aero_metrics.py

recipes/aero_cfd/callbacks/aero_metrics.py

  1#  Copyright © 2026 Emmi AI GmbH. All rights reserved.
  2
  3from __future__ import annotations
  4
  5import math
  6from collections import defaultdict
  7from pathlib import Path
  8
  9import torch
 10from pydantic import Field, model_validator
 11
 12from noether.core.callbacks.periodic import PeriodicDataIteratorCallback
 13from noether.core.schemas.callbacks import PeriodicDataIteratorCallbackConfig
 14
 15
 16class AeroMetricsCallbackConfig(PeriodicDataIteratorCallbackConfig):
 17    """Configuration for surface/volume evaluation metrics callback."""
 18
 19    kind: str | None = "aero_cfd.callbacks.AeroMetricsCallback"
 20
 21    forward_properties: list[str] = []
 22    """List of properties in the dataset to be forwarded during inference."""
 23    chunked_inference: bool = False
 24    """If True, perform inference in chunks over the full simulation geometry."""
 25    chunk_properties: list[str] = []
 26    """List of properties in the dataset to be chunked for chunked inference."""
 27    batch_size: int = Field(1)
 28    """Batch size for evaluation. Currently only batch_size=1 is supported."""
 29    chunk_size: int | None = None
 30    """Size of each chunk when performing chunked inference."""
 31    sample_size_property: str | None = Field(None)
 32    """Property in the batch to determine the sample size for chunking."""
 33    save_predictions: bool = False
 34    """If True, save denormalized predictions to disk during evaluation."""
 35    predictions_path: str | None = None
 36    """Directory to save per-sample prediction files. Required when save_predictions=True."""
 37    batch_properties_to_save: list[str] = []
 38    """Batch keys (e.g. position tensors) to save alongside predictions."""
 39    compute_forces: bool = False
 40    """If True, compute drag/lift coefficients per sample and log errors."""
 41
 42    @model_validator(mode="after")
 43    def validate_config(self) -> AeroMetricsCallbackConfig:
 44        if self.batch_size != 1:
 45            raise ValueError("AeroMetricsCallback only supports batch_size=1")
 46        if self.save_predictions and self.predictions_path is None:
 47            raise ValueError("predictions_path must be specified when save_predictions=True")
 48        if self.chunked_inference:
 49            if self.chunk_size is None:
 50                raise ValueError("chunk_size must be specified when chunked_inference is True")
 51            if not self.forward_properties:
 52                raise ValueError("forward_properties must be specified when chunked_inference is True")
 53            if not self.chunk_properties:
 54                raise ValueError("chunk_properties must be specified when chunked_inference is True")
 55        return self
 56
 57
 58# Constants
 59DEFAULT_EVALUATION_MODES = [
 60    "surface_pressure",
 61    "surface_friction",
 62    "volume_velocity",
 63    "volume_pressure",
 64    "volume_vorticity",
 65]
 66
 67METRIC_SUFFIX_TARGET = "_target"
 68METRIC_PREFIX_LOSS = "loss/"
 69
 70
 71class MetricType:
 72    """Metric type identifiers."""
 73
 74    MSE = "mse"
 75    MAE = "mae"
 76    L2ERR = "l2err"
 77
 78
 79class AeroMetricsCallback(PeriodicDataIteratorCallback):
 80    """Evaluation callback for aerodynamic surface and volume predictions.
 81
 82    Computes MSE, MAE, and relative L2 error metrics for physical fields
 83    (pressure, friction, velocity, vorticity) by running model inference on
 84    an evaluation dataset.  Supports chunked inference for memory efficiency.
 85
 86    When ``save_predictions=True``, denormalized predictions (and optionally
 87    batch properties such as positions) are saved to disk per-sample for
 88    downstream use (VTK export, force coefficient computation).
 89
 90    Args:
 91        callback_config: Configuration for the callback including dataset key,
 92            forward properties, and chunking settings.
 93        **kwargs: Additional arguments passed to parent class.
 94
 95    Attributes:
 96        dataset_key: Identifier for the dataset to evaluate.
 97        evaluation_modes: List of field names to evaluate.
 98        dataset_normalizers: Normalizers for denormalizing predictions.
 99        forward_properties: Properties to pass to model forward.
100        chunked_inference: Whether to use chunked inference.
101        chunk_properties: Properties to chunk.
102        chunk_size: Size of each chunk.
103        sample_size_property: Property to determine chunk count.
104    """
105
106    def __init__(self, callback_config: AeroMetricsCallbackConfig, **kwargs):
107        super().__init__(callback_config, **kwargs)
108
109        self._config = callback_config
110        self.dataset_key = callback_config.dataset_key
111        self.evaluation_modes = DEFAULT_EVALUATION_MODES
112        self.dataset_normalizers = self.data_container.get_dataset(self.dataset_key).normalizers
113        self.forward_properties = callback_config.forward_properties
114        self.chunked_inference = callback_config.chunked_inference
115        self.chunk_properties = callback_config.chunk_properties
116        self.chunk_size = callback_config.chunk_size
117        self.sample_size_property = callback_config.sample_size_property
118        self._save_predictions = callback_config.save_predictions
119        self._predictions_path = callback_config.predictions_path
120        self._prediction_counter: int = 0
121        self._compute_forces = callback_config.compute_forces
122        if self._compute_forces:
123            from scipy.spatial import cKDTree
124
125            from aero_cfd.utils.drag_lift import FlowConditions, compute_force_coefficients
126
127            self._cKDTree = cKDTree
128            self._FlowConditions = FlowConditions
129            self._compute_force_coefficients = compute_force_coefficients
130
131    def _compute_metrics(
132        self, denormalized_predictions: torch.Tensor, denormalized_targets: torch.Tensor, field_name: str
133    ) -> dict[str, torch.Tensor]:
134        """
135        Compute evaluation metrics for predictions vs targets.
136
137        Calculates Mean Squared Error (MSE), Mean Absolute Error (MAE),
138        and relative L2 error for the given field.
139
140        Args:
141            denormalized_predictions: Denormalized prediction tensor
142            denormalized_targets: Denormalized target tensor
143            field_name: Name of the field being evaluated (used for metric naming)
144
145        Returns:
146            Dictionary mapping metric names to computed values
147        """
148        delta = denormalized_predictions - denormalized_targets
149
150        metrics = {
151            f"{field_name}_{MetricType.MSE}": (delta**2).mean(),
152            f"{field_name}_{MetricType.MAE}": delta.abs().mean(),
153        }
154
155        # L2 relative error (avoid division by zero)
156        target_norm = denormalized_targets.norm()
157        if target_norm > 1e-8:
158            metrics[f"{field_name}_{MetricType.L2ERR}"] = delta.norm() / target_norm
159        else:
160            self.logger.warning(f"Target norm too small for {field_name}, skipping L2 error")
161
162        return metrics
163
164    def _create_chunked_batch(
165        self, batch: dict[str, torch.Tensor], start_idx: int, end_idx: int
166    ) -> dict[str, torch.Tensor]:
167        """
168        Create a batch slice for chunked processing.
169
170        Args:
171            batch: Full batch dictionary
172            start_idx: Start index for the chunk
173            end_idx: End index for the chunk
174
175        Returns:
176            Dictionary with chunked tensors for specified properties
177        """
178        chunked_batch = {}
179        for key, value in batch.items():
180            if key in self.chunk_properties:
181                chunked_batch[key] = value[:, start_idx:end_idx]
182            else:
183                chunked_batch[key] = value
184        return chunked_batch
185
186    def _get_chunk_indices(self, batch_size: int) -> list[tuple[int, int]]:
187        """
188        Calculate start and end indices for all chunks.
189
190        Args:
191            batch_size: Total size of the batch to chunk
192
193        Returns:
194            List of (start_idx, end_idx) tuples for each chunk
195        """
196        indices = []
197        num_chunks = math.ceil(batch_size / self.chunk_size)
198
199        for chunk_idx in range(num_chunks):
200            start = chunk_idx * self.chunk_size
201            end = min(start + self.chunk_size, batch_size)
202            indices.append((start, end))
203
204        return indices
205
206    def _chunked_model_inference(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
207        """
208        Run model inference in chunks to reduce memory usage.
209
210        Splits the batch into smaller chunks, processes each independently,
211        and concatenates the results.
212
213        Args:
214            batch: Full batch dictionary
215
216        Returns:
217            Dictionary of model outputs with concatenated chunk results
218        """
219
220        batch_size = batch[self.sample_size_property].shape[1]
221        chunk_indices = self._get_chunk_indices(batch_size)
222
223        model_outputs = defaultdict(list)
224        for start_idx, end_idx in chunk_indices:
225            chunked_batch = self._create_chunked_batch(batch, start_idx, end_idx)
226            forward_inputs = {k: v for k, v in chunked_batch.items() if k in self.forward_properties}
227
228            with self.trainer.autocast_context:
229                chunked_outputs = self.model(**forward_inputs)
230
231            # Accumulate outputs
232            for key, value in chunked_outputs.items():
233                model_outputs[key].append(value)
234
235        # Concatenate all chunks
236        return {key: torch.cat(chunks, dim=1) for key, chunks in model_outputs.items()}
237
238    def _run_model_inference(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
239        """
240        Run model inference, optionally in chunks.
241
242        Args:
243            batch: Input batch dictionary
244
245        Returns:
246            Dictionary of model outputs
247        """
248        if self.chunked_inference:
249            return self._chunked_model_inference(batch)
250        else:
251            forward_inputs = {k: v for k, v in batch.items() if k in self.forward_properties}
252            with self.trainer.autocast_context:
253                return self.model(**forward_inputs)
254
255    def _align_chunk_sizes(self, prediction: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
256        """
257        Align prediction and target sizes when using chunked inference.
258
259        Args:
260            prediction: Prediction tensor
261            target: Target tensor
262
263        Returns:
264            Tuple of (aligned_prediction, aligned_target)
265        """
266        if self.chunked_inference and prediction.shape[1] != target.shape[1]:
267            min_size = min(prediction.shape[1], target.shape[1])
268            prediction = prediction[:, :min_size]
269            target = target[:, :min_size]
270        return prediction, target
271
272    def _compute_mode_metrics(
273        self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor], mode: str
274    ) -> dict[str, torch.Tensor]:
275        """
276        Compute metrics for a specific evaluation mode.
277
278        Args:
279            batch: Input batch containing targets
280            model_outputs: Model predictions
281            mode: Evaluation mode (field name)
282
283        Returns:
284            Dictionary of computed metrics for this mode
285        """
286        target = batch.get(f"{mode}{METRIC_SUFFIX_TARGET}")
287        prediction = model_outputs.get(mode)
288
289        if prediction is None or target is None:
290            return {}
291
292        dataset = self.data_container.get_dataset(self.dataset_key)
293        denorm_pred = dataset.denormalize(mode, prediction)
294        denorm_target = dataset.denormalize(mode, target)
295
296        # Align sizes if needed
297        denorm_pred, denorm_target = self._align_chunk_sizes(denorm_pred, denorm_target)
298
299        # Compute metrics
300        return self._compute_metrics(denorm_pred, denorm_target, mode)
301
302    def _compute_force_metrics(
303        self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor]
304    ) -> dict[str, torch.Tensor]:
305        """Compute drag/lift coefficient errors for the current sample.
306
307        Uses full-resolution mesh geometry from the batch (``surface_normals``,
308        ``surface_area``, ``surface_position``) and loads full-resolution GT
309        fields from disk (since batch targets are subsampled by the pipeline).
310        Predicted Cd/Cl uses denormalized model outputs matched to the mesh via
311        nearest-neighbor lookup.
312
313        Requires ``surface_normals``, ``surface_area``, and ``surface_position``
314        to be present in the batch. Enable these by removing them from
315        ``excluded_properties`` in the dataset config.
316        """
317        # Full-resolution mesh geometry from batch
318        surface_normals = batch.get("surface_normals")
319        surface_areas = batch.get("surface_area")
320        mesh_positions = batch.get("surface_position")
321
322        if surface_normals is None or surface_areas is None or mesh_positions is None:
323            self.logger.warning(
324                "Skipping force computation: surface_normals, surface_area, or surface_position "
325                "not in batch. Ensure these fields are not excluded in the dataset config."
326            )
327            return {}
328
329        surface_normals = surface_normals.cpu().squeeze(0).float()
330        surface_areas = surface_areas.cpu().squeeze(0).float()
331        mesh_positions = mesh_positions.cpu().squeeze(0).float()
332
333        # Ground-truth Cd/Cl from full-resolution dataset files.
334        # Batch targets are subsampled by the pipeline, so we load the originals.
335        dataset = self.data_container.get_dataset(self.dataset_key)
336        sample_idx = batch["index"].squeeze().item()
337        info = dataset.sample_info(sample_idx)
338        run_dir = Path(info["sample_uri"])
339
340        # Load per-run reference area if available, otherwise use defaults.
341        design_id = info["design_id"]
342        ref_csv = run_dir / f"geo_ref_{design_id}.csv"
343        if ref_csv.exists():
344            import pandas as pd
345
346            ref_area = float(pd.read_csv(ref_csv)["aRef"][0])
347            flow = self._FlowConditions(reference_area=ref_area)
348        else:
349            flow = self._FlowConditions()
350
351        gt_pressure_path = run_dir / "surface_pressure.pt"
352        gt_shear_path = run_dir / "surface_wallshearstress.pt"
353        if not gt_pressure_path.exists() or not gt_shear_path.exists():
354            self.logger.debug(f"Skipping GT force computation for sample {sample_idx}: missing GT files")
355            return {}
356
357        gt_pressure = torch.load(gt_pressure_path, map_location="cpu", weights_only=True).float()
358        gt_shear = torch.load(gt_shear_path, map_location="cpu", weights_only=True).float()
359        if gt_pressure.ndim == 2 and gt_pressure.shape[-1] == 1:
360            gt_pressure = gt_pressure.squeeze(-1)
361
362        gt_coeffs = self._compute_force_coefficients(gt_pressure, gt_shear, surface_normals, surface_areas, flow)
363
364        # Predicted Cd/Cl from model outputs (denormalized)
365        pred_pressure = model_outputs.get("surface_pressure")
366        pred_friction = model_outputs.get("surface_friction")
367        pred_positions = batch.get("surface_anchor_position")
368
369        if pred_pressure is None or pred_friction is None or pred_positions is None:
370            return {}
371
372        pred_pressure_denorm = self.dataset_normalizers["surface_pressure"].inverse(pred_pressure.cpu()).squeeze(0)
373        pred_friction_denorm = self.dataset_normalizers["surface_friction"].inverse(pred_friction.cpu()).squeeze(0)
374        pred_positions_cpu = pred_positions.cpu().squeeze(0)
375
376        if pred_pressure_denorm.ndim == 2 and pred_pressure_denorm.shape[-1] == 1:
377            pred_pressure_denorm = pred_pressure_denorm.squeeze(-1)
378
379        # Match predicted positions to mesh positions for normals/areas lookup
380        position_tree = self._cKDTree(mesh_positions.numpy())
381        _, matched_indices = position_tree.query(pred_positions_cpu.numpy())
382
383        pred_coeffs = self._compute_force_coefficients(
384            pred_pressure_denorm,
385            pred_friction_denorm,
386            surface_normals[matched_indices],
387            surface_areas[matched_indices],
388            flow,
389        )
390
391        return {
392            "drag_error": torch.tensor(abs(gt_coeffs.cd - pred_coeffs.cd)),
393            "lift_error": torch.tensor(abs(gt_coeffs.cl - pred_coeffs.cl)),
394        }
395
396    def process_data(self, batch: dict[str, torch.Tensor], **_) -> dict[str, torch.Tensor]:
397        """
398        Execute forward pass and compute metrics.
399
400        Args:
401            batch: Input batch dictionary
402            **_: Additional unused arguments
403
404        Returns:
405            Dictionary mapping metric names to computed values
406        """
407        model_outputs = self._run_model_inference(batch)
408
409        metrics = {}
410        for mode in self.evaluation_modes:
411            metrics.update(self._compute_mode_metrics(batch, model_outputs, mode))
412
413        if self._compute_forces:
414            metrics.update(self._compute_force_metrics(batch, model_outputs))
415
416        if self._save_predictions:
417            self._collect_predictions(batch, model_outputs)
418
419        return metrics
420
421    def _collect_predictions(self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor]) -> None:
422        """Denormalize and save predictions (and batch properties) for the current sample.
423
424        Saves each sample to disk immediately to avoid accumulating large tensors in memory.
425        """
426        sample = {}
427        for mode in self.evaluation_modes:
428            prediction = model_outputs.get(mode)
429            if prediction is None:
430                continue
431            normalizer = self.dataset_normalizers.get(mode)
432            if normalizer is not None:
433                denorm = normalizer.inverse(prediction.cpu())
434            else:
435                denorm = prediction.cpu()
436            sample[mode] = denorm.squeeze(0)
437        for key in self._config.batch_properties_to_save:
438            if key in batch:
439                sample[key] = batch[key].cpu().squeeze(0)
440        if sample:
441            out_dir = Path(self._predictions_path)
442            out_dir.mkdir(parents=True, exist_ok=True)
443            idx = self._prediction_counter
444            torch.save(sample, out_dir / f"sample_{idx:04d}.pt")
445            self._prediction_counter += 1
446
447    def process_results(self, results: dict[str, torch.Tensor], **_) -> None:
448        """
449        Log computed metrics to writer and optionally save predictions.
450
451        Args:
452            results: Dictionary of computed metrics
453            **_: Additional unused arguments
454        """
455        if not results:
456            self.logger.warning(f"No metrics computed for dataset '{self.dataset_key}'")
457            return
458
459        for name, metric in results.items():
460            metric_key = f"{METRIC_PREFIX_LOSS}{self.dataset_key}/{name}"
461            self.writer.add_scalar(
462                key=metric_key,
463                value=metric.mean(),
464                logger=self.logger,
465                format_str=".6f",
466            )
467
468        self.logger.debug(f"Logged {len(results)} metrics for dataset '{self.dataset_key}'")
469
470        if self._save_predictions and self._prediction_counter > 0:
471            self.logger.info(f"Saved {self._prediction_counter} prediction files to {self._predictions_path}")
472            self._prediction_counter = 0