"""Plot functions for kriging."""
# Copyright 2024 Koji Noshita
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import numpy as np
import numpy.typing as npt
from ..landmark._kernels import tps_coefficients, tps_warp
from ._base import require_dependencies
def tps_grid_2d_plot(
    x_reference: npt.ArrayLike,
    x_target: npt.ArrayLike,
    grid_size: str | float = "auto",
    outer: float = 0.1,
    n_grid_inner: int = 10,
    ax: object | None = None,
) -> object:
    """Plot the thin-plate spline 2D warped grid.

    A regular grid of square cells covering ``x_reference`` is mapped through
    the thin-plate spline fitted from ``x_reference`` to ``x_target``.  The
    warped grid lines are drawn together with both landmark configurations.

    Parameters
    ----------
    x_reference : array-like, shape (n_landmarks, 2)
        Reference configuration.
    x_target : array-like, shape (n_landmarks, 2)
        Target configuration.
    grid_size : "auto" or float, optional
        Side length of one grid cell, by default "auto".  "auto" uses one
        tenth of the shorter side of the enlarged bounding box.
    outer : float, optional
        Fraction by which the bounding box of ``x_reference`` is enlarged
        (about the landmark centroid) before the grid is laid out, by
        default 0.1.
    n_grid_inner : int, optional
        Number of sample points per grid-cell edge used to trace each warped
        grid line, by default 10.  Larger values give smoother curves.
    ax : :class:`matplotlib.axes.Axes`, optional
        Pre-existing axes for the plot. Otherwise, a new figure and axes
        are created.

    Returns
    -------
    ax : :class:`matplotlib.axes.Axes`
        Matplotlib axes.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If the configurations are not of shape (n_landmarks, 2) with matching
        shapes, if ``grid_size`` is neither "auto" nor a positive number, or
        if ``n_grid_inner`` is less than 1.
    """
    require_dependencies("matplotlib")
    import matplotlib.pyplot as plt

    x_reference = np.asarray(x_reference)
    x_target = np.asarray(x_target)
    if x_reference.ndim != 2 or x_reference.shape[1] != 2:
        raise ValueError(
            "x_reference must have shape (n_landmarks, 2), "
            f"got {x_reference.shape}"
        )
    if x_target.shape != x_reference.shape:
        raise ValueError(
            "x_target must have the same shape as x_reference, "
            f"got {x_target.shape} and {x_reference.shape}"
        )
    if n_grid_inner < 1:
        raise ValueError(f"n_grid_inner must be at least 1, got {n_grid_inner}")

    W, c, A = tps_coefficients(x_reference, x_target)

    # Enlarge the bounding box about the landmark centroid.  Scaling the raw
    # min/max by (1 + outer) only enlarges the box for origin-centred data;
    # otherwise it can shrink or shift it so landmarks fall outside the grid.
    # For centroid-centred data both formulas agree.
    centroid = np.mean(x_reference, axis=0)
    x_min, y_min = centroid + (1 + outer) * (np.min(x_reference, axis=0) - centroid)
    x_max, y_max = centroid + (1 + outer) * (np.max(x_reference, axis=0) - centroid)
    w = x_max - x_min
    h = y_max - y_min

    if isinstance(grid_size, str):
        if grid_size != "auto":
            raise ValueError(
                f"grid_size must be 'auto' or a positive number, got {grid_size!r}"
            )
        grid_size_ = min(w, h) / 10
        if not grid_size_ > 0:
            raise ValueError(
                "cannot choose grid_size automatically: x_reference spans "
                "zero width or height"
            )
    else:
        grid_size_ = float(grid_size)
        if not grid_size_ > 0:
            raise ValueError(
                f"grid_size must be 'auto' or a positive number, got {grid_size!r}"
            )

    # Use a whole number (at least one) of square cells of side grid_size_
    # on each axis.  The ratio is rounded first so that floating-point noise
    # (e.g. 10.000000000002) does not add a spurious extra cell.  The box is
    # then padded symmetrically to exactly n_grid * grid_size_.
    n_grid_x = max(1, int(np.ceil(np.round(w / grid_size_, 6))))
    n_grid_y = max(1, int(np.ceil(np.round(h / grid_size_, 6))))
    pad_x = (n_grid_x * grid_size_ - w) / 2
    pad_y = (n_grid_y * grid_size_ - h) / 2
    x_min, x_max = x_min - pad_x, x_max + pad_x
    y_min, y_max = y_min - pad_y, y_max + pad_y

    # Grid points: every n_grid_inner-th sample along an axis lies on a grid
    # line; the samples in between trace the line's curve after warping.
    n_pts_x = n_grid_x * n_grid_inner + 1
    n_pts_y = n_grid_y * n_grid_inner + 1
    xx = np.linspace(x_min, x_max, n_pts_x)
    yy = np.linspace(y_min, y_max, n_pts_y)
    grid_x, grid_y = np.meshgrid(xx, yy, indexing="ij")
    grid_points = np.column_stack([grid_x.ravel(), grid_y.ravel()])

    # Warp all grid points.  With "ij" indexing and C-order ravel,
    # warped[i, j] corresponds to (xx[i], yy[j]).
    warped = tps_warp(grid_points, x_reference, W, c, A).reshape(n_pts_x, n_pts_y, 2)
    # ax.plot draws each column of a 2D array as a separate line.  Columns of
    # `warped` (fixed j) are lines of constant reference y; columns of its
    # transpose (fixed i) are lines of constant reference x.
    lines_const_y = warped[:, ::n_grid_inner]
    lines_const_x = warped.transpose(1, 0, 2)[:, ::n_grid_inner]

    # Create the figure only once all inputs are validated and computed, so
    # an error does not leave an empty figure behind.
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(lines_const_y[..., 0], lines_const_y[..., 1], "gray")
    ax.plot(lines_const_x[..., 0], lines_const_x[..., 1], "gray")
    ax.axis("equal")
    ax.scatter(x=x_reference[:, 0], y=x_reference[:, 1], zorder=2)
    ax.scatter(x=x_target[:, 0], y=x_target[:, 1], zorder=2)
    return ax