The Trainer

The AerodynamicsCFDTrainer (in trainers/aerodynamics_cfd.py) is a specialized trainer designed for aerodynamics Computational Fluid Dynamics (CFD) tasks, specifically for the AhmedML, DrivAerML, DrivAerNet++, ShapeNet-Car, and Emmi-Wing datasets. Its primary role is to manage the training step by processing model outputs, computing a flexible weighted loss, and returning the results.

For a step-by-step guide on implementing custom trainers, see How to Implement a Custom Trainer.

BaseTrainer implementation

To implement a custom Trainer for a downstream project, you must extend the BaseTrainer class. The BaseTrainer handles the full training loop and provides two key methods:

def loss_compute(
    self, forward_output: dict[str, torch.Tensor], targets: dict[str, torch.Tensor]
) -> LossResult | tuple[LossResult, dict[str, torch.Tensor]]:
    """
    Each trainer that extends this class needs to implement a custom loss computation using the targets and the model output.

    Args:
        forward_output: Output of the model after the forward pass.
        targets: Dict with target tensors needed to compute the loss for this trainer.

    Returns:
        A dict with the (weighted) sub-losses to log, or a tuple of (losses, additional_outputs), where
        additional_outputs is a dict with additional information about the model forward pass that is passed to the
        track_after_accumulation_step method of the callbacks (e.g., the logits and targets needed to calculate a
        training accuracy in a callback).

    Note: If a tuple is returned, the second element will be passed as additional_outputs in the TrainerResult returned by the train_step method.
    """
    raise NotImplementedError("Subclasses must implement loss_compute.")

def train_step(self, batch: dict[str, torch.Tensor], model: torch.nn.Module) -> TrainerResult:
    """Overriding this function is optional. By default, the `train_step` of the model will be called and is
    expected to return a TrainerResult. Trainers can override this method to implement custom training logic.

    Args:
        batch: Batch of data from which the loss is calculated.
        model: Model to use for processing the data.

    Returns:
        TrainerResult dataclass with the loss for backpropagation, (optionally) individual losses if multiple
        losses are used, and (optionally) additional information about the model forward pass that is passed
        to the callbacks (e.g., the logits and targets to calculate a training accuracy in a callback).
    """
    forward_batch, targets_batch = self._split_batch(batch)
    forward_output = model(**forward_batch)
    additional_outputs = None
    losses = self.loss_compute(forward_output=forward_output, targets=targets_batch)

    if isinstance(losses, tuple) and len(losses) == 2:
        losses, additional_outputs = losses

    if isinstance(losses, torch.Tensor):
        return TrainerResult(
            total_loss=losses, additional_outputs=additional_outputs, losses_to_log={"loss": losses}
        )
    elif isinstance(losses, list):
        losses = {f"loss_{i}": loss for i, loss in enumerate(losses)}

    if len(losses) == 0:
        raise ValueError("No losses computed, check your output keys and loss function.")

    return TrainerResult(
        total_loss=sum(losses.values(), start=torch.zeros_like(next(iter(losses.values())))),
        losses_to_log=losses,
        additional_outputs=additional_outputs,
    )

Understanding the two key methods:

As an end user, you always need to implement loss_compute; in some cases, you will also want to override train_step.
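As a minimal sketch of the first case: a loss_compute implementation for a model with a single output could look like the following. The key names "prediction" and "prediction_target" are illustrative, not part of the library, and in a real project this function would be a method on a BaseTrainer subclass.

```python
import torch
import torch.nn.functional as F

# Minimal sketch of a loss_compute implementation; key names are illustrative.
def loss_compute(
    forward_output: dict[str, torch.Tensor], targets: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    return {"mse": F.mse_loss(forward_output["prediction"], targets["prediction_target"])}

losses = loss_compute({"prediction": torch.zeros(4)}, {"prediction_target": torch.ones(4)})
# losses["mse"] is a scalar tensor equal to 1.0 for these inputs
```

Because a dict is returned, the base train_step sums its values to obtain the total loss used for backpropagation.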

The train_step method receives the batch from the multi-stage pipeline and the model being trained (which can be a DistributedDataParallel model when training on multiple GPUs).

In the base implementation, the batch is split into two sub-batches:

  1. Forward batch: Contains all tensors needed for the forward pass. The model receives the forward_batch as named keyword arguments, and the forward pass is computed.

  2. Targets batch: Contains tensors needed for loss computation. The loss_compute method computes the custom loss for your task.
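Conceptually, the split can be sketched in plain Python. This illustrates the idea only; the library's actual _split_batch implementation may differ in details.

```python
# Illustrative sketch: split a batch into forward and target sub-batches
# using the configured property lists (not the library's actual code).
def split_batch(batch, forward_properties, target_properties):
    forward_batch = {k: v for k, v in batch.items() if k in forward_properties}
    targets_batch = {k: v for k, v in batch.items() if k in target_properties}
    return forward_batch, targets_batch

batch = {"positions": [0.1, 0.2], "surface_pressure_target": [0.3], "unused_key": [0.0]}
fwd, tgt = split_batch(batch, ["positions"], ["surface_pressure_target"])
# "unused_key" lands in neither sub-batch; the real trainer warns about such keys
```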

Important

A warning is emitted if there are keys in the batch that end up in neither the forward batch nor the target batch. This indicates that the collator returns tensors that are never used, either in the forward pass or in the loss computation.

Return value requirements:

The train_step method must always return the TrainerResult dataclass, which should contain:

  • A scalar value of the total loss used to compute gradients (can be a weighted sum of multiple losses)

  • A dictionary with the losses you want to log

  • Optionally, a dictionary with additional output for logging

When to override train_step:

The train_step method defined in the BaseTrainer class fits most general deep learning forward passes. However, you can decide whether this implementation is sufficient for your downstream training task. If not, you can always implement a custom train_step method in the child trainer class (as has been done in the scaffold template at src/noether/scaffold/template_files/trainer/base.py).

BaseTrainer configuration

When using the default train_step method, you must define both the forward_properties and the target_properties to define which tensors are part of the forward_batch and which tensors are part of the target_batch.

In this walkthrough, the target properties are fixed per dataset, while the forward_properties depend on the model. Therefore, the target_properties are listed explicitly in the trainer config, and the forward_properties are interpolated from the model configuration.

Full trainer configuration:

The complete trainer config for ShapeNet-Car is defined in configs/trainer/shapenet_trainer.yaml:

# BaseTrainerConfig
kind: trainers.AerodynamicsCFDTrainer
precision: bfloat16
max_epochs: 500
effective_batch_size: 1
log_every_n_epochs: 1
callbacks: ${callbacks}
forward_properties: ${model.forward_properties}
target_properties:
  - surface_pressure_target
  - volume_velocity_target
# AerodynamicsCFDTrainerConfig
surface_weight: 1.0
volume_weight: 1.0
surface_pressure_weight: 1.0
volume_velocity_weight: 1.0
use_physics_features: false # whether to use physics features (e.g., surface normals, volume normals, SDF) as model input

AerodynamicsCFDTrainer implementation

The most important variables in the __init__ method are the loss weights, which give you fine-grained control over the training objective.

Loss weight hierarchy:

The loss has two levels of weights:

  • Individual weights: Parameters like surface_pressure_weight and volume_velocity_weight control the importance of a specific physical quantity in the total loss.

  • Group weights: The surface_weight and volume_weight parameters apply an additional weight to all surface-related or volume-related losses, respectively.
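The effective weight of each sub-loss is therefore the product of its individual weight and its group weight. For example, with a hypothetical setting of volume_weight: 0.5 (the config above uses 1.0 throughout):

```python
# Effective weight of a sub-loss = individual weight * group weight.
# Setting volume_weight to 0.5 halves every volume-related loss term.
surface_pressure_effective = 1.0 * 1.0  # surface_pressure_weight * surface_weight
volume_velocity_effective = 1.0 * 0.5   # volume_velocity_weight * volume_weight
```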

During initialization, the trainer uses these weights to build an internal loss_items list. The output_modes parameter (e.g., ['surface_pressure', 'volume_velocity']) specifies which of these potential losses should be computed during training.
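Hypothetically, the loss_items list could be assembled as follows. This is a sketch of the idea; the actual __init__ in trainers/aerodynamics_cfd.py may differ, and the names follow the config above.

```python
# Hypothetical sketch of building the loss_items list from the configured
# output modes and weights; the real __init__ may differ.
def build_loss_items(output_modes, individual_weights, group_weights):
    items = []
    for mode in output_modes:
        group = "surface" if mode.startswith("surface") else "volume"
        items.append((mode, individual_weights[mode], group_weights[group]))
    return items

items = build_loss_items(
    ["surface_pressure", "volume_velocity"],
    {"surface_pressure": 1.0, "volume_velocity": 1.0},
    {"surface": 1.0, "volume": 1.0},
)
# items: [("surface_pressure", 1.0, 1.0), ("volume_velocity", 1.0, 1.0)]
```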

Custom loss calculation (loss_compute):

This method contains the core logic of the trainer for computing the loss:

def loss_compute(
    self, forward_output: dict[str, torch.Tensor], targets: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Given the output of the model and the targets, compute the losses.

    Args:
        forward_output: The output of the model, containing the predictions for each output mode.
        targets: Dict containing all target values to compute the loss.

    Returns:
        A dictionary containing the computed losses for each output mode.
    """
    losses: dict[str, torch.Tensor] = {}
    for item, weight, group_weight in self.loss_items:
        if weight > 0 and group_weight > 0 and item in forward_output:
            if f"{item}_target" not in targets:
                raise ValueError(
                    f"Target for '{item}' not found in targets. Ensure the targets contain the correct keys."
                )
            losses[f"{item}_loss"] = (
                F.mse_loss(forward_output[item], targets[f"{item}_target"]) * weight * group_weight
            )
    if len(losses) == 0:
        raise ValueError("No losses computed, check your output keys and loss function.")
    return losses

It iterates through the loss_items configured during initialization. For each item (like surface_pressure), it checks that its weight is non-zero and that the model produced a corresponding output key.
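The control flow can be mirrored in plain Python. The real method operates on torch tensors via F.mse_loss; this sketch only illustrates the gating and weighting logic.

```python
# Pure-Python mirror of the loss loop above: skip zero-weighted items and
# items the model did not produce, then apply both weight levels.
def compute_losses(loss_items, forward_output, targets):
    losses = {}
    for item, weight, group_weight in loss_items:
        if weight > 0 and group_weight > 0 and item in forward_output:
            pred, tgt = forward_output[item], targets[f"{item}_target"]
            mse = sum((p - t) ** 2 for p, t in zip(pred, tgt)) / len(pred)
            losses[f"{item}_loss"] = mse * weight * group_weight
    return losses

loss_items = [("surface_pressure", 1.0, 1.0), ("volume_velocity", 0.0, 1.0)]
out = compute_losses(
    loss_items,
    {"surface_pressure": [0.0, 0.0], "volume_velocity": [1.0]},
    {"surface_pressure_target": [1.0, 1.0], "volume_velocity_target": [1.0]},
)
# volume_velocity is skipped (zero weight); surface_pressure_loss == 1.0
```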

This flexible system allows you to easily experiment with different combinations of output objectives without changing the underlying code.

When only a single loss value is used, a separate loss_compute method is not strictly necessary: the loss can be computed directly inside a custom train_step (by overriding the base train_step method, as done in the scaffold template at src/noether/scaffold/template_files/trainer/base.py).