noether.modeling.modules.layers.vit_layers
Classes
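| AvgPool2DPatchify | Tokenize a 2D grid by average-pooling each patch_size × patch_size patch. |
| MaskPatchify | Downsample a boolean mask to patch resolution via max-pooling (True = at least one valid cell). |
| FinalLayer | Final unpatchify projection with optional AdaLN modulation conditioned on a global vector c. |
| ConvOutputHead | Conv output head that decodes tokens to a spatial output. |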
Module Contents
- class noether.modeling.modules.layers.vit_layers.AvgPool2DPatchify(patch_size=16)
  Bases: torch.nn.Module
  Tokenize a 2D grid by average-pooling each patch_size × patch_size patch.
  - Parameters:
    patch_size (int)
  - patch_size = 16
  - patch
  - forward(x)
    Pool spatial features into patches.
    - Parameters:
      x (torch.Tensor) – Input grid with shape (B, H, W, C).
    - Returns:
      Pooled patch grid of shape (B, H // patch_size, W // patch_size, C).
    - Return type:
      torch.Tensor
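The shape contract of AvgPool2DPatchify.forward can be illustrated with a minimal sketch. The class name AvgPoolPatchifySketch and its pool attribute are hypothetical, and the use of torch.nn.AvgPool2d with a permute to channels-first is an assumption about how such a layer could be built, not the library's actual code.

```python
import torch
from torch import nn


class AvgPoolPatchifySketch(nn.Module):
    """Hypothetical stand-in for illustration; not the noether implementation."""

    def __init__(self, patch_size: int = 16) -> None:
        super().__init__()
        self.patch_size = patch_size
        # Non-overlapping windows: kernel size and stride both equal patch_size.
        self.pool = nn.AvgPool2d(kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, H, W, C) -> (B, C, H, W), the layout AvgPool2d expects.
        x = x.permute(0, 3, 1, 2)
        x = self.pool(x)
        # Back to channels-last: (B, H // patch_size, W // patch_size, C).
        return x.permute(0, 2, 3, 1)


grid = torch.arange(2 * 8 * 8 * 3, dtype=torch.float32).reshape(2, 8, 8, 3)
tokens = AvgPoolPatchifySketch(patch_size=4)(grid)
print(tuple(tokens.shape))  # (2, 2, 2, 3)
```

Each output cell is the per-channel mean of one 4 × 4 block of the input, so an 8 × 8 grid yields a 2 × 2 patch grid.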
- class noether.modeling.modules.layers.vit_layers.MaskPatchify(patch_size)
  Bases: torch.nn.Module
  Downsample a boolean mask to patch resolution via max-pooling (True = at least one valid cell).
  - Parameters:
    patch_size (int)
  - patch_size
  - forward(mask)
    Downsample boolean mask to patch resolution.
    - Parameters:
      mask (torch.Tensor) – Boolean mask of shape (B, H, W).
    - Returns:
      Flat boolean mask of shape (B, (H // patch_size) * (W // patch_size)).
    - Return type:
      torch.Tensor
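The "True = at least one valid cell" rule is equivalent to a logical OR over each patch, which max-pooling computes once the mask is cast to a float. The function below is a rough sketch under that assumption; mask_patchify_sketch is a hypothetical name, and the row-major flattening order is assumed rather than confirmed by the source.

```python
import torch
import torch.nn.functional as F


def mask_patchify_sketch(mask: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical helper: OR-reduce a (B, H, W) boolean mask per patch."""
    # max_pool2d needs a floating-point (B, C, H, W) input, so add a channel axis.
    pooled = F.max_pool2d(
        mask.unsqueeze(1).float(), kernel_size=patch_size, stride=patch_size
    )
    # (B, 1, H // p, W // p) -> (B, (H // p) * (W // p)), converted back to bool.
    return pooled.flatten(1).bool()


mask = torch.zeros(1, 4, 4, dtype=torch.bool)
mask[0, 0, 3] = True  # a single valid cell, located in the top-right patch
patch_mask = mask_patchify_sketch(mask, patch_size=2)
print(patch_mask.tolist())  # [[False, True, False, False]]
```

A single valid cell is enough to mark its whole patch as valid, while patches with no valid cells stay False.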
- class noether.modeling.modules.layers.vit_layers.FinalLayer(hidden_size, patch_size, out_channels, use_modulation=True)
  Bases: torch.nn.Module
  Final unpatchify projection with optional AdaLN modulation conditioned on a global vector c.
  - norm_final
  - linear
  - adaLN_modulation: torch.nn.Linear | None
  - forward(x, c=None)
    Apply the (optionally AdaLN-modulated) norm, then the linear projection.
    - Parameters:
      x (torch.Tensor) – Tokens of shape (B, L, hidden_size).
      c (torch.Tensor | None) – Conditioning vector of shape (B, hidden_size) when use_modulation=True; must be None when use_modulation=False. The caller is responsible for any upstream activation (e.g. SiLU); this layer applies the AdaLN linear directly.
    - Returns:
      Tensor of shape (B, L, patch_size**2 * out_channels).
    - Return type:
      torch.Tensor
- class noether.modeling.modules.layers.vit_layers.ConvOutputHead(hidden_dim, out_channels, patch_size, mid_channels=64)
  Bases: torch.nn.Module
  Conv output head that decodes tokens to a spatial output.
  - patch_size
  - out_channels
  - stages
  - forward(x, grid_h, grid_w)
    Decode tokens to spatial output via cascaded PixelShuffle stages.
    - Parameters:
      x (torch.Tensor) – Flattened tokens of shape (B, grid_h * grid_w, hidden_dim).
      grid_h (int) – Patch grid height (H // patch_size).
      grid_w (int) – Patch grid width (W // patch_size).
    - Returns:
      Spatial tensor of shape (B, H, W, out_channels) after upsampling.
    - Return type:
      torch.Tensor
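The cascaded PixelShuffle decoding can be sketched as follows. This is an illustrative guess at the structure, not the library's code: it assumes patch_size is a power of two, that each stage is a 3 × 3 convolution followed by PixelShuffle(2) and GELU, that a final 1 × 1 convolution maps to out_channels, and that tokens are ordered row-major. ConvOutputHeadSketch is a hypothetical name.

```python
import torch
from torch import nn


class ConvOutputHeadSketch(nn.Module):
    """Hypothetical stand-in; the stage layout is assumed, not taken from noether."""

    def __init__(self, hidden_dim: int, out_channels: int, patch_size: int,
                 mid_channels: int = 64) -> None:
        super().__init__()
        if patch_size < 1 or patch_size & (patch_size - 1):
            raise ValueError("this sketch assumes patch_size is a power of two")
        self.patch_size = patch_size
        self.out_channels = out_channels
        layers = []
        in_channels = hidden_dim
        # One 2x upsampling stage per factor of two in patch_size.
        for _ in range(patch_size.bit_length() - 1):
            layers += [
                # PixelShuffle(2) trades 4x channels for 2x height and 2x width.
                nn.Conv2d(in_channels, mid_channels * 4, kernel_size=3, padding=1),
                nn.PixelShuffle(2),
                nn.GELU(),
            ]
            in_channels = mid_channels
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
        self.stages = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, grid_h: int, grid_w: int) -> torch.Tensor:
        batch, num_tokens, hidden_dim = x.shape
        if num_tokens != grid_h * grid_w:
            raise ValueError("token count must equal grid_h * grid_w")
        # (B, L, D) -> (B, D, grid_h, grid_w), assuming row-major token order.
        x = x.transpose(1, 2).reshape(batch, hidden_dim, grid_h, grid_w)
        x = self.stages(x)
        # Channels-last output: (B, H, W, out_channels).
        return x.permute(0, 2, 3, 1)


head = ConvOutputHeadSketch(hidden_dim=32, out_channels=3, patch_size=4, mid_channels=8)
image = head(torch.randn(2, 6, 32), grid_h=2, grid_w=3)
print(tuple(image.shape))  # (2, 8, 12, 3)
```

With patch_size=4 the sketch builds two 2x stages, so a 2 × 3 patch grid is upsampled to an 8 × 12 spatial output.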