import numpy as np
from numpy import interp
import matplotlib.pyplot as plt
from pytranskit.optrans.continuous.base import BaseTransform
from pytranskit.optrans.utils import check_array, assert_equal_shape, signal_to_pdf
[docs]class CDT(BaseTransform):
"""
Cumulative Distribution Transform.
Attributes
-----------
displacements_ : 1d array
Displacements u.
transport_map_ : 1d array
Transport map f.
References
----------
[The cumulative distribution transform and linear pattern classification]
(https://arxiv.org/abs/1507.05936)
"""
def __init__(self):
super(CDT, self).__init__()
[docs] def forward(self, x0, sig0, x1, sig1, rm_edge=False):
"""
Forward transform.
Parameters
----------
x0 : 1d array
Independent axis variable of reference signal (sig0).
sig0 : 1d array
Reference signal.
x1 : 1d array
Independent axis variable of the signal to transform (sig1).
sig1 : 1d array
Signal to transform.
Returns
-------
sig1_cdt : 1d array
CDT of input signal sig1 (new definition).
sig1_hat : 1d array
old definition.
xilde : 1d array
Independent axis variable in CDT space.
"""
# Check input arrays
sig0 = check_array(sig0, ndim=1, dtype=[np.float64, np.float32],
force_strictly_positive=True)
sig1 = check_array(sig1, ndim=1, dtype=[np.float64, np.float32],
force_strictly_positive=True)
# Input signals must be the same size
#assert_equal_shape(sig0, sig1, ['sig0', 'sig1'])
self.sig0_ = sig0
# Cumulative sums
cum0 = np.cumsum(sig0)
cum1 = np.cumsum(sig1)
# x co-ordinates and interpolated y co-ordinates
x = x1
y = np.linspace(0, 1, sig0.size)
y0 = interp(y, cum0, x0) # inverse of CDF of sig0
if len(np.unique(sig0)) == 1:
y1 = interp(y, cum1, x) # inverse of CDF of sig1 for uniform reference
else:
y1 = interp(cum0, cum1, x) # inverse of CDF of sig1
# Compute displacements: u = f(x0)-x0
self.displacements_ = interp(x0, y0, y1-y0)
#self.displacements_ = y1 - x0
# Compute transport map: f = u - x0
#self.transport_map_ = self.displacements_ - x0
self.transport_map_ = y1
# CDT (new definition)
sig1_cdt = self.transport_map_
self.xtilde = x0
# OLD CDT = (f - x) * sqrt(I0)
sig1_hat = self.displacements_ * np.sqrt(sig0)
if rm_edge:
sig1_cdt = np.delete(sig1_cdt, 0)
sig1_cdt = np.delete(sig1_cdt, len(sig1_cdt)-1)
sig1_hat = np.delete(sig1_hat, 0)
sig1_hat = np.delete(sig1_hat, len(sig1_hat)-1)
y1 = np.delete(y1, 0)
y1 = np.delete(y1, len(y1)-1)
self.xtilde = np.delete(self.xtilde, 0)
self.xtilde = np.delete(self.xtilde, len(self.xtilde)-1)
self.transport_map_ = np.delete(self.transport_map_, 0)
self.transport_map_ = np.delete(self.transport_map_, len(self.transport_map_)-1)
self.is_fitted = True
return sig1_cdt, sig1_hat, self.xtilde
[docs] def inverse(self, transport_map, sig0, x1):
"""
Inverse transform.
Parameters
----------
transport_map : 1d array
Forward transport map.
sig0 : 1d array
Reference signal.
x1 : 1d array
Independent axis variable of the signal to reconstruct.
Returns
-------
sig1_recon : 1d array
Reconstructed signal.
"""
self._check_is_fitted()
return self.apply_inverse_map(transport_map, sig0, x1)
[docs] def apply_forward_map(self, transport_map, sig1):
"""
Appy forward transport map.
Parameters
----------
transport_map : 1d array
Forward transport map.
sig1 : 1d array
Signal to transform.
Returns
-------
sig0_recon : 1d array
Reconstructed reference signal sig0.
"""
# Check inputs
transport_map = check_array(transport_map, ndim=1,
dtype=[np.float64, np.float32])
sig1 = check_array(sig1, ndim=1, dtype=[np.float64, np.float32],
force_strictly_positive=True)
assert_equal_shape(transport_map, sig1, ['transport_map', 'sig1'])
# Reconstruct sig0
x = np.arange(sig1.size)
fprime = np.gradient(transport_map)
sig0_recon = fprime * interp(transport_map, x, sig1)
return sig0_recon
[docs] def apply_inverse_map(self, transport_map, sig0, x):
"""
Apply inverse transport map.
Parameters
----------
transport_map : 1d array
Forward transport map. Inverse is computed in this function.
sig0 : 1d array
Reference signal.
Returns
-------
sig1_recon : 1d array
Reconstructed signal sig1.
"""
# Check inputs
transport_map = check_array(transport_map, ndim=1,
dtype=[np.float64, np.float32])
sig0 = check_array(sig0, ndim=1, dtype=[np.float64, np.float32],
force_strictly_positive=True)
assert_equal_shape(transport_map, sig0, ['transport_map', 'sig0'])
# Reconstruct sig1
fprime = np.gradient(transport_map)
if len(np.unique(sig0)) == 1:
sig1_recon = interp(x, transport_map, 1/fprime)
else:
sig1_recon = interp(x, transport_map, sig0/fprime)
sig1_recon = signal_to_pdf(sig1_recon, epsilon=1e-7,
total=1.)
return sig1_recon