aerodynamics_cfd.py
recipes/aero_cfd/trainers/aerodynamics_cfd.py
1# Copyright © 2025 Emmi AI GmbH. All rights reserved.
2
3from __future__ import annotations
4
5import torch
6import torch.nn.functional as F
7
8from noether.core.schemas.trainers import BaseTrainerConfig
9from noether.training.trainers import BaseTrainer
10
11
class AerodynamicsCfdTrainerConfig(BaseTrainerConfig):
    """Configuration for the aerodynamics CFD trainer.

    Holds two group weights (surface, volume) and one weight per predicted
    field. The trainer multiplies each field's loss by its own weight and by
    the weight of the group the field belongs to.
    """

    surface_weight: float = 1.0
    """Weight of the predicted values on the surface mesh. Defaults to 1.0."""
    volume_weight: float = 1.0
    """Weight of the predicted values in the volume. Defaults to 1.0."""
    surface_pressure_weight: float = 1.0
    """Weight of the predicted values for the surface pressure. Defaults to 1.0."""
    surface_friction_weight: float = 0.0
    """Weight of the predicted values for the surface wall shear stress. Defaults to 0.0."""
    volume_velocity_weight: float = 1.0
    """Weight of the predicted values for the volume velocity. Defaults to 1.0."""
    volume_pressure_weight: float = 0.0
    """Weight of the predicted values for the volume total pressure coefficient. Defaults to 0.0."""
    volume_vorticity_weight: float = 0.0
    """Weight of the predicted values for the volume vorticity. Defaults to 0.0."""
    # NOTE(review): this flag is not read by the trainer class in this file;
    # presumably it is consumed elsewhere (e.g. by the base trainer or the
    # data pipeline) -- confirm before relying on the description below.
    use_physics_features: bool = False
    """Whether physics features are used. Defaults to False."""
28
29
class AerodynamicsCFDTrainer(BaseTrainer):
    """Trainer for automotive aerodynamics CFD models on the AhmedML, DrivAerML and ShapeNet-Car datasets."""

    def __init__(self, trainer_config: AerodynamicsCfdTrainerConfig, **kwargs):
        """Set up the per-output loss weights for the configured target properties.

        Args:
            trainer_config: Configuration for the trainer.
            **kwargs: Additional keyword arguments forwarded to ``BaseTrainer``.

        Raises:
            ValueError: When a target property does not end in ``"_target"`` or
                its output mode is not defined in the loss items.
        """
        super().__init__(
            config=trainer_config,
            **kwargs,
        )

        self.surface_pressure_weight = trainer_config.surface_pressure_weight
        self.surface_friction_weight = trainer_config.surface_friction_weight
        self.volume_velocity_weight = trainer_config.volume_velocity_weight
        self.volume_pressure_weight = trainer_config.volume_pressure_weight
        self.volume_vorticity_weight = trainer_config.volume_vorticity_weight

        self.surface_weight = trainer_config.surface_weight
        self.volume_weight = trainer_config.volume_weight

        # Output name -> (weight of the individual field, weight of its group).
        weights_by_output = {
            "surface_pressure": (self.surface_pressure_weight, self.surface_weight),
            "surface_friction": (self.surface_friction_weight, self.surface_weight),  # not used for ShapeNet-Car
            "volume_velocity": (self.volume_velocity_weight, self.volume_weight),
            "volume_pressure": (self.volume_pressure_weight, self.volume_weight),  # not used for ShapeNet-Car
            "volume_vorticity": (self.volume_vorticity_weight, self.volume_weight),  # not used for ShapeNet-Car
        }

        suffix = "_target"
        # List of (output_name, weight, group_weight) tuples, one per target property.
        self.loss_items = []
        # NOTE(review): `target_properties` is not assigned in this class; it is
        # presumably provided by BaseTrainer -- confirm.
        for target_property in self.target_properties:
            # Only strip the suffix when it is really there: blindly slicing off
            # len("_target") characters would chop arbitrary text from a
            # malformed name and could make it match a valid loss item.
            output_name = target_property[: -len(suffix)] if target_property.endswith(suffix) else None
            if output_name not in weights_by_output:
                raise ValueError(f"Output mode '{target_property}' is not defined in loss items.")
            weight, group_weight = weights_by_output[output_name]
            self.loss_items.append((output_name, weight, group_weight))

    def loss_compute(
        self, forward_output: dict[str, torch.Tensor], targets: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """Given the output of the model and the targets, compute the losses.

        An output contributes a loss only when both its own weight and its
        group weight are positive and the model actually produced it; other
        outputs are skipped silently.

        Args:
            forward_output: The output of the model, containing the predictions for each output mode.
            targets: Dict containing all target values to compute the loss.

        Returns:
            A dictionary mapping ``"<output>_loss"`` to the weighted MSE loss of that output.

        Raises:
            ValueError: When the target of an active output is missing from
                ``targets``, or when no loss could be computed at all.
        """
        losses: dict[str, torch.Tensor] = {}
        for item, weight, group_weight in self.loss_items:
            is_active = weight > 0 and group_weight > 0 and item in forward_output
            if not is_active:
                continue

            target_key = f"{item}_target"
            if target_key not in targets:
                raise ValueError(
                    f"Target for '{item}' not found in targets. Ensure the targets contain the correct keys."
                )

            # Prediction first, target second: the documented (input, target)
            # order of mse_loss. The value is the same either way because the
            # squared error is symmetric.
            mse = F.mse_loss(forward_output[item], targets[target_key])
            losses[f"{item}_loss"] = mse * weight * group_weight

        if not losses:
            raise ValueError("No losses computed, check your output keys and loss function.")
        return losses