noether.modeling.modules.untied

Classes

UntiedLinear

Linear layer with per-domain weight banks.

UntiedMixedAttention

Multi-head attention with per-type QKV and output projections.

UntiedPerceiverAttention

Perceiver cross-attention with per-type Q and output projections.

UntiedMLP

Multi-layer perceptron with per-type weights.

UntiedTransformerBlock

Pre-norm transformer block with per-type (untied) attention and MLP weights.

UntiedPerceiverBlock

Perceiver block with per-type (untied) Q/output projections and MLP weights.

Module Contents

class noether.modeling.modules.untied.UntiedLinear(config)

Bases: torch.nn.Module

Linear layer with per-domain weight banks.

Groups token specs by TokenSpec.domain and applies a separate F.linear per domain. Automatically picks the fastest strategy at runtime:

  1. torch._grouped_mm (CUDA, equal-size groups)

  2. Reshape + single torch.bmm (equal-size groups, any device)

  3. Padded torch.bmm (moderate skew)

  4. Split + F.linear loop (heavy skew or very few groups)

Per-type weights are stored as independent 2D nn.Parameter entries in an nn.ParameterList (one matrix per type). The bmm-based fast paths stack them into a 3D tensor on the fly. Storing them as 2D is what lets torch.optim.Muon (which rejects non-2D parameters) update each type’s weight matrix independently.

Domains must be strictly consecutive in token_specs.
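As a rough illustration of the ideas above, the sketch below shows per-type 2D weight banks in an nn.ParameterList and the split + F.linear fallback (strategy 4). It assumes a simplified spec of (domain, length) pairs; the real TokenSpec schema, the domain grouping, and the bmm/_grouped_mm fast paths are not shown.

    import torch
    import torch.nn.functional as F
    from torch import nn

    class UntiedLinearSketch(nn.Module):
        """Illustrative only: per-type 2D weight banks plus the split + F.linear path."""

        def __init__(self, num_types: int, input_dim: int, output_dim: int, bias: bool = True):
            super().__init__()
            # One independent 2D matrix per type, so optimizers that only accept 2D
            # parameters can update each type's weight matrix on its own.
            self.weight = nn.ParameterList(
                nn.Parameter(torch.empty(output_dim, input_dim)) for _ in range(num_types)
            )
            self.bias = (
                nn.ParameterList(nn.Parameter(torch.zeros(output_dim)) for _ in range(num_types))
                if bias
                else None
            )

        def forward(self, x: torch.Tensor, token_specs) -> torch.Tensor:
            # token_specs: contiguous (domain, length) slices of x; specs sharing a
            # domain are assumed to be adjacent, so a real implementation can merge
            # them into one F.linear call per domain (or a single bmm over all domains).
            outputs, start = [], 0
            for domain, length in token_specs:
                chunk = x[:, start : start + length]                     # (B, length, D_in)
                b = self.bias[domain] if self.bias is not None else None
                outputs.append(F.linear(chunk, self.weight[domain], b))  # (B, length, D_out)
                start += length
            return torch.cat(outputs, dim=1)  # same positional order as x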

Parameters:

config (noether.core.schemas.modules.untied.UntiedLinearConfig) – Number of types and shared linear-projection geometry.

bias: torch.nn.ParameterList | None
num_types
init_weights
input_dim
output_dim
weight
reset_parameters()

Initialize the per-type weight banks following LinearProjection’s scheme.

Supported modes (from config.linear_projection.init_weights):

  • "torch" – PyTorch nn.Linear defaults, applied per type.

  • "truncnormal" / "truncnormal002" – truncated normal with std=0.02, zero bias.

  • "zeros" – zero weight and bias.
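A hedged sketch of how such a per-type initialization loop could look; the authoritative scheme lives in LinearProjection and may differ in detail (the helper name and signature here are illustrative):

    import math
    from torch import nn

    def reset_untied_weights(weights, biases, mode: str) -> None:
        """Illustrative per-type init; mode strings follow the docs above."""
        for i, w in enumerate(weights):
            b = biases[i] if biases is not None else None
            if mode == "torch":
                # nn.Linear defaults, applied independently to each type's matrix.
                nn.init.kaiming_uniform_(w, a=math.sqrt(5))
                if b is not None:
                    bound = 1.0 / math.sqrt(w.shape[1])
                    nn.init.uniform_(b, -bound, bound)
            elif mode in ("truncnormal", "truncnormal002"):
                nn.init.trunc_normal_(w, std=0.02)
                if b is not None:
                    nn.init.zeros_(b)
            elif mode == "zeros":
                nn.init.zeros_(w)
                if b is not None:
                    nn.init.zeros_(b)
            else:
                raise ValueError(f"unknown init mode: {mode!r}")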

Return type:

None

forward(x, token_specs)

Apply per-domain linear projections.

Parameters:
Returns:

Output tensor (B, S, D_out) in the same positional order as x.

Return type:

torch.Tensor

class noether.modeling.modules.untied.UntiedMixedAttention(config)

Bases: noether.modeling.modules.attention.anchor_attention.mixed.MixedAttention

Multi-head attention with per-type QKV and output projections.

Each token type has its own QKV and output projection weights (via UntiedLinear). The attention computation is pattern-aware, mirroring MixedAttention:

  • If attention_patterns is supplied (typically by a wrapping MultiBranchAnchorAttention such as SelfAnchorAttention/CrossAnchorAttention/JointAnchorAttention), those patterns are honored.

  • If no patterns are supplied (standalone usage), it falls back to a single all-to-all pattern — every token attends to every other token.

This lets the same module act as either a drop-in pattern-free attention or the inner workhorse of a multi-branch anchor attention, while keeping QKV and output projection weights untied per token type.
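A minimal sketch of the fallback behaviour only, assuming queries/keys/values are already shaped (B, num_heads, S, head_dim); the pattern-aware branch is deliberately left abstract because the pattern format is owned by the anchor-attention wrappers:

    import torch.nn.functional as F

    def mixed_attention_core(q, k, v, attention_patterns=None):
        # q, k, v: (B, num_heads, S, head_dim), produced by the per-type QKV projections.
        if attention_patterns is None:
            # Standalone usage: one all-to-all pattern, i.e. plain unmasked SDPA.
            return F.scaled_dot_product_attention(q, k, v)
        # Pattern-aware path (abstract): run attention per supplied pattern over the
        # query/key slices chosen by the wrapping MultiBranchAnchorAttention, then
        # scatter the per-branch results back into a (B, num_heads, S, head_dim) output.
        raise NotImplementedError("pattern handling is owned by the real module")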

Parameters:
q
k
v
proj
forward(x, token_specs, attention_patterns, key_padding_mask=None, freqs=None, kv_cache=None)

Apply attention with per-type QKV/output projections.

Positional argument order matches MixedAttention.forward() so this module is a drop-in replacement for the shared MixedAttention that MultiBranchAnchorAttention owns — a wrapping anchor attention can call self.mixed_attention(x, token_specs, patterns, ...) unchanged.

Parameters:
Returns:

Output tensor (B, S, D).

Return type:

torch.Tensor

class noether.modeling.modules.untied.UntiedPerceiverAttention(config, num_types)

Bases: noether.modeling.modules.attention.perceiver.PerceiverAttention

Perceiver cross-attention with per-type Q and output projections.

The Q and output projections are UntiedLinear layers (one weight bank per token type), while the KV projection remains a shared nn.Linear since it operates on a single source (e.g. geometry encoding).
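A sketch of the projection layout only (head splitting, RoPE, and KV caching omitted); make_untied_linear stands in for UntiedLinear construction, and the combined-KV output dimension is an assumption, not the module's actual geometry:

    from torch import nn

    class UntiedPerceiverProjectionsSketch(nn.Module):
        """Illustrative layout: per-type Q/output banks, shared KV projection."""

        def __init__(self, hidden_dim: int, num_types: int, make_untied_linear):
            super().__init__()
            # Per-type weight banks: queries come from mixed token types.
            self.q = make_untied_linear(num_types, hidden_dim, hidden_dim)
            self.proj = make_untied_linear(num_types, hidden_dim, hidden_dim)
            # Shared projection: keys/values come from a single source
            # (e.g. a geometry encoding), so one weight matrix suffices.
            # (Fused KV output dim assumed; separate K and V projections would also fit.)
            self.kv = nn.Linear(hidden_dim, 2 * hidden_dim)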

Parameters:
q
proj
forward(q, token_specs, kv=None, attn_mask=None, q_freqs=None, k_freqs=None, kv_cache=None)

Forward pass with per-type Q and output projections.

Parameters:
Returns:

Tuple of (output, new_kv_cache).

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor] | None]

class noether.modeling.modules.untied.UntiedMLP(config)

Bases: torch.nn.Module

Multi-layer perceptron with per-type weights.

Mirrors the topology of MLP (input -> [hidden]*(num_layers+1) -> output, with activations between layers), but uses UntiedLinear for every linear layer so each token type learns independent weights.
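A sketch of the layer stack implied by that topology; make_untied_linear is a placeholder for UntiedLinear construction (the real module is driven by UntiedMLPConfig):

    from torch import nn

    def build_untied_mlp_layers(num_types, input_dim, hidden_dim, output_dim,
                                num_layers, make_untied_linear) -> nn.ModuleList:
        """Illustrative: input -> [hidden] * (num_layers + 1) -> output, per-type weights."""
        dims = [input_dim] + [hidden_dim] * (num_layers + 1) + [output_dim]
        # forward() would pass (x, token_specs) through each layer, applying the
        # activation between consecutive layers.
        return nn.ModuleList(
            make_untied_linear(num_types, d_in, d_out)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )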

Parameters:

config (noether.core.schemas.modules.untied.UntiedMLPConfig) – Configuration for the untied MLP.

layers
activation
forward(x, token_specs)

Apply the untied MLP.

Parameters:
Returns:

Output tensor (B, S, mlp_config.output_dim).

Return type:

torch.Tensor

class noether.modeling.modules.untied.UntiedTransformerBlock(config)

Bases: noether.modeling.modules.blocks.transformer.TransformerBlock

Pre-norm transformer block with per-type (untied) attention and MLP weights.

Same architecture and control flow as TransformerBlock; only the attention and MLP sub-modules differ. The parent’s __init__ builds all the shared plumbing (norms, modulation, layer scale, drop path), including whichever attention module the configured attention_constructor selects, and the parent’s forward() already implements the full pass. This subclass then:

  1. Replaces self.mlp with UntiedMLP.

  2. Injects per-type QKV/output projections into the attention sub-layer without disturbing its attention pattern:

    • When the parent built a MultiBranchAnchorAttention (SelfAnchorAttention / CrossAnchorAttention / JointAnchorAttention), only its inner mixed_attention is swapped for UntiedMixedAttention. The outer wrapper keeps emitting the correct per-branch attention patterns — so self_untied behaves like self with untied weights (not like a joint attention).

    • Otherwise (default dot_product attention, or any other non-multi-branch constructor) self.attention_block is replaced outright with UntiedMixedAttention, which falls back to a single all-to-all pattern.

  3. Overrides _mlp_forward() to route token_specs from attn_kwargs into UntiedMLP.

Attention receives token_specs automatically because the parent’s forward passes **attn_kwargs to self.attention_block, which in turn forwards them on to the inner UntiedMixedAttention.
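A hedged sketch of that constructor-time swap. Class and attribute names follow the description above, but the constructor arguments, the _mlp_forward signature, and the MultiBranchAnchorAttention import path are assumptions, not the module's actual code:

    from noether.modeling.modules.blocks.transformer import TransformerBlock
    from noether.modeling.modules.untied import UntiedMLP, UntiedMixedAttention
    # Import path assumed for the multi-branch wrapper:
    from noether.modeling.modules.attention.anchor_attention import MultiBranchAnchorAttention

    class UntiedTransformerBlockSketch(TransformerBlock):
        def __init__(self, config):
            super().__init__(config)  # parent builds norms, modulation, attention, MLP, ...
            if isinstance(self.attention_block, MultiBranchAnchorAttention):
                # Keep the wrapper (and its per-branch attention patterns); only the
                # inner shared attention gains per-type QKV/output projections.
                self.attention_block.mixed_attention = UntiedMixedAttention(...)
            else:
                # Non-multi-branch attention is replaced outright; the untied attention
                # then falls back to a single all-to-all pattern.
                self.attention_block = UntiedMixedAttention(...)
            self.mlp = UntiedMLP(...)  # per-type MLP weights

        def _mlp_forward(self, x, attn_kwargs):  # signature illustrative
            # Route token_specs from attn_kwargs into the untied MLP.
            return self.mlp(x, attn_kwargs["token_specs"])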

Parameters:
num_types
mlp
forward(x, condition=None, attn_kwargs=None)

Validate token_specs upfront, then delegate to TransformerBlock.forward().

The untied weight banks are indexed by name-order in token_specs, so duplicates would silently merge types and missing types would index out of range. We check here (before attention runs) to surface a clear error.
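A minimal sketch of that kind of check, assuming each TokenSpec exposes a name attribute and the block records how many weight banks it owns (hypothetical helper, not the block's actual implementation):

    def validate_token_specs(token_specs, num_types: int) -> None:
        names = [spec.name for spec in token_specs]
        if len(set(names)) != len(names):
            raise ValueError(f"duplicate token spec names would merge types: {names}")
        if len(names) != num_types:
            raise ValueError(
                f"expected {num_types} token specs (one per untied weight bank), got {len(names)}"
            )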

Parameters:
Return type:

tuple[torch.Tensor, dict[str, dict[str, torch.Tensor]] | None]

class noether.modeling.modules.untied.UntiedPerceiverBlock(config)

Bases: noether.modeling.modules.blocks.perceiver.PerceiverBlock

Perceiver block with per-type (untied) Q/output projections and MLP weights.

Same architecture and control flow as PerceiverBlock; the Q and output projections in the attention module become per-type via UntiedPerceiverAttention, and the feed-forward MLP becomes UntiedMLP. The KV projection remains shared since it operates on a single source (e.g. geometry encoding).

token_specs must be provided via attn_kwargs["token_specs"].
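An illustrative call pattern under that requirement (the block, tensors, and specs are supplied by the caller; shapes in the comments follow the parameter docs below):

    import torch

    def run_untied_perceiver_block(block, q: torch.Tensor, kv: torch.Tensor, token_specs):
        """Illustrative only: token_specs must travel via attn_kwargs."""
        # q:  (B, S, hidden_dim), tokens from all domains concatenated in spec order.
        # kv: (B, S_kv, hidden_dim), single-source keys/values (e.g. geometry encoding).
        out, kv_cache = block(
            q,
            kv=kv,
            attn_kwargs={"token_specs": token_specs},  # required by the untied sub-modules
        )
        return out, kv_cache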

Parameters:
num_types
attn
mlp
forward(q, kv=None, condition=None, attn_kwargs=None)

Forward pass with per-type Q/output projections and MLP.

Validates token_specs upfront, then runs the same perceiver block logic as the parent with token_specs routed to the untied sub-modules.

Parameters:
  • q (torch.Tensor) – Query tensor (B, S, hidden_dim) with tokens from all domains concatenated in token_specs order.

  • kv (torch.Tensor | None) – Key/value tensor from a single source (e.g. geometry encoding). Can be None when kv_cache is provided in attn_kwargs.

  • condition (torch.Tensor | None) – Optional conditioning vector for modulation.

  • attn_kwargs (dict[str, Any] | None) – Must contain "token_specs" key. Other entries (kv_cache, RoPE frequencies, masks) are forwarded to attention.

Returns:

Tuple of (output, kv_cache).

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor] | None]