noether.modeling.models.vit

Classes

ViTConfig

Configuration for the ViT model.

ViT

Vision Transformer for spatial regression on continuous-coordinate grids.

Module Contents

class noether.modeling.models.vit.ViTConfig(/, **data)

Bases: noether.core.models.base.ModelBaseConfig

Configuration for the ViT model.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

model_config

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

coord_dim: int = None

Coordinate dimensionality of the input grid (2 for 2D, 3 for 3D).

out_channels: int = None

Number of output channels emitted per spatial cell.

patch_size: int = None

Patch side length in cells. The grid resolution must be divisible by this value.
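The divisibility rule fixes the number of tokens the transformer sees. A minimal sketch, assuming non-overlapping p × p patches; `patch_grid` is a hypothetical helper for illustration, not part of noether:

```python
def patch_grid(height: int, width: int, patch_size: int) -> tuple[int, int, int]:
    """Return (grid_h, grid_w, num_tokens) for an H x W grid split into p x p patches.

    Hypothetical helper: assumes non-overlapping square patches.
    """
    if patch_size <= 0:
        raise ValueError("patch_size must be positive")
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"grid {height}x{width} is not divisible by patch_size={patch_size}"
        )
    grid_h, grid_w = height // patch_size, width // patch_size
    return grid_h, grid_w, grid_h * grid_w  # one token per patch


print(patch_grid(64, 128, 8))  # (8, 16, 128)
```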

hidden_dim: int = None

Token hidden dimension throughout the transformer stack.

num_heads: int = None

Number of attention heads in each transformer block.

depth: int = None

Number of stacked transformer blocks.

mlp_ratio: int = None

FFN expansion factor inside each transformer block.
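Two derived widths follow from these fields under the usual transformer conventions. The sketch below assumes, without confirmation from this page, that attention splits hidden_dim evenly across num_heads and that the FFN inner width is hidden_dim · mlp_ratio; `ViTSizes` is a hypothetical stand-in, not the real ViTConfig:

```python
from dataclasses import dataclass


@dataclass
class ViTSizes:
    # Hypothetical stand-in for three ViTConfig fields; not the real pydantic model.
    hidden_dim: int
    num_heads: int
    mlp_ratio: int

    @property
    def head_dim(self) -> int:
        # Assumption: the usual multi-head split, hidden_dim divided evenly across heads.
        if self.hidden_dim % self.num_heads:
            raise ValueError("hidden_dim must be divisible by num_heads")
        return self.hidden_dim // self.num_heads

    @property
    def ffn_dim(self) -> int:
        # Assumption: FFN inner width = hidden_dim * mlp_ratio.
        return self.hidden_dim * self.mlp_ratio


sizes = ViTSizes(hidden_dim=384, num_heads=6, mlp_ratio=4)
print(sizes.head_dim, sizes.ffn_dim)  # 64 1536
```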

use_conditioning: bool = True

If True, enable AdaLN-Zero conditioning, in which case forward requires cond; if False, build a plain ViT, and cond must be None.
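AdaLN-Zero, as described in the DiT paper, projects the conditioning vector to a per-channel shift, scale, and gate for each residual branch, and zero-initializes that projection so every block starts as the identity. The NumPy sketch below illustrates the idea under simplifying assumptions (no SiLU or bias on the projection, parameter-free RMSNorm before modulation); it is not noether's implementation, and `adaln_zero_residual` is a hypothetical name:

```python
import numpy as np


def rms_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Parameter-free RMSNorm over the last (feature) axis.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


def adaln_zero_residual(x, cond, w_mod, sublayer):
    """One AdaLN-Zero residual branch in the style of DiT (simplified sketch).

    x:     tokens, shape (B, L, D)
    cond:  conditioning vector, shape (B, D)
    w_mod: projection from cond to (shift, scale, gate), shape (D, 3 * D)
    """
    shift, scale, gate = np.split(cond @ w_mod, 3, axis=-1)  # each (B, D)
    # Modulate the normalized tokens, broadcasting over the token axis L.
    h = rms_norm(x) * (1 + scale[:, None, :]) + shift[:, None, :]
    # The gate scales the residual update; gate == 0 makes the branch a no-op.
    return x + gate[:, None, :] * sublayer(h)


rng = np.random.default_rng(0)
B, L, D = 2, 5, 8
x = rng.standard_normal((B, L, D))
cond = rng.standard_normal((B, D))
w_zero = np.zeros((D, 3 * D))  # "Zero" init: shift, scale and gate all start at 0
out = adaln_zero_residual(x, cond, w_zero, sublayer=np.tanh)
print(np.allclose(out, x))  # True: the branch is the identity at initialization
```

The zero-initialized gate is what lets a deep conditioned stack start out as a plain residual pass-through and learn how strongly to use cond.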

token_dropout: float = None

Per-patch token dropout probability used during training.
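The exact dropout scheme is not specified on this page. One plausible reading, sketched below under that assumption, drops whole token vectors independently with probability token_dropout during training and leaves tokens untouched at inference; whether noether zeroes, rescales, or removes dropped tokens is not confirmed here, and `token_dropout` is a hypothetical function:

```python
import numpy as np


def token_dropout(x: np.ndarray, p: float, training: bool, rng: np.random.Generator) -> np.ndarray:
    """Zero entire tokens of x (shape (B, L, D)) independently with probability p.

    Assumption: dropped tokens are zeroed without rescaling; the real module may differ.
    """
    if not training or p == 0.0:
        return x  # no-op at inference
    keep = rng.random(x.shape[:2])[..., None] >= p  # (B, L, 1) mask, broadcast over D
    return x * keep


rng = np.random.default_rng(0)
x = np.ones((2, 1000, 4))
y = token_dropout(x, p=0.25, training=True, rng=rng)
dropped = np.all(y == 0, axis=-1).mean()
print(float(dropped))  # fraction of dropped tokens, close to p
```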

attn_drop: float = None

Dropout probability inside attention.

use_conv_output_head: bool = True

If True, decode via a cascaded PixelShuffle conv head; if False, decode via a linear unpatchify.
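The PixelShuffle building block is a channel-to-space rearrangement. The NumPy sketch below reproduces the index layout documented for torch.nn.PixelShuffle; how noether's head stages its convolutions is not described here, so the assumed cascade of two ×2 stages for patch_size = 4 is illustrative only, and the convolutions are omitted:

```python
import numpy as np


def pixel_shuffle(x: np.ndarray, r: int) -> np.ndarray:
    """Depth-to-space rearrangement with the index layout of torch.nn.PixelShuffle.

    (N, C * r * r, H, W) -> (N, C, H * r, W * r), where
    out[n, c, h * r + i, w * r + j] == x[n, c * r * r + i * r + j, h, w].
    """
    n, crr, h, w = x.shape
    c = crr // (r * r)
    x = x.reshape(n, c, r, r, h, w)        # split channels into (C, i, j)
    x = x.transpose(0, 1, 4, 2, 5, 3)      # -> (N, C, H, i, W, j)
    return x.reshape(n, c, h * r, w * r)   # merge (H, i) and (W, j)


# Assumed cascade: patch_size 4 recovered by two x2 stages (16 -> 4 -> 1 channels).
# A real head would apply a conv before each shuffle; this only tracks shapes.
tokens = np.arange(1 * 16 * 2 * 3, dtype=float).reshape(1, 16, 2, 3)
stage1 = pixel_shuffle(tokens, 2)
stage2 = pixel_shuffle(stage1, 2)
print(stage1.shape, stage2.shape)  # (1, 4, 4, 6) (1, 1, 8, 12)
```

Cascading small upsampling factors, rather than one ×p shuffle, gives each stage a convolution that can smooth artifacts at patch boundaries, which a single linear unpatchify cannot.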

property transformer_block_config: noether.modeling.modules.blocks.transformer.TransformerBlockConfig
Return type:

noether.modeling.modules.blocks.transformer.TransformerBlockConfig

class noether.modeling.models.vit.ViT(config)

Bases: torch.nn.Module

Vision Transformer for spatial regression on continuous-coordinate grids.

Based on the ViT paper (https://arxiv.org/pdf/2010.11929) with several modifications, such as:

  • Continuous coordinate inputs with sincos positional embedding and RoPE (vs. learned 1D position embeddings).

  • Optional AdaLN-Zero conditioning, à la DiT (https://arxiv.org/abs/2212.09748).

  • RMSNorm and QK-norm in attention (vs. LayerNorm only).
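A common way to embed continuous coordinates with fixed sin-cos features is to give each coordinate axis an equal slice of hidden_dim and fill it with sines and cosines at geometrically spaced frequencies. The sketch below assumes that layout and a base of 10000 (both borrowed from standard ViT/MAE practice, not confirmed for noether); RoPE, which instead rotates query/key feature pairs inside attention, is not shown:

```python
import numpy as np


def sincos_embed(coords: np.ndarray, hidden_dim: int, base: float = 10000.0) -> np.ndarray:
    """Sin-cos embedding of continuous coordinates (illustrative layout).

    coords:  (..., coord_dim) real-valued positions
    returns: (..., hidden_dim); each coordinate axis gets hidden_dim // coord_dim features
    """
    coord_dim = coords.shape[-1]
    per_axis = hidden_dim // coord_dim
    if per_axis % 2 or per_axis * coord_dim != hidden_dim:
        raise ValueError("hidden_dim must split into an even width per coordinate axis")
    half = per_axis // 2
    freqs = base ** (-np.arange(half) / half)   # (half,) geometric frequencies
    angles = coords[..., :, None] * freqs        # (..., coord_dim, half)
    emb = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (..., coord_dim, per_axis)
    return emb.reshape(*coords.shape[:-1], hidden_dim)


coords = np.random.default_rng(0).uniform(0, 1, size=(2, 4, 6, 2))  # (B, H, W, coord_dim)
pe = sincos_embed(coords, hidden_dim=64)
print(pe.shape)  # (2, 4, 6, 64)
```

Because the features are a fixed function of the coordinate values rather than of a token index, the same embedding applies to any grid resolution or non-uniform spacing.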

Parameters:

config (ViTConfig) – Configuration for the ViT model. See ViTConfig for available options.

coord_dim
out_channels
patch_size
hidden_dim
num_heads
token_dropout
use_conditioning
pool_patch
mask_patchify
pos_embedding
rope
backbone
use_conv_output_head
initialize_weights()

Initialize the backbone weights.

Return type:

None

unpatchify(x, grid_h, grid_w)

Linear unpatchify: (B, L, p²·C_out) → (B, H, W, C_out).

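Parameters:
  • x (torch.Tensor) – Tokens of shape (B, L, p²·C_out), with L = grid_h · grid_w.

  • grid_h (int) – Number of patch rows.

  • grid_w (int) – Number of patch columns.

Return type: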

torch.Tensor
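The linear unpatchify is a pure reshape once token order and per-token feature layout are fixed. The sketch below assumes row-major token order and a (p, p, C_out) feature layout; unlike the documented method, which can take the patch size from the module, it receives patch_size explicitly. `patchify` is a hypothetical inverse included only to check the round trip:

```python
import numpy as np


def patchify(y: np.ndarray, patch_size: int) -> np.ndarray:
    """Hypothetical inverse: (B, H, W, C) -> (B, L, p*p*C) with row-major token order."""
    b, h, w, c = y.shape
    p = patch_size
    y = y.reshape(b, h // p, p, w // p, p, c)     # (B, gh, ph, gw, pw, C)
    y = y.transpose(0, 1, 3, 2, 4, 5)             # (B, gh, gw, ph, pw, C)
    return y.reshape(b, (h // p) * (w // p), p * p * c)


def unpatchify(x: np.ndarray, grid_h: int, grid_w: int, patch_size: int) -> np.ndarray:
    """(B, L, p*p*C_out) -> (B, grid_h*p, grid_w*p, C_out), assuming a (p, p, C_out) layout."""
    b, num_tokens, feat = x.shape
    p = patch_size
    if num_tokens != grid_h * grid_w or feat % (p * p):
        raise ValueError("expected x of shape (B, grid_h * grid_w, p * p * C_out)")
    c_out = feat // (p * p)
    x = x.reshape(b, grid_h, grid_w, p, p, c_out)  # (B, gh, gw, ph, pw, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)              # (B, gh, ph, gw, pw, C)
    return x.reshape(b, grid_h * p, grid_w * p, c_out)


y = np.arange(2 * 8 * 12 * 3).reshape(2, 8, 12, 3)  # (B, H, W, C_out)
tokens = patchify(y, 4)                              # (2, 6, 48)
print(np.array_equal(unpatchify(tokens, 2, 3, 4), y))  # True
```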

forward(x, coords, mask=None, cond=None, return_tokens=False)

Run the ViT forward pass.

Parameters:
  • x (torch.Tensor | None) – Optional pre-computed patch embeddings of shape (B, L, hidden_dim). When None, tokens come purely from positional encoding.

  • coords (torch.Tensor) – Per-cell coordinates of shape (B, H, W, coord_dim).

  • mask (torch.Tensor | None) – Optional per-cell fluid mask of shape (B, H, W).

  • cond (torch.Tensor | None) – AdaLN conditioning vector of shape (B, hidden_dim). Required when the ViT was built with use_conditioning=True (the default); must be None otherwise.

  • return_tokens (bool) – If True, return raw post-FinalLayer tokens plus (grid_h, grid_w) instead of the decoded spatial output.

Returns:

A tensor of shape (B, H, W, out_channels), or, if return_tokens is True, the tuple (tokens, (grid_h, grid_w)).

Return type:

torch.Tensor | tuple[torch.Tensor, tuple[int, int]]
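The argument contract of forward can be summarized as a shape check. `forward_shapes` is a hypothetical helper that mirrors the documented rules (coords of shape (B, H, W, coord_dim), optional x of shape (B, L, hidden_dim), cond of shape (B, hidden_dim) required exactly when use_conditioning=True) and returns both the token grid and the decoded output shape; it does not run the model:

```python
from typing import Optional, Sequence


def forward_shapes(
    coords_shape: Sequence[int],
    *,
    patch_size: int,
    hidden_dim: int,
    out_channels: int,
    use_conditioning: bool,
    x_shape: Optional[Sequence[int]] = None,
    cond_shape: Optional[Sequence[int]] = None,
) -> tuple[tuple[int, int], tuple[int, int, int, int]]:
    """Hypothetical checker for the documented ViT.forward contract.

    Returns ((grid_h, grid_w), (B, H, W, out_channels)): the grid reported when
    return_tokens=True, and the decoded output shape when return_tokens=False.
    """
    b, h, w, _ = coords_shape  # coords: (B, H, W, coord_dim)
    if h % patch_size or w % patch_size:
        raise ValueError("grid resolution must be divisible by patch_size")
    grid_h, grid_w = h // patch_size, w // patch_size
    if x_shape is not None and tuple(x_shape) != (b, grid_h * grid_w, hidden_dim):
        raise ValueError("x must have shape (B, L, hidden_dim) with L = grid_h * grid_w")
    if use_conditioning != (cond_shape is not None):
        raise ValueError("pass cond if and only if use_conditioning=True")
    if cond_shape is not None and tuple(cond_shape) != (b, hidden_dim):
        raise ValueError("cond must have shape (B, hidden_dim)")
    return (grid_h, grid_w), (b, h, w, out_channels)


print(forward_shapes((2, 64, 64, 2), patch_size=8, hidden_dim=256,
                     out_channels=3, use_conditioning=True, cond_shape=(2, 256)))
# ((8, 8), (2, 64, 64, 3))
```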