noether.modeling.modules.layers.transformer_batchnorm

Classes

TransformerBatchNorm

Wrapper around torch.nn.BatchNorm1d that considers all tokens of a single sample as the full batch.

Module Contents

class noether.modeling.modules.layers.transformer_batchnorm.TransformerBatchNorm(num_features, eps=1e-05, elementwise_affine=True, bias=True)

Bases: torch.nn.Module

Wrapper around torch.nn.BatchNorm1d that considers all tokens of a single sample as the full batch. Additionally remaps affine to elementwise_affine and supports disabling bias to comply with the torch.nn.LayerNorm interface. Does not use any nn.BatchNorm1d modules to avoid errors with nn.SyncBatchNorm.
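The module's source is not reproduced here; a minimal sketch of such a wrapper, assuming the statistics are computed manually over the token dimension (so no nn.BatchNorm1d submodule and no running statistics), might look like:

```python
import torch
from torch import nn


class TransformerBatchNorm(nn.Module):
    """Normalizes each feature over all tokens of a single sample.

    Sketch only: mirrors the torch.nn.LayerNorm constructor interface
    (elementwise_affine, bias) but normalizes over the token axis.
    """

    def __init__(self, num_features, eps=1e-05, elementwise_affine=True, bias=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            if bias:
                self.bias = nn.Parameter(torch.zeros(num_features))
            else:
                self.register_parameter("bias", None)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        # x: (batch_size, seqlen, dim) -- normalize over seqlen, i.e. the
        # tokens of each sample act as the "batch" for that sample.
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, unbiased=False, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        if self.weight is not None:
            x = x * self.weight
        if self.bias is not None:
            x = x + self.bias
        return x
```

Because no nn.BatchNorm1d submodule is created, converting the surrounding model with torch.nn.SyncBatchNorm.convert_sync_batchnorm leaves this layer untouched.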

Parameters:
num_features
eps = 1e-05
elementwise_affine = True
bias = True

forward(x)

BatchNorm1d where all tokens of a single sample correspond to a full batch.

Parameters:

x (torch.Tensor) – Tensor of shape (batch_size, seqlen, dim).

Returns:

Normalized x of shape (batch_size, seqlen, dim).

Return type:

torch.Tensor
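The normalization that forward performs can be checked directly; the sketch below (an assumption, not the module's actual code) applies the same per-sample, per-feature statistics over the token axis:

```python
import torch

x = torch.randn(8, 128, 64)  # (batch_size, seqlen, dim)

# Treat the 128 tokens of each sample as the "batch": compute the mean and
# variance per sample and per feature over the seqlen dimension.
mean = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, unbiased=False, keepdim=True)
y = (x - mean) / torch.sqrt(var + 1e-05)

print(y.shape)  # torch.Size([8, 128, 64]) -- shape is preserved
```

After this step each feature of each sample has approximately zero mean and unit variance across its tokens.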