Commit d235c44b authored by Virginie Uhlmann's avatar Virginie Uhlmann
Browse files

Merge branch 'refactor' into 'master'

Refactoring proposals

See merge request uhlmann-group/python-spline-fitting-toolbox!3
parents 97801aec 76030ca5
Pipeline #141097 failed with stage
in 5 minutes and 36 seconds
......@@ -9,7 +9,7 @@ before_script:
tests :
stage: test
script:
- pip install -e .
- pip install -e . -v
- pip install pytest pytest-cov
- pytest --cov=sft/ tests/
- coverage xml
......
......@@ -10,7 +10,7 @@
git clone <this repo url>
cd python-spline-fitting-toolbox
conda env create -f environment.yml
conda activate meshcorr
conda activate sft
pip install .
```
......@@ -23,7 +23,7 @@ pip install .
git clone <this repo url>
cd python-spline-fitting-toolbox
conda env create -f environment.yml
conda activate meshcorr
conda activate sft
pip install --editable .
```
......
......@@ -79,7 +79,9 @@ class KSD:
Rdn = Rdn - vM[l] * Undn[:, l].conj()
norm2 = norm2 - np.abs(Undn[:, l]) ** 2
norm2[norm2 < 0] = 0 # sometimes happens to have small negative numbers
norm2[
norm2 < 0
] = 0 # sometimes happens to have small negative numbers
non_null = norm2 > 1e-8
scores = np.zeros(J, dtype=complex)
scores[non_null] = Rdn[non_null] / np.sqrt(norm2)[non_null]
......@@ -94,7 +96,9 @@ class KSD:
def normalize(self, D_c):
"""Normalizing the columns of D_c, which are centered configurations"""
J = D_c.shape[1]
uu = ssu.normalizedUnitConfiguration(self.Phi, self.numAnchors, self.configType)
uu = ssu.normalizedUnitConfiguration(
self.Phi, self.numAnchors, self.configType
)
for j in range(J):
d = D_c[:, j]
......@@ -120,7 +124,10 @@ class KSD:
diff_re = shapes_re - DA_re
diff_im = shapes_im - DA_im
E1 = np.sum(
np.diag(diff_re.T @ self.Phi @ diff_re + diff_im.T @ self.Phi @ diff_im)
np.diag(
diff_re.T @ self.Phi @ diff_re
+ diff_im.T @ self.Phi @ diff_im
)
)
if incoherence:
......@@ -182,12 +189,22 @@ class KSD:
Dv = res.x
D = np.zeros(D0.shape)
D[: self.N] = Dv[: self.N * self.numAtoms].reshape(self.N, self.numAtoms)
D[self.N :] = Dv[self.N * self.numAtoms :].reshape(self.N, self.numAtoms)
D[: self.N] = Dv[: self.N * self.numAtoms].reshape(
self.N, self.numAtoms
)
D[self.N :] = Dv[self.N * self.numAtoms :].reshape(
self.N, self.numAtoms
)
return D
def torchLoss(
self, Atorch, Dtorch, Phitorch, shapestorch, incoherence=False, alpha=1.0
self,
Atorch,
Dtorch,
Phitorch,
shapestorch,
incoherence=False,
alpha=1.0,
):
Dx, Dy = split_xy(Dtorch, self.N)
Ax, Ay = split_xy(Atorch, self.numAtoms)
......@@ -231,11 +248,12 @@ class KSD:
):
"""2D Kendall Shape Dictionary classically alternates between:
- a sparse coding step : the weights A are updated using a Cholesky-based
Order Recursive Matching Pursuit (ORMP), as a direct adaptation to the
complex setting of Mairal's implementation for the real setting in the SPAMS toolbox.
- a dictionary update : following the Method of Optimal Directions (MOD),
we update D as
- a sparse coding step : the weights A are updated using a
Cholesky-based Order Recursive Matching Pursuit (ORMP), as a direct
adaptation to the complex setting of Mairal's implementation for the
real setting in the SPAMS toolbox.
- a dictionary update : following the Method of Optimal Directions
(MOD), we update D as
D <- [z_1,...,z_K] @ A^H @ (A @ A^H)^{-1}
D <- Pi_S(D) (center and normalize all the non-null atoms d_j)
......@@ -245,13 +263,17 @@ class KSD:
(nb of data using d_j) / (K*N0) < 1 / (50*numAtoms)
Parameters:
- dataset in C^{(K,n)} is a complex array containing the horizontally stacked dataset [z_1,...,z_K]^T
- dataset in C^{(K,n)} is a complex array containing the
horizontally stacked dataset [z_1,...,z_K]^T
- sparsity determines the L0 sparsity, N0, of the weights a_k
- numAtoms fixes the number of atoms that we want to learn
- init = None initializes the dictionary with randomly picked data shapes.
if init is a given (n,numAtoms) complex array, then the initialization starts with init.
- init = None initializes the dictionary with randomly picked
data shapes.
if init is a given (n,numAtoms) complex array, then the
initialization starts with init.
- numIter is the number of iterations
- if verbose == True, the algorithm keeps track of the loss function E to be minimized at each iteration.
- if verbose == True, the algorithm keeps track of the loss
function E to be minimized at each iteration.
"""
K = len(dataset)
......@@ -324,7 +346,12 @@ class KSD:
torch.from_numpy(A).type(torch.FloatTensor)
)
loss = self.torchLoss(
Atorch, Dtorch, Phitorch, shapestorch, incoherence, alpha
Atorch,
Dtorch,
Phitorch,
shapestorch,
incoherence,
alpha,
)
loss.backward()
return loss
......@@ -349,7 +376,10 @@ class KSD:
< self.sparsity / (5 * self.numAtoms)
)[0]
for j in range(self.numAtoms):
if ssu.configurationNorm(D_c[:, j], self.Phi) < 1e-8 or j in purge_j:
if (
ssu.configurationNorm(D_c[:, j], self.Phi) < 1e-8
or j in purge_j
):
if verbose:
print("purged ", j, "at iteration", t)
D_c[:, j] = dataset[np.random.randint(K)]
......@@ -366,7 +396,9 @@ class KSD:
if verbose:
elapsed = time.time() - start
print("duration of the algorithm: ", np.round(elapsed, 2), "seconds")
print(
"duration of the algorithm: ", np.round(elapsed, 2), "seconds"
)
diffs = dataset.T - D_c @ A_c
E1 = np.diag(diffs.T.conj() @ self.Phi @ diffs).sum().real
......@@ -390,8 +422,9 @@ class KSD:
def greedyOrder(inputA):
"""Gives a permutation of the dictionary atoms, so that the first atom of the dictionary is the one
for which the modulus of A[:,0] is maximal; then the second atom for which..."""
"""Gives a permutation of the dictionary atoms, so that the first atom of
the dictionary is the one for which the modulus of A[:,0] is maximal;
then the second atom for which..."""
import copy
A = copy.deepcopy(inputA)
......
......@@ -31,7 +31,7 @@ class KSD:
return D_c
def ORMPCholesky(self, D_c, dataset): # columns of D_c must be preshapes
def ORMPCholesky(self, D_c, dataset):
"""Order Recursive Matching Pursuit with Cholesky-based optimisation,
as in the SPAMS toolbox of Mairal et al. (2009).
This is a direct adaptation of their code in C++ to the complex setting
......@@ -77,7 +77,9 @@ class KSD:
Rdn = Rdn - vM[l] * Undn[:, l].conj()
norm2 = norm2 - np.abs(Undn[:, l]) ** 2
norm2[norm2 < 0] = 0 # sometimes happens to have small negative numbers
norm2[
norm2 < 0
] = 0 # sometimes happens to have small negative numbers
non_null = norm2 > 1e-8
scores = np.zeros(J, dtype=complex)
scores[non_null] = Rdn[non_null] / np.sqrt(norm2)[non_null]
......@@ -91,8 +93,10 @@ class KSD:
def reciprocal(self, sigmas):
"""Auxiliary function of KSD_optimal_directions().
Given a 1D array of non-negative elements called sigmas, with possibly zero elements,
returns the array of multiplicative inverses whenever possible, and leaves the zeroes."""
Given a 1D array of non-negative elements called sigmas, with possibly
zero elements,
returns the array of multiplicative inverses whenever possible, and
leaves the zeroes."""
sigmas_rec = np.zeros_like(sigmas)
for i, x in enumerate(sigmas):
if x != 0:
......@@ -120,11 +124,12 @@ class KSD:
def learn(self, dataset, init=None, numIter=100, verbose=False):
"""2D Kendall Shape Dictionary classically alternates between:
- a sparse coding step : the weights A are updated using a Cholesky-based
Order Recursive Matching Pursuit (ORMP), as a direct adaptation to the
complex setting of Mairal's implementation for the real setting in the SPAMS toolbox.
- a dictionary update : following the Method of Optimal Directions (MOD),
we update D as
- a sparse coding step : the weights A are updated using a
Cholesky-based Order Recursive Matching Pursuit (ORMP), as a direct
adaptation to the complex setting of Mairal's implementation for the
real setting in the SPAMS toolbox.
- a dictionary update : following the Method of Optimal Directions
(MOD), we update D as
D <- [z_1,...,z_K] @ A^H @ (A @ A^H)^{-1}
D <- Pi_S(D) (center and normalize all the non-null atoms d_j)
......@@ -134,11 +139,17 @@ class KSD:
(nb of data using d_j) / (K*N0) < 1 / (50*numAtoms)
Parameters:
- dataset in C^{(K,n)} is a complex array containing the horizontally stacked dataset [z_1,...,z_K]^T
- init = None initializes the dictionary with randomly picked data shapes.
if init is a given (n,numAtoms) complex array, then the initialization starts with init.
- dataset in C^{(K,n)} is a complex array containing the
horizontally stacked dataset [z_1,...,z_K]^T
- sparsity determines the L0 sparsity, N0, of the weights a_k
- numAtoms fixes the number of atoms that we want to learn
- init = None initializes the dictionary with randomly picked
data shapes.
if init is a given (n,numAtoms) complex array, then the
initialization starts with init.
- numIter is the number of iterations
- if verbose == True, the algorithm keeps track of the loss function E to be minimized at each iteration.
- if verbose == True, the algorithm keeps track of the loss
function E to be minimized at each iteration.
"""
K = len(dataset)
......@@ -190,7 +201,10 @@ class KSD:
< self.sparsity / (5 * self.numAtoms)
)[0]
for j in range(self.numAtoms):
if ssu.configurationNorm(D_c[:, j], self.Phi) < 1e-8 or j in purge_j:
if (
ssu.configurationNorm(D_c[:, j], self.Phi) < 1e-8
or j in purge_j
):
if verbose:
print("purged ", j, "at iteration", t)
D_c[:, j] = dataset[np.random.randint(K)]
......@@ -201,7 +215,9 @@ class KSD:
if verbose:
elapsed = time.time() - start
print("duration of the algorithm: ", np.round(elapsed, 2), "seconds")
print(
"duration of the algorithm: ", np.round(elapsed, 2), "seconds"
)
diffs = dataset.T - D_c @ A_c
E = np.diag(diffs.T.conj() @ self.Phi @ diffs).sum().real
......@@ -220,8 +236,9 @@ class KSD:
def greedyOrder(inputA):
"""Gives a permutation of the dictionary atoms, so that the first atom of the dictionary is the one
for which the modulus of A[:,0] is maximal; then the second atom for which..."""
"""Gives a permutation of the dictionary atoms, so that the first atom of
the dictionary is the one for which the modulus of A[:,0] is maximal;
then the second atom for which..."""
import copy
A = copy.deepcopy(inputA)
......
......@@ -23,7 +23,12 @@ def shapePCA(data, numAnchors, N, configType, c=3.0):
modes = []
count = len(np.where(pca.diag > 1e-6)[0])
for i in range(count):
mode = c * np.std(pcWeights[:, i]) * np.sqrt(pca.diag[i]) * pca.complexPC[:, i]
mode = (
c
* np.std(pcWeights[:, i])
* np.sqrt(pca.diag[i])
* pca.complexPC[:, i]
)
modes.append(
[
ssu.exponentialMap(frechetMean, mode, Phi),
......@@ -50,7 +55,9 @@ def shapePCReconstruction(data, pca, count=0):
originalShape = ssu.exponentialMap(pca.frechetMean, aligned, Phi)
error.append(
np.round(
ssu.configurationNorm(originalShape - reconstructedShapes[k], Phi).real,
ssu.configurationNorm(
originalShape - reconstructedShapes[k], Phi
).real,
4,
)
)
......@@ -141,7 +148,9 @@ class PCA:
self.diag = Diag[sortindr]
Vmodes = Vmodes[:, sortindr]
Wmodes = np.linalg.inv(sqrtPsi) @ Vmodes # eigenmodes in the tangent plane at m
Wmodes = (
np.linalg.inv(sqrtPsi) @ Vmodes
) # eigenmodes in the tangent plane at m
W = np.zeros((self.N, 2 * self.N), dtype=complex)
W.real = Wmodes[: self.N, :]
W.imag = Wmodes[self.N :, :]
......
......@@ -8,9 +8,11 @@ extended to spline curves thanks to the Hermitian product Phi.
It defines:
- the three distances d_F, d_P, and rho, also called full, partial, or geodesic distances.
- some operations on the Riemannian manifold, such as exp(z,v), log(z,w), geo(z,w),
and logarithmic or orthogonal projections on a tangent space
- the three distances d_F, d_P, and rho, also called full, partial,
or geodesic distances.
- some operations on the Riemannian manifold, such as
exp(z,v), log(z,w), geo(z,w),
and logarithmic or orthogonal projections on a tangent space
All functions are defined supposing z and w to be preshapes.
......@@ -19,8 +21,9 @@ Warning:
Mathematical shapes can only be numerically handled as preshapes.
Hence the `shape' variable designates in fact a preshape
(centered and normalized configuration), that is one of the many representatives
of the equivalence class defining the corresponding (mathematical) shape.
(centered and normalized configuration), that is one of the many
representatives of the equivalence class defining the corresponding
(mathematical) shape.
"""
......@@ -47,8 +50,8 @@ def getPhi(configType, numAnchors):
def aux(k):
return (
lambda t: basis.B3().value(numAnchors * (t - 1) - k)
+ basis.B3().value(numAnchors * t - k)
+ basis.B3().value(numAnchors * (t + 1) - k)
+ basis.B3().value(numAnchors * t - k)
+ basis.B3().value(numAnchors * (t + 1) - k)
)
for k in range(numAnchors):
......@@ -58,7 +61,10 @@ def getPhi(configType, numAnchors):
Phi[k, l] = innerProductL2(u, v)
return Phi
elif configType is configurationType.CLOSEDHSPLINE or configurationType.OPENHSPLINE:
elif (
configType is configurationType.CLOSEDHSPLINE
or configurationType.OPENHSPLINE
):
II = np.repeat(np.arange(numAnchors)[:, None], numAnchors, axis=1)
JJ = np.repeat(np.arange(numAnchors)[None, :], numAnchors, axis=0)
......@@ -126,7 +132,9 @@ def getPhi(configType, numAnchors):
]
)
Phi = np.concatenate((np.hstack((Phi11, Phi12)), np.hstack((Phi12.T, Phi22))))
Phi = np.concatenate(
(np.hstack((Phi11, Phi12)), np.hstack((Phi12.T, Phi22)))
)
return Phi
else:
......@@ -160,7 +168,9 @@ def configurationNorm(z, Phi):
return np.sqrt(squaredNorm).real
def exponentialMap(z, v, Phi): # z preshape, v in C^n referring to a tangent vector
def exponentialMap(
z, v, Phi
): # z preshape, v in C^n referring to a tangent vector
"""Computes the exponential of v at z, that corresponds to a preshape"""
t = configurationNorm(v, Phi)
if t < 1e-16: # catch numerical errors
......@@ -190,7 +200,8 @@ def theta(z, w, Phi): # preshapes
def logPreshape(z, w, Phi):
"""Computes v = log_z(w) where log is relative to the preshape sphere \S.
(Requires that z* Phi w > 0, because otherwise
v does not satisfy Re(z* Phi v) = 0 in order to be on the tangent space T_z \S.
v does not satisfy Re(z* Phi v) = 0 in order to be on the tangent space
T_z \S.
As a consequence, expo(logPre(z,w)) would not be a preshape in \S.)
"""
ro = geodesicDistance(z, w, Phi)
......@@ -198,7 +209,8 @@ def logPreshape(z, w, Phi):
def log(z, w, Phi): # preshapes
"""Computes a preshape pertaining to the shape (equivalence class) log_[z] ([w])
"""Computes a preshape pertaining to the shape (equivalence class) log_[z]
([w])
where log is relative the shape space Sigma."""
ta = theta(z, w, Phi)
return logPreshape(z, np.exp(-1j * ta) * w, Phi)
......@@ -220,15 +232,13 @@ def configurationMean(z, dataType, numAnchors):
+ (z[numAnchors] - z[-1]) / (12 * numAnchors - 12)
)
else:
raise NotImplementedError(
"Unknown configuration type (" + str(configType) + ")."
)
return
raise NotImplementedError(f"Unknown data type: {dataType}.")
return m
def preshape(z, Phi, numAnchors, dataType):
"""Centers and normalizes (i.e. preshapes it) the configuration z in C^N."""
"""Centers and normalizes (i.e. preshapes it) the configuration
z in C^N."""
m = configurationMean(z, dataType, numAnchors)
if (
dataType is configurationType.LANDMARKS
......@@ -290,7 +300,8 @@ def alignShapes(dataset, Phi, numAnchors, dataType, inputMeanShape=None):
def geodesicPath(z, w, Phi, numSteps=5): # preshapes
"""Returns elements regularly spaced along the geodesic curve joining z to w (preshapes)."""
"""Returns elements regularly spaced along the geodesic curve joining
z to w (preshapes)."""
ro = geodesicDistance(z, w, Phi)
steps = np.arange(numSteps + 1) / numSteps
......@@ -307,7 +318,8 @@ def geodesicPath(z, w, Phi, numSteps=5): # preshapes
def projectTangentSpace(m, z, Phi): # preshapes
"""Orthogonally project z onto tangent space at m (considering the ambient space)"""
"""Orthogonally project z onto tangent space at m
(considering the ambient space)"""
if np.abs(hermitianProduct(m, z, Phi)) < 1e-2:
return z
ta = theta(z, m, Phi)
......@@ -316,22 +328,29 @@ def projectTangentSpace(m, z, Phi): # preshapes
def realToComplex(realConfigurations):
"""Converts a collection of real configuration in R^{2N} in a collection of complex configurations in C^N."""
"""Converts a collection of real configuration in R^{2N} in a
collection of complex configurations in C^N."""
return realConfigurations[:, :, 0] + 1j * realConfigurations[:, :, 1]
def complexToReal(complexConfigurations):
"""Converts a collection of complex configurations in C^N to a collection of real configuration in R^{2N}."""
return np.stack((complexConfigurations.real, complexConfigurations.imag), axis=2)
"""Converts a collection of complex configurations in C^N to a
collection of real configuration in R^{2N}."""
return np.stack(
(complexConfigurations.real, complexConfigurations.imag), axis=2
)
def getPreshapes(complexConfigurations, numAnchors, Phi, dataType):
"""Same as preshape(), but for several horizontally stacked configurations."""
"""Same as preshape(), but for several horizontally stacked
configurations."""
if (
dataType is configurationType.LANDMARKS
or dataType is configurationType.CLOSEDBSPLINE
):
shapes = np.zeros((complexConfigurations.shape[0], numAnchors), dtype=complex)
shapes = np.zeros(
(complexConfigurations.shape[0], numAnchors), dtype=complex
)
elif (
dataType is configurationType.CLOSEDHSPLINE
or dataType is configurationType.OPENHSPLINE
......@@ -341,5 +360,7 @@ def getPreshapes(complexConfigurations, numAnchors, Phi, dataType):
)
for k in range(complexConfigurations.shape[0]):
shapes[k] = preshape(complexConfigurations[k], Phi, numAnchors, dataType)
shapes[k] = preshape(
complexConfigurations[k], Phi, numAnchors, dataType
)
return shapes
......@@ -6,8 +6,8 @@ import abc
# TODO: quadratic prefilters
class SplineGenerator(abc.ABC):
unimplementedMessage = "This method is not implemented."
class Basis(abc.ABC):
_unimplemented_message = "This method is not implemented."
def __init__(self, multigenerator, support):
self.multigenerator = multigenerator
......@@ -15,27 +15,27 @@ class SplineGenerator(abc.ABC):
@abc.abstractmethod
def value(self, x):
raise NotImplementedError(SplineGenerator.unimplementedMessage)
raise NotImplementedError(Basis._unimplemented_message)
@abc.abstractmethod
def firstDerivativeValue(self, x):
raise NotImplementedError(SplineGenerator.unimplementedMessage)
raise NotImplementedError(Basis._unimplemented_message)
@abc.abstractmethod
def secondDerivativeValue(self, x):
raise NotImplementedError(SplineGenerator.unimplementedMessage)
raise NotImplementedError(Basis._unimplemented_message)
@abc.abstractmethod
def filterSymmetric(self, s):
raise NotImplementedError(SplineGenerator.unimplementedMessage)
raise NotImplementedError(Basis._unimplemented_message)
@abc.abstractmethod
def filterPeriodic(self, s):
raise NotImplementedError(SplineGenerator.unimplementedMessage)
raise NotImplementedError(Basis._unimplemented_message)
@abc.abstractmethod
def refinementMask(self):
raise NotImplementedError(SplineGenerator.unimplementedMessage)
raise NotImplementedError(Basis._unimplemented_message)
@property
def support(self):
......@@ -46,9 +46,9 @@ class SplineGenerator(abc.ABC):
raise RuntimeError(f"Can't override {self.__class__.__name__}.support")
class B1(SplineGenerator):
class B1(Basis):
def __init__(self):
SplineGenerator.__init__(self, False, 2.0)
Basis.__init__(self, multigenerator=False, support=2.0)
def value(self, x):
val = 0.0
......@@ -80,9 +80,9 @@ class B1(SplineGenerator):
return mask
class B2(SplineGenerator):
class B2(Basis):
def __init__(self):
SplineGenerator.__init__(self, False, 3.0)
Basis.__init__(self, multigenerator=False, support=3.0)
def value(self, x):
val = 0.0
......@@ -127,9 +127,9 @@ class B2(SplineGenerator):
raise NotImplementedError()
class B3(SplineGenerator):
class B3(Basis):
def __init__(self):
SplineGenerator.__init__(self, False, 4.0)
Basis.__init__(self, multigenerator=False, support=4.0)
def value(self, x):
val = 0.0
......@@ -169,7 +169,9 @@ class B3(SplineGenerator):
cp = np.zeros(M)
eps = 1e-8
k0 = np.min(((2 * M) - 2, int(np.ceil(np.log(eps) / np.log(np.abs(pole))))))
k0 = np.min(
((2 * M) - 2, int(np.ceil(np.log(eps) / np.log(np.abs(pole)))))
)
for k in range(0, k0):
k = k % (2 * M - 2)
if k >= M:
......@@ -228,9 +230,9 @@ class B3(SplineGenerator):
return mask
class EM(SplineGenerator):
class EM(Basis):
def __init__(self, M, alpha):
SplineGenerator.__init__(self, False, 3.0)
Basis.__init__(self, multigenerator=False, support=3.0)
self.M = M
self.alpha = alpha
......@@ -240,7 +242,11 @@ class EM(SplineGenerator):
val = 0.0
if 0 <= x < 1:
val = 2.0 * np.sin(self.alpha * 0.5 * x) * np.sin(self.alpha * 0.5 * x)
val = (
2.0
* np.sin(self.alpha * 0.5 * x)
* np.sin(self.alpha * 0.5 * x)
)
elif 1 <= x < 2:
val = (
np.cos(self.alpha * (x - 2))
......@@ -283,7 +289,10 @@ class EM(SplineGenerator):
val = (
self.alpha
* self.alpha
* (-np.cos(self.alpha * (1 - x)) - np.cos(self.alpha * (2 - x)))
* (
-np.cos(self.alpha * (1 - x))
- np.cos(self.alpha * (2 - x))
)
)
elif 2 < x <= 3:
val = self.alpha * self.alpha * np.cos(self.alpha * (x - 3))
......@@ -299,7 +308,10 @@ class EM(SplineGenerator):
cp = np.zeros(self.M)
eps = 1e-8
k0 = np.min(
((2 * self.M) - 2, int(np.ceil(np.log(eps) / np.log(np.abs(pole)))))
(
(2 * self.M) - 2,
int(np.ceil(np.log(eps) / np.log(np.abs(pole)))),
)
)
for k in range(0, k0):