Generalized Procrustes analysis#

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.decomposition import PCA

from ktch.datasets import load_landmark_mosquito_wings
from ktch.landmark import GeneralizedProcrustesAnalysis
/tmp/ipykernel_1948/3494986560.py:2: DeprecationWarning: 
Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd

Load mosquito wing landmark dataset#

from Rohlf and Slice 1990 Syst. Zool.

data_landmark_mosquito_wings = load_landmark_mosquito_wings(as_frame=True)
data_landmark_mosquito_wings.coords
x y
specimen_id coord_id
1 0 -0.4933 0.0130
1 -0.0777 0.0832
2 0.2231 0.0861
3 0.2641 0.0462
4 0.2645 0.0261
... ... ... ...
127 13 -0.2028 0.0371
14 0.0490 0.0347
15 -0.0422 0.0204
16 0.1004 -0.0180
17 -0.1473 -0.0057

2286 rows × 2 columns

fig, ax = plt.subplots()
sns.scatterplot(
    data = data_landmark_mosquito_wings.coords,
    x="x",y="y", 
    hue="specimen_id", style="coord_id",ax=ax
    )
ax.set_aspect('equal')
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
<matplotlib.legend.Legend at 0x7f3308a1c350>
../../_images/5963ec8fd561bcca98610c99de2bab1eda9875d9c1f270a110c5a383224960bb.png

For applying generalized Procrustes analysis (GPA), we convert the configuration data into DataFrame of shape n_specimens x (n_landmarks*n_dim).

df_coords = data_landmark_mosquito_wings.coords.unstack().swaplevel(1, 0, axis=1).sort_index(axis=1)
df_coords.columns = [dim +"_"+ str(landmark_idx) for landmark_idx,dim in df_coords.columns]
df_coords
x_0 y_0 x_1 y_1 x_2 y_2 x_3 y_3 x_4 y_4 ... x_13 y_13 x_14 y_14 x_15 y_15 x_16 y_16 x_17 y_17
specimen_id
1 -0.4933 0.0130 -0.0777 0.0832 0.2231 0.0861 0.2641 0.0462 0.2645 0.0261 ... -0.1768 0.0341 0.0715 0.0509 -0.0540 0.0238 0.0575 -0.0059 -0.1401 -0.0240
2 -0.4814 0.0135 -0.0058 0.0780 0.2345 0.0644 0.2460 0.0467 0.2487 0.0281 ... -0.1808 0.0229 0.0484 0.0405 -0.0519 0.0164 0.0623 -0.0047 -0.1444 -0.0286
3 -0.4622 0.0159 0.0089 0.0689 0.2404 0.0545 0.2501 0.0424 0.2600 0.0230 ... -0.1724 0.0182 0.0577 0.0344 -0.0468 0.0115 0.0766 -0.0079 -0.1602 -0.0162
4 -0.4534 -0.0028 -0.0318 0.0738 0.2423 0.0808 0.2627 0.0559 0.2654 0.0322 ... -0.1536 0.0150 0.0617 0.0436 -0.0549 0.0217 0.0705 -0.0031 -0.1507 -0.0273
5 -0.4926 -0.0212 -0.0260 0.0708 0.2347 0.0679 0.2398 0.0584 0.2415 0.0355 ... -0.1374 0.0154 0.0762 0.0457 -0.0313 0.0170 0.0880 0.0037 -0.1318 -0.0278
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
123 -0.4703 -0.0009 0.0011 0.0805 0.2146 0.0747 0.2348 0.0585 0.2453 0.0350 ... -0.1624 0.0193 0.0557 0.0467 -0.0272 0.0190 0.0718 -0.0052 -0.1583 -0.0251
124 -0.4725 0.0441 0.0318 0.0808 0.2382 0.0419 0.2504 0.0237 0.2492 0.0019 ... -0.1724 0.0390 0.0667 0.0418 -0.0108 0.0205 0.0548 -0.0049 -0.1546 -0.0115
125 -0.4697 0.0196 0.0007 0.0850 0.2228 0.0760 0.2485 0.0540 0.2548 0.0319 ... -0.1816 0.0213 0.0376 0.0408 -0.0073 0.0137 0.0592 -0.0103 -0.1531 -0.0300
126 -0.4620 0.0204 0.0082 0.0893 0.2061 0.0797 0.2374 0.0529 0.2457 0.0306 ... -0.1876 0.0241 0.0279 0.0403 -0.0425 0.0145 0.0729 -0.0123 -0.1460 -0.0269
127 -0.4570 0.0468 -0.0090 0.0753 0.2381 0.0463 0.2556 0.0275 0.2614 0.0064 ... -0.2028 0.0371 0.0490 0.0347 -0.0422 0.0204 0.1004 -0.0180 -0.1473 -0.0057

127 rows × 36 columns

GPA#

gpa = GeneralizedProcrustesAnalysis().set_output(transform="pandas")
df_shapes = gpa.fit_transform(df_coords)
df_shapes
x_0 y_0 x_1 y_1 x_2 y_2 x_3 y_3 x_4 y_4 ... x_13 y_13 x_14 y_14 x_15 y_15 x_16 y_16 x_17 y_17
specimen_id
1 -0.488913 0.021159 -0.075651 0.083813 0.222655 0.081657 0.262641 0.041408 0.262701 0.021471 ... -0.174735 0.036786 0.071747 0.049290 -0.053145 0.024519 0.056916 -0.006796 -0.139317 -0.021437
2 -0.479850 0.030667 -0.002995 0.078029 0.236289 0.055873 0.247131 0.037801 0.249161 0.019145 ... -0.179578 0.029304 0.049746 0.038675 -0.051194 0.018213 0.062000 -0.006922 -0.145098 -0.023383
3 -0.461238 0.021194 0.009668 0.068674 0.240606 0.051631 0.250150 0.039440 0.259809 0.019959 ... -0.171907 0.020150 0.057987 0.033671 -0.046599 0.012014 0.076367 -0.008775 -0.160124 -0.014332
4 -0.451582 0.021733 -0.027671 0.075199 0.245617 0.067345 0.264582 0.041450 0.265988 0.017707 ... -0.152122 0.023242 0.063790 0.040074 -0.053488 0.024575 0.070026 -0.006899 -0.151522 -0.019031
5 -0.491092 0.019428 -0.020015 0.072475 0.238724 0.048169 0.243009 0.038313 0.242816 0.015425 ... -0.135231 0.026596 0.079447 0.039142 -0.029701 0.019466 0.087717 -0.003550 -0.133219 -0.016779
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
123 -0.468839 0.028646 0.006143 0.080164 0.218584 0.060969 0.237700 0.043553 0.246690 0.019470 ... -0.160671 0.029436 0.058442 0.043044 -0.025929 0.020642 0.071228 -0.009699 -0.159374 -0.015077
124 -0.473804 0.016670 0.027045 0.082425 0.235150 0.055559 0.248370 0.038112 0.248433 0.016301 ... -0.174200 0.028931 0.064108 0.045546 -0.011957 0.019822 0.054939 -0.001719 -0.153528 -0.020407
125 -0.468449 0.032047 0.002958 0.084844 0.224473 0.069955 0.249549 0.047306 0.255251 0.025073 ... -0.180752 0.026088 0.038626 0.039732 -0.006925 0.013867 0.058834 -0.011863 -0.153659 -0.025890
126 -0.459471 0.036434 0.011277 0.088662 0.208061 0.072197 0.238303 0.044411 0.245792 0.021909 ... -0.186025 0.030549 0.029190 0.039168 -0.041832 0.015925 0.072178 -0.014794 -0.146368 -0.021701
127 -0.457745 0.025260 -0.012492 0.074647 0.235186 0.057309 0.253511 0.039387 0.260281 0.018625 ... -0.203892 0.027493 0.047226 0.036891 -0.043017 0.018365 0.100931 -0.013236 -0.146563 -0.012573

127 rows × 36 columns

We create a DataFrame, called df_shapes_vis, of shape (n_specimens*n_landmarks) x n_dim to visualize the aligned shapes.

df_shapes_vis = df_shapes.copy()
df_shapes_vis.columns = pd.MultiIndex.from_tuples([[int(landmark_idx), dim] for dim, landmark_idx in [idx.split("_") for idx in df_shapes_vis.columns]], names=["coord_id","dim"])
df_shapes_vis.sort_index(axis=1, inplace=True)
df_shapes_vis = df_shapes_vis.swaplevel(0,1,axis=1).stack(level=1)
df_shapes_vis
/tmp/ipykernel_1948/3231015902.py:4: FutureWarning: The previous implementation of stack is deprecated and will be removed in a future version of pandas. See the What's New notes for pandas 2.1.0 for details. Specify future_stack=True to adopt the new implementation and silence this warning.
  df_shapes_vis = df_shapes_vis.swaplevel(0,1,axis=1).stack(level=1)
dim x y
specimen_id coord_id
1 0 -0.488913 0.021159
1 -0.075651 0.083813
2 0.222655 0.081657
3 0.262641 0.041408
4 0.262701 0.021471
... ... ... ...
127 13 -0.203892 0.027493
14 0.047226 0.036891
15 -0.043017 0.018365
16 0.100931 -0.013236
17 -0.146563 -0.012573

2286 rows × 2 columns

fig, ax = plt.subplots()
sns.scatterplot(
    data = df_shapes_vis,
    x="x",y="y", 
    hue="specimen_id", style="coord_id",ax=ax
    )
ax.set_aspect('equal')
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
<matplotlib.legend.Legend at 0x7f3308a176d0>
../../_images/adb2df782b479ca3fe463e00bdfa8ba01de9dcb187fda17924b35832c658372e.png

PCA#

pca = PCA(n_components=10).set_output(transform="pandas")
df_pca = pca.fit_transform(df_shapes)

df_pca = df_pca.join(data_landmark_mosquito_wings.meta)
df_pca
pca0 pca1 pca2 pca3 pca4 pca5 pca6 pca7 pca8 pca9 genus
specimen_id
1 -0.003856 0.093210 -0.029908 -0.049589 0.007598 0.042657 -0.005370 0.010359 -0.011595 0.034991 AN
2 -0.029857 0.023434 -0.006114 0.017268 0.014229 0.014271 0.013344 0.021849 -0.011211 0.005069 AN
3 -0.012572 0.019728 0.034610 0.025665 0.005116 -0.017905 0.008511 -0.004312 0.001597 0.009884 AN
4 -0.001652 0.049890 0.034137 0.001013 -0.019706 -0.010102 -0.001692 -0.010425 0.021595 -0.006472 AN
5 0.026869 0.032989 -0.010228 0.051852 -0.016741 0.023749 0.015917 -0.000273 -0.013100 0.012755 AN
... ... ... ... ... ... ... ... ... ... ... ...
123 -0.007864 0.009853 0.006369 0.026082 0.025424 0.013786 -0.009967 -0.018669 0.000513 -0.005515 CX
124 -0.005360 -0.014451 0.014347 0.013199 -0.000102 0.001379 -0.002919 -0.009755 -0.020409 -0.017790 CX
125 -0.030056 -0.000885 -0.013455 -0.000542 0.008940 0.008503 -0.021390 -0.014563 -0.003408 -0.010083 CX
126 -0.048757 0.000199 -0.024147 0.029841 0.040190 0.013335 0.000598 0.016086 0.018985 0.004560 DE
127 -0.017155 0.030018 0.001935 0.014438 0.014540 -0.035565 -0.011888 0.017324 -0.009739 0.012342 DE

127 rows × 11 columns

fig, ax = plt.subplots()
sns.scatterplot(
    data=df_pca, x="pca0", y = "pca1", 
    hue="genus",palette="Paired",
    ax=ax
    )
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
<matplotlib.legend.Legend at 0x7f330695d990>
../../_images/0b48308ebd7bffd2872e38c31ede8d51272d3587f20b88ddf22a85e0ee033bfe.png