# recipes/aero_cfd/callbacks/aero_metrics.py
1# Copyright © 2026 Emmi AI GmbH. All rights reserved.
2
3from __future__ import annotations
4
5import math
6from collections import defaultdict
7from pathlib import Path
8
9import torch
10from pydantic import Field, model_validator
11
12from noether.core.callbacks.periodic import PeriodicDataIteratorCallback
13from noether.core.schemas.callbacks import PeriodicDataIteratorCallbackConfig
14
15
class AeroMetricsCallbackConfig(PeriodicDataIteratorCallbackConfig):
    """Configuration for surface/volume evaluation metrics callback."""

    kind: str | None = "aero_cfd.callbacks.AeroMetricsCallback"

    forward_properties: list[str] = Field(default_factory=list)
    """List of properties in the dataset to be forwarded during inference."""
    chunked_inference: bool = False
    """If True, perform inference in chunks over the full simulation geometry."""
    chunk_properties: list[str] = Field(default_factory=list)
    """List of properties in the dataset to be chunked for chunked inference."""
    batch_size: int = Field(1)
    """Batch size for evaluation. Currently only batch_size=1 is supported."""
    chunk_size: int | None = None
    """Size of each chunk when performing chunked inference."""
    sample_size_property: str | None = Field(None)
    """Property in the batch to determine the sample size for chunking."""
    save_predictions: bool = False
    """If True, save denormalized predictions to disk during evaluation."""
    predictions_path: str | None = None
    """Directory to save per-sample prediction files. Required when save_predictions=True."""
    batch_properties_to_save: list[str] = Field(default_factory=list)
    """Batch keys (e.g. position tensors) to save alongside predictions."""
    compute_forces: bool = False
    """If True, compute drag/lift coefficients per sample and log errors."""

    @model_validator(mode="after")
    def validate_config(self) -> AeroMetricsCallbackConfig:
        """Validate cross-field constraints after model construction.

        Raises:
            ValueError: If batch_size != 1, if save_predictions is enabled
                without predictions_path, or if chunked_inference is enabled
                without chunk_size, forward_properties, chunk_properties,
                or sample_size_property.
        """
        if self.batch_size != 1:
            raise ValueError("AeroMetricsCallback only supports batch_size=1")
        if self.save_predictions and self.predictions_path is None:
            raise ValueError("predictions_path must be specified when save_predictions=True")
        if self.chunked_inference:
            if self.chunk_size is None:
                raise ValueError("chunk_size must be specified when chunked_inference is True")
            if not self.forward_properties:
                raise ValueError("forward_properties must be specified when chunked_inference is True")
            if not self.chunk_properties:
                raise ValueError("chunk_properties must be specified when chunked_inference is True")
            # Fix: chunked inference reads batch[sample_size_property] to size the
            # chunks; without this check a missing value only surfaces later as an
            # opaque KeyError (batch[None]) during evaluation.
            if self.sample_size_property is None:
                raise ValueError("sample_size_property must be specified when chunked_inference is True")
        return self
56
57
# Constants

# Field names evaluated by AeroMetricsCallback. Each mode is looked up both as
# a model-output key and as a "<mode>_target" key in the batch; modes missing
# from either side are silently skipped.
DEFAULT_EVALUATION_MODES = [
    "surface_pressure",
    "surface_friction",
    "volume_velocity",
    "volume_pressure",
    "volume_vorticity",
]

# Suffix appended to a mode name to locate its ground-truth tensor in the batch.
METRIC_SUFFIX_TARGET = "_target"
# Prefix for logged metric keys ("loss/<dataset_key>/<metric_name>").
METRIC_PREFIX_LOSS = "loss/"
69
70
class MetricType:
    """String suffixes identifying the metrics reported per evaluated field.

    Metric keys are built as ``f"{field_name}_{MetricType.<X>}"``.
    """

    MSE: str = "mse"  # mean squared error
    MAE: str = "mae"  # mean absolute error
    L2ERR: str = "l2err"  # relative L2 error (||delta|| / ||target||)
77
78
class AeroMetricsCallback(PeriodicDataIteratorCallback):
    """Evaluation callback for aerodynamic surface and volume predictions.

    Computes MSE, MAE, and relative L2 error metrics for physical fields
    (pressure, friction, velocity, vorticity) by running model inference on
    an evaluation dataset. Supports chunked inference for memory efficiency.

    When ``save_predictions=True``, denormalized predictions (and optionally
    batch properties such as positions) are saved to disk per-sample for
    downstream use (VTK export, force coefficient computation).

    Args:
        callback_config: Configuration for the callback including dataset key,
            forward properties, and chunking settings.
        **kwargs: Additional arguments passed to parent class.

    Attributes:
        dataset_key: Identifier for the dataset to evaluate.
        evaluation_modes: List of field names to evaluate.
        dataset_normalizers: Normalizers for denormalizing predictions.
        forward_properties: Properties to pass to model forward.
        chunked_inference: Whether to use chunked inference.
        chunk_properties: Properties to chunk.
        chunk_size: Size of each chunk.
        sample_size_property: Property to determine chunk count.
    """

    def __init__(self, callback_config: AeroMetricsCallbackConfig, **kwargs):
        super().__init__(callback_config, **kwargs)

        self._config = callback_config
        self.dataset_key = callback_config.dataset_key
        self.evaluation_modes = DEFAULT_EVALUATION_MODES
        self.dataset_normalizers = self.data_container.get_dataset(self.dataset_key).normalizers
        self.forward_properties = callback_config.forward_properties
        self.chunked_inference = callback_config.chunked_inference
        self.chunk_properties = callback_config.chunk_properties
        self.chunk_size = callback_config.chunk_size
        self.sample_size_property = callback_config.sample_size_property
        self._save_predictions = callback_config.save_predictions
        self._predictions_path = callback_config.predictions_path
        # Number of samples saved so far; used for output file naming and
        # reset after logging in process_results.
        self._prediction_counter: int = 0
        self._compute_forces = callback_config.compute_forces
        if self._compute_forces:
            # Lazy imports: scipy and the drag/lift helpers are only required
            # when force computation is enabled, so they are bound here rather
            # than at module import time.
            from scipy.spatial import cKDTree

            from aero_cfd.utils.drag_lift import FlowConditions, compute_force_coefficients

            self._cKDTree = cKDTree
            self._FlowConditions = FlowConditions
            self._compute_force_coefficients = compute_force_coefficients

    def _compute_metrics(
        self, denormalized_predictions: torch.Tensor, denormalized_targets: torch.Tensor, field_name: str
    ) -> dict[str, torch.Tensor]:
        """
        Compute evaluation metrics for predictions vs targets.

        Calculates Mean Squared Error (MSE), Mean Absolute Error (MAE),
        and relative L2 error for the given field.

        Args:
            denormalized_predictions: Denormalized prediction tensor
            denormalized_targets: Denormalized target tensor
            field_name: Name of the field being evaluated (used for metric naming)

        Returns:
            Dictionary mapping metric names to computed values
        """
        delta = denormalized_predictions - denormalized_targets

        metrics = {
            f"{field_name}_{MetricType.MSE}": (delta**2).mean(),
            f"{field_name}_{MetricType.MAE}": delta.abs().mean(),
        }

        # L2 relative error (avoid division by zero); when the target norm is
        # near zero the L2 entry is omitted entirely rather than reported as 0.
        target_norm = denormalized_targets.norm()
        if target_norm > 1e-8:
            metrics[f"{field_name}_{MetricType.L2ERR}"] = delta.norm() / target_norm
        else:
            self.logger.warning(f"Target norm too small for {field_name}, skipping L2 error")

        return metrics

    def _create_chunked_batch(
        self, batch: dict[str, torch.Tensor], start_idx: int, end_idx: int
    ) -> dict[str, torch.Tensor]:
        """
        Create a batch slice for chunked processing.

        Properties listed in ``chunk_properties`` are sliced along dim 1
        (the point dimension); all other entries are passed through unchanged.

        Args:
            batch: Full batch dictionary
            start_idx: Start index for the chunk
            end_idx: End index for the chunk

        Returns:
            Dictionary with chunked tensors for specified properties
        """
        chunked_batch = {}
        for key, value in batch.items():
            if key in self.chunk_properties:
                chunked_batch[key] = value[:, start_idx:end_idx]
            else:
                chunked_batch[key] = value
        return chunked_batch

    def _get_chunk_indices(self, batch_size: int) -> list[tuple[int, int]]:
        """
        Calculate start and end indices for all chunks.

        Args:
            batch_size: Total number of elements (along dim 1) to split into
                chunks of at most ``self.chunk_size``.

        Returns:
            List of (start_idx, end_idx) tuples for each chunk; the last chunk
            may be smaller than ``chunk_size``.
        """
        indices = []
        num_chunks = math.ceil(batch_size / self.chunk_size)

        for chunk_idx in range(num_chunks):
            start = chunk_idx * self.chunk_size
            end = min(start + self.chunk_size, batch_size)
            indices.append((start, end))

        return indices

    def _chunked_model_inference(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Run model inference in chunks to reduce memory usage.

        Splits the batch into smaller chunks, processes each independently,
        and concatenates the results.

        NOTE(review): concatenation assumes every model output is per-point
        with the point dimension at dim 1 — confirm for all output heads.

        Args:
            batch: Full batch dictionary

        Returns:
            Dictionary of model outputs with concatenated chunk results
        """

        # The tensor at sample_size_property defines the number of points to
        # split; its dim 1 is the point dimension.
        batch_size = batch[self.sample_size_property].shape[1]
        chunk_indices = self._get_chunk_indices(batch_size)

        model_outputs = defaultdict(list)
        for start_idx, end_idx in chunk_indices:
            chunked_batch = self._create_chunked_batch(batch, start_idx, end_idx)
            forward_inputs = {k: v for k, v in chunked_batch.items() if k in self.forward_properties}

            with self.trainer.autocast_context:
                chunked_outputs = self.model(**forward_inputs)

            # Accumulate outputs
            for key, value in chunked_outputs.items():
                model_outputs[key].append(value)

        # Concatenate all chunks
        return {key: torch.cat(chunks, dim=1) for key, chunks in model_outputs.items()}

    def _run_model_inference(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Run model inference, optionally in chunks.

        Args:
            batch: Input batch dictionary

        Returns:
            Dictionary of model outputs
        """
        if self.chunked_inference:
            return self._chunked_model_inference(batch)
        else:
            # Non-chunked path: forward only the configured properties.
            forward_inputs = {k: v for k, v in batch.items() if k in self.forward_properties}
            with self.trainer.autocast_context:
                return self.model(**forward_inputs)

    def _align_chunk_sizes(self, prediction: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Align prediction and target sizes when using chunked inference.

        Truncates both tensors to the smaller dim-1 length, so any trailing
        points present in only one of the two are dropped from the metrics.

        Args:
            prediction: Prediction tensor
            target: Target tensor

        Returns:
            Tuple of (aligned_prediction, aligned_target)
        """
        if self.chunked_inference and prediction.shape[1] != target.shape[1]:
            min_size = min(prediction.shape[1], target.shape[1])
            prediction = prediction[:, :min_size]
            target = target[:, :min_size]
        return prediction, target

    def _compute_mode_metrics(
        self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor], mode: str
    ) -> dict[str, torch.Tensor]:
        """
        Compute metrics for a specific evaluation mode.

        Args:
            batch: Input batch containing targets
            model_outputs: Model predictions
            mode: Evaluation mode (field name)

        Returns:
            Dictionary of computed metrics for this mode; empty when either
            the prediction or the "<mode>_target" entry is absent.
        """
        target = batch.get(f"{mode}{METRIC_SUFFIX_TARGET}")
        prediction = model_outputs.get(mode)

        if prediction is None or target is None:
            return {}

        dataset = self.data_container.get_dataset(self.dataset_key)
        denorm_pred = dataset.denormalize(mode, prediction)
        denorm_target = dataset.denormalize(mode, target)

        # Align sizes if needed
        denorm_pred, denorm_target = self._align_chunk_sizes(denorm_pred, denorm_target)

        # Compute metrics
        return self._compute_metrics(denorm_pred, denorm_target, mode)

    def _compute_force_metrics(
        self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """Compute drag/lift coefficient errors for the current sample.

        Uses full-resolution mesh geometry from the batch (``surface_normals``,
        ``surface_area``, ``surface_position``) and loads full-resolution GT
        fields from disk (since batch targets are subsampled by the pipeline).
        Predicted Cd/Cl uses denormalized model outputs matched to the mesh via
        nearest-neighbor lookup.

        Requires ``surface_normals``, ``surface_area``, and ``surface_position``
        to be present in the batch. Enable these by removing them from
        ``excluded_properties`` in the dataset config.

        Returns:
            Dict with ``drag_error`` and ``lift_error`` scalar tensors, or an
            empty dict when any required input is unavailable.
        """
        # Full-resolution mesh geometry from batch
        surface_normals = batch.get("surface_normals")
        surface_areas = batch.get("surface_area")
        mesh_positions = batch.get("surface_position")

        if surface_normals is None or surface_areas is None or mesh_positions is None:
            self.logger.warning(
                "Skipping force computation: surface_normals, surface_area, or surface_position "
                "not in batch. Ensure these fields are not excluded in the dataset config."
            )
            return {}

        # Drop the batch dimension (batch_size=1 is enforced by the config)
        # and move to CPU for scipy/torch post-processing.
        surface_normals = surface_normals.cpu().squeeze(0).float()
        surface_areas = surface_areas.cpu().squeeze(0).float()
        mesh_positions = mesh_positions.cpu().squeeze(0).float()

        # Ground-truth Cd/Cl from full-resolution dataset files.
        # Batch targets are subsampled by the pipeline, so we load the originals.
        dataset = self.data_container.get_dataset(self.dataset_key)
        sample_idx = batch["index"].squeeze().item()
        info = dataset.sample_info(sample_idx)
        run_dir = Path(info["sample_uri"])

        # Load per-run reference area if available, otherwise use defaults.
        design_id = info["design_id"]
        ref_csv = run_dir / f"geo_ref_{design_id}.csv"
        if ref_csv.exists():
            import pandas as pd

            ref_area = float(pd.read_csv(ref_csv)["aRef"][0])
            flow = self._FlowConditions(reference_area=ref_area)
        else:
            flow = self._FlowConditions()

        gt_pressure_path = run_dir / "surface_pressure.pt"
        gt_shear_path = run_dir / "surface_wallshearstress.pt"
        if not gt_pressure_path.exists() or not gt_shear_path.exists():
            self.logger.debug(f"Skipping GT force computation for sample {sample_idx}: missing GT files")
            return {}

        gt_pressure = torch.load(gt_pressure_path, map_location="cpu", weights_only=True).float()
        gt_shear = torch.load(gt_shear_path, map_location="cpu", weights_only=True).float()
        # Flatten a trailing singleton channel so pressure is per-point scalar.
        if gt_pressure.ndim == 2 and gt_pressure.shape[-1] == 1:
            gt_pressure = gt_pressure.squeeze(-1)

        gt_coeffs = self._compute_force_coefficients(gt_pressure, gt_shear, surface_normals, surface_areas, flow)

        # Predicted Cd/Cl from model outputs (denormalized)
        pred_pressure = model_outputs.get("surface_pressure")
        pred_friction = model_outputs.get("surface_friction")
        pred_positions = batch.get("surface_anchor_position")

        if pred_pressure is None or pred_friction is None or pred_positions is None:
            return {}

        # NOTE(review): this path denormalizes via dataset_normalizers[...].inverse
        # while _compute_mode_metrics uses dataset.denormalize — presumably
        # equivalent; verify the two stay in sync.
        pred_pressure_denorm = self.dataset_normalizers["surface_pressure"].inverse(pred_pressure.cpu()).squeeze(0)
        pred_friction_denorm = self.dataset_normalizers["surface_friction"].inverse(pred_friction.cpu()).squeeze(0)
        pred_positions_cpu = pred_positions.cpu().squeeze(0)

        if pred_pressure_denorm.ndim == 2 and pred_pressure_denorm.shape[-1] == 1:
            pred_pressure_denorm = pred_pressure_denorm.squeeze(-1)

        # Match predicted positions to mesh positions for normals/areas lookup
        position_tree = self._cKDTree(mesh_positions.numpy())
        _, matched_indices = position_tree.query(pred_positions_cpu.numpy())

        pred_coeffs = self._compute_force_coefficients(
            pred_pressure_denorm,
            pred_friction_denorm,
            surface_normals[matched_indices],
            surface_areas[matched_indices],
            flow,
        )

        # Absolute coefficient errors, wrapped as tensors so process_results
        # can treat them like the other metrics.
        return {
            "drag_error": torch.tensor(abs(gt_coeffs.cd - pred_coeffs.cd)),
            "lift_error": torch.tensor(abs(gt_coeffs.cl - pred_coeffs.cl)),
        }

    def process_data(self, batch: dict[str, torch.Tensor], **_) -> dict[str, torch.Tensor]:
        """
        Execute forward pass and compute metrics.

        Args:
            batch: Input batch dictionary
            **_: Additional unused arguments

        Returns:
            Dictionary mapping metric names to computed values
        """
        model_outputs = self._run_model_inference(batch)

        metrics = {}
        for mode in self.evaluation_modes:
            metrics.update(self._compute_mode_metrics(batch, model_outputs, mode))

        if self._compute_forces:
            metrics.update(self._compute_force_metrics(batch, model_outputs))

        if self._save_predictions:
            self._collect_predictions(batch, model_outputs)

        return metrics

    def _collect_predictions(self, batch: dict[str, torch.Tensor], model_outputs: dict[str, torch.Tensor]) -> None:
        """Denormalize and save predictions (and batch properties) for the current sample.

        Saves each sample to disk immediately to avoid accumulating large tensors in memory.
        Output files are named ``sample_NNNN.pt`` using the running prediction counter.
        """
        sample = {}
        for mode in self.evaluation_modes:
            prediction = model_outputs.get(mode)
            if prediction is None:
                continue
            # Fall back to the raw prediction when no normalizer is registered
            # for this mode.
            normalizer = self.dataset_normalizers.get(mode)
            if normalizer is not None:
                denorm = normalizer.inverse(prediction.cpu())
            else:
                denorm = prediction.cpu()
            sample[mode] = denorm.squeeze(0)
        for key in self._config.batch_properties_to_save:
            if key in batch:
                sample[key] = batch[key].cpu().squeeze(0)
        if sample:
            out_dir = Path(self._predictions_path)
            out_dir.mkdir(parents=True, exist_ok=True)
            idx = self._prediction_counter
            torch.save(sample, out_dir / f"sample_{idx:04d}.pt")
            self._prediction_counter += 1

    def process_results(self, results: dict[str, torch.Tensor], **_) -> None:
        """
        Log computed metrics to writer and optionally save predictions.

        Args:
            results: Dictionary of computed metrics
            **_: Additional unused arguments
        """
        if not results:
            self.logger.warning(f"No metrics computed for dataset '{self.dataset_key}'")
            return

        for name, metric in results.items():
            metric_key = f"{METRIC_PREFIX_LOSS}{self.dataset_key}/{name}"
            self.writer.add_scalar(
                key=metric_key,
                value=metric.mean(),
                logger=self.logger,
                format_str=".6f",
            )

        self.logger.debug(f"Logged {len(results)} metrics for dataset '{self.dataset_key}'")

        # NOTE(review): resetting the counter means a later evaluation pass
        # reuses sample_0000.pt onward, overwriting earlier prediction files —
        # confirm this overwrite is intended.
        if self._save_predictions and self._prediction_counter > 0:
            self.logger.info(f"Saved {self._prediction_counter} prediction files to {self._predictions_path}")
            self._prediction_counter = 0