noether.modeling.modules.untied

Classes

UntiedLinearConfig

Configuration for a linear layer with per-type (untied) weight banks.

UntiedMixedAttentionConfig

Configuration for multi-head attention with per-type (untied) QKV and output projections.

UntiedMLPConfig

Configuration for an MLP with per-type (untied) weights.

UntiedTransformerBlockConfig

Configuration for a transformer block with per-type (untied) attention and MLP weights.

UntiedPerceiverBlockConfig

Configuration for a perceiver block with per-type (untied) Q/output projections and MLP weights.

UntiedLinear

Linear layer with per-domain weight banks.

UntiedMixedAttention

Multi-head attention with per-type QKV and output projections.

UntiedPerceiverAttention

Perceiver cross-attention with per-type Q and output projections.

UntiedMLP

Multi-layer perceptron with per-type weights.

UntiedTransformerBlock

Pre-norm transformer block with per-type (untied) attention and MLP weights.

UntiedPerceiverBlock

Perceiver block with per-type (untied) Q/output projections and MLP weights.

Module Contents

class noether.modeling.modules.untied.UntiedLinearConfig(/, **data)

Bases: pydantic.BaseModel

Configuration for a linear layer with per-type (untied) weight banks.

Composes a LinearProjectionConfig (shared across types) with a num_types field: each token type gets its own independent weight matrix with the geometry described by the linear projection config.

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Parameters:

data (Any)

num_types: int = None

Number of distinct token types, each with its own weight bank.

linear_projection: noether.modeling.modules.layers.linear_projection.LinearProjectionConfig

Shared geometry (input/output dims, bias, init) for every per-type weight bank.
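As a rough illustration of the composition described above, the sketch below uses plain dataclasses as stand-ins for the real pydantic models. The class names, the field names input_dim, output_dim, and bias, and the helper weight_shapes are assumptions made for this example and are not the library's API.

```python
from dataclasses import dataclass


@dataclass
class LinearProjectionSketch:
    """Hypothetical stand-in for LinearProjectionConfig (field names assumed)."""
    input_dim: int
    output_dim: int
    bias: bool = True


@dataclass
class UntiedLinearSketch:
    """Hypothetical stand-in for UntiedLinearConfig."""
    num_types: int
    linear_projection: LinearProjectionSketch

    def __post_init__(self) -> None:
        if self.num_types < 1:
            raise ValueError("num_types must be at least 1")

    def weight_shapes(self) -> list[tuple[int, int]]:
        # One (output_dim, input_dim) matrix per token type, all sharing one geometry.
        lp = self.linear_projection
        return [(lp.output_dim, lp.input_dim)] * self.num_types


cfg = UntiedLinearSketch(num_types=3, linear_projection=LinearProjectionSketch(16, 32))
shapes = cfg.weight_shapes()
```

The point of the composition is that the geometry is written once and the number of banks is the only extra degree of freedom.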

class noether.modeling.modules.untied.UntiedMixedAttentionConfig(/, **data)

Bases: noether.modeling.modules.attention.anchor_attention.mixed.MixedAttentionConfig

Configuration for multi-head attention with per-type (untied) QKV and output projections.

Extends MixedAttentionConfig with a num_types field: the QKV and output projections are UntiedLinear layers so each token type gets its own projection weights. Attention itself is still computed across all tokens via MixedAttention._process_pattern_batched().

Parameters:

data (Any)

num_types: int = None

Number of distinct token types, each with its own QKV/output weight bank.

projection_config()

Configuration for the per-type QKV and output projections.

Return type:

UntiedLinearConfig

class noether.modeling.modules.untied.UntiedMLPConfig(/, **data)

Bases: pydantic.BaseModel

Configuration for an MLP with per-type (untied) weights.

Composes an MLPConfig (architecture: dims, activation, init) with a num_types field. The untied MLP mirrors MLP’s topology (input -> [hidden]*(num_layers+1) -> output with activations between layers) but uses UntiedLinear for every linear layer.

Parameters:

data (Any)

num_types: int = None

Number of distinct token types.

mlp: noether.modeling.modules.mlp.mlp.MLPConfig

Underlying MLP architecture (dims, activation, init).

class noether.modeling.modules.untied.UntiedTransformerBlockConfig(/, **data)

Bases: pydantic.BaseModel

Configuration for a transformer block with per-type (untied) attention and MLP weights.

Composes a TransformerBlockConfig (shared layout: dims, heads, layer scale, drop path, etc.) with a num_types field. Both sub-layers have per-type weights: UntiedMixedAttention for attention and UntiedMLP for the feed-forward network.

Parameters:

data (Any)

num_types: int = None

Number of distinct token types for the untied attention and MLP weight banks.

transformer_block: noether.modeling.modules.blocks.transformer.TransformerBlockConfig

Shared transformer-block layout (dims, heads, layer scale, drop path, etc.).

attention_config()

Configuration for the UntiedMixedAttention sub-layer.

Return type:

UntiedMixedAttentionConfig

untied_mlp_config()

Configuration for the UntiedMLP sub-layer.

Return type:

UntiedMLPConfig

class noether.modeling.modules.untied.UntiedPerceiverBlockConfig(/, **data)

Bases: pydantic.BaseModel

Configuration for a perceiver block with per-type (untied) Q/output projections and MLP weights.

Composes a PerceiverBlockConfig (shared layout: dims, heads, layer scale, drop path, etc.) with a num_types field. The Q and output projections in PerceiverAttention become per-type via UntiedLinear, while the KV projection stays shared (it operates on a single geometry encoding). The MLP is also replaced with UntiedMLP.

Parameters:

data (Any)

num_types: int = None

Number of distinct token types for the untied projections.

perceiver_block_config: noether.modeling.modules.blocks.perceiver.PerceiverBlockConfig

Shared perceiver-block layout (dims, heads, kv_dim, layer scale, drop path, etc.).

untied_mlp_config()

Configuration for the UntiedMLP sub-layer.

Return type:

UntiedMLPConfig

class noether.modeling.modules.untied.UntiedLinear(config)

Bases: torch.nn.Module

Linear layer with per-domain weight banks.

Groups token specs by TokenSpec.domain and applies a separate F.linear per domain. Automatically picks the fastest strategy at runtime:

  1. torch._grouped_mm (CUDA, equal-size groups)

  2. Reshape + single torch.bmm (equal-size groups, any device)

  3. Padded torch.bmm (moderate skew)

  4. Split + F.linear loop (heavy skew or very few groups)

Per-type weights are stored as independent 2D nn.Parameter entries in an nn.ParameterList (one matrix per type). The bmm-based fast paths stack them into a 3D tensor on the fly. Storing them as 2D is what lets torch.optim.Muon (which rejects non-2D parameters) update each type’s weight matrix independently.
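The storage scheme described above can be sketched in plain PyTorch. This is a minimal illustration, assuming equal-size groups laid out consecutively along the sequence; the variable names and sizes are hypothetical and the code is not taken from the library. It compares the split + F.linear loop with the reshape + torch.bmm path after stacking the 2D matrices into a 3D tensor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_types, d_in, d_out = 3, 4, 5
batch, tokens_per_type = 2, 6  # equal-size groups, so the reshape + bmm path applies

# One independent 2D parameter per type (assumed layout mirroring the description).
weight = nn.ParameterList(
    [nn.Parameter(torch.randn(d_out, d_in)) for _ in range(num_types)]
)

x = torch.randn(batch, num_types * tokens_per_type, d_in)

# Loop path: split the sequence by type and apply F.linear per group.
chunks = x.split(tokens_per_type, dim=1)
y_loop = torch.cat([F.linear(c, w) for c, w in zip(chunks, weight)], dim=1)

# bmm path: stack the 2D matrices into (T, D_out, D_in) on the fly.
w3d = torch.stack(list(weight), dim=0)
xg = x.view(batch, num_types, tokens_per_type, d_in).transpose(0, 1)  # (T, B, n, D_in)
xg = xg.reshape(num_types, batch * tokens_per_type, d_in)
yg = torch.bmm(xg, w3d.transpose(1, 2))  # (T, B*n, D_out)
y_bmm = (
    yg.view(num_types, batch, tokens_per_type, d_out)
    .transpose(0, 1)
    .reshape(batch, num_types * tokens_per_type, d_out)
)
```

Both paths produce the same output up to floating-point tolerance, while every entry of weight stays a 2D parameter that an optimizer restricted to matrices can update on its own.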

Domains must be strictly consecutive in token_specs.
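One reading of the consecutiveness requirement is that all specs of a domain must sit next to each other, so a domain may not reappear after a different one has started. The sketch below checks that under this assumption. Spec is a hypothetical stand-in for TokenSpec with only a domain and a size, and group_sizes_by_domain is an illustrative helper, not a library function.

```python
from typing import NamedTuple


class Spec(NamedTuple):
    """Hypothetical stand-in for TokenSpec; only domain and size are modelled."""
    domain: str
    size: int


def group_sizes_by_domain(token_specs: list[Spec]) -> list[tuple[str, int]]:
    """Merge adjacent specs with the same domain; reject a domain that reappears later."""
    groups: list[tuple[str, int]] = []
    seen: set[str] = set()
    for spec in token_specs:
        if groups and groups[-1][0] == spec.domain:
            # Same domain as the previous spec: extend the current group.
            groups[-1] = (spec.domain, groups[-1][1] + spec.size)
        elif spec.domain in seen:
            raise ValueError(f"domain {spec.domain!r} is not consecutive in token_specs")
        else:
            seen.add(spec.domain)
            groups.append((spec.domain, spec.size))
    return groups


ok = group_sizes_by_domain([Spec("surface", 4), Spec("surface", 2), Spec("volume", 8)])
```

The resulting group sizes are what a per-domain split of the sequence dimension would use.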

Parameters:

config (UntiedLinearConfig) – Number of types and shared linear-projection geometry.

bias: torch.nn.ParameterList | None

num_types

init_weights

input_dim

output_dim

weight

reset_parameters()

Initialize the per-type weight banks following LinearProjection’s scheme.

Supported modes (from config.linear_projection.init_weights): "torch" (PyTorch nn.Linear defaults, applied per type), "truncnormal" / "truncnormal002" (truncated normal, std=0.02, zero bias), "zeros" (zero weight and bias).

Return type:

None
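The initialization modes listed for reset_parameters() could be applied per type roughly as follows. This is a sketch under stated assumptions: init_bank is a hypothetical helper, the "torch" branch reproduces what nn.Linear does by default, and the truncation bounds of the truncated normal are left at PyTorch's defaults because the source does not state them.

```python
import math
from typing import Optional

import torch
import torch.nn as nn


def init_bank(weight: nn.ParameterList, bias: Optional[nn.ParameterList], mode: str) -> None:
    """Initialise every per-type matrix independently (illustrative helper)."""
    for t, w in enumerate(weight):
        b = bias[t] if bias is not None else None
        if mode == "torch":
            # Same scheme nn.Linear uses in its own reset_parameters.
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))
            if b is not None:
                bound = 1.0 / math.sqrt(w.shape[1])
                nn.init.uniform_(b, -bound, bound)
        elif mode in ("truncnormal", "truncnormal002"):
            nn.init.trunc_normal_(w, std=0.02)
            if b is not None:
                nn.init.zeros_(b)
        elif mode == "zeros":
            nn.init.zeros_(w)
            if b is not None:
                nn.init.zeros_(b)
        else:
            raise ValueError(f"unknown init mode: {mode}")


weight = nn.ParameterList([nn.Parameter(torch.empty(5, 4)) for _ in range(3)])
bias = nn.ParameterList([nn.Parameter(torch.empty(5)) for _ in range(3)])
init_bank(weight, bias, "truncnormal")
```

Looping over the bank keeps each type's matrix statistically independent, which matches the "applied per type" wording above.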

forward(x, token_specs)

Apply per-domain linear projections.

Returns:

Output tensor (B, S, D_out) in the same positional order as x.

Return type:

torch.Tensor

class noether.modeling.modules.untied.UntiedMixedAttention(config)

Bases: noether.modeling.modules.attention.anchor_attention.mixed.MixedAttention

Multi-head attention with per-type QKV and output projections.

Each token type has its own QKV and output projection weights (via UntiedLinear). The attention computation is pattern-aware, mirroring MixedAttention:

  • If attention_patterns is supplied (typically by a wrapping MultiBranchAnchorAttention such as SelfAnchorAttention/CrossAnchorAttention/JointAnchorAttention), those patterns are honored.

  • If no patterns are supplied (standalone usage), it falls back to a single all-to-all pattern — every token attends to every other token.

This lets the same module act as either a drop-in pattern-free attention or the inner workhorse of a multi-branch anchor attention, while keeping QKV and output projection weights untied per token type.
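The standalone fallback can be pictured with a small PyTorch sketch: projections are applied per type, but attention runs all-to-all over the concatenated sequence. Everything here is illustrative. The fused QKV bank, the per_type_linear helper, and the tensor sizes are assumptions, and the real module may organize its projections differently.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, heads = 2, 8, 2
sizes = [3, 5]  # tokens per type, laid out consecutively along the sequence
S = sum(sizes)
x = torch.randn(B, S, D)

# Hypothetical per-type weight banks: one fused QKV matrix and one output matrix per type.
w_qkv = [torch.randn(3 * D, D) / D**0.5 for _ in sizes]
w_out = [torch.randn(D, D) / D**0.5 for _ in sizes]


def per_type_linear(t, weights):
    # Apply each type's matrix to its own slice of the sequence, keeping positional order.
    parts = t.split(sizes, dim=1)
    return torch.cat([F.linear(p, w) for p, w in zip(parts, weights)], dim=1)


q, k, v = per_type_linear(x, w_qkv).chunk(3, dim=-1)
# Reshape to (B, heads, S, head_dim); no mask, so every token attends to every token.
q, k, v = (t.reshape(B, S, heads, D // heads).transpose(1, 2) for t in (q, k, v))
attn = F.scaled_dot_product_attention(q, k, v)
out = per_type_linear(attn.transpose(1, 2).reshape(B, S, D), w_out)
```

Only the two per_type_linear calls depend on the token type; the attention call in the middle is unaware of it.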

Parameters:

config (UntiedMixedAttentionConfig) – Configuration specifying attention dims and num_types. See MixedAttentionConfig for the inherited options.

q

k

v

proj

forward(x, token_specs, attention_patterns, key_padding_mask=None, freqs=None, kv_cache=None)

Apply attention with per-type QKV/output projections.

Positional argument order matches MixedAttention.forward() so this module is a drop-in replacement for the shared MixedAttention that MultiBranchAnchorAttention owns — a wrapping anchor attention can call self.mixed_attention(x, token_specs, patterns, ...) unchanged.

Returns:

Output tensor (B, S, D).

Return type:

torch.Tensor

class noether.modeling.modules.untied.UntiedPerceiverAttention(config, num_types)

Bases: noether.modeling.modules.attention.perceiver.PerceiverAttention

Perceiver cross-attention with per-type Q and output projections.

The Q and output projections are UntiedLinear layers (one weight bank per token type), while the KV projection remains a shared nn.Linear since it operates on a single source (e.g. geometry encoding).
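A minimal PyTorch sketch of this split between per-type and shared projections, assuming two query types and a single key/value source. The names, sizes, and single-head layout are illustrative only, and the per-type output projection is omitted because it would follow the same pattern as the query projection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, D, kv_dim = 2, 8, 6
sizes = [4, 2]                      # query tokens per type, consecutive
q_in = torch.randn(B, sum(sizes), D)
kv_in = torch.randn(B, 10, kv_dim)  # single shared source

# One 2D query matrix per type; the KV projection is an ordinary shared nn.Linear.
q_banks = nn.ParameterList([nn.Parameter(torch.randn(D, D) / D**0.5) for _ in sizes])
kv_proj = nn.Linear(kv_dim, 2 * D)

q = torch.cat(
    [F.linear(p, w) for p, w in zip(q_in.split(sizes, dim=1), q_banks)], dim=1
)
k, v = kv_proj(kv_in).chunk(2, dim=-1)

# Single-head attention for brevity: add a head dimension of size 1, then drop it.
out = F.scaled_dot_product_attention(
    q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)
).squeeze(1)
```

Because the keys and values come from one source, a single projection suffices for them regardless of how many query types exist.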

q

proj

forward(q, token_specs, kv=None, attn_mask=None, q_freqs=None, k_freqs=None, kv_cache=None)

Forward pass with per-type Q and output projections.

Returns:

Tuple of (output, new_kv_cache).

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor] | None]

class noether.modeling.modules.untied.UntiedMLP(config)

Bases: torch.nn.Module

Multi-layer perceptron with per-type weights.

Mirrors the topology of MLP (input -> [hidden]*(num_layers+1) -> output, with activations between layers), but uses UntiedLinear for every linear layer so each token type learns independent weights.
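Read literally, the topology notation implies num_layers + 2 linear layers. The helper below spells that out; mlp_layer_dims and its argument names are hypothetical, and the interpretation of the notation is an assumption rather than something confirmed against the MLP implementation.

```python
def mlp_layer_dims(
    input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
) -> list[tuple[int, int]]:
    """Return (in, out) pairs for input -> [hidden]*(num_layers+1) -> output."""
    widths = [input_dim] + [hidden_dim] * (num_layers + 1) + [output_dim]
    # Consecutive widths give the shape of each linear layer.
    return list(zip(widths[:-1], widths[1:]))


dims = mlp_layer_dims(input_dim=16, hidden_dim=64, output_dim=8, num_layers=1)
```

In the untied variant, each of these pairs would describe the geometry of one UntiedLinear, replicated once per token type.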

Parameters:

config (UntiedMLPConfig) – Configuration for the untied MLP.

layers

activation

forward(x, token_specs)

Apply the untied MLP.

Returns:

Output tensor (B, S, mlp_config.output_dim).

Return type:

torch.Tensor

class noether.modeling.modules.untied.UntiedTransformerBlock(config)

Bases: noether.modeling.modules.blocks.transformer.TransformerBlock

Pre-norm transformer block with per-type (untied) attention and MLP weights.

Same architecture and control flow as TransformerBlock; only the attention and MLP sub-modules differ. The parent’s __init__ builds all the shared plumbing (norms, modulation, layer scale, drop path, and the full forward pass) — including whichever attention module the configured attention_constructor selects. This subclass then:

  1. Replaces self.mlp with UntiedMLP.

  2. Injects per-type QKV/output projections into the attention sub-layer without disturbing its attention pattern:

    • When the parent built a MultiBranchAnchorAttention (SelfAnchorAttention / CrossAnchorAttention / JointAnchorAttention), only its inner mixed_attention is swapped for UntiedMixedAttention. The outer wrapper keeps emitting the correct per-branch attention patterns — so self_untied behaves like self with untied weights (not like a joint attention).

    • Otherwise (default dot_product attention, or any other non-multi-branch constructor) self.attention_block is replaced outright with UntiedMixedAttention, which falls back to a single all-to-all pattern.

  3. Overrides _mlp_forward() to route token_specs from attn_kwargs into UntiedMLP.

Attention receives token_specs automatically because the parent’s forward passes **attn_kwargs to self.attention_block, which in turn forwards them on to the inner UntiedMixedAttention.
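The replacement logic in steps 1 and 2 can be summarized with stub classes. This is a toy sketch only: the stub classes and the untie function are hypothetical, and strings stand in for the real modules so the control flow is visible without constructing any attention layers.

```python
class MultiBranchStub:
    """Hypothetical stand-in for a MultiBranchAnchorAttention wrapper."""

    def __init__(self):
        self.mixed_attention = "shared-mixed-attention"


class DotProductStub:
    """Hypothetical stand-in for a non-multi-branch attention module."""


class BlockStub:
    """Hypothetical stand-in for the block after the parent's __init__ has run."""

    def __init__(self, attention_block):
        self.attention_block = attention_block
        self.mlp = "shared-mlp"


def untie(block: BlockStub) -> BlockStub:
    block.mlp = "untied-mlp"  # step 1: replace the MLP
    if isinstance(block.attention_block, MultiBranchStub):
        # Step 2, first case: keep the wrapper and its patterns, swap only the inner module.
        block.attention_block.mixed_attention = "untied-mixed-attention"
    else:
        # Step 2, second case: replace the whole attention module.
        block.attention_block = "untied-mixed-attention"
    return block


wrapped = untie(BlockStub(MultiBranchStub()))
plain = untie(BlockStub(DotProductStub()))
```

In the first case the outer wrapper object survives, which is why its per-branch attention patterns are preserved.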

Parameters:

config (UntiedTransformerBlockConfig) – Configuration for the untied transformer block. See TransformerBlockConfig for the available shared options.

num_types

mlp

forward(x, condition=None, attn_kwargs=None)

Validate token_specs upfront, then delegate to TransformerBlock.forward().

The untied weight banks are indexed by name-order in token_specs, so duplicates would silently merge types and missing types would index out of range. We check here (before attention runs) to surface a clear error.

Return type:

tuple[torch.Tensor, dict[str, dict[str, torch.Tensor]] | None]
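The upfront check described for forward() might look roughly like the following. The exact checks the library performs are not shown in this reference, so validate_token_specs, its arguments, and its error messages are assumptions; the sketch only covers the two failure modes named above (duplicate types and a type count that does not match num_types).

```python
def validate_token_specs(names: list[str], num_types: int) -> None:
    """Reject duplicate names and a type count that differs from num_types."""
    if len(set(names)) != len(names):
        dupes = sorted({n for n in names if names.count(n) > 1})
        raise ValueError(f"duplicate token types in token_specs: {dupes}")
    if len(names) != num_types:
        raise ValueError(f"expected {num_types} token types, got {len(names)}")


# Passes silently: two distinct names for two weight banks.
validate_token_specs(["surface", "volume"], num_types=2)
```

Failing early with a named error is easier to debug than an index error raised from inside a weight-bank lookup.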

class noether.modeling.modules.untied.UntiedPerceiverBlock(config)

Bases: noether.modeling.modules.blocks.perceiver.PerceiverBlock

Perceiver block with per-type (untied) Q/output projections and MLP weights.

Same architecture and control flow as PerceiverBlock; the Q and output projections in the attention module become per-type via UntiedPerceiverAttention, and the feed-forward MLP becomes UntiedMLP. The KV projection remains shared since it operates on a single source (e.g. geometry encoding).

token_specs must be provided via attn_kwargs["token_specs"].

Parameters:

config (UntiedPerceiverBlockConfig) – Configuration for the untied perceiver block. See PerceiverBlockConfig for the available shared options.

num_types

attn

mlp

forward(q, kv=None, condition=None, attn_kwargs=None)

Forward pass with per-type Q/output projections and MLP.

Validates token_specs upfront, then runs the same perceiver block logic as the parent with token_specs routed to the untied sub-modules.

Parameters:
  • q (torch.Tensor) – Query tensor (B, S, hidden_dim) with tokens from all domains concatenated in token_specs order.

  • kv (torch.Tensor | None) – Key/value tensor from a single source (e.g. geometry encoding). Can be None when kv_cache is provided in attn_kwargs.

  • condition (torch.Tensor | None) – Optional conditioning vector for modulation.

  • attn_kwargs (dict[str, Any] | None) – Must contain "token_specs" key. Other entries (kv_cache, RoPE frequencies, masks) are forwarded to attention.

Returns:

Tuple of (output, kv_cache).

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor] | None]