# Source code for ktch.plot._configuration

"""Configuration plot for landmark data."""

# Copyright 2026 Koji Noshita
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

import numpy as np
import pandas as pd

from ._base import require_dependencies


def _coerce_to_dataframe(
    data: pd.DataFrame | np.ndarray,
    x: str,
    y: str,
    z: str | None,
) -> pd.DataFrame:
    """Return landmark coordinates as a DataFrame.

    Anything that is not a NumPy array is handed back unchanged. Arrays
    are wrapped in a DataFrame whose columns are the requested coordinate
    names and whose index identifies landmarks (and specimens for 3D input).

    Parameters
    ----------
    data : DataFrame or ndarray
        A 2D array is read as ``(n_landmarks, n_dim)`` (one specimen); a 3D
        array is read as ``(n_specimens, n_landmarks, n_dim)``.
    x, y : str
        Names given to the first and second coordinate columns.
    z : str or None
        Name given to the third coordinate column; ``None`` keeps only two
        coordinate columns.

    Returns
    -------
    DataFrame
        For 2D arrays, indexed by ``coord_id``. For 3D arrays, indexed by
        the MultiIndex ``(specimen_id, coord_id)``. Extra trailing columns
        of the array beyond the requested dimensions are dropped.

    Raises
    ------
    ValueError
        If the array is neither 2D nor 3D, or its last axis is shorter than
        the number of requested coordinate columns.
    """
    # Non-array input (e.g. an existing DataFrame) passes straight through.
    if not isinstance(data, np.ndarray):
        return data

    dim_cols = [x, y, z] if z is not None else [x, y]
    n_dim = len(dim_cols)

    if data.ndim == 2:
        if data.shape[1] < n_dim:
            raise ValueError(
                f"Expected at least {n_dim} columns, "
                f"got array with shape {data.shape}."
            )
        frame = pd.DataFrame(data[:, :n_dim], columns=dim_cols)
        frame.index.name = "coord_id"
        return frame

    if data.ndim == 3:
        if data.shape[2] < n_dim:
            raise ValueError(
                f"Expected at least {n_dim} columns in last axis, "
                f"got array with shape {data.shape}."
            )
        n_specimens, n_landmarks = data.shape[:2]
        # Flatten specimens x landmarks into rows; the two id columns below
        # follow the same row-major order as the reshape.
        frame = pd.DataFrame(
            data[:, :, :n_dim].reshape(-1, n_dim), columns=dim_cols
        )
        frame["specimen_id"] = np.repeat(np.arange(n_specimens), n_landmarks)
        frame["coord_id"] = np.tile(np.arange(n_landmarks), n_specimens)
        return frame.set_index(["specimen_id", "coord_id"])

    raise ValueError(f"Expected 2D or 3D array, got {data.ndim}D.")


def _detect_coord_id_col(data: pd.DataFrame) -> str:
    """Detect the coordinate-id column name from index metadata.

    The coordinate id is taken to be the last level of the index. The
    value returned is the column label that level receives after
    ``data.reset_index()``.

    Parameters
    ----------
    data : DataFrame
        Landmark data whose index carries the coordinate ids.

    Returns
    -------
    str
        - MultiIndex with a named last level: that name.
        - MultiIndex with an unnamed last level: ``"level_<pos>"``, which
          is the label ``reset_index()`` assigns to unnamed levels.
        - Flat index: its name, or ``"coord_id"`` when the name is missing
          or falsy. In that fallback case the column may not exist after
          ``reset_index()``; callers are expected to handle this.
    """
    index = data.index
    if isinstance(index, pd.MultiIndex):
        last_name = index.names[-1]
        if last_name is None:
            # Previously this returned None, which never matched a real
            # column; mirror reset_index()'s naming of unnamed levels.
            return f"level_{index.nlevels - 1}"
        return last_name
    return index.name or "coord_id"


def _resolve_color_map(
    config: pd.DataFrame,
    hue: str,
    hue_order: Sequence,
    palette: str | Sequence | dict | None,
) -> dict:
    """Build a mapping from hue values to colors.

    Parameters
    ----------
    config : DataFrame
        Data containing the ``hue`` column; only its dtype is inspected.
    hue : str
        Name of the column used for color grouping.
    hue_order : sequence
        Hue values to assign colors to, in the desired order.
    palette : str, sequence, dict, or None
        - dict: used as-is, restricted to (and ordered by) ``hue_order``.
        - str or sequence: passed to ``seaborn.color_palette`` with one
          color requested per hue value.
        - None: a cubehelix colormap sampled at min-max normalized hue
          values for numeric (non-boolean) hue columns, otherwise the
          default seaborn color palette.

    Returns
    -------
    dict
        Mapping ``hue value -> color``.

    Raises
    ------
    ValueError
        If ``palette`` is a dict lacking an entry for some hue value.
    """
    if isinstance(palette, dict):
        missing = [k for k in hue_order if k not in palette]
        if missing:
            raise ValueError(f"palette dict is missing keys for hue values: {missing}")
        return {k: palette[k] for k in hue_order}

    # Nothing to color; also avoids min()/max() on an empty sequence below.
    if len(hue_order) == 0:
        return {}

    # Deferred so the dict path above does not need seaborn at all.
    import seaborn as sns

    hue_dtype = config[hue].dtype
    # Booleans count as "numeric" for pandas, but numpy.bool_ values do not
    # support subtraction, so the normalization below would raise TypeError.
    # Treat them as categories instead.
    is_continuous = pd.api.types.is_numeric_dtype(
        hue_dtype
    ) and not pd.api.types.is_bool_dtype(hue_dtype)

    if palette is not None:
        colors = sns.color_palette(palette, n_colors=len(hue_order))
    elif is_continuous:
        cmap = sns.cubehelix_palette(as_cmap=True)
        hue_min, hue_max = min(hue_order), max(hue_order)
        colors = []
        for hue_val in hue_order:
            if hue_max > hue_min:
                norm_val = (hue_val - hue_min) / (hue_max - hue_min)
            else:
                # Degenerate range (single distinct value): use the middle
                # of the colormap rather than dividing by zero.
                norm_val = 0.5
            colors.append(cmap(norm_val))
    else:
        colors = sns.color_palette(n_colors=len(hue_order))
    return dict(zip(hue_order, colors))


def _draw_links_2d(
    config: pd.DataFrame,
    links: Sequence[Sequence[int]],
    coord_id_col: str,
    ax: Any,
    *,
    hue: str | None,
    hue_order: Sequence | None,
    color_map: dict | None,
    color_links: str,
    alpha: float,
    x: str,
    y: str,
) -> None:
    """Add link segments between landmarks to a 2D axes.

    Parameters
    ----------
    config : DataFrame
        Landmark rows with the coordinate columns, ``coord_id_col`` and,
        when ``hue`` is given, the hue column.
    links : sequence of sequence of int
        Groups of coordinate ids to connect.
    coord_id_col : str
        Column of ``config`` holding the coordinate ids.
    ax : Any
        Axes that receives the ``LineCollection`` objects.
    hue : str or None
        Grouping column. If ``None``, one collection is built from all rows
        and colored with ``color_links``.
    hue_order : sequence or None
        Hue values to iterate over when ``hue`` is given.
    color_map : dict or None
        Color for each hue value when ``hue`` is given.
    color_links : str
        Link color used when ``hue`` is ``None``.
    alpha : float
        Transparency of the links.
    x, y : str
        Coordinate column names.
    """
    from matplotlib.collections import LineCollection

    xy_cols = [x, y]

    def _segments(frame: pd.DataFrame) -> list:
        # For each link keep the rows (in frame order) whose id is in the
        # link. A link yields a segment only if at least two rows match, and
        # only the first two matching rows serve as endpoints.
        matched = (frame[frame[coord_id_col].isin(pair)] for pair in links)
        return [rows[xy_cols].values[:2] for rows in matched if len(rows) >= 2]

    if hue is not None:
        # One collection per hue value, colored from color_map.
        for level in hue_order:
            found = _segments(config[config[hue] == level])
            if found:
                ax.add_collection(
                    LineCollection(found, colors=[color_map[level]], alpha=alpha)
                )
    else:
        found = _segments(config)
        if found:
            ax.add_collection(LineCollection(found, colors=color_links, alpha=alpha))

    ax.autoscale_view()


def _draw_links_3d(
    config: pd.DataFrame,
    links: Sequence[Sequence[int]],
    coord_id_col: str,
    ax: Any,
    *,
    hue: str | None,
    hue_order: Sequence | None,
    color_map: dict | None,
    color_links: str,
    alpha: float,
    x: str,
    y: str,
    z: str,
) -> None:
    """Add link segments between landmarks to a 3D axes.

    Parameters
    ----------
    config : DataFrame
        Landmark rows with the coordinate columns, ``coord_id_col`` and,
        when ``hue`` is given, the hue column.
    links : sequence of sequence of int
        Groups of coordinate ids to connect.
    coord_id_col : str
        Column of ``config`` holding the coordinate ids.
    ax : Any
        Axes that receives the ``Line3DCollection`` objects via
        ``add_collection3d``.
    hue : str or None
        Grouping column. If ``None``, one collection is built from all rows
        and colored with ``color_links``.
    hue_order : sequence or None
        Hue values to iterate over when ``hue`` is given.
    color_map : dict or None
        Color for each hue value when ``hue`` is given.
    color_links : str
        Link color used when ``hue`` is ``None``.
    alpha : float
        Transparency of the links.
    x, y, z : str
        Coordinate column names.
    """
    from mpl_toolkits.mplot3d.art3d import Line3DCollection

    xyz_cols = [x, y, z]

    def _endpoints(frame: pd.DataFrame) -> list:
        # Same selection rule for every group: a link needs at least two
        # matching rows, and only the first two become the segment endpoints.
        collected = []
        for pair in links:
            rows = frame[frame[coord_id_col].isin(pair)]
            if len(rows) < 2:
                continue
            collected.append(rows[xyz_cols].values[:2])
        return collected

    if hue is None:
        # Single collection over all rows, uniformly colored.
        segs = _endpoints(config)
        if segs:
            ax.add_collection3d(
                Line3DCollection(segs, colors=color_links, alpha=alpha)
            )
        return

    # One collection per hue value, colored from color_map.
    for level in hue_order:
        segs = _endpoints(config[config[hue] == level])
        if segs:
            ax.add_collection3d(
                Line3DCollection(segs, colors=[color_map[level]], alpha=alpha)
            )


def _scatter_3d(
    config: pd.DataFrame,
    x: str,
    y: str,
    z: str,
    ax: Any,
    *,
    hue: str | None,
    hue_order: Sequence | None,
    color_map: dict | None,
    color: str,
    s: float,
    alpha: float,
) -> None:
    """Scatter landmark points onto a 3D axes.

    Parameters
    ----------
    config : DataFrame
        Landmark rows with the three coordinate columns and, when ``hue``
        is given, the hue column.
    x, y, z : str
        Coordinate column names.
    ax : Any
        Axes whose ``scatter`` method is called.
    hue : str or None
        Grouping column. If ``None``, all rows are drawn in one call using
        ``color``.
    hue_order : sequence or None
        Hue values to iterate over when ``hue`` is given; one ``scatter``
        call is issued per value, labelled with that value.
    color_map : dict or None
        Color for each hue value when ``hue`` is given.
    color : str
        Point color used when ``hue`` is ``None``.
    s : float
        Marker size.
    alpha : float
        Transparency of the points.
    """
    axes_cols = (x, y, z)

    if hue is None:
        ax.scatter(*(config[col] for col in axes_cols), c=color, s=s, alpha=alpha)
        return

    for level in hue_order:
        members = config[config[hue] == level]
        ax.scatter(
            *(members[col] for col in axes_cols),
            c=[color_map[level]],
            s=s,
            alpha=alpha,
            label=level,
        )


def _set_box_aspect_equal(ax: Any, config: pd.DataFrame, x: str, y: str, z: str):
    """Set the 3D box aspect from the per-axis coordinate ranges.

    The range (max minus min) of each of the ``x``, ``y`` and ``z`` columns
    of ``config`` is passed to ``ax.set_box_aspect``. Ranges are floored at
    ``1e-10`` -- presumably so that a flat axis does not produce a zero
    aspect entry.
    """
    extents = []
    for col in (x, y, z):
        values = config[col]
        extents.append(max(values.max() - values.min(), 1e-10))
    ax.set_box_aspect(extents)


def configuration_plot(
    data: pd.DataFrame | np.ndarray,
    *,
    x: str = "x",
    y: str = "y",
    z: str | None = None,
    links: Sequence[Sequence[int]] | None = None,
    ax: object | None = None,
    hue: str | None = None,
    hue_order: Sequence | None = None,
    palette: str | Sequence | dict | None = None,
    color: str = "gray",
    color_links: str | None = None,
    style: str | None = None,
    s: float = 10,
    alpha: float = 1.0,
) -> object:
    """Plot configurations of landmarks.

    Visualizes one or more configurations of landmarks (specimens) as
    scatter points with optional line segments connecting landmarks.
    Supports both 2D and 3D data.

    Parameters
    ----------
    data : pandas.DataFrame or numpy.ndarray
        Landmark coordinates. Accepts:

        - A DataFrame with ``MultiIndex (specimen_id, coord_id)``
          for multiple specimens.
        - A DataFrame with a single ``coord_id`` index for one specimen
          (e.g., ``data.coords.loc[0]``).
        - A 2D array of shape ``(n_landmarks, n_dim)`` for a single
          specimen.
        - A 3D array of shape ``(n_specimens, n_landmarks, n_dim)`` for
          multiple specimens.
    x : str, default="x"
        Column name for x coordinates.
    y : str, default="y"
        Column name for y coordinates.
    z : str or None, default=None
        Column name for z coordinates. If given, a 3D plot is produced.
    links : sequence of sequence of int, optional
        Pairs of ``coord_id`` values to connect with line segments,
        e.g., ``[[0, 1], [1, 2]]``. A NumPy array of pairs is accepted
        as well. ``None`` or an empty sequence draws no links.
    ax : matplotlib.axes.Axes or None, default=None
        Target axes. If ``None``, a new figure and axes are created.
        For 3D plots (``z`` is not ``None``), a ``projection='3d'`` axes
        is required; one is created automatically when ``ax=None``.
    hue : str or None, default=None
        Column name (after ``reset_index()``) for color grouping,
        e.g., ``"specimen_id"``. When ``hue`` is active, ``color`` and
        ``color_links`` are ignored and colors are determined by the
        palette.
    hue_order : sequence or None, default=None
        Order of hue categories. If ``None``, uses the order of
        appearance.
    palette : str, sequence, dict, or None, default=None
        Seaborn palette name, explicit color list, or a dict mapping
        hue values to colors. A dict must contain all values in
        ``hue_order``.
    color : str, default="gray"
        Point color when ``hue`` is ``None``.
    color_links : str or None, default=None
        Link color when ``hue`` is ``None``. If ``None``, uses ``color``.
    style : str or None, default=None
        Column name for marker style differentiation (2D only, passed to
        ``seaborn.scatterplot``). Ignored for 3D plots.
    s : float, default=10
        Marker size.
    alpha : float, default=1.0
        Transparency for both points and links.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes with the plot.

    Raises
    ------
    ValueError
        If the dimensionality of a supplied ``ax`` does not match the
        presence or absence of ``z``.
    """
    require_dependencies("matplotlib", "seaborn")
    import matplotlib.pyplot as plt
    import seaborn as sns

    data = _coerce_to_dataframe(data, x, y, z)
    is_3d = z is not None

    # Validate axes dimensionality. A 3D axes is recognized by the presence
    # of a ``get_zlim`` attribute.
    if ax is not None:
        ax_is_3d = hasattr(ax, "get_zlim")
        if is_3d and not ax_is_3d:
            raise ValueError(
                "z is specified but ax is not a 3D axes. "
                "Pass ax=None or create axes with projection='3d'."
            )
        if not is_3d and ax_is_3d:
            raise ValueError(
                "z is not specified but ax is a 3D axes. Pass ax=None or a 2D axes."
            )

    # Create axes if needed
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d" if is_3d else None)

    if color_links is None:
        color_links = color

    # Explicit emptiness check instead of ``if links:`` -- the truth value
    # of a multi-element NumPy array is ambiguous and would raise.
    has_links = links is not None and len(links) > 0

    # Normalize input: ensure coord_id is available as a column
    coord_id_col = _detect_coord_id_col(data)
    config = data.reset_index()

    # Assign positional coord_id when the detected column does not exist
    if has_links and coord_id_col not in config.columns:
        config[coord_id_col] = range(len(config))

    # Resolve hue
    color_map = None
    if hue is not None:
        if hue_order is None:
            hue_order = list(config[hue].unique())
        color_map = _resolve_color_map(config, hue, hue_order, palette)

    # Draw links
    if has_links:
        if is_3d:
            _draw_links_3d(
                config,
                links,
                coord_id_col,
                ax,
                hue=hue,
                hue_order=hue_order,
                color_map=color_map,
                color_links=color_links,
                alpha=alpha,
                x=x,
                y=y,
                z=z,
            )
        else:
            _draw_links_2d(
                config,
                links,
                coord_id_col,
                ax,
                hue=hue,
                hue_order=hue_order,
                color_map=color_map,
                color_links=color_links,
                alpha=alpha,
                x=x,
                y=y,
            )

    # Draw points
    if is_3d:
        _scatter_3d(
            config,
            x,
            y,
            z,
            ax,
            hue=hue,
            hue_order=hue_order,
            color_map=color_map,
            color=color,
            s=s,
            alpha=alpha,
        )
        _set_box_aspect_equal(ax, config, x, y, z)
    else:
        if hue is not None:
            sns.scatterplot(
                data=config,
                x=x,
                y=y,
                ax=ax,
                hue=hue,
                hue_order=hue_order,
                palette=palette,
                style=style,
                alpha=alpha,
                s=s,
            )
        else:
            sns.scatterplot(
                data=config,
                x=x,
                y=y,
                ax=ax,
                color=color,
                style=style,
                alpha=alpha,
                s=s,
            )
        # NOTE(review): placed in the 2D branch because the 3D branch sets
        # its own box aspect above -- confirm against the upstream file.
        ax.set_aspect("equal")

    # Move legend outside when present
    if ax.get_legend() is not None:
        sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

    return ax