Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,17 @@ authors:
- given-names: David
family-names: Coeurjolly
affiliation: CNRS, LIRIS
- given-names: Thibaut
family-names: Germain
affiliation: Ecole Polytechnique
- given-names: Sienna
family-names: O'Shea
affiliation: Ecole Polytechnique
- given-names: Marco
family-names: Corneli
affiliation: Université Côte d'Azur
- given-names: Ferdinand Genans
- given-names: Ferdinand
family-names: Genans
affiliation: Sorbonne Université, LPSM, CNRS
identifiers:
- type: url
Expand Down
2 changes: 2 additions & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ The contributors to this library are:
* [Julie Delon](https://judelo.github.io/) (GMM OT)
* [Samuel Boïté](https://samuelbx.github.io/) (GMM OT)
* [Nathan Neike](https://github.com/nathanneike) (Sparse EMD solver)
* [Thibaut Germain](https://thibaut-germain.github.io) (SGOT)
* Sienna O'Shea (SGOT)


## Acknowledgments
Expand Down
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ This new release adds support for sparse cost matrices and a new lazy exact OT s
- Update the geomloss wrapper to the new version and API (PR #826)
- Fix docstrings for `lowrank_gromov_wasserstein_samples` and `lowrank_sinkhorn` (PR #823)
- Reorganize all tests per backend (PR #828)
- Update sgot cost function and example (PR #830)


#### Closed issues
Expand Down
249 changes: 130 additions & 119 deletions examples/plot_sgot.py → examples/others/plot_sgot.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,33 +66,35 @@
theta_0 = np.pi / 4


def rotation_matrix(theta):
return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])


def generate_data(time, tau, freq, theta):
t_ = np.sin(2 * np.pi * freq[None, :] * time[:, None]) * np.exp(
-tau[None, :] * time[:, None]
)
t_ = t_.sum(axis=1)
traj_0 = np.zeros((t_.shape[0], 2))
traj_0[:, 0] = t_
rotation_matrix = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
traj_0 = traj_0 @ rotation_matrix.T
R_ = rotation_matrix(theta)
traj_0 = traj_0 @ R_.T
return traj_0


traj_0 = generate_data(time, tau_0, freq_0, theta_0)
traj_0_proj = traj_0 @ rotation_matrix(theta_0)[:, 0]


# plot the observed signal components and their sum
plt.figure(figsize=(10, 4))
plt.plot(time, traj_0, label="base trajectory", linewidth=2)
plt.plot(time, traj_0_proj, label="projected trajectory", linewidth=2)
plt.xlabel("time")
plt.ylabel("amplitude")
plt.legend()
plt.title(r"Observed scalar signal along $\vec{e}(\theta)$")
plt.show()


# %%
# 2. Interpret the signal as coming from a continuous linear dynamical system
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down Expand Up @@ -274,12 +276,8 @@ def augment(traj, window_length=2):
# Processing Systems, 35, pp.4017-4031.


def estimator(X, Y, rank=4):
# X: (n_samples, n_features)
# Y: (n_samples, n_features)

# estimate operator
cxx = X.T @ X
def estimator(X, Y, rank=4, eps=1e-8):
cxx = X.T @ X + eps * np.eye(X.shape[1])
U, S, Vt = np.linalg.svd(cxx)
S_inv = np.divide(1, S, out=np.zeros_like(S), where=S != 0)
cxx_inv_half = Vt.T @ np.diag(np.sqrt(S_inv)) @ U.T
Expand Down Expand Up @@ -416,6 +414,24 @@ def estimator(X, Y, rank=4):
# spectral atoms, taking into account both the location of eigenvalues and the
# relative geometry of their eigenspaces.

# %%
# A wider delay window for the SGOT experiments below
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The window of length 4 used above is enough to identify a single reference
# operator, but the experiments below also probe signals whose two modes
# nearly coincide in frequency (e.g. :math:`\omega_2'\to\omega_1`). Telling
# such near-degenerate modes apart requires the delay embedding to span
# enough time to "see" their differing decay, so we re-embed the reference
# signal with a longer window before running the sweeps.

sgot_window = 10
Z = augment(traj_0, sgot_window)
_, B_0_spec_sgot = estimator(Z[:-1], Z[1:])
D_0_sgot = np.log(B_0_spec_sgot["eig_val"]) * fs
L_0_sgot = B_0_spec_sgot["eig_vec_left"]
R_0_sgot = B_0_spec_sgot["eig_vec_right"]

# %%
# SGOT distance versus rotation angle
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -434,52 +450,66 @@ def estimator(X, Y, rank=4):
# this experiment isolates the effect of rotating the underlying one-dimensional
# subspace in the observation plane.

thetas = np.linspace(0, np.pi / 2, 50)
lst = []
for i, theta in enumerate(thetas):
traj = generate_data(time, tau_0, freq_0, theta)
Z = augment(traj, 4)
X = Z[:-1]
Y = Z[1:]
B, B_spec = estimator(X, Y, rank=4)
D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"]
D = np.log(D) * fs
lst.append(sgot_metric(D_0, R_0, L_0, D, R, L, eta=0.01))

plt.figure(figsize=(8, 5))
plt.plot(thetas, lst)
plt.xlabel("theta")
plt.ylabel("SGOT distance")
plt.title("SGOT distance vs rotation angle")
thetas = np.linspace(0, np.pi / 2, 51)
rotation_scores = []

for theta in thetas:
Z = augment(generate_data(time, tau_0, freq_0, theta), sgot_window)
B, B_spec = estimator(Z[:-1], Z[1:])
D = np.log(B_spec["eig_val"]) * fs
L = B_spec["eig_vec_left"]
R = B_spec["eig_vec_right"]
rotation_scores.append(
sgot_metric(
D_0_sgot, R_0_sgot, L_0_sgot, D, R, L, eta=0.9, grassmann_metric="chordal"
)
)

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(thetas, rotation_scores, linewidth=1.8)
ax.axvline(theta_0, color="gray", linestyle="--", linewidth=0.8)
ax.set_xlabel(r"Rotation angle $\theta$ (rad)")
ax.set_ylabel(r"$d_S$")
ax.set_title("SGOT distance vs. rotation angle")
fig.tight_layout()
plt.show()

# %%
# Comparison across Grassmannian metrics for SGOT distance versus rotation angle
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

thetas = np.linspace(0, np.pi / 2, 50)
lst = []
for i, theta in enumerate(thetas):
traj = generate_data(time, tau_0, freq_0, theta)
Z = augment(traj, 4)
X = Z[:-1]
Y = Z[1:]
B, B_spec = estimator(X, Y, rank=4)
D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"]
D = np.log(D) * fs
lst1 = []
for name in ["chordal", "martin", "geodesic", "procrustes"]:
lst1.append(sgot_metric(D_0, R_0, L_0, D, R, L, eta=0.9, grassmann_metric=name))
lst.append(lst1)
lst2 = np.array(lst)
plt.figure(figsize=(8, 5))
for i, name in enumerate(["chordal", "martin", "geodesic", "procrustes"]):
plt.plot(thetas, lst2[:, i], label=name)

plt.xlabel("theta")
plt.ylabel("SGOT distance")
plt.title("SGOT distance vs rotation angle")
plt.legend()
metrics = ["chordal", "geodesic", "procrustes", "martin"]
styles = {"chordal": "-", "geodesic": "--", "procrustes": "-.", "martin": ":"}
rotation_results = {m: [] for m in metrics}

for theta in thetas:
Z = augment(generate_data(time, tau_0, freq_0, theta), sgot_window)
B, B_spec = estimator(Z[:-1], Z[1:])
D = np.log(B_spec["eig_val"]) * fs
L = B_spec["eig_vec_left"]
R = B_spec["eig_vec_right"]
for m in metrics:
rotation_results[m].append(
sgot_metric(
D_0_sgot, R_0_sgot, L_0_sgot, D, R, L, eta=0.9, grassmann_metric=m
)
)

fig, ax = plt.subplots(figsize=(7, 4))
for m in metrics:
ax.plot(thetas, rotation_results[m], styles[m], label=m, linewidth=1.8)
ax.axvline(
theta_0,
color="gray",
linestyle="--",
linewidth=0.8,
label=r"$\theta_0 = \pi/4$ (reference)",
)
ax.set_xlabel(r"Rotation angle $\theta$ (rad)")
ax.set_ylabel(r"$d_S$")
ax.set_title("SGOT distance vs. rotation angle across Grassmannian metrics")
ax.legend()
fig.tight_layout()
plt.show()

# %%
Expand All @@ -501,38 +531,38 @@ def estimator(X, Y, rank=4):
# distance changes as a function of the perturbed frequency :math:`\omega_2'`.

omegas = np.linspace(0.5, 3.0, 21)
methods = ["chordal", "martin", "geodesic", "procrustes"]
scores_omega = []
theta = theta_0
frequency_scores = {m: [] for m in metrics}

eta_fixed = 0.9
for omega in omegas:
freq_1 = np.array([freq_0[0], omega])
traj = generate_data(time, tau_0, freq_1, theta)
Z = augment(traj, 4)
X = Z[:-1]
Y = Z[1:]

B, B_spec = estimator(X, Y, rank=4)
D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"]
D = np.log(D) * fs

row = []
for name in methods:
row.append(
sgot_metric(D_0, R_0, L_0, D, R, L, eta=eta_fixed, grassmann_metric=name)
Z = augment(
generate_data(time, tau_0, np.array([freq_0[0], omega]), theta_0), sgot_window
)
B, B_spec = estimator(Z[:-1], Z[1:])
D = np.log(B_spec["eig_val"]) * fs
L = B_spec["eig_vec_left"]
R = B_spec["eig_vec_right"]
for m in metrics:
frequency_scores[m].append(
sgot_metric(
D_0_sgot, R_0_sgot, L_0_sgot, D, R, L, eta=0.9, grassmann_metric=m
)
)
scores_omega.append(row)

scores_omega = np.array(scores_omega)
plt.figure(figsize=(8, 5))
for i, name in enumerate(methods):
plt.plot(omegas, scores_omega[:, i], label=name)

plt.xlabel("omega")
plt.ylabel("SGOT distance")
plt.title("SGOT distance vs omega")
plt.legend()
fig, ax = plt.subplots(figsize=(7, 4))
for m in metrics:
ax.plot(omegas, frequency_scores[m], styles[m], label=m, linewidth=1.8)
ax.axvline(
freq_0[1],
color="gray",
linestyle="--",
linewidth=0.8,
label=r"$\omega_2 = 2.0$ Hz (reference)",
)
ax.set_xlabel(r"Frequency $\omega_2'$ (Hz)")
ax.set_ylabel(r"$d_S$")
ax.set_title("SGOT distance vs. frequency across Grassmannian metrics")
ax.legend()
fig.tight_layout()
plt.show()

# %%
Expand All @@ -553,47 +583,28 @@ def estimator(X, Y, rank=4):
# In this way, both modes share the same modified decay parameter
# :math:`\tau`, allowing us to isolate the influence of dissipation on the SGOT
# distance.
decays = np.linspace(0.1, 3.0, 20) # adjust range as needed
methods = ["chordal", "martin", "geodesic", "procrustes"]
scores_decay = []
theta = theta_0

for tau in decays:
freq_1 = np.array([freq_0[0], recovered_freqs[1]])
tau_1 = np.array([tau, tau]) # or whatever structure your generator expects

traj = generate_data(time, tau_1, freq_1, theta)
Z = augment(traj, 4)
X = Z[:-1]
Y = Z[1:]

B, B_spec = estimator(X, Y, rank=4)
D, R, L = B_spec["eig_val"], B_spec["eig_vec_right"], B_spec["eig_vec_left"]
D = np.log(D) * fs

row = []
for name in methods:
row.append(
taus = np.linspace(0.1, 3.0, 21)
decay_scores = {m: [] for m in metrics}

for tau in taus:
Z = augment(generate_data(time, np.array([tau, tau]), freq_0, theta_0), sgot_window)
B, B_spec = estimator(Z[:-1], Z[1:])
D = np.log(B_spec["eig_val"]) * fs
L = B_spec["eig_vec_left"]
R = B_spec["eig_vec_right"]
for m in metrics:
decay_scores[m].append(
sgot_metric(
D_0,
R_0,
L_0,
D,
R,
L,
eta=0.9, # keep eta fixed here
grassmann_metric=name,
D_0_sgot, R_0_sgot, L_0_sgot, D, R, L, eta=0.9, grassmann_metric=m
)
)
scores_decay.append(row)

scores_decay = np.array(scores_decay)
plt.figure(figsize=(8, 5))
for i, name in enumerate(methods):
plt.plot(decays, scores_decay[:, i], label=name)

plt.xlabel("decay")
plt.ylabel("SGOT distance")
plt.title("SGOT distance vs decay")
plt.legend()
fig, ax = plt.subplots(figsize=(7, 4))
for m in metrics:
ax.plot(taus, decay_scores[m], styles[m], label=m, linewidth=1.8)
ax.set_xlabel(r"Decay rate $\tau$")
ax.set_ylabel(r"$d_S$")
ax.set_title("SGOT distance vs. decay across Grassmannian metrics")
ax.legend()
fig.tight_layout()
plt.show()
2 changes: 1 addition & 1 deletion ot/sgot.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _delta_matrix_1d(Rs, Ls, Rt, Lt, nx=None, eps=1e-12):
Ltn = _normalize_columns(Lt, nx=nx, eps=eps)

Cr = nx.dot(nx.conj(Rsn).T, Rtn)
Cl = nx.dot(nx.conj(Lsn).T, Ltn)
Cl = nx.dot(Lsn.T, nx.conj(Ltn))

delta = nx.abs(Cr * Cl)
delta = nx.clip(delta, 0.0, 1.0)
Expand Down
Loading