noether.modeling.modules.untied

Classes

- UntiedLinear – Linear layer with per-domain weight banks.
- UntiedMixedAttention – Multi-head attention with per-type QKV and output projections.
- UntiedPerceiverAttention – Perceiver cross-attention with per-type Q and output projections.
- UntiedMLP – Multi-layer perceptron with per-type weights.
- UntiedTransformerBlock – Pre-norm transformer block with per-type (untied) attention and MLP weights.
- UntiedPerceiverBlock – Perceiver block with per-type (untied) Q/output projections and MLP weights.

Module Contents
- class noether.modeling.modules.untied.UntiedLinear(config)
Bases: torch.nn.Module

Linear layer with per-domain weight banks.
Groups token specs by TokenSpec.domain and applies a separate F.linear per domain. Automatically picks the fastest strategy at runtime:

- torch._grouped_mm (CUDA, equal-size groups)
- Reshape + single torch.bmm (equal-size groups, any device)
- Padded torch.bmm (moderate skew)
- Split + F.linear loop (heavy skew or very few groups)
Per-type weights are stored as independent 2D nn.Parameter entries in an nn.ParameterList (one matrix per type). The bmm-based fast paths stack them into a 3D tensor on the fly. Storing them as 2D is what lets torch.optim.Muon (which rejects non-2D parameters) update each type’s weight matrix independently. Domains must be strictly consecutive in token_specs.
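A minimal standalone sketch of that storage layout (plain PyTorch, not the library’s code): each type’s weight is an independent 2D parameter, and the fast path stacks them into a 3D bank for one batched matmul.

```python
import torch
import torch.nn as nn

num_types, d_in, d_out, tokens_per_type = 3, 8, 16, 5

# One independent 2D weight matrix per type, as the optimizer sees them.
weights = nn.ParameterList(
    nn.Parameter(torch.randn(d_out, d_in)) for _ in range(num_types)
)

# Fast path: stack into a 3D bank on the fly, then one batched matmul.
x = torch.randn(num_types, tokens_per_type, d_in)
bank = torch.stack(list(weights))        # (num_types, d_out, d_in)
y = torch.bmm(x, bank.transpose(1, 2))   # (num_types, tokens_per_type, d_out)
```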
- Parameters:
config (noether.core.schemas.modules.untied.UntiedLinearConfig) – Number of types and shared linear-projection geometry.
- bias: torch.nn.ParameterList | None
- num_types
- init_weights
- input_dim
- output_dim
- weight
- reset_parameters()
Initialize the per-type weight banks following LinearProjection’s scheme.
Supported modes (from config.linear_projection.init_weights):

- "torch" (PyTorch nn.Linear defaults, applied per type)
- "truncnormal" / "truncnormal002" (truncated normal, std=0.02, zero bias)
- "zeros" (zero weight and bias)

- Return type:
None
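For instance, the truncated-normal modes correspond to per-type initialization along these lines (a sketch, not the library’s code):

```python
import torch
import torch.nn as nn

weights = nn.ParameterList(nn.Parameter(torch.empty(16, 8)) for _ in range(3))
biases = nn.ParameterList(nn.Parameter(torch.empty(16)) for _ in range(3))

# "truncnormal" / "truncnormal002": truncated normal (std=0.02), zero bias.
for w, b in zip(weights, biases):
    nn.init.trunc_normal_(w, std=0.02)
    nn.init.zeros_(b)
```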
- forward(x, token_specs)
Apply per-domain linear projections.
- Parameters:
x (torch.Tensor) – Input tensor (B, S, D_in) with tokens concatenated in token_specs order. Specs for the same domain must be adjacent.
token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Token specifications whose sizes sum to S.
- Returns:
Output tensor (B, S, D_out) in the same positional order as x.
- Return type:
torch.Tensor
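A self-contained sketch of the documented semantics, following the split + F.linear fallback path (plain PyTorch, not the library’s implementation):

```python
import torch
import torch.nn.functional as F

def untied_linear_reference(x, domain_sizes, weights, biases):
    """One F.linear per domain slice, outputs re-concatenated in order.

    x: (B, S, D_in) with S == sum(domain_sizes); weights[i]: (D_out, D_in).
    """
    outs, start = [], 0
    for size, w, b in zip(domain_sizes, weights, biases):
        outs.append(F.linear(x[:, start:start + size], w, b))
        start += size
    return torch.cat(outs, dim=1)  # (B, S, D_out), same order as x

x = torch.randn(2, 7, 8)
ws = [torch.randn(16, 8) for _ in range(2)]
bs = [torch.zeros(16) for _ in range(2)]
y = untied_linear_reference(x, [4, 3], ws, bs)  # (2, 7, 16)
```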
- class noether.modeling.modules.untied.UntiedMixedAttention(config)
Bases: noether.modeling.modules.attention.anchor_attention.mixed.MixedAttention

Multi-head attention with per-type QKV and output projections.
Each token type has its own QKV and output projection weights (via UntiedLinear). The attention computation is pattern-aware, mirroring MixedAttention:

- If attention_patterns is supplied (typically by a wrapping MultiBranchAnchorAttention such as SelfAnchorAttention / CrossAnchorAttention / JointAnchorAttention), those patterns are honored.
- If no patterns are supplied (standalone usage), it falls back to a single all-to-all pattern: every token attends to every other token.

This lets the same module act as either a drop-in pattern-free attention or the inner workhorse of a multi-branch anchor attention, while keeping QKV and output projection weights untied per token type.
- Parameters:
config (noether.core.schemas.modules.untied.UntiedMixedAttentionConfig) – Configuration specifying attention dims and num_types. See MixedAttentionConfig for the options shared with MixedAttention.
- q
- k
- v
- proj
- forward(x, token_specs, attention_patterns, key_padding_mask=None, freqs=None, kv_cache=None)
Apply attention with per-type QKV/output projections.
Positional argument order matches MixedAttention.forward(), so this module is a drop-in replacement for the shared MixedAttention that MultiBranchAnchorAttention owns: a wrapping anchor attention can call self.mixed_attention(x, token_specs, patterns, ...) unchanged.
- Parameters:
x (torch.Tensor) – Input tensor (B, S, D) with tokens concatenated in token_specs order. S must equal the sum of token_specs sizes.
token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Token specifications defining the input structure.
attention_patterns (collections.abc.Sequence[noether.core.schemas.modules.attention.AttentionPattern]) – Optional attention patterns. When None, falls back to a single all-to-all pattern over every name in token_specs (standalone usage).
key_padding_mask (torch.Tensor | None) – Optional boolean mask (B, S). True = real token.
freqs (torch.Tensor | None) – RoPE frequencies for positional encoding.
kv_cache (dict[str, torch.Tensor] | None) – Optional dictionary for caching key/value tensors. Not yet supported.
- Returns:
Output tensor (B, S, D).
- Return type:
torch.Tensor
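An illustrative plain-PyTorch sketch of the standalone fallback (per-type QKV banks, then one all-to-all attention over the concatenated tokens); the per-type output projection is omitted, and this is not the library’s implementation:

```python
import torch
import torch.nn.functional as F

B, D, n_heads = 2, 64, 4
sizes = [10, 6]                                # tokens per type; S = 16
x = torch.randn(B, sum(sizes), D)
qkv_banks = [torch.nn.Linear(D, 3 * D) for _ in sizes]  # untied per type

# Project each type's slice with its own QKV bank, preserving token order.
chunks, start = [], 0
for size, proj in zip(sizes, qkv_banks):
    chunks.append(proj(x[:, start:start + size]))
    start += size
q, k, v = torch.cat(chunks, dim=1).chunk(3, dim=-1)

# Fallback pattern: every token attends to every other token.
def heads(t):
    return t.view(B, -1, n_heads, D // n_heads).transpose(1, 2)

out = F.scaled_dot_product_attention(heads(q), heads(k), heads(v))
out = out.transpose(1, 2).reshape(B, -1, D)    # (B, S, D)
```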
- class noether.modeling.modules.untied.UntiedPerceiverAttention(config, num_types)
Bases: noether.modeling.modules.attention.perceiver.PerceiverAttention

Perceiver cross-attention with per-type Q and output projections.
The Q and output projections are UntiedLinear layers (one weight bank per token type), while the KV projection remains a shared nn.Linear since it operates on a single source (e.g. geometry encoding).
config (noether.core.schemas.modules.attention.AttentionConfig) – Attention configuration (
AttentionConfig-compatible).num_types (int) – Number of distinct token types for per-type projections.
config – Configuration for the PerceiverAttention module. See
AttentionConfigfor available options.
- q
- proj
- forward(q, token_specs, kv=None, attn_mask=None, q_freqs=None, k_freqs=None, kv_cache=None)
Forward pass with per-type Q and output projections.
- Parameters:
q (torch.Tensor) – Query tensor (B, S, hidden_dim) with tokens concatenated in token_specs order.
token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Token specifications defining the per-type structure.
kv (torch.Tensor | None) – Key/value tensor from a single source (e.g. geometry). Can be None when kv_cache is provided.
attn_mask (torch.Tensor | None) – Optional attention mask.
q_freqs (torch.Tensor | None) – RoPE frequencies for queries.
k_freqs (torch.Tensor | None) – RoPE frequencies for keys.
kv_cache (dict[str, torch.Tensor] | None) – Cached K/V tensors from a previous forward pass.
- Returns:
Tuple of (output, new_kv_cache).
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor] | None]
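A hedged usage sketch based only on the signature above; attention_config, latents, geometry_tokens, and token_specs are placeholders constructed elsewhere, not the actual schema:

```python
# Hypothetical usage; all names below are placeholders.
attn = UntiedPerceiverAttention(attention_config, num_types=len(token_specs))

# Per-type Q/output projections over the queries; shared KV over one source.
out, new_cache = attn(latents, token_specs, kv=geometry_tokens)

# Later calls can reuse cached K/V instead of recomputing them from `kv`.
out2, _ = attn(latents, token_specs, kv=None, kv_cache=new_cache)
```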
- class noether.modeling.modules.untied.UntiedMLP(config)
Bases: torch.nn.Module

Multi-layer perceptron with per-type weights.
Mirrors the topology of MLP (input -> [hidden]*(num_layers+1) -> output, with activations between layers), but uses UntiedLinear for every linear layer so each token type learns independent weights.
- Parameters:
config (noether.core.schemas.modules.untied.UntiedMLPConfig) – Configuration for the untied MLP.
- layers
- activation
- forward(x, token_specs)
Apply the untied MLP.
- Parameters:
x (torch.Tensor) – Input tensor (B, S, mlp_config.input_dim).
token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Token specifications defining the input structure.
- Returns:
Output tensor (B, S, mlp_config.output_dim).
- Return type:
torch.Tensor
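For reference, a plain-PyTorch sketch of the tied topology that UntiedMLP mirrors (ordinary nn.Linear layers here; the untied version swaps each one for an UntiedLinear):

```python
import torch.nn as nn

def tied_mlp(input_dim, hidden_dim, output_dim, num_layers, act=nn.GELU):
    # input -> [hidden] * (num_layers + 1) -> output, activations in between.
    dims = [input_dim] + [hidden_dim] * (num_layers + 1) + [output_dim]
    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(d_in, d_out), act()]
    return nn.Sequential(*layers[:-1])  # no activation after the last layer
```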
- class noether.modeling.modules.untied.UntiedTransformerBlock(config)
Bases: noether.modeling.modules.blocks.transformer.TransformerBlock

Pre-norm transformer block with per-type (untied) attention and MLP weights.
Same architecture and control flow as TransformerBlock; only the attention and MLP sub-modules differ. The parent’s __init__ builds all the shared plumbing (norms, modulation, layer scale, drop path, and the full forward pass), including whichever attention module the configured attention_constructor selects. This subclass then:

- Replaces self.mlp with UntiedMLP.
- Injects per-type QKV/output projections into the attention sub-layer without disturbing its attention pattern:
  - When the parent built a MultiBranchAnchorAttention (SelfAnchorAttention / CrossAnchorAttention / JointAnchorAttention), only its inner mixed_attention is swapped for UntiedMixedAttention. The outer wrapper keeps emitting the correct per-branch attention patterns, so self_untied behaves like self with untied weights (not like a joint attention).
  - Otherwise (default dot_product attention, or any other non-multi-branch constructor), self.attention_block is replaced outright with UntiedMixedAttention, which falls back to a single all-to-all pattern.
- Overrides _mlp_forward() to route token_specs from attn_kwargs into UntiedMLP.

Attention receives token_specs automatically because the parent’s forward passes **attn_kwargs to self.attention_block, which in turn forwards them on to the inner UntiedMixedAttention.
- Parameters:
config (noether.core.schemas.modules.untied.UntiedTransformerBlockConfig) – Configuration for the untied transformer block. See TransformerBlockConfig for the options shared with TransformerBlock.
- num_types
- mlp
- forward(x, condition=None, attn_kwargs=None)
Validate token_specs upfront, then delegate to TransformerBlock.forward().
The untied weight banks are indexed by name-order in token_specs, so duplicates would silently merge types and missing types would index out of range. We check here (before attention runs) to surface a clear error.
- Parameters:
x (torch.Tensor)
condition (torch.Tensor | None)
attn_kwargs (dict[str, Any] | None)
- Return type:
tuple[torch.Tensor, dict[str, dict[str, torch.Tensor]] | None]
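A hedged usage sketch based only on the signatures documented here; block_config, tokens, cond, and token_specs are placeholders built elsewhere:

```python
# Hypothetical usage; all names below are placeholders.
block = UntiedTransformerBlock(block_config)

# token_specs must name each type exactly once; duplicates or missing
# types are rejected upfront, before attention runs.
out, aux = block(tokens, condition=cond,
                 attn_kwargs={"token_specs": token_specs})
```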
- class noether.modeling.modules.untied.UntiedPerceiverBlock(config)
Bases: noether.modeling.modules.blocks.perceiver.PerceiverBlock

Perceiver block with per-type (untied) Q/output projections and MLP weights.
Same architecture and control flow as PerceiverBlock; the Q and output projections in the attention module become per-type via UntiedPerceiverAttention, and the feed-forward MLP becomes UntiedMLP. The KV projection remains shared since it operates on a single source (e.g. geometry encoding). token_specs must be provided via attn_kwargs["token_specs"].
- Parameters:
config (noether.core.schemas.modules.untied.UntiedPerceiverBlockConfig) – Configuration for the untied perceiver block. See PerceiverBlockConfig for available options.
- num_types
- attn
- mlp
- forward(q, kv=None, condition=None, attn_kwargs=None)
Forward pass with per-type Q/output projections and MLP.
Validates token_specs upfront, then runs the same perceiver block logic as the parent with token_specs routed to the untied sub-modules.
- Parameters:
q (torch.Tensor) – Query tensor (B, S, hidden_dim) with tokens from all domains concatenated in token_specs order.
kv (torch.Tensor | None) – Key/value tensor from a single source (e.g. geometry encoding). Can be None when kv_cache is provided in attn_kwargs.
condition (torch.Tensor | None) – Optional conditioning vector for modulation.
attn_kwargs (dict[str, Any] | None) – Must contain a "token_specs" key. Other entries (kv_cache, RoPE frequencies, masks) are forwarded to attention.
- Returns:
Tuple of (output, kv_cache).
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor] | None]
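A hedged usage sketch following the signature above; block_config, latents, geometry, and token_specs are placeholders built elsewhere:

```python
# Hypothetical usage; all names below are placeholders.
block = UntiedPerceiverBlock(block_config)

# First pass projects `geometry` into K/V and returns a cache.
out, kv_cache = block(latents, kv=geometry,
                      attn_kwargs={"token_specs": token_specs})

# Subsequent passes can reuse the cached K/V and skip `kv`.
out2, _ = block(out, attn_kwargs={"token_specs": token_specs,
                                  "kv_cache": kv_cache})
```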