aero_metrics.py

recipes/aero_cfd/src/aero_cfd/callbacks/aero_metrics.py

  1#  Copyright © 2026 Emmi AI GmbH. All rights reserved.
  2
  3from __future__ import annotations
  4
  5import math
  6from collections import defaultdict
  7from pathlib import Path
  8
  9import torch
 10from pydantic import Field, model_validator
 11
 12from noether.core.callbacks import PeriodicDataIteratorCallback, PeriodicDataIteratorCallbackConfig
 13from noether.core.utils.common.stopwatch import Stopwatch
 14
 15
 16class AeroMetricsCallbackConfig(PeriodicDataIteratorCallbackConfig):
 17    """Configuration for surface/volume evaluation metrics callback."""
 18
 19    kind: str | None = "aero_cfd.callbacks.AeroMetricsCallback"
 20
 21    forward_properties: list[str] = []
 22    """List of properties in the dataset to be forwarded during inference."""
 23    chunked_inference: bool = False
 24    """If True, perform inference in chunks over the full simulation geometry."""
 25    chunk_properties: list[str] = []
 26    """List of properties in the dataset to be chunked for chunked inference."""
 27    batch_size: int = Field(1)
 28    """Batch size for evaluation. Currently only batch_size=1 is supported."""
 29    chunk_size: int | None = None
 30    """Size of each chunk when performing chunked inference."""
 31    sample_size_property: str | None = Field(None)
 32    """Property in the batch to determine the sample size for chunking."""
 33    save_predictions: bool = False
 34    """If True, save denormalized predictions to disk during evaluation."""
 35    predictions_path: str | None = None
 36    """Directory to save per-sample prediction files. Required when save_predictions=True."""
 37    batch_properties_to_save: list[str] = []
 38    """Batch keys (e.g. position tensors) to save alongside predictions."""
 39    compute_forces: bool = False
 40    """If True, compute drag/lift coefficients per sample and log errors."""
 41    measure_inference_time: bool = False
 42    """If True, record per-sample model inference wall time (ms) and log a summary at the end."""
 43    inference_time_warmup_samples: int = 1
 44    """Number of leading samples to drop from inference-time stats (CUDA autotune, kernel
 45    compile, allocator growth on the first forward dominate the timing). Only used when
 46    ``measure_inference_time`` is True. Set to 0 to keep every sample."""
 47
 48    @model_validator(mode="after")
 49    def validate_config(self) -> AeroMetricsCallbackConfig:
 50        if self.batch_size != 1:
 51            raise ValueError("AeroMetricsCallback only supports batch_size=1")
 52        if self.save_predictions and self.predictions_path is None:
 53            raise ValueError("predictions_path must be specified when save_predictions=True")
 54        if self.chunked_inference:
 55            if self.chunk_size is None:
 56                raise ValueError("chunk_size must be specified when chunked_inference is True")
 57            if not self.forward_properties:
 58                raise ValueError("forward_properties must be specified when chunked_inference is True")
 59            if not self.chunk_properties:
 60                raise ValueError("chunk_properties must be specified when chunked_inference is True")
 61            if self.sample_size_property is None:
 62                raise ValueError("sample_size_property must be specified when chunked_inference is True")
 63        return self
 64
 65
# Constants

# Field names evaluated by AeroMetricsCallback. Each name is looked up as a key in the
# model outputs and, with METRIC_SUFFIX_TARGET appended, as the target key in the batch.
DEFAULT_EVALUATION_MODES = [
    "surface_pressure",
    "surface_friction",
    "volume_velocity",
    "volume_pressure",
    "volume_vorticity",
]

# Appended to an evaluation mode name to form the batch key of its target tensor.
METRIC_SUFFIX_TARGET = "_target"
# Prefix of every scalar key the callback passes to the writer.
METRIC_PREFIX_LOSS = "loss/"
 77
 78
class MetricType:
    """Metric type identifiers.

    Each value is used as the suffix of a metric name, ``f"{field_name}_{value}"``.
    """

    MSE = "mse"  # mean squared error
    MAE = "mae"  # mean absolute error
    L2ERR = "l2err"  # relative L2 error (error norm divided by target norm)
 85
 86
class AeroMetricsCallback(PeriodicDataIteratorCallback):
    """Evaluation callback for aerodynamic surface and volume predictions.

    Computes MSE, MAE, and relative L2 error metrics for physical fields
    (pressure, friction, velocity, vorticity) by running model inference on
    an evaluation dataset.  Supports chunked inference for memory efficiency.

    When ``save_predictions=True``, denormalized predictions (and optionally
    batch properties such as positions) are saved to disk per-sample for
    downstream use (VTK export, force coefficient computation).

    When ``compute_forces=True``, the absolute errors of the drag and lift
    coefficients are additionally computed per sample.  When
    ``measure_inference_time=True``, the per-sample inference wall time is
    recorded and a summary is logged together with the other metrics.

    Args:
        callback_config: Configuration for the callback including dataset key,
            forward properties, and chunking settings.
        **kwargs: Additional arguments passed to parent class.

    Attributes:
        dataset_key: Identifier for the dataset to evaluate.
        evaluation_modes: List of field names to evaluate.
        dataset_normalizers: Normalizers for denormalizing predictions.
        forward_properties: Properties to pass to model forward.
        chunked_inference: Whether to use chunked inference.
        chunk_properties: Properties to chunk.
        chunk_size: Size of each chunk.
        sample_size_property: Property to determine chunk count.
    """
113
114    def __init__(self, callback_config: AeroMetricsCallbackConfig, **kwargs):
115        super().__init__(callback_config, **kwargs)
116
117        self._config = callback_config
118        self.dataset_key = callback_config.dataset_key
119        self.evaluation_modes = DEFAULT_EVALUATION_MODES
120        self.dataset_normalizers = self.data_container.get_dataset(self.dataset_key).normalizers
121        self.forward_properties = callback_config.forward_properties
122        self.chunked_inference = callback_config.chunked_inference
123        self.chunk_properties = callback_config.chunk_properties
124        self.chunk_size = callback_config.chunk_size
125        self.sample_size_property = callback_config.sample_size_property
126        self._save_predictions = callback_config.save_predictions
127        self._predictions_path = callback_config.predictions_path
128        self._prediction_counter: int = 0
129        self._measure_inference_time = callback_config.measure_inference_time
130        self._inference_time_warmup_samples = callback_config.inference_time_warmup_samples
131        self._compute_forces = callback_config.compute_forces
132        if self._compute_forces:
133            from scipy.spatial import cKDTree
134
135            from aero_cfd.utils.drag_lift import FlowConditions, compute_force_coefficients
136
137            self._cKDTree = cKDTree
138            self._FlowConditions = FlowConditions
139            self._compute_force_coefficients = compute_force_coefficients
140
141    def _compute_metrics(
142        self, denormalized_predictions: torch.Tensor, denormalized_targets: torch.Tensor, field_name: str
143    ) -> dict[str, torch.Tensor]:
144        """
145        Compute evaluation metrics for predictions vs targets.
146
147        Calculates Mean Squared Error (MSE), Mean Absolute Error (MAE),
148        and relative L2 error for the given field.
149
150        Args:
151            denormalized_predictions: Denormalized prediction tensor
152            denormalized_targets: Denormalized target tensor
153            field_name: Name of the field being evaluated (used for metric naming)
154
155        Returns:
156            Dictionary mapping metric names to computed values
157        """
158        delta = denormalized_predictions - denormalized_targets
159
160        metrics = {
161            f"{field_name}_{MetricType.MSE}": (delta**2).mean(),
162            f"{field_name}_{MetricType.MAE}": delta.abs().mean(),
163        }
164
165        # L2 relative error (avoid division by zero)
166        target_norm = denormalized_targets.norm()
167        if target_norm > 1e-8:
168            metrics[f"{field_name}_{MetricType.L2ERR}"] = delta.norm() / target_norm
169        else:
170            self.logger.warning(f"Target norm too small for {field_name}, skipping L2 error")
171
172        return metrics
173
174    def _create_chunked_batch(
175        self, batch: dict[str, torch.Tensor], start_idx: int, end_idx: int
176    ) -> dict[str, torch.Tensor]:
177        """
178        Create a batch slice for chunked processing.
179
180        Args:
181            batch: Full batch dictionary
182            start_idx: Start index for the chunk
183            end_idx: End index for the chunk
184
185        Returns:
186            Dictionary with chunked tensors for specified properties
187        """
188        chunked_batch = {}
189        for key, value in batch.items():
190            if key in self.chunk_properties:
191                chunked_batch[key] = value[:, start_idx:end_idx]
192            else:
193                chunked_batch[key] = value
194        return chunked_batch
195
196    def _get_chunk_indices(self, batch_size: int) -> list[tuple[int, int]]:
197        """
198        Calculate start and end indices for all chunks.
199
200        Args:
201            batch_size: Total size of the batch to chunk
202
203        Returns:
204            List of (start_idx, end_idx) tuples for each chunk
205        """
206        indices = []
207        num_chunks = math.ceil(batch_size / self.chunk_size)
208
209        for chunk_idx in range(num_chunks):
210            start = chunk_idx * self.chunk_size
211            end = min(start + self.chunk_size, batch_size)
212            indices.append((start, end))
213
214        return indices
215
216    def _chunked_model_inference(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
217        """
218        Run model inference in chunks to reduce memory usage.
219
220        Splits the batch into smaller chunks, processes each independently,
221        and concatenates the results.
222
223        Args:
224            batch: Full batch dictionary
225
226        Returns:
227            Dictionary of model outputs with concatenated chunk results
228        """
229
230        batch_size = batch[self.sample_size_property].shape[1]
231        chunk_indices = self._get_chunk_indices(batch_size)
232
233        model_outputs = defaultdict(list)
234        for start_idx, end_idx in chunk_indices:
235            chunked_batch = self._create_chunked_batch(batch, start_idx, end_idx)
236            forward_inputs = {k: v for k, v in chunked_batch.items() if k in self.forward_properties}
237
238            with self.trainer.autocast_context:
239                chunked_outputs = self.model(**forward_inputs)
240
241            # Accumulate outputs
242            for key, value in chunked_outputs.items():
243                model_outputs[key].append(value)
244
245        # Concatenate all chunks
246        return {key: torch.cat(chunks, dim=1) for key, chunks in model_outputs.items()}
247
248    def _run_model_inference(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
249        """
250        Run model inference, optionally in chunks.
251
252        Args:
253            batch: Input batch dictionary
254
255        Returns:
256            Dictionary of model outputs
257        """
258        if self.chunked_inference:
259            return self._chunked_model_inference(batch)
260        else:
261            forward_inputs = {k: v for k, v in batch.items() if k in self.forward_properties}
262            with self.trainer.autocast_context:
263                return self.model(**forward_inputs)
264
265    def _align_chunk_sizes(self, prediction: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
266        """
267        Align prediction and target sizes when using chunked inference.
268
269        Args:
270            prediction: Prediction tensor
271            target: Target tensor
272
273        Returns:
274            Tuple of (aligned_prediction, aligned_target)
275        """
276        if self.chunked_inference and prediction.shape[1] != target.shape[1]:
277            min_size = min(prediction.shape[1], target.shape[1])
278            prediction = prediction[:, :min_size]
279            target = target[:, :min_size]
280        return prediction, target
281
282    def _compute_mode_metrics(
283        self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor], mode: str
284    ) -> dict[str, torch.Tensor]:
285        """
286        Compute metrics for a specific evaluation mode.
287
288        Args:
289            batch: Input batch containing targets
290            model_outputs: Model predictions
291            mode: Evaluation mode (field name)
292
293        Returns:
294            Dictionary of computed metrics for this mode
295        """
296        target = batch.get(f"{mode}{METRIC_SUFFIX_TARGET}")
297        prediction = model_outputs.get(mode)
298
299        if prediction is None or target is None:
300            return {}
301
302        dataset = self.data_container.get_dataset(self.dataset_key)
303        denorm_pred = dataset.denormalize(mode, prediction)
304        denorm_target = dataset.denormalize(mode, target)
305
306        # Align sizes if needed
307        denorm_pred, denorm_target = self._align_chunk_sizes(denorm_pred, denorm_target)
308
309        # Compute metrics
310        return self._compute_metrics(denorm_pred, denorm_target, mode)
311
312    def _compute_force_metrics(
313        self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor]
314    ) -> dict[str, torch.Tensor]:
315        """Compute drag/lift coefficient errors for the current sample.
316
317        Uses full-resolution mesh geometry from the batch (``surface_normals``,
318        ``surface_area``, ``surface_position``) and loads full-resolution GT
319        fields from disk (since batch targets are subsampled by the pipeline).
320        Predicted Cd/Cl uses denormalized model outputs matched to the mesh via
321        nearest-neighbor lookup.
322
323        Requires ``surface_normals``, ``surface_area``, and ``surface_position``
324        to be present in the batch. Enable these by removing them from
325        ``excluded_properties`` in the dataset config.
326        """
327        # Full-resolution mesh geometry from batch
328        surface_normals = batch.get("surface_normals")
329        surface_areas = batch.get("surface_area")
330        mesh_positions = batch.get("surface_position")
331
332        if surface_normals is None or surface_areas is None or mesh_positions is None:
333            self.logger.warning(
334                "Skipping force computation: surface_normals, surface_area, or surface_position "
335                "not in batch. Ensure these fields are not excluded in the dataset config."
336            )
337            return {}
338
339        surface_normals = surface_normals.cpu().squeeze(0).float()
340        surface_areas = surface_areas.cpu().squeeze(0).float()
341        mesh_positions = mesh_positions.cpu().squeeze(0).float()
342
343        # Ground-truth Cd/Cl from full-resolution dataset files.
344        # Batch targets are subsampled by the pipeline, so we load the originals.
345        dataset = self.data_container.get_dataset(self.dataset_key)
346        sample_idx = batch["index"].squeeze().item()
347        info = dataset.sample_info(sample_idx)
348        run_dir = Path(info["sample_uri"])
349
350        # Load per-run reference area if available, otherwise use defaults.
351        design_id = info["design_id"]
352        ref_csv = run_dir / f"geo_ref_{design_id}.csv"
353        if ref_csv.exists():
354            import pandas as pd
355
356            ref_area = float(pd.read_csv(ref_csv)["aRef"][0])
357            flow = self._FlowConditions(reference_area=ref_area)
358        else:
359            flow = self._FlowConditions()
360
361        gt_pressure_path = run_dir / "surface_pressure.pt"
362        gt_shear_path = run_dir / "surface_wallshearstress.pt"
363        if not gt_pressure_path.exists() or not gt_shear_path.exists():
364            self.logger.debug(f"Skipping GT force computation for sample {sample_idx}: missing GT files")
365            return {}
366
367        gt_pressure = torch.load(gt_pressure_path, map_location="cpu", weights_only=True).float()
368        gt_shear = torch.load(gt_shear_path, map_location="cpu", weights_only=True).float()
369        if gt_pressure.ndim == 2 and gt_pressure.shape[-1] == 1:
370            gt_pressure = gt_pressure.squeeze(-1)
371
372        gt_coeffs = self._compute_force_coefficients(gt_pressure, gt_shear, surface_normals, surface_areas, flow)
373
374        # Predicted Cd/Cl from model outputs (denormalized)
375        pred_pressure = model_outputs.get("surface_pressure")
376        pred_friction = model_outputs.get("surface_friction")
377        pred_positions = batch.get("surface_anchor_position")
378
379        if pred_pressure is None or pred_friction is None or pred_positions is None:
380            return {}
381
382        pred_pressure_denorm = self.dataset_normalizers["surface_pressure"].inverse(pred_pressure.cpu()).squeeze(0)
383        pred_friction_denorm = self.dataset_normalizers["surface_friction"].inverse(pred_friction.cpu()).squeeze(0)
384        pred_positions_cpu = pred_positions.cpu().squeeze(0)
385
386        if pred_pressure_denorm.ndim == 2 and pred_pressure_denorm.shape[-1] == 1:
387            pred_pressure_denorm = pred_pressure_denorm.squeeze(-1)
388
389        # Match predicted positions to mesh positions for normals/areas lookup
390        position_tree = self._cKDTree(mesh_positions.numpy())
391        _, matched_indices = position_tree.query(pred_positions_cpu.numpy())
392
393        pred_coeffs = self._compute_force_coefficients(
394            pred_pressure_denorm,
395            pred_friction_denorm,
396            surface_normals[matched_indices],
397            surface_areas[matched_indices],
398            flow,
399        )
400
401        return {
402            "drag_error": torch.tensor(abs(gt_coeffs.cd - pred_coeffs.cd)),
403            "lift_error": torch.tensor(abs(gt_coeffs.cl - pred_coeffs.cl)),
404        }
405
406    def _timed_model_inference(self, batch: dict[str, torch.Tensor]) -> tuple[dict[str, torch.Tensor], float]:
407        """Run ``_run_model_inference`` and return (outputs, elapsed_ms)."""
408        device = self.trainer.device if isinstance(self.trainer.device, torch.device) else None
409        with Stopwatch(device=device) as sw:
410            outputs = self._run_model_inference(batch)
411        return outputs, sw.elapsed_milliseconds
412
413    def process_data(self, batch: dict[str, torch.Tensor], **_) -> dict[str, torch.Tensor]:
414        """
415        Execute forward pass and compute metrics.
416
417        Args:
418            batch: Input batch dictionary
419            **_: Additional unused arguments
420
421        Returns:
422            Dictionary mapping metric names to computed values
423        """
424        if self._measure_inference_time:
425            model_outputs, elapsed_ms = self._timed_model_inference(batch)
426        else:
427            model_outputs = self._run_model_inference(batch)
428            elapsed_ms = None
429
430        metrics: dict[str, torch.Tensor] = {}
431        for mode in self.evaluation_modes:
432            metrics.update(self._compute_mode_metrics(batch, model_outputs, mode))
433
434        if self._compute_forces:
435            metrics.update(self._compute_force_metrics(batch, model_outputs))
436
437        if elapsed_ms is not None:
438            metrics["inference_time_ms"] = torch.tensor(elapsed_ms)
439
440        if self._save_predictions:
441            self._collect_predictions(batch, model_outputs)
442
443        return metrics
444
445    def _collect_predictions(self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor]) -> None:
446        """Denormalize and save predictions (and batch properties) for the current sample.
447
448        Saves each sample to disk immediately to avoid accumulating large tensors in memory.
449        """
450        sample = {}
451        for mode in self.evaluation_modes:
452            prediction = model_outputs.get(mode)
453            if prediction is None:
454                continue
455            normalizer = self.dataset_normalizers.get(mode)
456            if normalizer is not None:
457                denorm = normalizer.inverse(prediction.cpu())
458            else:
459                denorm = prediction.cpu()
460            sample[mode] = denorm.squeeze(0)
461        for key in self._config.batch_properties_to_save:
462            if key in batch:
463                sample[key] = batch[key].cpu().squeeze(0)
464        if sample:
465            out_dir = Path(self._predictions_path)
466            out_dir.mkdir(parents=True, exist_ok=True)
467            idx = self._prediction_counter
468            torch.save(sample, out_dir / f"sample_{idx:04d}.pt")
469            self._prediction_counter += 1
470
471    def process_results(self, results: dict[str, torch.Tensor], **_) -> None:
472        """
473        Log computed metrics to writer and optionally save predictions.
474
475        Args:
476            results: Dictionary of computed metrics
477            **_: Additional unused arguments
478        """
479        if not results:
480            self.logger.warning(f"No metrics computed for dataset '{self.dataset_key}'")
481            return
482
483        for name, metric in results.items():
484            if name == "inference_time_ms":
485                continue  # handled below with warmup-sample trimming
486            metric_key = f"{METRIC_PREFIX_LOSS}{self.dataset_key}/{name}"
487            self.writer.add_scalar(
488                key=metric_key,
489                value=metric.mean(),
490                logger=self.logger,
491                format_str=".6f",
492            )
493
494        self.logger.debug(f"Logged {len(results)} metrics for dataset '{self.dataset_key}'")
495
496        if self._measure_inference_time:
497            times = results.get("inference_time_ms")
498            if times is not None and times.numel() > 0:
499                self._log_inference_time_summary(times.float())
500
501        if self._save_predictions and self._prediction_counter > 0:
502            self.logger.info(f"Saved {self._prediction_counter} prediction files to {self._predictions_path}")
503            self._prediction_counter = 0
504
505    def _log_inference_time_summary(self, times_ms: torch.Tensor) -> None:
506        """Log count, mean/std/median/min/max inference time over all samples.
507
508        Drops the first ``inference_time_warmup_samples`` values, which are
509        typically dominated by one-off setup cost (CUDA autotune, kernel compile,
510        allocator growth) on the initial forward pass.
511        """
512        warmup = min(self._inference_time_warmup_samples, times_ms.numel())
513        dropped = times_ms[:warmup]
514        kept = times_ms[warmup:]
515
516        if kept.numel() == 0:
517            self.logger.warning(
518                f"Inference-time summary skipped: all {times_ms.numel()} sample(s) dropped as warmup "
519                f"(inference_time_warmup_samples={self._inference_time_warmup_samples})."
520            )
521            return
522
523        n = kept.numel()
524        mean = float(kept.mean())
525        std = float(kept.std(unbiased=False)) if n > 1 else 0.0
526        median = float(kept.median())
527        tmin = float(kept.min())
528        tmax = float(kept.max())
529
530        warmup_note = ""
531        if warmup > 0:
532            warmup_note = f" (dropped {warmup} warmup sample(s): {', '.join(f'{float(x):.1f}ms' for x in dropped)})"
533
534        summary = (
535            f"Inference time on '{self.dataset_key}' over {n} sample(s): "
536            f"mean={mean:.2f}ms  std={std:.2f}ms  median={median:.2f}ms  "
537            f"min={tmin:.2f}ms  max={tmax:.2f}ms{warmup_note}"
538        )
539        self.logger.info(summary)
540
541        self.writer.add_scalar(
542            key=f"{METRIC_PREFIX_LOSS}{self.dataset_key}/inference_time_ms",
543            value=torch.tensor(mean),
544            logger=self.logger,
545            format_str=".6f",
546        )