aero_metrics.py
recipes/aero_cfd/src/aero_cfd/callbacks/aero_metrics.py
1# Copyright © 2026 Emmi AI GmbH. All rights reserved.
2
3from __future__ import annotations
4
5import math
6from collections import defaultdict
7from pathlib import Path
8
9import torch
10from pydantic import Field, model_validator
11
12from noether.core.callbacks import PeriodicDataIteratorCallback, PeriodicDataIteratorCallbackConfig
13from noether.core.utils.common.stopwatch import Stopwatch
14
15
class AeroMetricsCallbackConfig(PeriodicDataIteratorCallbackConfig):
    """Configuration for surface/volume evaluation metrics callback."""

    kind: str | None = "aero_cfd.callbacks.AeroMetricsCallback"

    forward_properties: list[str] = []
    """List of properties in the dataset to be forwarded during inference."""
    chunked_inference: bool = False
    """If True, perform inference in chunks over the full simulation geometry."""
    chunk_properties: list[str] = []
    """List of properties in the dataset to be chunked for chunked inference."""
    batch_size: int = Field(1)
    """Batch size for evaluation. Currently only batch_size=1 is supported."""
    chunk_size: int | None = None
    """Size of each chunk when performing chunked inference. Must be positive when chunked_inference=True."""
    sample_size_property: str | None = Field(None)
    """Property in the batch to determine the sample size for chunking."""
    save_predictions: bool = False
    """If True, save denormalized predictions to disk during evaluation."""
    predictions_path: str | None = None
    """Directory to save per-sample prediction files. Required when save_predictions=True."""
    batch_properties_to_save: list[str] = []
    """Batch keys (e.g. position tensors) to save alongside predictions."""
    compute_forces: bool = False
    """If True, compute drag/lift coefficients per sample and log errors."""
    measure_inference_time: bool = False
    """If True, record per-sample model inference wall time (ms) and log a summary at the end."""
    inference_time_warmup_samples: int = 1
    """Number of leading samples to drop from inference-time stats (CUDA autotune, kernel
    compile, allocator growth on the first forward dominate the timing). Only used when
    ``measure_inference_time`` is True. Set to 0 to keep every sample; must not be negative."""

    @model_validator(mode="after")
    def validate_config(self) -> AeroMetricsCallbackConfig:
        """Reject option combinations the callback cannot run with.

        Returns:
            The validated config (unchanged).

        Raises:
            ValueError: If ``batch_size`` is not 1, ``save_predictions`` is set without
                ``predictions_path``, chunked inference is missing any of its required
                settings or has a non-positive ``chunk_size``, or inference timing is
                enabled with a negative warmup count.
        """
        if self.batch_size != 1:
            raise ValueError("AeroMetricsCallback only supports batch_size=1")
        if self.save_predictions and self.predictions_path is None:
            raise ValueError("predictions_path must be specified when save_predictions=True")
        if self.chunked_inference:
            if self.chunk_size is None:
                raise ValueError("chunk_size must be specified when chunked_inference is True")
            # 0 would divide by zero when computing the chunk count; a negative size yields no chunks at all.
            if self.chunk_size <= 0:
                raise ValueError(f"chunk_size must be a positive integer, got {self.chunk_size}")
            if not self.forward_properties:
                raise ValueError("forward_properties must be specified when chunked_inference is True")
            if not self.chunk_properties:
                raise ValueError("chunk_properties must be specified when chunked_inference is True")
            if self.sample_size_property is None:
                raise ValueError("sample_size_property must be specified when chunked_inference is True")
        # A negative count would slice from the end and drop all but the last few samples instead.
        if self.measure_inference_time and self.inference_time_warmup_samples < 0:
            raise ValueError(
                f"inference_time_warmup_samples must be >= 0, got {self.inference_time_warmup_samples}"
            )
        return self
64
65
# Module-level constants.

# Suffix appended to a field name to find its ground-truth tensor in the batch
# (looked up as f"{field}{METRIC_SUFFIX_TARGET}").
METRIC_SUFFIX_TARGET: str = "_target"

# Namespace prefix for every scalar written to the writer; the dataset key and
# metric name follow it.
METRIC_PREFIX_LOSS: str = "loss/"

# Field names evaluated by default: surface quantities first, then volume quantities.
DEFAULT_EVALUATION_MODES: list[str] = [
    "surface_pressure",
    "surface_friction",
    "volume_velocity",
    "volume_pressure",
    "volume_vorticity",
]
78
class MetricType:
    """Namespace of metric identifiers, appended to field names as ``{field}_{metric}``."""

    # Mean absolute error.
    MAE: str = "mae"
    # Mean squared error.
    MSE: str = "mse"
    # Relative L2 error (||prediction - target|| / ||target||).
    L2ERR: str = "l2err"
85
86
class AeroMetricsCallback(PeriodicDataIteratorCallback):
    """Evaluation callback for aerodynamic surface and volume predictions.

    Computes MSE, MAE, and relative L2 error metrics for physical fields
    (pressure, friction, velocity, vorticity) by running model inference on
    an evaluation dataset. Supports chunked inference for memory efficiency.

    When ``save_predictions=True``, denormalized predictions (and optionally
    batch properties such as positions) are saved to disk per-sample for
    downstream use (VTK export, force coefficient computation).

    When ``compute_forces=True``, absolute drag/lift coefficient errors are
    computed per sample. When ``measure_inference_time=True``, a wall-time
    summary of model inference is logged at the end of each evaluation.

    Model outputs produced under the trainer's autocast context may be in half
    precision; they are upcast to float32 before denormalization so metrics,
    force coefficients and saved files are not computed in reduced precision.

    Args:
        callback_config: Configuration for the callback including dataset key,
            forward properties, and chunking settings.
        **kwargs: Additional arguments passed to parent class.

    Attributes:
        dataset_key: Identifier for the dataset to evaluate.
        evaluation_modes: List of field names to evaluate.
        dataset_normalizers: Normalizers for denormalizing predictions.
        forward_properties: Properties to pass to model forward.
        chunked_inference: Whether to use chunked inference.
        chunk_properties: Properties to chunk.
        chunk_size: Size of each chunk.
        sample_size_property: Property to determine chunk count.
    """

    def __init__(self, callback_config: AeroMetricsCallbackConfig, **kwargs):
        super().__init__(callback_config, **kwargs)

        self._config = callback_config
        self.dataset_key = callback_config.dataset_key
        # Copy so per-instance changes can never leak into the module-level default list.
        self.evaluation_modes = list(DEFAULT_EVALUATION_MODES)
        self.dataset_normalizers = self.data_container.get_dataset(self.dataset_key).normalizers
        self.forward_properties = callback_config.forward_properties
        self.chunked_inference = callback_config.chunked_inference
        self.chunk_properties = callback_config.chunk_properties
        self.chunk_size = callback_config.chunk_size
        self.sample_size_property = callback_config.sample_size_property
        self._save_predictions = callback_config.save_predictions
        self._predictions_path = callback_config.predictions_path
        # Index of the next prediction file; reset at the end of every evaluation in process_results.
        self._prediction_counter: int = 0
        self._measure_inference_time = callback_config.measure_inference_time
        self._inference_time_warmup_samples = callback_config.inference_time_warmup_samples
        self._compute_forces = callback_config.compute_forces
        if self._compute_forces:
            # Imported lazily so scipy and the drag/lift utilities are only needed when forces are enabled.
            from scipy.spatial import cKDTree

            from aero_cfd.utils.drag_lift import FlowConditions, compute_force_coefficients

            self._cKDTree = cKDTree
            self._FlowConditions = FlowConditions
            self._compute_force_coefficients = compute_force_coefficients

    @staticmethod
    def _to_full_precision(tensor: torch.Tensor) -> torch.Tensor:
        """Upcast a float16/bfloat16 tensor to float32; return any other tensor unchanged.

        Outputs computed under autocast may be half precision. Denormalizing them in
        half precision loses accuracy, and NumPy cannot represent bfloat16 at all.
        """
        if tensor.dtype in (torch.float16, torch.bfloat16):
            return tensor.float()
        return tensor

    def _compute_metrics(
        self, denormalized_predictions: torch.Tensor, denormalized_targets: torch.Tensor, field_name: str
    ) -> dict[str, torch.Tensor]:
        """
        Compute evaluation metrics for predictions vs targets.

        Calculates Mean Squared Error (MSE), Mean Absolute Error (MAE),
        and relative L2 error for the given field.

        Args:
            denormalized_predictions: Denormalized prediction tensor
            denormalized_targets: Denormalized target tensor
            field_name: Name of the field being evaluated (used for metric naming)

        Returns:
            Dictionary mapping metric names to computed values. The L2 error entry is
            omitted (with a warning) when the target norm is at most 1e-8.
        """
        delta = denormalized_predictions - denormalized_targets

        metrics = {
            f"{field_name}_{MetricType.MSE}": (delta**2).mean(),
            f"{field_name}_{MetricType.MAE}": delta.abs().mean(),
        }

        # L2 relative error (avoid division by zero)
        target_norm = denormalized_targets.norm()
        if target_norm > 1e-8:
            metrics[f"{field_name}_{MetricType.L2ERR}"] = delta.norm() / target_norm
        else:
            self.logger.warning(f"Target norm too small for {field_name}, skipping L2 error")

        return metrics

    def _create_chunked_batch(
        self, batch: dict[str, torch.Tensor], start_idx: int, end_idx: int
    ) -> dict[str, torch.Tensor]:
        """
        Create a batch slice for chunked processing.

        Properties listed in ``chunk_properties`` are sliced along dim 1; every other
        entry is passed through unchanged.

        Args:
            batch: Full batch dictionary
            start_idx: Start index for the chunk
            end_idx: End index for the chunk

        Returns:
            Dictionary with chunked tensors for specified properties
        """
        return {
            key: value[:, start_idx:end_idx] if key in self.chunk_properties else value
            for key, value in batch.items()
        }

    def _get_chunk_indices(self, batch_size: int) -> list[tuple[int, int]]:
        """
        Calculate start and end indices for all chunks.

        Args:
            batch_size: Total size of the batch to chunk

        Returns:
            List of (start_idx, end_idx) tuples for each chunk; empty when ``batch_size`` is 0.

        Raises:
            ValueError: If ``chunk_size`` is missing or not positive.
        """
        # 0 would raise ZeroDivisionError below and a negative size would silently produce no chunks.
        if self.chunk_size is None or self.chunk_size <= 0:
            raise ValueError(f"chunk_size must be a positive integer for chunked inference, got {self.chunk_size}")

        indices = []
        num_chunks = math.ceil(batch_size / self.chunk_size)

        for chunk_idx in range(num_chunks):
            start = chunk_idx * self.chunk_size
            end = min(start + self.chunk_size, batch_size)
            indices.append((start, end))

        return indices

    def _chunked_model_inference(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Run model inference in chunks to reduce memory usage.

        Splits the batch into smaller chunks, processes each independently,
        and concatenates the results along dim 1.

        Args:
            batch: Full batch dictionary

        Returns:
            Dictionary of model outputs with concatenated chunk results
        """
        batch_size = batch[self.sample_size_property].shape[1]
        chunk_indices = self._get_chunk_indices(batch_size)

        model_outputs = defaultdict(list)
        for start_idx, end_idx in chunk_indices:
            chunked_batch = self._create_chunked_batch(batch, start_idx, end_idx)
            forward_inputs = {k: v for k, v in chunked_batch.items() if k in self.forward_properties}

            with self.trainer.autocast_context:
                chunked_outputs = self.model(**forward_inputs)

            # Accumulate outputs
            for key, value in chunked_outputs.items():
                model_outputs[key].append(value)

        # Concatenate all chunks
        return {key: torch.cat(chunks, dim=1) for key, chunks in model_outputs.items()}

    def _run_model_inference(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Run model inference, optionally in chunks.

        Args:
            batch: Input batch dictionary

        Returns:
            Dictionary of model outputs
        """
        if self.chunked_inference:
            return self._chunked_model_inference(batch)
        forward_inputs = {k: v for k, v in batch.items() if k in self.forward_properties}
        with self.trainer.autocast_context:
            return self.model(**forward_inputs)

    def _align_chunk_sizes(self, prediction: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Align prediction and target sizes when using chunked inference.

        Both tensors are truncated along dim 1 to the shorter of the two lengths.

        Args:
            prediction: Prediction tensor
            target: Target tensor

        Returns:
            Tuple of (aligned_prediction, aligned_target)
        """
        if self.chunked_inference and prediction.shape[1] != target.shape[1]:
            min_size = min(prediction.shape[1], target.shape[1])
            prediction = prediction[:, :min_size]
            target = target[:, :min_size]
        return prediction, target

    def _compute_mode_metrics(
        self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor], mode: str
    ) -> dict[str, torch.Tensor]:
        """
        Compute metrics for a specific evaluation mode.

        Args:
            batch: Input batch containing targets
            model_outputs: Model predictions
            mode: Evaluation mode (field name)

        Returns:
            Dictionary of computed metrics for this mode; empty if either the prediction
            or the ``{mode}_target`` entry is missing.
        """
        target = batch.get(f"{mode}{METRIC_SUFFIX_TARGET}")
        prediction = model_outputs.get(mode)

        if prediction is None or target is None:
            return {}

        dataset = self.data_container.get_dataset(self.dataset_key)
        # Denormalize in full precision; autocast outputs may be fp16/bf16.
        denorm_pred = dataset.denormalize(mode, self._to_full_precision(prediction))
        denorm_target = dataset.denormalize(mode, self._to_full_precision(target))

        # Align sizes if needed
        denorm_pred, denorm_target = self._align_chunk_sizes(denorm_pred, denorm_target)

        return self._compute_metrics(denorm_pred, denorm_target, mode)

    def _load_flow_conditions(self, run_dir: Path, design_id):
        """Build flow conditions, using the reference area from ``geo_ref_<design_id>.csv`` if that file exists."""
        ref_csv = run_dir / f"geo_ref_{design_id}.csv"
        if not ref_csv.exists():
            return self._FlowConditions()

        import pandas as pd

        # .iloc: take the first row positionally, independent of the CSV's index labels.
        ref_area = float(pd.read_csv(ref_csv)["aRef"].iloc[0])
        return self._FlowConditions(reference_area=ref_area)

    def _load_ground_truth_surface_fields(
        self, run_dir: Path, sample_idx: int
    ) -> tuple[torch.Tensor, torch.Tensor] | None:
        """Load full-resolution GT surface pressure and wall shear stress as float32.

        Returns:
            ``(pressure, shear)``, with a trailing singleton pressure dim squeezed away,
            or None (logged at debug level) if either file is missing.
        """
        gt_pressure_path = run_dir / "surface_pressure.pt"
        gt_shear_path = run_dir / "surface_wallshearstress.pt"
        if not gt_pressure_path.exists() or not gt_shear_path.exists():
            self.logger.debug(f"Skipping GT force computation for sample {sample_idx}: missing GT files")
            return None

        gt_pressure = torch.load(gt_pressure_path, map_location="cpu", weights_only=True).float()
        gt_shear = torch.load(gt_shear_path, map_location="cpu", weights_only=True).float()
        if gt_pressure.ndim == 2 and gt_pressure.shape[-1] == 1:
            gt_pressure = gt_pressure.squeeze(-1)
        return gt_pressure, gt_shear

    def _compute_force_metrics(
        self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """Compute drag/lift coefficient errors for the current sample.

        Uses full-resolution mesh geometry from the batch (``surface_normals``,
        ``surface_area``, ``surface_position``) and loads full-resolution GT
        fields from disk (since batch targets are subsampled by the pipeline).
        Predicted Cd/Cl uses denormalized model outputs matched to the mesh via
        nearest-neighbor lookup on ``surface_anchor_position``.

        Requires ``surface_normals``, ``surface_area``, and ``surface_position``
        to be present in the batch. Enable these by removing them from
        ``excluded_properties`` in the dataset config.

        NOTE(review): predicted forces are summed over the prediction points only,
        each weighted by the area of its nearest mesh point. If the anchor points are
        a subsample of the mesh, the predicted force covers only part of the surface;
        confirm this matches what ``compute_force_coefficients`` expects.

        Returns:
            ``{"drag_error", "lift_error"}`` absolute coefficient errors, or an empty
            dict if geometry, predictions, or GT files are unavailable.
        """
        # Full-resolution mesh geometry from batch
        surface_normals = batch.get("surface_normals")
        surface_areas = batch.get("surface_area")
        mesh_positions = batch.get("surface_position")

        if surface_normals is None or surface_areas is None or mesh_positions is None:
            self.logger.warning(
                "Skipping force computation: surface_normals, surface_area, or surface_position "
                "not in batch. Ensure these fields are not excluded in the dataset config."
            )
            return {}

        # Check predictions before touching the disk: without them there is nothing to compare.
        pred_pressure = model_outputs.get("surface_pressure")
        pred_friction = model_outputs.get("surface_friction")
        pred_positions = batch.get("surface_anchor_position")
        if pred_pressure is None or pred_friction is None or pred_positions is None:
            return {}

        surface_normals = surface_normals.cpu().squeeze(0).float()
        surface_areas = surface_areas.cpu().squeeze(0).float()
        mesh_positions = mesh_positions.cpu().squeeze(0).float()

        # Ground-truth Cd/Cl from full-resolution dataset files.
        # Batch targets are subsampled by the pipeline, so we load the originals.
        dataset = self.data_container.get_dataset(self.dataset_key)
        sample_idx = batch["index"].squeeze().item()
        info = dataset.sample_info(sample_idx)
        run_dir = Path(info["sample_uri"])
        flow = self._load_flow_conditions(run_dir, info["design_id"])

        gt_fields = self._load_ground_truth_surface_fields(run_dir, sample_idx)
        if gt_fields is None:
            return {}
        gt_pressure, gt_shear = gt_fields
        gt_coeffs = self._compute_force_coefficients(gt_pressure, gt_shear, surface_normals, surface_areas, flow)

        # Predicted Cd/Cl from model outputs, denormalized in full precision.
        pred_pressure_denorm = (
            self.dataset_normalizers["surface_pressure"].inverse(self._to_full_precision(pred_pressure.cpu())).squeeze(0)
        )
        pred_friction_denorm = (
            self.dataset_normalizers["surface_friction"].inverse(self._to_full_precision(pred_friction.cpu())).squeeze(0)
        )
        # float32 like the mesh positions; NumPy (used by the KD-tree) cannot hold bfloat16.
        pred_positions_cpu = pred_positions.cpu().squeeze(0).float()

        if pred_pressure_denorm.ndim == 2 and pred_pressure_denorm.shape[-1] == 1:
            pred_pressure_denorm = pred_pressure_denorm.squeeze(-1)

        # Match predicted positions to mesh positions for normals/areas lookup
        position_tree = self._cKDTree(mesh_positions.numpy())
        _, matched_indices = position_tree.query(pred_positions_cpu.numpy())

        pred_coeffs = self._compute_force_coefficients(
            pred_pressure_denorm,
            pred_friction_denorm,
            surface_normals[matched_indices],
            surface_areas[matched_indices],
            flow,
        )

        # as_tensor accepts Python floats and tensors alike without a copy-construct warning.
        return {
            "drag_error": torch.as_tensor(abs(gt_coeffs.cd - pred_coeffs.cd)),
            "lift_error": torch.as_tensor(abs(gt_coeffs.cl - pred_coeffs.cl)),
        }

    def _timed_model_inference(self, batch: dict[str, torch.Tensor]) -> tuple[dict[str, torch.Tensor], float]:
        """Run ``_run_model_inference`` and return (outputs, elapsed_ms).

        The trainer device is passed to the stopwatch only when it is a ``torch.device``.
        """
        device = self.trainer.device if isinstance(self.trainer.device, torch.device) else None
        with Stopwatch(device=device) as sw:
            outputs = self._run_model_inference(batch)
        return outputs, sw.elapsed_milliseconds

    def process_data(self, batch: dict[str, torch.Tensor], **_) -> dict[str, torch.Tensor]:
        """
        Execute forward pass and compute metrics.

        Args:
            batch: Input batch dictionary
            **_: Additional unused arguments

        Returns:
            Dictionary mapping metric names to computed values (plus
            ``inference_time_ms`` when timing is enabled)
        """
        if self._measure_inference_time:
            model_outputs, elapsed_ms = self._timed_model_inference(batch)
        else:
            model_outputs = self._run_model_inference(batch)
            elapsed_ms = None

        metrics: dict[str, torch.Tensor] = {}
        for mode in self.evaluation_modes:
            metrics.update(self._compute_mode_metrics(batch, model_outputs, mode))

        if self._compute_forces:
            metrics.update(self._compute_force_metrics(batch, model_outputs))

        if elapsed_ms is not None:
            metrics["inference_time_ms"] = torch.tensor(elapsed_ms)

        if self._save_predictions:
            self._collect_predictions(batch, model_outputs)

        return metrics

    def _collect_predictions(self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor]) -> None:
        """Denormalize and save predictions (and batch properties) for the current sample.

        Saves each sample to disk immediately (``sample_{idx:04d}.pt`` under
        ``predictions_path``) to avoid accumulating large tensors in memory.
        Predictions are upcast to float32 before denormalization; batch properties
        are saved with their original dtype.
        """
        sample: dict[str, torch.Tensor] = {}
        for mode in self.evaluation_modes:
            prediction = model_outputs.get(mode)
            if prediction is None:
                continue
            prediction = self._to_full_precision(prediction.cpu())
            normalizer = self.dataset_normalizers.get(mode)
            denorm = normalizer.inverse(prediction) if normalizer is not None else prediction
            sample[mode] = denorm.squeeze(0)
        for key in self._config.batch_properties_to_save:
            if key in batch:
                sample[key] = batch[key].cpu().squeeze(0)
        if sample:
            out_dir = Path(self._predictions_path)
            out_dir.mkdir(parents=True, exist_ok=True)
            idx = self._prediction_counter
            torch.save(sample, out_dir / f"sample_{idx:04d}.pt")
            self._prediction_counter += 1

    def process_results(self, results: dict[str, torch.Tensor], **_) -> None:
        """
        Log computed metrics to writer and report saved predictions.

        Prediction bookkeeping runs even when no metrics were computed. That happens,
        for example, when predictions are saved for a dataset without targets. This way
        the "Saved N prediction files" message is always emitted and file numbering
        restarts at 0 for the next evaluation.

        Args:
            results: Dictionary of computed metrics
            **_: Additional unused arguments
        """
        if results:
            self._log_metric_results(results)
        else:
            self.logger.warning(f"No metrics computed for dataset '{self.dataset_key}'")

        if self._save_predictions and self._prediction_counter > 0:
            self.logger.info(f"Saved {self._prediction_counter} prediction files to {self._predictions_path}")
        self._prediction_counter = 0

    def _log_metric_results(self, results: dict[str, torch.Tensor]) -> None:
        """Write the mean of every metric as a scalar, plus the inference-time summary if enabled."""
        num_logged = 0
        for name, metric in results.items():
            if name == "inference_time_ms":
                continue  # handled below with warmup-sample trimming
            self.writer.add_scalar(
                key=f"{METRIC_PREFIX_LOSS}{self.dataset_key}/{name}",
                value=metric.mean(),
                logger=self.logger,
                format_str=".6f",
            )
            num_logged += 1

        self.logger.debug(f"Logged {num_logged} metrics for dataset '{self.dataset_key}'")

        if self._measure_inference_time:
            times = results.get("inference_time_ms")
            if times is not None and times.numel() > 0:
                self._log_inference_time_summary(times.float())

    def _log_inference_time_summary(self, times_ms: torch.Tensor) -> None:
        """Log count, mean/std/median/min/max inference time over all samples.

        Drops the first ``inference_time_warmup_samples`` values, which are
        typically dominated by one-off setup cost (CUDA autotune, kernel compile,
        allocator growth) on the initial forward pass. ``std`` is the population
        standard deviation (``unbiased=False``).
        """
        # Clamp to [0, n]: a negative count would otherwise slice from the end.
        warmup = max(0, min(self._inference_time_warmup_samples, times_ms.numel()))
        dropped = times_ms[:warmup]
        kept = times_ms[warmup:]

        if kept.numel() == 0:
            self.logger.warning(
                f"Inference-time summary skipped: all {times_ms.numel()} sample(s) dropped as warmup "
                f"(inference_time_warmup_samples={self._inference_time_warmup_samples})."
            )
            return

        n = kept.numel()
        mean = float(kept.mean())
        std = float(kept.std(unbiased=False)) if n > 1 else 0.0
        median = float(kept.median())
        tmin = float(kept.min())
        tmax = float(kept.max())

        warmup_note = ""
        if warmup > 0:
            warmup_note = f" (dropped {warmup} warmup sample(s): {', '.join(f'{float(x):.1f}ms' for x in dropped)})"

        summary = (
            f"Inference time on '{self.dataset_key}' over {n} sample(s): "
            f"mean={mean:.2f}ms std={std:.2f}ms median={median:.2f}ms "
            f"min={tmin:.2f}ms max={tmax:.2f}ms{warmup_note}"
        )
        self.logger.info(summary)

        self.writer.add_scalar(
            key=f"{METRIC_PREFIX_LOSS}{self.dataset_key}/inference_time_ms",
            value=torch.tensor(mean),
            logger=self.logger,
            format_str=".6f",
        )