# Source code for ktch.plot._kriging

"""Plot functions for kriging."""

# Copyright 2024 Koji Noshita
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import numpy as np
import numpy.typing as npt

from ..landmark._kernels import tps_coefficients, tps_warp
from ._base import require_dependencies


def tps_grid_2d_plot(
    x_reference: npt.ArrayLike,
    x_target: npt.ArrayLike,
    grid_size: str | float = "auto",
    outer: float = 0.1,
    n_grid_inner: int = 10,
    ax: object | None = None,
) -> object:
    """Plot the thin-plate spline 2D warped grid.

    A regular grid of square cells is laid over the reference configuration,
    every grid point is mapped through the thin-plate spline computed from
    ``x_reference`` and ``x_target``, and the warped grid lines are drawn
    together with both landmark configurations.

    Parameters
    ----------
    x_reference : array-like, shape (n_landmarks, 2)
        Reference configuration.
    x_target : array-like, shape (n_landmarks, 2)
        Target configuration.
    grid_size : "auto" or float, optional
        Side length of one square grid cell, by default "auto".
        With "auto", the shorter side of the covered region is split into
        10 cells.
    outer : float, optional
        Relative margin around x_reference covered by the grid: the
        landmark bounding box is scaled by ``1 + outer`` about the landmark
        centroid before the grid is laid out, by default 0.1.
    n_grid_inner : int, optional
        Number of sub-intervals used to sample each grid-cell edge when
        drawing the warped grid lines, by default 10.
    ax : :class:`matplotlib.axes.Axes`, optional
        Pre-existing axes for the plot. Otherwise, a new figure and axes
        are created.

    Returns
    -------
    ax : :class:`matplotlib.axes.Axes`
        Matplotlib axes.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If the configurations are not arrays of matching shape
        (n_landmarks, 2), if ``grid_size`` is neither "auto" nor a positive
        number, if ``n_grid_inner`` is smaller than 1, or if all reference
        landmarks coincide while ``grid_size="auto"``.
    """
    require_dependencies("matplotlib")
    import matplotlib.pyplot as plt

    x_reference = np.asarray(x_reference)
    # Convert the target too: it is column-indexed when plotting below,
    # which would fail for plain nested lists.
    x_target = np.asarray(x_target)
    if x_reference.ndim != 2 or x_reference.shape[1] != 2:
        raise ValueError(
            "x_reference must have shape (n_landmarks, 2), "
            f"got {x_reference.shape}."
        )
    if x_target.shape != x_reference.shape:
        raise ValueError(
            "x_target must have the same shape as x_reference, "
            f"got {x_target.shape} and {x_reference.shape}."
        )
    if n_grid_inner < 1:
        raise ValueError(f"n_grid_inner must be at least 1, got {n_grid_inner}.")

    W, c, A = tps_coefficients(x_reference, x_target)

    if ax is None:
        _, ax = plt.subplots()

    # Enlarge the landmark bounding box by ``outer`` about the centroid.
    # For centred configurations (centroid at the origin) this equals scaling
    # the box corners by ``1 + outer``; unlike plain scaling it also grows the
    # box outward on every side when the configuration is not centred.
    centroid = x_reference.mean(axis=0)
    lower = centroid + (1 + outer) * (x_reference.min(axis=0) - centroid)
    upper = centroid + (1 + outer) * (x_reference.max(axis=0) - centroid)
    extent = upper - lower

    if isinstance(grid_size, str):
        if grid_size != "auto":
            raise ValueError(
                f'grid_size must be "auto" or a positive number, got {grid_size!r}.'
            )
        # Split the shorter side into 10 cells; a zero-length side (collinear
        # landmarks) is ignored so that a usable grid is still produced.
        nonzero = extent[extent > 0]
        if nonzero.size == 0:
            raise ValueError(
                "x_reference must contain at least two distinct landmarks."
            )
        cell = nonzero.min() / 10
    else:
        cell = float(grid_size)
        if not cell > 0:  # also rejects NaN
            raise ValueError(
                f'grid_size must be "auto" or a positive number, got {grid_size!r}.'
            )

    # Round each side up to a whole number of cells (the tolerance absorbs
    # floating-point error in ``extent / cell``), then grow the box
    # symmetrically so that every cell really is ``cell`` x ``cell``.
    n_cells = np.maximum(np.ceil(extent / cell - 1e-9), 1).astype(int)
    pad = (n_cells * cell - extent) / 2
    lower = lower - pad
    upper = upper + pad

    # Sample each cell edge with ``n_grid_inner`` sub-intervals so the warped
    # grid lines are drawn as smooth curves.
    n_x, n_y = (int(n) * n_grid_inner + 1 for n in n_cells)
    xx = np.linspace(lower[0], upper[0], n_x)
    yy = np.linspace(lower[1], upper[1], n_y)
    grid_x, grid_y = np.meshgrid(xx, yy, indexing="ij")
    grid_points = np.column_stack([grid_x.ravel(), grid_y.ravel()])

    # Warp all grid points at once. With indexing="ij", axis 0 of the
    # reshaped result follows x and axis 1 follows y.
    warped = tps_warp(grid_points, x_reference, W, c, A).reshape(n_x, n_y, 2)

    # ``ax.plot`` draws each column of a 2-D array as one line. Keeping every
    # ``n_grid_inner``-th column yields the lines of constant y; after swapping
    # the first two axes, the same selection yields the lines of constant x.
    lines_const_y = warped[:, ::n_grid_inner]
    lines_const_x = warped.transpose(1, 0, 2)[:, ::n_grid_inner]
    ax.plot(lines_const_y[..., 0], lines_const_y[..., 1], color="gray")
    ax.plot(lines_const_x[..., 0], lines_const_x[..., 1], color="gray")
    ax.axis("equal")
    ax.scatter(x=x_reference[:, 0], y=x_reference[:, 1], zorder=2)
    ax.scatter(x=x_target[:, 0], y=x_target[:, 1], zorder=2)
    return ax