# Source code for pytranskit.optrans.continuous.base
import numpy as np
from ..utils import check_array, assert_equal_shape, interp2d, griddata2d
class BaseTransform(object):
    """
    Base class for optimal transport transform methods.

    .. warning::
        This class should **not** be used directly. Use derived classes instead.

    Derived classes are expected to implement :meth:`forward`,
    :meth:`apply_forward_map` and :meth:`apply_inverse_map`. They are also
    expected to populate the fitted attributes (``is_fitted``, ``sig0_``,
    ``transport_map_``, ...) when :meth:`forward` is called.
    """

    def __init__(self):
        # Becomes True once a derived class has computed the forward transform;
        # checked by _check_is_fitted before any method that needs fitted state.
        self.is_fitted = False
        # Fitted state: all None until a derived class's ``forward`` fills it
        # in. ``sig0_`` and ``transport_map_`` are consumed by ``inverse``.
        # ``displacements_`` is not read here; presumably it is set by
        # subclasses (verify against derived classes).
        self.sig0_ = None
        self.displacements_ = None
        self.transport_map_ = None

    def _check_is_fitted(self):
        """
        Raise if the forward transform has not been computed yet.

        Raises
        ------
        AssertionError
            If ``self.is_fitted`` is falsy.
        """
        if not self.is_fitted:
            raise AssertionError("The forward transform of {0!s} has not been "
                                 "called yet. Call 'forward' before using "
                                 "this method".format(type(self).__name__))

    def forward(self, *args, **kwargs):
        """
        Placeholder for forward transform.

        Subclasses should implement this method! It accepts any arguments so
        that a call on the base class always reports the missing override.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def inverse(self):
        """
        Inverse transform.

        Applies :meth:`apply_inverse_map` to the fitted ``transport_map_`` and
        reference signal ``sig0_``.

        Returns
        -------
        sig1_recon : array, shape (height, width)
            Reconstructed signal sig1.

        Raises
        ------
        AssertionError
            If :meth:`forward` has not been called yet.
        """
        self._check_is_fitted()
        return self.apply_inverse_map(self.transport_map_, self.sig0_)

    def apply_forward_map(self, *args, **kwargs):
        """
        Placeholder for application of forward transport map.

        Subclasses should implement this method!

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError

    def apply_inverse_map(self, *args, **kwargs):
        """
        Placeholder for application of inverse transport map.

        Subclasses should implement this method! It accepts any arguments
        because :meth:`inverse` calls it as
        ``apply_inverse_map(transport_map, sig0)``. Without that, the base
        class would fail with a ``TypeError`` instead of the intended
        ``NotImplementedError``.

        Raises
        ------
        NotImplementedError
            Always, in the base class.
        """
        raise NotImplementedError
class BaseMapper2D(BaseTransform):
    """
    Base class for 2D optimal transport transform methods (e.g. CLOT, VOT2D).

    Provides shared implementations of :meth:`apply_forward_map` and
    :meth:`apply_inverse_map` for transport maps of shape
    ``(2, height, width)``.

    .. warning::
        This class should **not** be used directly. Use derived classes instead.
    """

    def __init__(self):
        super(BaseMapper2D, self).__init__()

    @staticmethod
    def _jacobian_determinant(transport_map):
        """
        Pointwise determinant of the Jacobian of a 2D transport map.

        Parameters
        ----------
        transport_map : array, shape (2, height, width)
            Transport map with components ``transport_map[0]`` and
            ``transport_map[1]``.

        Returns
        -------
        detJ : array, shape (height, width)
            Jacobian determinant. It is computed from finite differences
            (central in the interior, one-sided at the borders) via
            :func:`numpy.gradient`.
        """
        # np.gradient on a 2D array returns (d/d axis 0, d/d axis 1); axis 0
        # is named "y" and axis 1 "x" here, i.e. rows and columns.
        f0y, f0x = np.gradient(transport_map[0])
        f1y, f1x = np.gradient(transport_map[1])
        return (f1x * f0y) - (f1y * f0x)

    def apply_forward_map(self, transport_map, sig1):
        """
        Apply forward transport map.

        Parameters
        ----------
        transport_map : array, shape (2, height, width)
            Forward transport map.
        sig1 : array, shape (height, width)
            Signal to transform.

        Returns
        -------
        sig0_recon : array, shape (height, width)
            Reconstructed reference signal sig0.
        """
        # Check inputs
        transport_map = check_array(transport_map, ndim=3,
                                    dtype=[np.float64, np.float32])
        sig1 = check_array(sig1, ndim=2, dtype=[np.float64, np.float32],
                           force_strictly_positive=True)
        assert_equal_shape(transport_map[0], sig1, ['transport_map', 'sig1'])

        # Sample sig1 at the mapped coordinates and scale by the Jacobian
        # determinant (presumably the change-of-variables relation for
        # densities). fill_value is presumably used for samples falling
        # outside sig1; verify against interp2d.
        detJ = self._jacobian_determinant(transport_map)
        sig0_recon = detJ * interp2d(sig1, transport_map, fill_value=sig1.min())
        return sig0_recon

    def apply_inverse_map(self, transport_map, sig0):
        """
        Apply inverse transport map.

        Parameters
        ----------
        transport_map : array, shape (2, height, width)
            Forward transport map. Inverse is computed in this function.
        sig0 : array, shape (height, width)
            Reference signal.

        Returns
        -------
        sig1_recon : array, shape (height, width)
            Reconstructed signal sig1.
        """
        # Check inputs
        transport_map = check_array(transport_map, ndim=3,
                                    dtype=[np.float64, np.float32])
        sig0 = check_array(sig0, ndim=2, dtype=[np.float64, np.float32],
                           force_strictly_positive=True)
        assert_equal_shape(transport_map[0], sig0, ['transport_map', 'sig0'])

        detJ = self._jacobian_determinant(transport_map)
        # NOTE(review): there is no guard against detJ == 0. Such points
        # produce inf/NaN in sig0 / detJ, which are passed unchanged to
        # griddata2d.
        sig1_recon = griddata2d(sig0 / detJ, transport_map,
                                fill_value=sig0.min())
        return sig1_recon