noether.modeling.models.vit¶
Classes¶
Module Contents¶
- class noether.modeling.models.vit.ViTConfig(/, **data)¶
Bases: noether.core.models.base.ModelBaseConfig
Configuration for the ViT model.
Create a new model by parsing and validating input data from keyword arguments.
Raises [ValidationError][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
- model_config¶
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- patch_size: int = None¶
Patch side length in cells. The grid resolution must be divisible by this value.
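The divisibility requirement can be illustrated with a small standalone sketch. `token_grid` is a hypothetical helper written for this example, not part of the `noether` API. It assumes the grid is split into non-overlapping square patches, so a `(H, W)` grid yields `(H / patch_size) * (W / patch_size)` tokens.

```python
def token_grid(height: int, width: int, patch_size: int) -> tuple[int, int]:
    """Return (grid_h, grid_w), the number of patches along each axis.

    Hypothetical helper: mirrors the documented rule that the grid
    resolution must be divisible by patch_size.
    """
    if patch_size <= 0:
        raise ValueError("patch_size must be a positive integer")
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"grid {height}x{width} is not divisible by patch_size={patch_size}"
        )
    return height // patch_size, width // patch_size


# A 64 x 96 grid with 8-cell patches gives an 8 x 12 token grid.
grid_h, grid_w = token_grid(64, 96, patch_size=8)
num_tokens = grid_h * grid_w
```

A resolution such as 65 x 96 would raise `ValueError` in this sketch because 65 is not a multiple of 8.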
- hidden_dim¶
Token hidden dimension throughout the transformer stack.
- use_conditioning: bool = True¶
If True, enable AdaLN-Zero conditioning (forward requires cond); if False, plain ViT (cond must be None).
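For readers unfamiliar with AdaLN-Zero, the following is a minimal NumPy sketch of the idea as described in the DiT paper, not the library's implementation. The function names, the affine-free layer norm, and the single `(D, 3*D)` projection are all assumptions made for illustration. The conditioning vector is projected to a per-sample shift, scale, and gate; because the projection starts at zero, the gated branch contributes nothing at initialization.

```python
import numpy as np


def silu(x):
    # SiLU activation: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))


def layer_norm(x, eps=1e-6):
    # Normalization without learned affine parameters (assumption).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def adaln_zero_branch(x, cond, weight, bias, sublayer):
    """One conditioned residual branch.

    x: (B, L, D) tokens, cond: (B, D) conditioning vector,
    weight: (D, 3*D) and bias: (3*D,) of the modulation projection.
    """
    params = silu(cond) @ weight + bias  # (B, 3*D)
    shift, scale, gate = np.split(params, 3, axis=-1)  # each (B, D)
    # Broadcast the per-sample modulation over the token axis.
    h = layer_norm(x) * (1.0 + scale[:, None, :]) + shift[:, None, :]
    return x + gate[:, None, :] * sublayer(h)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))
cond = rng.normal(size=(2, 8))
weight = np.zeros((8, 24))  # the "Zero" in AdaLN-Zero: zero-initialized projection
bias = np.zeros(24)
out = adaln_zero_branch(x, cond, weight, bias, sublayer=lambda h: 3.0 * h + 1.0)
```

With the zero-initialized projection, `out` equals `x` whatever the sublayer computes, which is the identity-at-initialization property the "Zero" refers to.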
- use_conv_output_head: bool = True¶
If True, decode via a cascaded PixelShuffle conv head; if False, decode via a linear unpatchify.
- property transformer_block_config: noether.modeling.modules.blocks.transformer.TransformerBlockConfig¶
- class noether.modeling.models.vit.ViT(config)¶
Bases: torch.nn.Module
Vision Transformer for spatial regression on continuous-coordinate grids.
Based on the ViT paper (https://arxiv.org/pdf/2010.11929) with several modifications, such as:
Continuous coordinate inputs with sincos positional embedding and RoPE (vs. learned 1D position embeddings).
Optional AdaLN-Zero conditioning, à la DiT (https://arxiv.org/abs/2212.09748).
RMSNorm and QK-norm in attention (vs. LayerNorm only).
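As a rough illustration of the last item, the sketch below shows RMS normalization and one common form of QK-norm in NumPy. It assumes that QK-norm means applying a weight-free RMS normalization to queries and keys before the dot product; this page does not state which normalization the library applies there, and the function names are hypothetical.

```python
import numpy as np


def rms_norm(x, weight=None, eps=1e-6):
    # Divide by the root mean square over the last axis; no mean subtraction.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    y = x / rms
    return y if weight is None else y * weight


def qk_norm_attention_weights(q, k):
    """Attention weights with normalized queries and keys.

    q, k: (B, heads, L, head_dim). Returns (B, heads, L, L).
    """
    q = rms_norm(q)
    k = rms_norm(k)
    logits = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
q = rng.normal(size=(1, 2, 4, 8))
k = rng.normal(size=(1, 2, 4, 8))
attn = qk_norm_attention_weights(q, k)
attn_scaled = qk_norm_attention_weights(100.0 * q, k)
```

Because queries are normalized first, rescaling `q` leaves the attention weights essentially unchanged, which is the usual motivation for QK-norm: the logits stay bounded regardless of activation magnitude.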
- Parameters:
config (ViTConfig) – Configuration for the ViT model. See ViTConfig for available options.
- coord_dim¶
- out_channels¶
- patch_size¶
- num_heads¶
- token_dropout¶
- use_conditioning¶
- pool_patch¶
- mask_patchify¶
- pos_embedding¶
- rope¶
- backbone¶
- use_conv_output_head¶
- initialize_weights()¶
Initialize backbone weights
- Return type:
None
- unpatchify(x, grid_h, grid_w)¶
Linear unpatchify: (B, L, p²·C_out) → (B, H, W, C_out).
- Parameters:
x (torch.Tensor)
grid_h (int)
grid_w (int)
- Return type:
torch.Tensor
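The shape transformation performed by a linear unpatchify can be sketched as follows. This is a standalone NumPy illustration, not the method's source. It assumes tokens are ordered row-major over the patch grid and that each token's features are laid out as `(patch_row, patch_col, channel)`; the library's actual layout may differ. Unlike the method, which takes only `(x, grid_h, grid_w)`, the sketch passes `patch_size` and `out_channels` explicitly.

```python
import numpy as np


def unpatchify(x, grid_h, grid_w, patch_size, out_channels):
    """(B, L, p*p*C_out) -> (B, H, W, C_out) with H = grid_h*p, W = grid_w*p."""
    b, num_tokens, dim = x.shape
    p = patch_size
    if num_tokens != grid_h * grid_w:
        raise ValueError("token count does not match grid_h * grid_w")
    if dim != p * p * out_channels:
        raise ValueError("token width does not match p*p*C_out")
    x = x.reshape(b, grid_h, grid_w, p, p, out_channels)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, grid_h, p, grid_w, p, C_out)
    return x.reshape(b, grid_h * p, grid_w * p, out_channels)


def patchify(img, patch_size):
    """Inverse of unpatchify, used here only to check the round trip."""
    b, h, w, c = img.shape
    p = patch_size
    gh, gw = h // p, w // p
    x = img.reshape(b, gh, p, gw, p, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, gh * gw, p * p * c)


image = np.arange(2 * 4 * 6 * 3, dtype=np.float32).reshape(2, 4, 6, 3)
tokens = patchify(image, patch_size=2)  # 2 x 3 grid of tokens, each of width 12
restored = unpatchify(tokens, grid_h=2, grid_w=3, patch_size=2, out_channels=3)
```

Round-tripping through `patchify` and `unpatchify` recovers the original array, which is a convenient check that the axis permutation is right.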
- forward(x, coords, mask=None, cond=None, return_tokens=False)¶
Run the standard ViT.
- Parameters:
x (torch.Tensor | None) – Optional pre-computed patch embeddings of shape (B, L, hidden_dim). When None, tokens come purely from positional encoding.
coords (torch.Tensor) – Per-cell coordinates of shape (B, H, W, coord_dim).
mask (torch.Tensor | None) – Optional per-cell fluid mask of shape (B, H, W).
cond (torch.Tensor | None) – AdaLN conditioning vector of shape (B, hidden_dim). Required when the ViT was built with use_conditioning=True (the default); must be None otherwise.
return_tokens (bool) – If True, return raw post-FinalLayer tokens plus (grid_h, grid_w) instead of the decoded spatial output.
- Returns:
Either (B, H, W, out_channels) or (tokens, (grid_h, grid_w)) if return_tokens.
- Return type:
torch.Tensor | tuple[torch.Tensor, tuple[int, int]]
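The argument contract documented for `forward` can be summarized with a small shape checker. `check_forward_inputs` is a hypothetical helper that works on shape tuples only. It encodes the rules stated above; whether the library itself performs these checks, and which exception it raises, is not specified here.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def check_forward_inputs(
    coords_shape: Shape,
    mask_shape: Optional[Shape],
    cond_shape: Optional[Shape],
    *,
    use_conditioning: bool,
    coord_dim: int,
    hidden_dim: int,
) -> None:
    """Raise ValueError if the shapes violate the documented contract."""
    if len(coords_shape) != 4 or coords_shape[-1] != coord_dim:
        raise ValueError("coords must have shape (B, H, W, coord_dim)")
    batch = coords_shape[0]
    if mask_shape is not None and mask_shape != coords_shape[:3]:
        raise ValueError("mask must have shape (B, H, W)")
    if use_conditioning:
        if cond_shape is None:
            raise ValueError("cond is required when use_conditioning=True")
        if cond_shape != (batch, hidden_dim):
            raise ValueError("cond must have shape (B, hidden_dim)")
    elif cond_shape is not None:
        raise ValueError("cond must be None when use_conditioning=False")


# A conditioned model on a 64 x 96 grid with 2-D coordinates.
check_forward_inputs(
    coords_shape=(4, 64, 96, 2),
    mask_shape=(4, 64, 96),
    cond_shape=(4, 256),
    use_conditioning=True,
    coord_dim=2,
    hidden_dim=256,
)
```

Omitting `cond_shape` with `use_conditioning=True`, or supplying it with `use_conditioning=False`, raises `ValueError` in this sketch.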