noether.modeling.modules.layers.transformer_batchnorm¶
Classes¶
TransformerBatchNorm – Wrapper around torch.nn.BatchNorm1d that considers all tokens of a single sample as the full batch.
Module Contents¶
- class noether.modeling.modules.layers.transformer_batchnorm.TransformerBatchNorm(num_features, eps=1e-05, elementwise_affine=True, bias=True)¶
Bases: torch.nn.Module

Wrapper around torch.nn.BatchNorm1d that considers all tokens of a single sample as the full batch. Additionally remaps affine to elementwise_affine and supports disabling bias to comply with the torch.nn.LayerNorm interface. Does not use any nn.BatchNorm1d modules internally, to avoid errors with nn.SyncBatchNorm.
- num_features¶
- eps = 1e-05¶
- elementwise_affine = True¶
- forward(x)¶
BatchNorm1d where all tokens of a single sample correspond to a full batch.
- Parameters:
x (torch.Tensor) – Tensor of shape (batch_size, seqlen, dim).
- Returns:
Normalized x of shape (batch_size, seqlen, dim).
- Return type:
torch.Tensor
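The behavior documented above can be illustrated with a minimal sketch. This is a hypothetical re-implementation based solely on the description, not the library's actual source: each feature is normalized over the token (seqlen) dimension independently per sample, using the biased variance as BatchNorm does in training mode, with optional elementwise affine parameters matching the torch.nn.LayerNorm-style interface.

```python
import torch
from torch import nn


class TransformerBatchNorm(nn.Module):
    """Sketch: batchnorm where each sample's tokens act as the batch.

    Hypothetical implementation inferred from the documented behavior;
    the real module may differ in detail.
    """

    def __init__(self, num_features, eps=1e-05, elementwise_affine=True, bias=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features)) if bias else None
        else:
            self.weight = None
            self.bias = None

    def forward(self, x):
        # x: (batch_size, seqlen, dim). Normalize each feature over the
        # token dimension, independently per sample, using the biased
        # variance (as nn.BatchNorm1d does during training).
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, unbiased=False, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        if self.weight is not None:
            x = x * self.weight
        if self.bias is not None:
            x = x + self.bias
        return x
```

Because the normalization is computed per sample rather than across the batch, no cross-device statistics are needed, which is consistent with the note about avoiding nn.SyncBatchNorm issues.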