Source code for ktch.io._tps

"""TPS file I/O functions."""

# Copyright 2023 Koji Noshita
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import numpy.typing as npt
import pandas as pd

####################################
# TPS dataclass
####################################


@dataclass
class TPSData:
    """Container for one specimen of a TPS file.

    Attributes
    ----------
    idx : str
        Specimen identifier (the ``ID`` field).
    landmarks : ndarray of shape (n_landmarks, n_dim)
        Landmark coordinates. Whatever is assigned is passed through
        ``np.array`` by the property setter below. ``to_dataframe``
        requires ``n_dim`` to be 2 or 3.
    image_path : str, optional
        Path to the associated image (the ``IMAGE`` field).
    scale : float, optional
        Scale factor (the ``SCALE`` field).
    curves : list of ndarray, optional
        Semilandmark curves, each of shape ``(n_points, n_dim)``.
    comments : str, optional
        Free-text comments (the ``COMMENTS`` field).
    """

    idx: str
    # Annotated here so the dataclass registers the field; the property
    # defined below then takes over attribute access for it.
    landmarks: np.ndarray
    image_path: str = None
    scale: float = None
    curves: list[np.ndarray] = None
    comments: str = None

    @property
    def landmarks(self) -> np.ndarray:
        """Landmark coordinates as stored by the setter."""
        return self._landmarks

    @landmarks.setter
    def landmarks(self, value: npt.ArrayLike) -> None:
        self._landmarks = np.array(value)

    def to_numpy(self):
        """Return the coordinates as ndarray(s).

        Returns
        -------
        ndarray or tuple
            The landmark array alone when ``curves`` is None, otherwise
            the pair ``(landmarks, curves)``.
        """
        if self.curves is not None:
            return self.landmarks, self.curves
        return self.landmarks

    def to_dataframe(self):
        """Return the coordinates as pandas DataFrame(s).

        Rows are indexed by a ``(specimen_id, coord_id)`` MultiIndex and
        columns are ``x, y`` or ``x, y, z``.

        Returns
        -------
        DataFrame or tuple
            The landmark frame alone when ``curves`` is None, otherwise
            the pair ``(landmark_frame, [curve_frame, ...])``.

        Raises
        ------
        ValueError
            If the second axis of ``landmarks`` has a length other than
            2 or 3.
        """
        coords = self.landmarks
        axis_names = {2: ["x", "y"], 3: ["x", "y", "z"]}.get(coords.shape[1])
        if axis_names is None:
            raise ValueError("n_dim must be 2 or 3.")

        def build_frame(points):
            # One row per point, labelled with this specimen's identifier.
            row_labels = [[self.idx, i] for i in range(len(points))]
            return pd.DataFrame(
                points,
                columns=axis_names,
                index=pd.MultiIndex.from_tuples(
                    row_labels,
                    name=["specimen_id", "coord_id"],
                ),
            )

        landmark_frame = build_frame(coords)
        if self.curves is None:
            return landmark_frame
        return landmark_frame, [build_frame(curve) for curve in self.curves]


####################################
# TPS I/O functions
####################################


def read_tps(file_path, as_frame=False):
    """Read landmark data from a TPS file.

    Parameters
    ----------
    file_path : str or path-like
        Path to the TPS file. The extension must be ``.tps``; the check is
        case-insensitive, so ``.TPS`` is accepted as well.
    as_frame : bool, default=False
        If True, return pandas DataFrame(s). Otherwise, return ndarray(s).

    Returns
    -------
    landmarks : ndarray, DataFrame, or list of DataFrame
        Landmark coordinates. For a single specimen the shape is
        ``(n_landmarks, n_dim)``. For several specimens it is an array
        built from the per-specimen arrays, or with ``as_frame=True`` one
        concatenated DataFrame. When the file has CURVES data and
        ``as_frame=True``, the per-specimen DataFrames are returned as a
        list and are not concatenated.
    semilandmarks : list, optional
        Returned only when the file contains CURVES data. For a single
        specimen, the list of its curves. For several specimens, a list
        with one list of curves per specimen. Curves are ndarrays, or
        DataFrames when ``as_frame=True``.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    ValueError
        If the extension is not ``.tps``, or if only some of the specimens
        carry CURVES data.
    """
    path = Path(file_path)
    if not path.exists():
        raise FileNotFoundError(f"{path} does not exist.")
    # Compare the extension case-insensitively so that files named
    # ``*.TPS`` are not rejected.
    if path.suffix.lower() != ".tps":
        raise ValueError(f"{path} is not a TPS file.")

    tps_data = _read_tps(path)

    # A file with exactly one specimen comes back as a bare TPSData.
    if isinstance(tps_data, TPSData):
        return tps_data.to_dataframe() if as_frame else tps_data.to_numpy()

    n_with_curves = sum(datum.curves is not None for datum in tps_data)
    if n_with_curves not in (0, len(tps_data)):
        raise ValueError("Some specimens have semilandmarks and others do not.")

    # Convert every specimen exactly once; each item is either a single
    # object (no curves) or a (landmarks, curves) pair.
    converted = [
        datum.to_dataframe() if as_frame else datum.to_numpy()
        for datum in tps_data
    ]

    if n_with_curves == 0:
        return pd.concat(converted) if as_frame else np.array(converted)

    landmarks = [pair[0] for pair in converted]
    semilandmarks = [pair[1] for pair in converted]
    if not as_frame:
        landmarks = np.array(landmarks)
    return landmarks, semilandmarks
def write_tps(
    file_path,
    landmarks,
    image_path=None,
    idx=None,
    scale=None,
    semilandmarks=None,
    comments=None,
) -> None:
    """Write landmark data to a TPS file.

    Parameters
    ----------
    file_path : str or path-like
        Output file path.
    landmarks : ndarray or list of ndarray
        Landmark coordinates. A 2D array of shape ``(n_landmarks, n_dim)``
        for a single specimen, a 3D array of shape
        ``(n_specimens, n_landmarks, n_dim)`` for multiple specimens, or a
        list of per-specimen 2D arrays.
    image_path : str or list of str, optional
        Image path(s) for the IMAGE field.
    idx : str, int, or list, optional
        Specimen identifier(s) for the ID field. When omitted, multiple
        specimens get their 0-based position and a single specimen keeps
        ``None`` as its identifier.
    scale : float or list of float, optional
        Scale factor(s) for the SCALE field.
    semilandmarks : array-like or list, optional
        Semilandmark curves. For a single specimen, a 2D array or a list
        of 2D arrays (one per curve). For multiple specimens, a list with
        one entry per specimen.
    comments : str or list of str, optional
        Free-text comment(s) for the COMMENTS field.
    """
    per_specimen_landmarks = _normalize_landmarks_input(landmarks)
    n_specimens = len(per_specimen_landmarks)

    def expand(values, arg_name):
        # Broadcast a scalar, or length-check a sequence, to one value
        # per specimen.
        return _normalize_optional_values(
            values, n_specimens=n_specimens, arg_name=arg_name, default=None
        )

    image_paths = expand(image_path, "image_path")
    if idx is not None:
        ids = expand(idx, "idx")
    elif n_specimens == 1:
        ids = [None]
    else:
        ids = list(range(n_specimens))
    scales = expand(scale, "scale")
    curve_blocks = _normalize_semilandmarks_input(semilandmarks, n_specimens)
    comment_values = expand(comments, "comments")

    records = [
        TPSData(
            idx=specimen_id,
            landmarks=coords,
            image_path=image,
            scale=factor,
            curves=curves,
            comments=note,
        )
        for specimen_id, coords, image, factor, curves, note in zip(
            ids,
            per_specimen_landmarks,
            image_paths,
            scales,
            curve_blocks,
            comment_values,
        )
    ]

    # A lone specimen is handed over as a bare TPSData, not a list.
    payload = records[0] if n_specimens == 1 else records
    _write_tps(file_path=file_path, tps_data=payload)
def _normalize_landmarks_input(landmarks) -> list[np.ndarray]: """Normalize landmarks input into a per-specimen list of arrays.""" if isinstance(landmarks, np.ndarray): if landmarks.ndim == 2: return [landmarks] if landmarks.ndim == 3: return [landmarks[i] for i in range(len(landmarks))] raise ValueError("landmarks must be a 2D or 3D numpy array.") if isinstance(landmarks, list): if len(landmarks) == 0: raise ValueError("landmarks list cannot be empty.") return [np.asarray(lm) for lm in landmarks] if isinstance(landmarks, pd.DataFrame): raise NotImplementedError() raise ValueError( "landmarks must be a numpy array or a list of per-specimen arrays." ) def _normalize_optional_values( values, *, n_specimens: int, arg_name: str, default=None, ): """Normalize scalar or sequence metadata into n_specimens length list.""" if values is None: return [default] * n_specimens if isinstance(values, np.ndarray): if values.ndim == 0: return [values.item()] * n_specimens values_list = values.tolist() if len(values_list) != n_specimens: raise ValueError( f"{arg_name} must have length {n_specimens}, got {len(values_list)}." ) return values_list if isinstance(values, Sequence) and not isinstance(values, (str, bytes, bytearray)): values_list = list(values) if len(values_list) != n_specimens: raise ValueError( f"{arg_name} must have length {n_specimens}, got {len(values_list)}." ) return values_list return [values] * n_specimens def _normalize_curve_block(curves): """Normalize one specimen's semilandmarks to a list of 2D arrays.""" if curves is None: return None if isinstance(curves, np.ndarray): if curves.ndim != 2: raise ValueError( "Each curve must be a 2D array with shape (n_points, n_dim)." ) return [curves] if isinstance(curves, Sequence) and not isinstance(curves, (str, bytes, bytearray)): curves_list = [] for curve in curves: curve_arr = np.asarray(curve) if curve_arr.ndim != 2: raise ValueError( "Each curve must be a 2D array with shape (n_points, n_dim)." 
) curves_list.append(curve_arr) return curves_list raise ValueError( "semilandmarks must be None, a 2D array, or a sequence of 2D arrays." ) def _normalize_semilandmarks_input(semilandmarks, n_specimens: int): """Normalize semilandmarks input into per-specimen curve lists.""" if semilandmarks is None: return [None] * n_specimens if n_specimens == 1: if ( isinstance(semilandmarks, Sequence) and not isinstance(semilandmarks, (str, bytes, bytearray, np.ndarray)) and len(semilandmarks) == 1 and ( semilandmarks[0] is None or isinstance(semilandmarks[0], (Sequence, np.ndarray)) ) ): return [_normalize_curve_block(semilandmarks[0])] return [_normalize_curve_block(semilandmarks)] if not ( isinstance(semilandmarks, Sequence) and not isinstance(semilandmarks, (str, bytes, bytearray, np.ndarray)) ): raise ValueError( "For multiple specimens, semilandmarks must be a sequence with one " "entry per specimen." ) semilandmarks_list = list(semilandmarks) if len(semilandmarks_list) != n_specimens: raise ValueError( "For multiple specimens, semilandmarks must have the same length as " f"landmarks ({n_specimens}), got {len(semilandmarks_list)}." 
) return [_normalize_curve_block(curves) for curves in semilandmarks_list] ###################################### # Regular Express Patterns # ###################################### PTN_HEAD = re.compile(r"^LM3?\s*=", flags=re.MULTILINE) PTN_LM = re.compile( r"^(?P<LM>LM3?\s*=\s*[0-9]+(?:\s+[-\w\.]+)*)$", flags=re.MULTILINE, ) PTN_CURVES = re.compile( r"^(?P<CURVES>CURVES\s*=\s*[0-9]+\s+(?:POINTS\s*=\s*[0-9]+(?:\s+[-0-9\.]+)+\s*)+)(?=^[A-Z]|\Z)", flags=re.MULTILINE, ) PTN_POINTS = re.compile( r"^(?P<POINTS>POINTS\s*=\s*[0-9]+(?:\s+[-0-9\.]+)*)$", flags=re.MULTILINE, ) PTN_DICT = re.compile(r"^(?P<key>\w+)\s*=\s*(?P<value>[^\r\n]+)$") PTN_COORD = re.compile( r"^(?P<x>[-0-9\.]+|nan)\s(?P<y>[-0-9\.]+|nan)\s*(?P<z>[-0-9\.]+|nan)?$", flags=re.IGNORECASE, ) ###################################### # Helper functions # ###################################### def _read_tps(file_path): with open(file_path, "r") as f: read_data = f.read() specimens = PTN_HEAD.split(read_data) specimens = ["LM=" + specimen for specimen in specimens if len(specimen) > 0] tps_data = [_read_tps_single(specimen) for specimen in specimens] if len(tps_data) == 1: return tps_data[0] return tps_data def _read_tps_single(specimen_str: str) -> TPSData: m = PTN_LM.search(specimen_str) if m is not None: key, landmarks = _read_coordinate_values( [row for row in m["LM"].splitlines() if len(row) > 0] ) else: raise ValueError("Failed to parse landmark (LM) section in TPS specimen") filtered_str = PTN_CURVES.sub("", PTN_LM.sub("", specimen_str)) filtered_list = [row for row in filtered_str.splitlines() if len(row) > 0] meta_dict = {} for row in filtered_list: m_meta = PTN_DICT.search(row) if m_meta is None: raise ValueError(f"Invalid metadata line in TPS specimen: {row!r}") meta_dict[m_meta["key"]] = m_meta["value"] if "ID" in meta_dict.keys(): idx = meta_dict["ID"] else: raise ValueError("Missing required 'ID' field in TPS specimen") if "IMAGE" in meta_dict.keys(): image_path = meta_dict["IMAGE"] else: 
image_path = None if "SCALE" in meta_dict.keys(): scale = meta_dict["SCALE"] else: scale = None if "COMMENTS" in meta_dict.keys(): comments = meta_dict["COMMENTS"] else: comments = None m = PTN_CURVES.search(specimen_str) if m is not None: curves = [] for points in PTN_POINTS.finditer(m["CURVES"]): key, val = _read_coordinate_values( [row for row in points["POINTS"].splitlines() if len(row) > 0] ) curves.append(val) else: curves = None tps_data = TPSData( idx=idx, landmarks=landmarks, image_path=image_path, scale=scale, curves=curves, comments=comments, ) return tps_data def _read_coordinate_values(coordinate_list): header = coordinate_list[0] body = coordinate_list[1:] m = PTN_DICT.match(header) if m is None: raise ValueError(f"Invalid coordinate header in TPS: {header!r}") key = m["key"] value_rows = [] for row in body: m_coord = PTN_COORD.match(row) if m_coord is None: raise ValueError(f"Invalid coordinate row in TPS: {row!r}") value_rows.append( [float(x) for x in m_coord.groups() if x is not None and len(x) > 0] ) value = np.array(value_rows) return key, value def _write_tps_single(file_path, tps_data, write_mode="w"): with open(file_path, write_mode) as f: f.write("LM=" + str(len(tps_data.landmarks)) + "\n") f.write( "\n".join([" ".join(map(str, row)) for row in tps_data.landmarks.tolist()]) ) f.write("\n") if tps_data.image_path is not None: f.write("IMAGE=" + tps_data.image_path + "\n") f.write("ID=" + str(tps_data.idx) + "\n") if tps_data.scale is not None: f.write("SCALE=" + str(tps_data.scale) + "\n") if tps_data.curves is not None: f.write("CURVES=" + str(len(tps_data.curves)) + "\n") for curve in tps_data.curves: f.write("POINTS=" + str(len(curve)) + "\n") curve_lines = [" ".join(map(str, row)) for row in curve.tolist()] f.write("\n".join(curve_lines)) f.write("\n") if tps_data.comments is not None: f.write("COMMENTS=" + tps_data.comments + "\n") f.write("\n") def _write_tps(file_path, tps_data): if isinstance(tps_data, TPSData): 
_write_tps_single(file_path, tps_data) elif isinstance(tps_data, list): for i, tps_datum in enumerate(tps_data): if i == 0: mode = "w" else: mode = "a" _write_tps_single(file_path, tps_datum, write_mode=mode)