{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "be45b072",
   "metadata": {},
   "source": [
    "# Generalized Procrustes analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbac26b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "from ktch.datasets import load_landmark_mosquito_wings\n",
    "from ktch.landmark import GeneralizedProcrustesAnalysis\n",
    "from ktch.plot import explained_variance_ratio_plot"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0735d6d",
   "metadata": {},
   "source": [
    "## Load mosquito wing landmark dataset\n",
    "The data are from Rohlf and Slice (1990), _Syst. Zool._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "879bc7d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_landmark_mosquito_wings = load_landmark_mosquito_wings(as_frame=True)\n",
    "data_landmark_mosquito_wings.coords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdaa44ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "sns.scatterplot(\n",
    "    data=data_landmark_mosquito_wings.coords,\n",
    "    x=\"x\",\n",
    "    y=\"y\",\n",
    "    hue=\"specimen_id\",\n",
    "    style=\"coord_id\",\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_aspect(\"equal\")\n",
    "ax.legend(loc=\"upper left\", bbox_to_anchor=(1, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd609e30",
   "metadata": {},
   "source": [
    "To apply generalized Procrustes analysis (GPA),\n",
    "we convert the configuration data into a DataFrame of shape n_specimens x (n_landmarks*n_dim) (see `df_coords` below)."
] }, { "cell_type": "code", "execution_count": null, "id": "b9ad0413", "metadata": {}, "outputs": [], "source": [ "def configulation_plot(\n", " configuration_2d,\n", " x=\"x\",\n", " y=\"y\",\n", " links=[],\n", " ax=None,\n", " hue=None,\n", " c=\"gray\",\n", " palette=None,\n", " c_line=\"gray\",\n", " style=None,\n", " s=10,\n", " alpha=1,\n", "):\n", " if ax is None:\n", " fig = plt.figure()\n", " ax = fig.add_subplot(111)\n", "\n", " configuration = configuration_2d.reset_index()\n", "\n", " for link in links:\n", " sns.lineplot(\n", " data=configuration[configuration[\"coord_id\"].isin(link)],\n", " x=x,\n", " y=y,\n", " sort=False,\n", " ax=ax,\n", " hue=hue,\n", " c=c_line,\n", " palette=palette,\n", " alpha=alpha,\n", " legend=False,\n", " )\n", "\n", " axis = sns.scatterplot(\n", " data=configuration,\n", " x=x,\n", " y=y,\n", " ax=ax,\n", " hue=hue,\n", " palette=palette,\n", " style=style,\n", " c=c,\n", " alpha=alpha,\n", " s=s,\n", " )\n", "\n", " if axis.legend_:\n", " sns.move_legend(ax, \"upper left\", bbox_to_anchor=(1, 1))\n", "\n", " ax.set_aspect(\"equal\")" ] }, { "cell_type": "code", "execution_count": null, "id": "3aa9a24f", "metadata": {}, "outputs": [], "source": [ "links = [\n", " [0, 1],\n", " [1, 2],\n", " [2, 3],\n", " [3, 4],\n", " [4, 5],\n", " [5, 6],\n", " [6, 7],\n", " [7, 8],\n", " [8, 9],\n", " [9, 10],\n", " [10, 11],\n", " [1, 12],\n", " [2, 12],\n", " [13, 14],\n", " [14, 3],\n", " [14, 4],\n", " [15, 5],\n", " [16, 6],\n", " [16, 7],\n", " [17, 8],\n", " [17, 9],\n", "]\n", "\n", "configulation_plot(\n", " data_landmark_mosquito_wings.coords.loc[1], links=links, alpha=0.5, s=20\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "68dad4a7", "metadata": {}, "outputs": [], "source": [ "configulation_plot(\n", " data_landmark_mosquito_wings.coords,\n", " links=links,\n", " hue=\"specimen_id\",\n", " style=\"coord_id\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "4590b0a4", "metadata": {}, 
"outputs": [], "source": [ "configulation_plot(\n", " data_landmark_mosquito_wings.coords.loc[0:2],\n", " links=links,\n", " hue=\"specimen_id\",\n", " style=\"coord_id\",\n", " palette=\"Set2\",\n", " s=30,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "39086c3d", "metadata": {}, "outputs": [], "source": [ "index = pd.MultiIndex.from_tuples(\n", " [(2, i) for i in data_landmark_mosquito_wings.coords.loc[2].index],\n", " names=[\"specimen_id\", \"coord_id\"],\n", ")\n", "x2 = pd.DataFrame(\n", " data_landmark_mosquito_wings.coords.loc[2].to_numpy(),\n", " columns=[\"x\", \"y\"],\n", " index=index,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "d3669d77", "metadata": {}, "outputs": [], "source": [ "data_landmark_mosquito_wings.coords.loc[1:3]" ] }, { "cell_type": "code", "execution_count": null, "id": "dd968c4b", "metadata": {}, "outputs": [], "source": [ "df_coords = (\n", " data_landmark_mosquito_wings.coords.unstack()\n", " .swaplevel(1, 0, axis=1)\n", " .sort_index(axis=1)\n", ")\n", "df_coords.columns = [\n", " dim + \"_\" + str(landmark_idx) for landmark_idx, dim in df_coords.columns\n", "]\n", "df_coords" ] }, { "cell_type": "markdown", "id": "b271add5", "metadata": {}, "source": [ "## GPA" ] }, { "cell_type": "code", "execution_count": null, "id": "070c8445", "metadata": {}, "outputs": [], "source": [ "gpa = GeneralizedProcrustesAnalysis().set_output(transform=\"pandas\")" ] }, { "cell_type": "code", "execution_count": null, "id": "bd056e36", "metadata": {}, "outputs": [], "source": [ "df_shapes = gpa.fit_transform(df_coords)\n", "df_shapes" ] }, { "cell_type": "markdown", "id": "179198b8", "metadata": {}, "source": [ "We create a DataFrame, called `df_shapes_vis`, of shape (n_specimens*n_landmarks) x n_dim to visualize the aligned shapes." 
] }, { "cell_type": "code", "execution_count": null, "id": "b3f20eef", "metadata": {}, "outputs": [], "source": [ "df_shapes_vis = df_shapes.copy()\n", "df_shapes_vis.columns = pd.MultiIndex.from_tuples(\n", " [\n", " [int(landmark_idx), dim]\n", " for dim, landmark_idx in [idx.split(\"_\") for idx in df_shapes_vis.columns]\n", " ],\n", " names=[\"coord_id\", \"dim\"],\n", ")\n", "df_shapes_vis.sort_index(axis=1, inplace=True)\n", "df_shapes_vis = df_shapes_vis.swaplevel(0, 1, axis=1).stack(level=1, future_stack=True)\n", "df_shapes_vis" ] }, { "cell_type": "code", "execution_count": null, "id": "1cf92983", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "sns.scatterplot(\n", " data=df_shapes_vis, x=\"x\", y=\"y\", hue=\"specimen_id\", style=\"coord_id\", ax=ax\n", ")\n", "ax.set_aspect(\"equal\")\n", "ax.legend(loc=\"upper left\", bbox_to_anchor=(1, 1))" ] }, { "cell_type": "code", "execution_count": null, "id": "a460aad5", "metadata": {}, "outputs": [], "source": [ "configulation_plot(df_shapes_vis, links=links, hue=\"specimen_id\", style=\"coord_id\")" ] }, { "cell_type": "code", "execution_count": null, "id": "0531f4b2", "metadata": {}, "outputs": [], "source": [ "configulation_plot(\n", " df_shapes_vis.loc[0:2], links=links, hue=\"specimen_id\", style=\"coord_id\"\n", ")" ] }, { "cell_type": "markdown", "id": "b3058685", "metadata": {}, "source": [ "## PCA" ] }, { "cell_type": "code", "execution_count": null, "id": "386eeefd", "metadata": {}, "outputs": [], "source": [ "pca = PCA(n_components=10).set_output(transform=\"pandas\")\n", "df_pca = pca.fit_transform(df_shapes)\n", "\n", "df_pca = df_pca.join(data_landmark_mosquito_wings.meta)\n", "df_pca" ] }, { "cell_type": "code", "execution_count": null, "id": "6c7af48a", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "sns.scatterplot(data=df_pca, x=\"pca0\", y=\"pca1\", hue=\"genus\", palette=\"Paired\", ax=ax)\n", "ax.legend(loc=\"upper left\", 
bbox_to_anchor=(1, 1))" ] }, { "cell_type": "markdown", "id": "0e5aa60d", "metadata": {}, "source": [ "## Morphospace" ] }, { "cell_type": "code", "execution_count": null, "id": "27bf1b02", "metadata": {}, "outputs": [], "source": [ "def get_pc_scores_for_morphospace(ax, num=5):\n", " xrange = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], num)\n", " yrange = np.linspace(ax.get_ylim()[0], ax.get_ylim()[1], num)\n", " return xrange, yrange\n", "\n", "\n", "def plot_recon_morphs(\n", " pca,\n", " fig,\n", " ax,\n", " n_PCs_xy=[1, 2],\n", " morph_num=3,\n", " morph_alpha=1.0,\n", " morph_scale=1.0,\n", " links=[],\n", "):\n", " pc_scores_h, pc_scores_v = get_pc_scores_for_morphospace(ax, morph_num)\n", " print(\"PC_h: \", pc_scores_h, \", PC_v: \", pc_scores_v)\n", " for pc_score_h in pc_scores_h:\n", " for pc_score_v in pc_scores_v:\n", " pc_score = np.zeros(pca.n_components_)\n", " n_PC_h, n_PC_v = n_PCs_xy\n", " pc_score[n_PC_h - 1] = pc_score_h\n", " pc_score[n_PC_v - 1] = pc_score_v\n", "\n", " arr_shapes = pca.inverse_transform([pc_score])\n", " arr_shapes = arr_shapes.reshape(-1, 2)\n", "\n", " df_shapes = pd.DataFrame(arr_shapes, columns=[\"x\", \"y\"])\n", " df_shapes[\"coord_id\"] = [i for i in range(len(df_shapes))]\n", " df_shapes = df_shapes.set_index(\"coord_id\")\n", "\n", " ax_width = ax.get_window_extent().width\n", " fig_width = fig.get_window_extent().width\n", " fig_height = fig.get_window_extent().height\n", " morph_size = morph_scale * ax_width / (fig_width * morph_num)\n", " loc = ax.transData.transform((pc_score_h, pc_score_v))\n", "\n", " if arr_shapes.shape[1] == 3:\n", " axins = fig.add_axes(\n", " [\n", " loc[0] / fig_width - morph_size / 2,\n", " loc[1] / fig_height - morph_size / 2,\n", " morph_size,\n", " morph_size,\n", " ],\n", " anchor=\"C\",\n", " projection=\"3d\",\n", " )\n", " axins.patch.set_alpha(0.3)\n", "\n", " configulation_plot(df_shapes, links=links, ax=axins, alpha=morph_alpha)\n", "\n", " else:\n", " axins = 
fig.add_axes(\n", " [\n", " loc[0] / fig_width - morph_size / 2,\n", " loc[1] / fig_height - morph_size / 2,\n", " morph_size,\n", " morph_size,\n", " ],\n", " anchor=\"C\",\n", " )\n", " configulation_plot(df_shapes, links=links, ax=axins, alpha=morph_alpha)\n", "\n", " axins.axis(\"off\")" ] }, { "cell_type": "code", "execution_count": null, "id": "0443db9f", "metadata": {}, "outputs": [], "source": [ "morph_num = 5\n", "morph_scale = 1.0\n", "morph_alpha = 0.5\n", "\n", "fig = plt.figure(figsize=(16, 16), dpi=200)\n", "hue_order = df_pca[\"genus\"].unique()\n", "\n", "#########\n", "# PC1\n", "#########\n", "ax = fig.add_subplot(2, 2, 1)\n", "sns.scatterplot(\n", " data=df_pca,\n", " x=\"pca0\",\n", " y=\"pca1\",\n", " hue=\"genus\",\n", " hue_order=hue_order,\n", " palette=\"Paired\",\n", " ax=ax,\n", " legend=True,\n", ")\n", "\n", "plot_recon_morphs(\n", " pca,\n", " morph_num=5,\n", " morph_scale=morph_scale,\n", " morph_alpha=0.5,\n", " fig=fig,\n", " ax=ax,\n", " links=links,\n", ")\n", "\n", "ax.patch.set_alpha(0)\n", "ax.set(xlabel=\"PC1\", ylabel=\"PC2\")\n", "\n", "print(\"PC1-PC2 done\")\n", "\n", "#########\n", "# PC2\n", "#########\n", "ax = fig.add_subplot(2, 2, 2)\n", "sns.scatterplot(\n", " data=df_pca,\n", " x=\"pca1\",\n", " y=\"pca2\",\n", " hue=\"genus\",\n", " hue_order=hue_order,\n", " palette=\"Paired\",\n", " ax=ax,\n", " legend=True,\n", ")\n", "\n", "plot_recon_morphs(\n", " pca,\n", " morph_num=5,\n", " morph_scale=morph_scale,\n", " morph_alpha=0.5,\n", " fig=fig,\n", " ax=ax,\n", " links=links,\n", " n_PCs_xy=[2, 3],\n", ")\n", "\n", "\n", "ax.patch.set_alpha(0)\n", "ax.set(xlabel=\"PC2\", ylabel=\"PC3\")\n", "\n", "print(\"PC2-PC3 done\")\n", "\n", "#########\n", "# PC3\n", "#########\n", "ax = fig.add_subplot(2, 2, 3)\n", "sns.scatterplot(\n", " data=df_pca,\n", " x=\"pca2\",\n", " y=\"pca0\",\n", " hue=\"genus\",\n", " hue_order=hue_order,\n", " palette=\"Paired\",\n", " ax=ax,\n", " legend=True,\n", ")\n", "\n", 
"plot_recon_morphs(\n", " pca,\n", " morph_num=5,\n", " morph_scale=morph_scale,\n", " morph_alpha=0.5,\n", " fig=fig,\n", " ax=ax,\n", " links=links,\n", " n_PCs_xy=[3, 1],\n", ")\n", "\n", "\n", "ax.patch.set_alpha(0)\n", "ax.set(xlabel=\"PC3\", ylabel=\"PC1\")\n", "\n", "print(\"PC3-PC1 done\")\n", "\n", "#########\n", "# CCR\n", "#########\n", "\n", "ax = fig.add_subplot(2, 2, 4)\n", "explained_variance_ratio_plot(pca, ax=ax, verbose=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "a4c4d154", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "default_lexer": "ipython3" }, "kernelspec": { "display_name": "ktch", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }