aerodynamics_cfd.py

recipes/aero_cfd/trainers/aerodynamics_cfd.py

  1#  Copyright © 2025 Emmi AI GmbH. All rights reserved.
  2
  3from __future__ import annotations
  4
  5import torch
  6import torch.nn.functional as F
  7
  8from noether.core.schemas.trainers import BaseTrainerConfig
  9from noether.training.trainers import BaseTrainer
 10
 11
 12class AerodynamicsCfdTrainerConfig(BaseTrainerConfig):
 13    surface_weight: float = 1.0
 14    """ Weight of the predicted values on the surface mesh. Defaults to 1.0.."""
 15    volume_weight: float = 1.0
 16    """Weight of the predicted values in the volume. Defaults to 1.0."""
 17    surface_pressure_weight: float = 1.0
 18    """Weight of the predicted values for the surface pressure. Defaults to 1.0."""
 19    surface_friction_weight: float = 0.0
 20    """Weight of the predicted values for the surface wall shear stress. Defaults to 0.0."""
 21    volume_velocity_weight: float = 1.0
 22    """Weight of the predicted values for the volume velocity. Defaults to 1.0."""
 23    volume_pressure_weight: float = 0.0
 24    """Weight of the predicted values for the volume total pressure coefficient. Defaults to 0.0."""
 25    volume_vorticity_weight: float = 0.0
 26    """Weight of the predicted values for the volume vorticity. Defaults to 0.0."""
 27    use_physics_features: bool = False
 28
 29
 30class AerodynamicsCFDTrainer(BaseTrainer):
 31    """Trainer class for to train automative aerodynaimcs CFD for the: AhmedML, DrivaerML and Shapenet-Car Car dataset."""
 32
 33    def __init__(self, trainer_config: AerodynamicsCfdTrainerConfig, **kwargs):
 34        """Trainer class for to train automative aerodynaimcs CFD for the: AhmedML, DrivaerML and Shapenet-Car Car dataset.
 35
 36        Args:
 37            trainer_config: Configuration for the trainer.
 38            **kwargs: Additional keyword arguments for the SgdTrainer.
 39
 40        Raises:
 41            ValueError: When an output mode is not defined in the loss items.
 42        """
 43        super().__init__(
 44            config=trainer_config,
 45            **kwargs,
 46        )
 47
 48        self.surface_pressure_weight = trainer_config.surface_pressure_weight
 49        self.surface_friction_weight = trainer_config.surface_friction_weight
 50        self.volume_velocity_weight = trainer_config.volume_velocity_weight
 51        self.volume_pressure_weight = trainer_config.volume_pressure_weight
 52        self.volume_vorticity_weight = trainer_config.volume_vorticity_weight
 53
 54        self.surface_weight = trainer_config.surface_weight
 55        self.volume_weight = trainer_config.volume_weight
 56
 57        loss_items = {
 58            "surface_pressure": (self.surface_pressure_weight, self.surface_weight),
 59            "surface_friction": (
 60                self.surface_friction_weight,
 61                self.surface_weight,
 62            ),  # not used for ShapeNet-Car
 63            "volume_velocity": (self.volume_velocity_weight, self.volume_weight),
 64            "volume_pressure": (self.volume_pressure_weight, self.volume_weight),  # not used for ShapeNet-Car
 65            "volume_vorticity": (self.volume_vorticity_weight, self.volume_weight),  # not used for ShapeNet-Car
 66        }
 67
 68        self.loss_items = []
 69        for target_property in self.target_properties:
 70            if target_property[: -len("_target")] not in loss_items:
 71                raise ValueError(f"Output mode '{target_property}' is not defined in loss items.")
 72            self.loss_items.append(
 73                (
 74                    target_property[: -len("_target")],
 75                    loss_items[target_property[: -len("_target")]][0],
 76                    loss_items[target_property[: -len("_target")]][1],
 77                )
 78            )
 79
 80    def loss_compute(
 81        self, forward_output: dict[str, torch.Tensor], targets: dict[str, torch.Tensor]
 82    ) -> dict[str, torch.Tensor]:
 83        """Given the output of the model and the targets, compute the losses.
 84        Args:
 85            forward_output The output of the model, containing the predictions for each output mode.
 86            targets: Dict containing all target values to compute the loss.
 87
 88        Returns:
 89            A dictionary containing the computed losses for each output mode.
 90        """
 91        losses: dict[str, torch.Tensor] = {}
 92        for item, weight, group_weight in self.loss_items:
 93            if weight > 0 and group_weight > 0 and item in forward_output:
 94                if f"{item}_target" not in targets:
 95                    raise ValueError(
 96                        f"Target for '{item}' not found in targets. Ensure the targets contain the correct keys."
 97                    )
 98                losses[f"{item}_loss"] = (
 99                    F.mse_loss(targets[f"{item}_target"], forward_output[item]) * weight * group_weight
100                )
101        if len(losses) == 0:
102            raise ValueError("No losses computed, check your output keys and loss function.")
103        return losses