Source code for ktch.plot._morphospace

"""Morphospace scatter plot with reconstructed shapes."""

# Copyright 2026 Koji Noshita
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any

import numpy as np
import numpy.typing as npt

from ._base import require_dependencies
from ._params import (
    _detect_shape_type,
    _get_renderer_and_projection,
    _resolve_descriptor_params,
    _resolve_reducer_params,
    _validate_components,
)
from ._renderers import _resolve_render_kw


[docs] def morphospace_plot( data: Any | None = None, *, x: str | npt.ArrayLike | None = None, y: str | npt.ArrayLike | None = None, hue: str | npt.ArrayLike | None = None, reducer: Any | None = None, reducer_inverse_transform: Callable[[np.ndarray], np.ndarray] | None = None, n_components: int | None = None, descriptor: Any | None = None, descriptor_inverse_transform: Callable[[np.ndarray], np.ndarray] | None = None, components: tuple[int, int] = (0, 1), shape_type: str = "auto", render_fn: Callable[..., None] | None = None, n_dim: int | None = None, links: Sequence[Sequence[int]] | None = None, n_shapes: int = 5, shape_scale: float = 1.0, shape_color: str = "lightgray", shape_alpha: float = 0.7, palette: str | Sequence | None = None, hue_order: Sequence | None = None, scatter_kw: dict[str, Any] | None = None, ax: object | None = None, **render_kw: Any, ) -> object: """Scatter plot of specimens in morphospace with shape insets. Draws a scatter plot of scores from dimension reduction (reducer) and overlays reconstructed shapes at a regular grid of positions in the low-dimensional space. The function uses the same two-stage inverse transform pipeline as :func:`shape_variation_plot`: ``scores -> [reducer_inverse_transform] -> coefficients -> [descriptor_inverse_transform] -> shape coordinates``. This function calls ``fig.canvas.draw()`` internally to compute accurate pixel positions for inset axes. Inset positions are fixed at draw time and will not automatically update if the figure is later resized or saved at a different DPI. For best results, set the final figure size before calling this function. Parameters ---------- data : DataFrame, optional DataFrame containing scores and metadata. If provided, ``x``, ``y``, ``hue`` refer to column names. x : str or array-like, optional Horizontal axis values (column name or array). y : str or array-like, optional Vertical axis values (column name or array). hue : str or array-like, optional Grouping variable for scatter coloring. reducer : fitted estimator, optional Convenience parameter. Extracts ``reducer_inverse_transform`` via ``.inverse_transform`` and ``n_components`` via ``.n_components_`` (fallback to ``.n_components``). reducer_inverse_transform : callable, optional Overrides ``reducer.inverse_transform``. n_components : int, optional Overrides ``reducer.n_components_``. descriptor : fitted estimator, optional Convenience parameter. Extracts ``descriptor_inverse_transform`` via ``.inverse_transform``. descriptor_inverse_transform : callable, optional Overrides ``descriptor.inverse_transform``. components : tuple of (int, int) 0-indexed component indices for (horizontal, vertical) axes. shape_type : str Shape rendering type. One of ``"auto"``, ``"curve_2d"``, ``"curve_3d"``, ``"surface_3d"``, ``"landmarks_2d"``, ``"landmarks_3d"``. render_fn : callable, optional Custom renderer ``(coords, ax, **kw) -> None``. n_dim : int, optional Spatial dimensionality (for GPA identity case). links : sequence of sequence of int, optional Landmark link pairs. n_shapes : int Number of shapes along each axis (total: ``n_shapes * n_shapes``). shape_scale : float Scale factor for inset shape size. shape_color : str Color for reconstructed shapes. shape_alpha : float Transparency for reconstructed shapes. palette : str or sequence, optional Forwarded to ``sns.scatterplot``. hue_order : sequence, optional Forwarded to ``sns.scatterplot``. scatter_kw : dict, optional Additional kwargs forwarded to ``sns.scatterplot``. ax : matplotlib.axes.Axes, optional Pre-existing axes. If ``None``, creates new figure and axes. **render_kw Forwarded to the shape renderer. Returns ------- ax : matplotlib.axes.Axes The main scatter plot axes. Raises ------ ImportError If matplotlib or seaborn are not installed. ValueError If required parameters cannot be resolved. Notes ----- When ``shape_type="auto"`` (the default), the type is inferred from the output of the descriptor inverse transform: - 4-D array -> ``"surface_3d"`` - 3-D array with last dimension 2 -> ``"curve_2d"`` - 3-D array with last dimension 3 -> ``"curve_3d"`` - No descriptor (identity / GPA case) with ``n_dim=2`` -> ``"landmarks_2d"`` - No descriptor (identity / GPA case) with ``n_dim=3`` -> ``"landmarks_3d"`` For 3-D arrays with ``shape[-1] == 3``, auto-detection chooses ``"curve_3d"``. If the data represents landmarks, specify ``shape_type="landmarks_3d"`` explicitly. 3-D shape types (``"surface_3d"``, ``"curve_3d"``, ``"landmarks_3d"``) use matplotlib 3-D projection for each inset, which is significantly slower. For 3-D surfaces (e.g., SHA), consider using ``n_shapes <= 3`` and reducing surface resolution via a ``descriptor_inverse_transform`` wrapper. See Also -------- shape_variation_plot : Shape grid along component axes. explained_variance_ratio_plot : Scree plot of explained variance. Examples -------- >>> from ktch.plot import morphospace_plot >>> ax = morphospace_plot( # doctest: +SKIP ... data=df_pca, ... x="PC1", y="PC2", hue="genus", ... reducer=pca, ... descriptor=efa, ... palette="Paired", ... n_shapes=5, ... shape_scale=0.8, ... ) """ require_dependencies("matplotlib", "seaborn") import matplotlib.pyplot as plt import seaborn as sns # Create or reuse axes if ax is None: fig, ax = plt.subplots() else: fig = ax.figure # Draw scatter plot (if data provided) if x is not None and y is not None: sns.scatterplot( data=data, x=x, y=y, hue=hue, palette=palette, hue_order=hue_order, ax=ax, **(scatter_kw or {}), ) # Resolve reducer/descriptor parameters (if reducer available) if reducer is not None or reducer_inverse_transform is not None: reducer_inverse_transform, _, n_components = _resolve_reducer_params( reducer, reducer_inverse_transform, explained_variance=None, n_components=n_components, require_variance=False, ) descriptor_inverse_transform, n_dim = _resolve_descriptor_params( descriptor, descriptor_inverse_transform, n_dim, shape_type, ) _validate_components(components, n_components) # Overlay shapes fig.canvas.draw() # force layout for accurate positions comp_h, comp_v = components x_range = np.linspace(*ax.get_xlim(), n_shapes) y_range = np.linspace(*ax.get_ylim(), n_shapes) # Batch reconstruction grid = np.array([(h, v) for h in x_range for v in y_range]) all_scores = np.zeros((len(grid), n_components)) all_scores[:, comp_h] = grid[:, 0] all_scores[:, comp_v] = grid[:, 1] all_coeffs = reducer_inverse_transform(all_scores) if descriptor_inverse_transform is not None: all_coords = np.asarray(descriptor_inverse_transform(all_coeffs)) else: all_coords = all_coeffs.reshape(len(grid), -1, n_dim) # Auto-detect shape_type if needed if shape_type == "auto": shape_type = _detect_shape_type( all_coords[0], descriptor_inverse_transform, n_dim, ) renderer, proj = _get_renderer_and_projection(shape_type, render_fn) # Compute inset sizing ax_extent = ax.get_window_extent() fig_extent = fig.get_window_extent() fig_width = fig_extent.width fig_height = fig_extent.height inset_size = shape_scale * ax_extent.width / (fig_width * n_shapes) resolved = _resolve_render_kw( render_kw, color=shape_color, alpha=shape_alpha, links=links, ) for idx, (score_h, score_v) in enumerate(grid): single = all_coords[idx] loc = ax.transData.transform((score_h, score_v)) axins = fig.add_axes( [ loc[0] / fig_width - inset_size / 2, loc[1] / fig_height - inset_size / 2, inset_size, inset_size, ], anchor="C", projection=proj, ) renderer(single, axins, **resolved) axins.axis("off") ax.set_zorder(1) # draw scatter on top of inset shapes ax.patch.set_alpha(0) # transparent background so insets show through return ax