noether.modeling.modules.layers.vit_layers

Classes

AvgPool2DPatchify

Tokenize a 2D grid by average-pooling each patch_size × patch_size patch.

MaskPatchify

Downsample a boolean mask to patch resolution via max-pooling (True = at least one valid cell).

FinalLayer

Final unpatchify projection with optional AdaLN modulation conditioned on a global vector c.

ConvOutputHead

Convolutional output head that decodes tokens to a spatial output.

Module Contents

class noether.modeling.modules.layers.vit_layers.AvgPool2DPatchify(patch_size=16)

Bases: torch.nn.Module

Tokenize a 2D grid by average-pooling each patch_size × patch_size patch.

Parameters:

patch_size (int)

patch_size = 16
patch
forward(x)

Pool spatial features into patches.

Parameters:

x (torch.Tensor) – Input grid with shape (B, H, W, C).

Returns:

Pooled patch grid of shape (B, H // patch_size, W // patch_size, C).

Return type:

torch.Tensor
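
The shape contract above can be illustrated with a minimal sketch in plain PyTorch. This is not the library's implementation: the helper name avg_pool_patchify is hypothetical, and the sketch assumes that H and W are divisible by patch_size and that pooling is non-overlapping (stride equal to patch_size).

```python
import torch
import torch.nn.functional as F


def avg_pool_patchify(x: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Hypothetical stand-in: average-pool a channels-last grid into patches."""
    # avg_pool2d expects channels-first input, so move C in front of H and W.
    x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
    pooled = F.avg_pool2d(x, kernel_size=patch_size, stride=patch_size)
    # Restore the channels-last layout described in the docstring.
    return pooled.permute(0, 2, 3, 1)  # (B, H // p, W // p, C)


torch.manual_seed(0)
x = torch.randn(2, 32, 48, 3)  # B=2, H=32, W=48, C=3
tokens = avg_pool_patchify(x, patch_size=16)
print(tokens.shape)  # torch.Size([2, 2, 3, 3])
```

Each output cell is the mean of one 16 × 16 block, so tokens[0, 0, 0] equals the per-channel mean of x[0, :16, :16].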

class noether.modeling.modules.layers.vit_layers.MaskPatchify(patch_size)

Bases: torch.nn.Module

Downsample a boolean mask to patch resolution via max-pooling (True = at least one valid cell).

Parameters:

patch_size (int)

patch_size
forward(mask)

Downsample boolean mask to patch resolution.

Parameters:

mask (torch.Tensor) – Boolean mask of shape (B, H, W).

Returns:

Flat boolean mask of shape (B, (H // patch_size) * (W // patch_size)).

Return type:

torch.Tensor
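
The max-pooling rule (a patch is True if at least one of its cells is True) can be sketched as follows. The helper name mask_patchify is hypothetical, and the row-major flattening order of the patch grid is an assumption, not something the reference states.

```python
import torch
import torch.nn.functional as F


def mask_patchify(mask: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical stand-in: downsample a boolean mask with max-pooling."""
    # max_pool2d needs a floating-point (B, C, H, W) tensor, so add a channel dim.
    m = mask.float().unsqueeze(1)  # (B, 1, H, W)
    pooled = F.max_pool2d(m, kernel_size=patch_size, stride=patch_size)
    # Flatten the patch grid (assumed row-major) and convert back to bool.
    return pooled.flatten(1) > 0.5  # (B, (H // p) * (W // p))


mask = torch.zeros(1, 4, 4, dtype=torch.bool)
mask[0, 0, 3] = True  # a single valid cell in the top-right 2 x 2 patch
patch_mask = mask_patchify(mask, patch_size=2)
print(patch_mask.tolist())  # [[False, True, False, False]]
```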

class noether.modeling.modules.layers.vit_layers.FinalLayer(hidden_size, patch_size, out_channels, use_modulation=True)

Bases: torch.nn.Module

Final unpatchify projection with optional AdaLN modulation conditioned on a global vector c.

Parameters:
  • hidden_size (int)

  • patch_size (int)

  • out_channels (int)

  • use_modulation (bool)

norm_final
linear
adaLN_modulation: torch.nn.Linear | None
forward(x, c=None)

Apply (optionally AdaLN-modulated) norm then linear projection.

Parameters:
  • x (torch.Tensor) – Tokens of shape (B, L, hidden_size).

  • c (torch.Tensor | None) – Conditioning vector of shape (B, hidden_size) when use_modulation=True; must be None when use_modulation=False. The caller is responsible for any upstream activation (e.g. SiLU); this layer applies the AdaLN linear projection to c directly.

Returns:

Tensor of shape (B, L, patch_size**2 * out_channels).

Return type:

torch.Tensor
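
A rough sketch of how such a layer could be wired is shown below. It is an illustration only: the class name FinalLayerSketch is hypothetical, and the shift/scale formula norm(x) * (1 + scale) + shift, the non-affine LayerNorm, and the 2 * hidden_size width of the modulation linear are assumptions borrowed from common DiT-style layers, not details confirmed by this reference.

```python
import torch
from torch import nn


class FinalLayerSketch(nn.Module):
    """Hypothetical stand-in for a final projection with optional AdaLN."""

    def __init__(self, hidden_size, patch_size, out_channels, use_modulation=True):
        super().__init__()
        # Assumption: the norm has no learned affine parameters.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels)
        # Assumption: one linear map produces both shift and scale.
        self.adaLN_modulation = (
            nn.Linear(hidden_size, 2 * hidden_size) if use_modulation else None
        )

    def forward(self, x, c=None):
        x = self.norm_final(x)
        if self.adaLN_modulation is not None:
            if c is None:
                raise ValueError("c is required when use_modulation=True")
            # No activation is applied to c here; the caller handles that upstream.
            shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
            # Broadcast the per-sample (B, hidden_size) terms over the L tokens.
            x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        elif c is not None:
            raise ValueError("c must be None when use_modulation=False")
        return self.linear(x)


torch.manual_seed(0)
x = torch.randn(2, 6, 8)  # (B, L, hidden_size)
c = torch.randn(2, 8)     # (B, hidden_size)
out = FinalLayerSketch(hidden_size=8, patch_size=4, out_channels=3)(x, c)
out_plain = FinalLayerSketch(8, 4, 3, use_modulation=False)(x)
print(out.shape, out_plain.shape)  # both torch.Size([2, 6, 48])
```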

class noether.modeling.modules.layers.vit_layers.ConvOutputHead(hidden_dim, out_channels, patch_size, mid_channels=64)

Bases: torch.nn.Module

Convolutional output head that decodes tokens to a spatial output.

Parameters:
  • hidden_dim (int)

  • out_channels (int)

  • patch_size (int)

  • mid_channels (int)

patch_size
out_channels
stages
forward(x, grid_h, grid_w)

Decode tokens to spatial output via cascaded PixelShuffle stages.

Parameters:
  • x (torch.Tensor) – Flattened tokens of shape (B, grid_h * grid_w, hidden_dim).

  • grid_h (int) – Patch grid height (H // patch_size).

  • grid_w (int) – Patch grid width (W // patch_size).

Returns:

Spatial tensor of shape (B, H, W, out_channels) after upsampling, where H = grid_h * patch_size and W = grid_w * patch_size.

Return type:

torch.Tensor
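
One way to realize "cascaded PixelShuffle stages" is sketched below. The class name ConvOutputHeadSketch is hypothetical, and several details are assumptions rather than documented behavior: patch_size is taken to be a power of two, each stage is a 3 × 3 convolution followed by PixelShuffle(2) (with GELU between stages), and tokens are laid out in row-major order over the patch grid.

```python
import torch
from torch import nn


class ConvOutputHeadSketch(nn.Module):
    """Hypothetical stand-in: decode tokens with repeated 2x PixelShuffle stages."""

    def __init__(self, hidden_dim, out_channels, patch_size, mid_channels=64):
        super().__init__()
        # Assumption: patch_size is a power of two, giving log2(patch_size) stages.
        if patch_size < 2 or patch_size & (patch_size - 1):
            raise ValueError("this sketch requires patch_size to be a power of two >= 2")
        n_stages = patch_size.bit_length() - 1
        self.patch_size = patch_size
        self.out_channels = out_channels
        layers, in_ch = [], hidden_dim
        for i in range(n_stages):
            last = i == n_stages - 1
            out_ch = out_channels if last else mid_channels
            # PixelShuffle(2) turns 4 * out_ch channels into out_ch at 2x resolution.
            layers.append(nn.Conv2d(in_ch, out_ch * 4, kernel_size=3, padding=1))
            layers.append(nn.PixelShuffle(2))
            if not last:
                layers.append(nn.GELU())
            in_ch = out_ch
        self.stages = nn.Sequential(*layers)

    def forward(self, x, grid_h, grid_w):
        b, _, d = x.shape
        # (B, L, D) -> (B, D, grid_h, grid_w), assuming row-major token order.
        x = x.transpose(1, 2).reshape(b, d, grid_h, grid_w)
        x = self.stages(x)  # (B, out_channels, H, W)
        return x.permute(0, 2, 3, 1)  # channels-last: (B, H, W, out_channels)


torch.manual_seed(0)
head = ConvOutputHeadSketch(hidden_dim=32, out_channels=3, patch_size=4, mid_channels=8)
tokens = torch.randn(2, 3 * 5, 32)  # grid_h=3, grid_w=5
image = head(tokens, grid_h=3, grid_w=5)
print(image.shape)  # torch.Size([2, 12, 20, 3])
```

With patch_size=4 the sketch uses two stages, so a 3 × 5 patch grid is upsampled to 12 × 20.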