diff --git a/bluemath_tk/deeplearning/metrics.py b/bluemath_tk/deeplearning/metrics.py
new file mode 100644
index 0000000..b743b4e
--- /dev/null
+++ b/bluemath_tk/deeplearning/metrics.py
@@ -0,0 +1,403 @@
+"""Metrics and losses for deeplearning models.
+
+This module provides reconstruction metrics that can be used both for
+post-training evaluation and, when using PyTorch tensors, as differentiable
+training losses.
+"""
+
+from __future__ import annotations
+
+from typing import Dict, Literal, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+
+ArrayLike = Union[np.ndarray, torch.Tensor]
+MetricName = Literal["mse", "mae", "rmse"]
+ReductionName = Literal["none", "sample", "mean", "sum"]
+
+
+__all__ = [
+    "reconstruction_error",
+    "evaluate_reconstruction",
+    "ReconstructionLoss",
+]
+
+
+# Tuples rather than sets so validation messages are deterministic.
+_VALID_METRICS = ("mse", "mae", "rmse")
+_VALID_REDUCTIONS = ("none", "sample", "mean", "sum")
+
+
+def _normalise_options(metric: str, reduction: str) -> Tuple[str, str]:
+    """Validate and normalise metric/reduction names.
+
+    Names are matched case-insensitively and returned in lower case.
+    A ValueError is raised for names outside the supported options.
+    """
+    metric = metric.lower()
+    reduction = reduction.lower()
+
+    if metric not in _VALID_METRICS:
+        raise ValueError(f"metric must be one of {_VALID_METRICS}, got {metric!r}")
+    if reduction not in _VALID_REDUCTIONS:
+        raise ValueError(
+            f"reduction must be one of {_VALID_REDUCTIONS}, got {reduction!r}"
+        )
+
+    return metric, reduction
+
+
+def _uses_torch(y_true: ArrayLike, y_pred: ArrayLike) -> bool:
+    """Return True when at least one input is a PyTorch tensor."""
+    return torch.is_tensor(y_true) or torch.is_tensor(y_pred)
+
+
+def _to_matching_tensor(
+    array: ArrayLike,
+    reference: torch.Tensor | None = None,
+) -> torch.Tensor:
+    """Convert an array to a tensor matching a reference tensor if provided."""
+    if torch.is_tensor(array):
+        return array
+
+    if reference is not None:
+        return torch.as_tensor(array, dtype=reference.dtype, device=reference.device)
+
+    return torch.as_tensor(array)
+
+
+def _promote_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    """Cast bool/integer tensors to the default floating dtype.
+
+    Unsigned integer subtraction wraps around and ``mean`` is not defined for
+    integer tensors. Floating and complex tensors are returned unchanged, so
+    their autograd history is preserved.
+    """
+    if tensor.is_floating_point() or tensor.is_complex():
+        return tensor
+
+    return tensor.to(torch.get_default_dtype())
+
+
+def _promote_array(array: ArrayLike) -> np.ndarray:
+    """Return a NumPy array, casting bool/integer dtypes to float64.
+
+    This avoids unsigned wrap-around and integer overflow when differencing
+    and squaring. Arrays of any other dtype are returned as they are.
+    """
+    array = np.asarray(array)
+    if array.dtype.kind in "biu":
+        return array.astype(np.float64)
+
+    return array
+
+
+def _to_numpy(array: ArrayLike) -> np.ndarray:
+    """Convert NumPy arrays or tensors to NumPy arrays for summary statistics."""
+    if torch.is_tensor(array):
+        return array.detach().cpu().numpy()
+
+    return np.asarray(array)
+
+
+def _check_same_shape(y_true: ArrayLike, y_pred: ArrayLike) -> None:
+    """Raise a ValueError when target and prediction shapes differ."""
+    if tuple(y_true.shape) != tuple(y_pred.shape):
+        raise ValueError(
+            "y_true and y_pred must have the same shape. "
+            f"Got y_true.shape={tuple(y_true.shape)} and "
+            f"y_pred.shape={tuple(y_pred.shape)}."
+        )
+
+
+def _elementwise_error_torch(
+    y_true: torch.Tensor,
+    y_pred: torch.Tensor,
+    metric: str,
+) -> torch.Tensor:
+    """Return elementwise reconstruction error for tensors."""
+    diff = y_pred - y_true
+
+    if metric == "mae":
+        return torch.abs(diff)
+
+    return diff.pow(2)
+
+
+def _elementwise_error_numpy(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    metric: str,
+) -> np.ndarray:
+    """Return elementwise reconstruction error for NumPy arrays."""
+    diff = y_pred - y_true
+
+    if metric == "mae":
+        return np.abs(diff)
+
+    return diff**2
+
+
+def _reduce_torch(
+    elementwise_error: torch.Tensor,
+    metric: str,
+    reduction: str,
+    eps: float,
+) -> torch.Tensor:
+    """Reduce elementwise tensor errors."""
+    if reduction == "none":
+        if metric == "rmse":
+            return torch.sqrt(elementwise_error + eps)
+        return elementwise_error
+
+    if elementwise_error.ndim <= 1:
+        sample_errors = elementwise_error
+    else:
+        axes = tuple(range(1, elementwise_error.ndim))
+        sample_errors = elementwise_error.mean(dim=axes)
+
+    if metric == "rmse":
+        sample_errors = torch.sqrt(sample_errors + eps)
+
+    if reduction == "sample":
+        return sample_errors
+    if reduction == "mean":
+        return sample_errors.mean()
+
+    return sample_errors.sum()
+
+
+def _reduce_numpy(
+    elementwise_error: np.ndarray,
+    metric: str,
+    reduction: str,
+    eps: float,
+) -> Union[np.ndarray, float]:
+    """Reduce elementwise NumPy errors."""
+    if reduction == "none":
+        if metric == "rmse":
+            return np.sqrt(elementwise_error + eps)
+        return elementwise_error
+
+    if elementwise_error.ndim <= 1:
+        sample_errors = elementwise_error
+    else:
+        axes = tuple(range(1, elementwise_error.ndim))
+        sample_errors = np.mean(elementwise_error, axis=axes)
+
+    if metric == "rmse":
+        sample_errors = np.sqrt(sample_errors + eps)
+
+    if reduction == "sample":
+        return sample_errors
+    if reduction == "mean":
+        return float(np.mean(sample_errors))
+
+    return float(np.sum(sample_errors))
+
+
+def reconstruction_error(
+    y_true: ArrayLike,
+    y_pred: ArrayLike,
+    metric: MetricName = "mse",
+    reduction: ReductionName = "mean",
+    eps: float = 0.0,
+) -> Union[np.ndarray, torch.Tensor, float]:
+    """Compute reconstruction error between target and prediction.
+
+    Parameters
+    ----------
+    y_true : np.ndarray or torch.Tensor
+        Target/reference data.
+    y_pred : np.ndarray or torch.Tensor
+        Reconstructed or predicted data.
+    metric : {"mse", "mae", "rmse"}, optional
+        Reconstruction metric. Default is "mse".
+    reduction : {"none", "sample", "mean", "sum"}, optional
+        Reduction mode.
+
+        - "none": return elementwise errors.
+        - "sample": return one value per sample, reducing over all axes except
+          the first.
+        - "mean": return the mean of sample errors.
+        - "sum": return the sum of sample errors.
+
+        Default is "mean".
+    eps : float, optional
+        Small non-negative value added inside the square root for RMSE.
+        Default is 0.0.
+
+    Returns
+    -------
+    np.ndarray, torch.Tensor or float
+        Reconstruction error according to the selected metric and reduction.
+
+    Raises
+    ------
+    ValueError
+        If ``metric`` or ``reduction`` is not supported, if ``eps`` is
+        negative, or if ``y_true`` and ``y_pred`` have different shapes.
+
+    Notes
+    -----
+    When PyTorch tensors are provided, the returned value remains a tensor and
+    can be used as a differentiable loss.
+
+    Boolean and integer inputs are cast to floating point before differencing,
+    so unsigned data (e.g. ``uint8`` images) does not wrap around.
+
+    The derivative of the square root is unbounded at zero, so pass a small
+    positive ``eps`` when training with ``metric="rmse"``.
+    """
+    metric, reduction = _normalise_options(metric, reduction)
+    if eps < 0:
+        raise ValueError(f"eps must be non-negative, got {eps!r}")
+
+    if _uses_torch(y_true, y_pred):
+        # The (promoted) tensor input fixes dtype/device for a NumPy partner.
+        reference = _promote_tensor(y_pred if torch.is_tensor(y_pred) else y_true)
+        y_true_tensor = _promote_tensor(_to_matching_tensor(y_true, reference))
+        y_pred_tensor = _promote_tensor(_to_matching_tensor(y_pred, reference))
+
+        _check_same_shape(y_true_tensor, y_pred_tensor)
+
+        elementwise_error = _elementwise_error_torch(
+            y_true_tensor,
+            y_pred_tensor,
+            metric,
+        )
+        return _reduce_torch(elementwise_error, metric, reduction, eps)
+
+    y_true_array = _promote_array(y_true)
+    y_pred_array = _promote_array(y_pred)
+
+    _check_same_shape(y_true_array, y_pred_array)
+
+    elementwise_error = _elementwise_error_numpy(
+        y_true_array,
+        y_pred_array,
+        metric,
+    )
+    return _reduce_numpy(elementwise_error, metric, reduction, eps)
+
+
+def evaluate_reconstruction(
+    y_true: ArrayLike,
+    y_pred: ArrayLike,
+    metric: MetricName = "mse",
+    eps: float = 0.0,
+) -> Dict[str, Union[str, int, float]]:
+    """Return summary statistics for per-sample reconstruction error.
+
+    Parameters
+    ----------
+    y_true : np.ndarray or torch.Tensor
+        Target/reference data.
+    y_pred : np.ndarray or torch.Tensor
+        Reconstructed or predicted data.
+    metric : {"mse", "mae", "rmse"}, optional
+        Reconstruction metric. Default is "mse".
+    eps : float, optional
+        Small value added inside the square root for RMSE. Default is 0.0.
+
+    Returns
+    -------
+    dict
+        Dictionary containing the metric name, number of samples, mean,
+        standard deviation, median, minimum and maximum reconstruction error.
+
+    Raises
+    ------
+    ValueError
+        If the inputs contain no samples, or if ``reconstruction_error``
+        rejects the arguments.
+    """
+    metric, _ = _normalise_options(metric, "sample")
+
+    sample_errors = reconstruction_error(
+        y_true,
+        y_pred,
+        metric=metric,
+        reduction="sample",
+        eps=eps,
+    )
+
+    values = np.ravel(_to_numpy(sample_errors)).astype(float)
+    if values.size == 0:
+        # min/max/median are undefined for an empty set of samples.
+        raise ValueError("evaluate_reconstruction requires at least one sample.")
+
+    return {
+        "metric": metric,
+        "n_samples": int(values.shape[0]),
+        "mean": float(np.mean(values)),
+        "std": float(np.std(values)),
+        "median": float(np.median(values)),
+        "min": float(np.min(values)),
+        "max": float(np.max(values)),
+    }
+
+
+class ReconstructionLoss(nn.Module):
+    """PyTorch loss wrapper for reconstruction metrics.
+
+    Parameters
+    ----------
+    metric : {"mse", "mae", "rmse"}, optional
+        Reconstruction metric. Default is "mse".
+    reduction : {"none", "sample", "mean", "sum"}, optional
+        Reduction mode. Default is "mean".
+    eps : float, optional
+        Small value added inside the square root for RMSE. Default is 0.0.
+
+    Examples
+    --------
+    >>> criterion = ReconstructionLoss(metric="mse", reduction="mean")
+    >>> loss = criterion(y_pred, y_true)
+    >>> loss.backward()
+    """
+
+    def __init__(
+        self,
+        metric: MetricName = "mse",
+        reduction: ReductionName = "mean",
+        eps: float = 0.0,
+    ):
+        super().__init__()
+        self.metric, self.reduction = _normalise_options(metric, reduction)
+        if eps < 0:
+            raise ValueError(f"eps must be non-negative, got {eps!r}")
+        self.eps = eps
+
+    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
+        """Compute reconstruction loss.
+
+        Parameters
+        ----------
+        y_pred : torch.Tensor
+            Predicted/reconstructed tensor.
+        y_true : torch.Tensor
+            Target/reference tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            Reconstruction loss.
+        """
+        loss = reconstruction_error(
+            y_true,
+            y_pred,
+            metric=self.metric,
+            reduction=self.reduction,
+            eps=self.eps,
+        )
+
+        if not torch.is_tensor(loss):
+            # Only reached when neither input is a tensor, so there is no
+            # tensor dtype/device to copy; let torch infer them.
+            return torch.as_tensor(loss)
+
+        return loss
diff --git a/tests/deeplearning/test_metrics.py b/tests/deeplearning/test_metrics.py
new file mode 100644
index 0000000..20b591e
--- /dev/null
+++ b/tests/deeplearning/test_metrics.py
@@ -0,0 +1,204 @@
+import numpy as np
+import pytest
+
+torch = pytest.importorskip("torch")
+
+from bluemath_tk.deeplearning.metrics import (
+    ReconstructionLoss,
+    evaluate_reconstruction,
+    reconstruction_error,
+)
+
+
+torch.set_num_threads(1)
+
+
+def test_reconstruction_error_numpy_sample_mse_matches_manual():
+    """Per-sample MSE should match the manual NumPy calculation."""
+    y_true = np.array(
+        [
+            [[0.0, 1.0], [2.0, 3.0]],
+            [[1.0, 1.0], [1.0, 1.0]],
+        ],
+        dtype=np.float32,
+    )
+    y_pred = np.array(
+        [
+            [[1.0, 1.0], [4.0, 3.0]],
+            [[2.0, 1.0], [1.0, 5.0]],
+        ],
+        dtype=np.float32,
+    )
+
+    errors = reconstruction_error(
+        y_true,
+        y_pred,
+        metric="mse",
+        reduction="sample",
+    )
+    manual_errors = np.mean((y_pred - y_true) ** 2, axis=(1, 2))
+
+    assert isinstance(errors, np.ndarray)
+    assert errors.shape == (2,)
+    assert np.allclose(errors, manual_errors)
+
+
+def test_reconstruction_error_numpy_reduction_none_and_sum():
+    """Elementwise and summed MAE reductions should match manual results."""
+    y_true = np.array([[0.0, 1.0, 2.0], [1.0, 1.0, 1.0]], dtype=np.float32)
+    y_pred = np.array([[1.0, 1.0, 4.0], [2.0, 3.0, 1.0]], dtype=np.float32)
+
+    elementwise = reconstruction_error(
+        y_true,
+        y_pred,
+        metric="mae",
+        reduction="none",
+    )
+    summed = reconstruction_error(
+        y_true,
+        y_pred,
+        metric="mae",
+        reduction="sum",
+    )
+
+    manual_elementwise = np.abs(y_pred - y_true)
+    manual_summed = np.sum(np.mean(manual_elementwise, axis=1))
+
+    assert np.allclose(elementwise, manual_elementwise)
+    assert np.isclose(summed, manual_summed)
+
+
+def test_reconstruction_error_numpy_rmse_sample_matches_manual():
+    """Per-sample RMSE should be sqrt of per-sample MSE."""
+    y_true = np.array([[0.0, 0.0], [0.0, 0.0]], dtype=np.float32)
+    y_pred = np.array([[3.0, 4.0], [0.0, 12.0]], dtype=np.float32)
+
+    errors = reconstruction_error(
+        y_true,
+        y_pred,
+        metric="rmse",
+        reduction="sample",
+    )
+    manual_errors = np.sqrt(np.mean((y_pred - y_true) ** 2, axis=1))
+
+    assert np.allclose(errors, manual_errors)
+
+
+def test_evaluate_reconstruction_returns_summary_statistics():
+    """evaluate_reconstruction should return useful scalar summary statistics."""
+    y_true = np.array([[0.0, 0.0], [1.0, 1.0]], dtype=np.float32)
+    y_pred = np.array([[1.0, 1.0], [1.0, 3.0]], dtype=np.float32)
+
+    summary = evaluate_reconstruction(y_true, y_pred, metric="mse")
+
+    assert summary["metric"] == "mse"
+    assert summary["n_samples"] == 2
+    for key in ["mean", "std", "median", "min", "max"]:
+        assert key in summary
+        assert np.isfinite(summary[key])
+    assert summary["min"] <= summary["mean"] <= summary["max"]
+
+
+def test_reconstruction_error_torch_mean_is_differentiable():
+    """Tensor inputs should keep gradients so the metric can be used as a loss."""
+    y_true = torch.zeros((4, 3), dtype=torch.float32)
+    y_pred = torch.randn((4, 3), dtype=torch.float32, requires_grad=True)
+
+    loss = reconstruction_error(
+        y_true,
+        y_pred,
+        metric="mse",
+        reduction="mean",
+    )
+    loss.backward()
+
+    assert torch.is_tensor(loss)
+    assert loss.ndim == 0
+    assert y_pred.grad is not None
+    assert torch.isfinite(y_pred.grad).all()
+
+
+def test_reconstruction_error_supports_mixed_numpy_target_and_tensor_prediction():
+    """NumPy targets should be converted to the prediction tensor device/dtype."""
+    y_true = np.zeros((4, 3), dtype=np.float32)
+    y_pred = torch.ones((4, 3), dtype=torch.float32, requires_grad=True)
+
+    loss = reconstruction_error(
+        y_true,
+        y_pred,
+        metric="mse",
+        reduction="mean",
+    )
+    loss.backward()
+
+    assert torch.is_tensor(loss)
+    assert y_pred.grad is not None
+    assert torch.isfinite(y_pred.grad).all()
+
+
+def test_reconstruction_loss_module_can_be_used_in_training():
+    """ReconstructionLoss should behave like a PyTorch criterion."""
+    y_true = torch.zeros((4, 3), dtype=torch.float32)
+    y_pred = torch.randn((4, 3), dtype=torch.float32, requires_grad=True)
+    criterion = ReconstructionLoss(metric="mae", reduction="mean")
+
+    loss = criterion(y_pred, y_true)
+    loss.backward()
+
+    assert torch.is_tensor(loss)
+    assert loss.ndim == 0
+    assert y_pred.grad is not None
+    assert torch.isfinite(y_pred.grad).all()
+
+
+def test_reconstruction_metrics_validate_inputs():
+    """Invalid options and shape mismatches should raise clear errors."""
+    y_true = np.zeros((4, 3), dtype=np.float32)
+    y_pred = np.zeros((4, 2), dtype=np.float32)
+
+    with pytest.raises(ValueError, match="same shape"):
+        reconstruction_error(y_true, y_pred)
+
+    with pytest.raises(ValueError, match="metric"):
+        reconstruction_error(y_true, y_true, metric="invalid")
+
+    with pytest.raises(ValueError, match="reduction"):
+        reconstruction_error(y_true, y_true, reduction="invalid")
+
+
+def test_reconstruction_error_promotes_integer_inputs():
+    """Unsigned/integer data must not wrap around or break tensor means."""
+    y_true = np.array([[10, 0]], dtype=np.uint8)
+    y_pred = np.array([[0, 20]], dtype=np.uint8)
+
+    mae = reconstruction_error(y_true, y_pred, metric="mae")
+    mse = reconstruction_error(
+        torch.from_numpy(y_true),
+        torch.from_numpy(y_pred),
+        metric="mse",
+    )
+
+    assert np.isclose(mae, 15.0)
+    assert torch.isclose(mse, torch.tensor(250.0))
+
+
+def test_reconstruction_loss_accepts_numpy_inputs():
+    """The non-tensor fallback in forward should still return a tensor."""
+    y_true = np.zeros((2, 2), dtype=np.float32)
+    y_pred = np.ones((2, 2), dtype=np.float32)
+
+    loss = ReconstructionLoss(metric="mse")(y_pred, y_true)
+
+    assert torch.is_tensor(loss)
+    assert float(loss) == 1.0
+
+
+def test_invalid_eps_and_empty_input_raise_value_error():
+    """Negative eps and zero-sample summaries should fail loudly."""
+    y_true = np.zeros((2, 2), dtype=np.float32)
+
+    with pytest.raises(ValueError, match="eps"):
+        reconstruction_error(y_true, y_true, metric="rmse", eps=-1.0)
+
+    with pytest.raises(ValueError, match="at least one sample"):
+        evaluate_reconstruction(np.zeros((0, 2)), np.zeros((0, 2)))