noether.modeling.modules.untied¶
Classes¶
UntiedLinearConfig | Configuration for a linear layer with per-type (untied) weight banks.
UntiedMixedAttentionConfig | Configuration for multi-head attention with per-type (untied) QKV and output projections.
UntiedMLPConfig | Configuration for an MLP with per-type (untied) weights.
UntiedTransformerBlockConfig | Configuration for a transformer block with per-type (untied) attention and MLP weights.
UntiedPerceiverBlockConfig | Configuration for a perceiver block with per-type (untied) Q/output projections and MLP weights.
UntiedLinear | Linear layer with per-domain weight banks.
UntiedMixedAttention | Multi-head attention with per-type QKV and output projections.
UntiedPerceiverAttention | Perceiver cross-attention with per-type Q and output projections.
UntiedMLP | Multi-layer perceptron with per-type weights.
UntiedTransformerBlock | Pre-norm transformer block with per-type (untied) attention and MLP weights.
UntiedPerceiverBlock | Perceiver block with per-type (untied) Q/output projections and MLP weights.
Module Contents¶
- class noether.modeling.modules.untied.UntiedLinearConfig(/, **data)¶
Bases: pydantic.BaseModel
Configuration for a linear layer with per-type (untied) weight banks.
Composes a LinearProjectionConfig (shared across types) with a num_types field: each token type gets its own independent weight matrix with the geometry described by the linear projection config.
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
- linear_projection: noether.modeling.modules.layers.linear_projection.LinearProjectionConfig¶
Shared geometry (input/output dims, bias, init) for every per-type weight bank.
- class noether.modeling.modules.untied.UntiedMixedAttentionConfig(/, **data)¶
Bases: noether.modeling.modules.attention.anchor_attention.mixed.MixedAttentionConfig
Configuration for multi-head attention with per-type (untied) QKV and output projections.
Extends MixedAttentionConfig with a num_types field: the QKV and output projections are UntiedLinear layers so each token type gets its own projection weights. Attention itself is still computed across all tokens via MixedAttention._process_pattern_batched().
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
- projection_config()¶
Configuration for the per-type QKV and output projections.
- Return type:
- class noether.modeling.modules.untied.UntiedMLPConfig(/, **data)¶
Bases: pydantic.BaseModel
Configuration for an MLP with per-type (untied) weights.
Composes an MLPConfig (architecture: dims, activation, init) with a num_types field. The untied MLP mirrors MLP’s topology (input -> [hidden]*(num_layers+1) -> output with activations between layers) but uses UntiedLinear for every linear layer.
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
- mlp: noether.modeling.modules.mlp.mlp.MLPConfig¶
Underlying MLP architecture (dims, activation, init).
- class noether.modeling.modules.untied.UntiedTransformerBlockConfig(/, **data)¶
Bases: pydantic.BaseModel
Configuration for a transformer block with per-type (untied) attention and MLP weights.
Composes a TransformerBlockConfig (shared layout: dims, heads, layer scale, drop path, etc.) with a num_types field. Both sub-layers have per-type weights: UntiedMultiHeadAttention for attention and UntiedMLP for the feed-forward.
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
- transformer_block: noether.modeling.modules.blocks.transformer.TransformerBlockConfig¶
Shared transformer-block layout (dims, heads, layer scale, drop path, etc.).
- attention_config()¶
Configuration for the UntiedMultiHeadAttention sub-layer.
- Return type:
- untied_mlp_config()¶
Configuration for the UntiedMLP sub-layer.
- Return type:
- class noether.modeling.modules.untied.UntiedPerceiverBlockConfig(/, **data)¶
Bases: pydantic.BaseModel
Configuration for a perceiver block with per-type (untied) Q/output projections and MLP weights.
Composes a PerceiverBlockConfig (shared layout: dims, heads, layer scale, drop path, etc.) with a num_types field. The Q and output projections in PerceiverAttention become per-type via UntiedLinear, while the KV projection stays shared (it operates on a single geometry encoding). The MLP is also replaced with UntiedMLP.
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
- Parameters:
data (Any)
- perceiver_block_config: noether.modeling.modules.blocks.perceiver.PerceiverBlockConfig¶
Shared perceiver-block layout (dims, heads, kv_dim, layer scale, drop path, etc.).
- untied_mlp_config()¶
Configuration for the UntiedMLP sub-layer.
- Return type:
- class noether.modeling.modules.untied.UntiedLinear(config)¶
Bases: torch.nn.Module
Linear layer with per-domain weight banks.
Groups token specs by TokenSpec.domain and applies a separate F.linear per domain. Automatically picks the fastest strategy at runtime:
  - torch._grouped_mm (CUDA, equal-size groups)
  - Reshape + single torch.bmm (equal-size groups, any device)
  - Padded torch.bmm (moderate skew)
  - Split + F.linear loop (heavy skew or very few groups)
Per-type weights are stored as independent 2D nn.Parameter entries in an nn.ParameterList (one matrix per type). The bmm-based fast paths stack them into a 3D tensor on the fly. Storing them as 2D is what lets torch.optim.Muon (which rejects non-2D parameters) update each type’s weight matrix independently.
Domains must be strictly consecutive in token_specs.
- Parameters:
config (UntiedLinearConfig) – Number of types and shared linear-projection geometry.
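The storage layout and two of the strategies listed above can be illustrated with a minimal sketch. It is not the library's implementation: TinyUntiedLinear, forward_loop, and forward_bmm are hypothetical names, bias is omitted, group sizes are passed as plain integers instead of TokenSpec objects, and the torch._grouped_mm and padded-bmm paths are not shown.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyUntiedLinear(nn.Module):
    """Hypothetical stand-in: one independent 2D weight matrix per token type."""

    def __init__(self, num_types: int, input_dim: int, output_dim: int) -> None:
        super().__init__()
        # 2D parameters held in a ParameterList, one entry per type.
        self.weight = nn.ParameterList(
            [nn.Parameter(torch.randn(output_dim, input_dim) * 0.02) for _ in range(num_types)]
        )

    def forward_loop(self, x: torch.Tensor, sizes: list) -> torch.Tensor:
        # Split + F.linear loop: valid for arbitrary (skewed) group sizes.
        chunks = torch.split(x, sizes, dim=1)
        return torch.cat([F.linear(c, w) for c, w in zip(chunks, self.weight)], dim=1)

    def forward_bmm(self, x: torch.Tensor, group_size: int) -> torch.Tensor:
        # Reshape + single bmm: only valid when every group has the same size.
        b, s, d_in = x.shape
        t = len(self.weight)
        w = torch.stack(list(self.weight))                    # (T, D_out, D_in), built on the fly
        xg = x.view(b, t, group_size, d_in).transpose(0, 1)   # (T, B, G, D_in)
        xg = xg.reshape(t, b * group_size, d_in)
        yg = torch.bmm(xg, w.transpose(1, 2))                 # (T, B*G, D_out)
        yg = yg.view(t, b, group_size, -1).transpose(0, 1)    # (B, T, G, D_out)
        return yg.reshape(b, s, -1)


torch.manual_seed(0)
layer = TinyUntiedLinear(num_types=3, input_dim=4, output_dim=5)
x = torch.randn(2, 6, 4)                  # three groups of two tokens each
y_loop = layer.forward_loop(x, [2, 2, 2])
y_bmm = layer.forward_bmm(x, 2)
```

Both paths compute the same result for equal-size groups; the loop is the general fallback, and the stacked 3D tensor exists only transiently, so every registered parameter stays 2D.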
- bias: torch.nn.ParameterList | None¶
- num_types¶
- init_weights¶
- input_dim¶
- output_dim¶
- weight¶
- reset_parameters()¶
Initialize the per-type weight banks following LinearProjection’s scheme.
Supported modes (from config.linear_projection.init_weights): "torch" (PyTorch nn.Linear defaults, applied per type), "truncnormal" / "truncnormal002" (truncated normal, std=0.02, zero bias), "zeros" (zero weight and bias).
- Return type:
None
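As a rough sketch of what the three modes amount to for a single type's bank, assuming the "torch" mode follows the recipe used by nn.Linear.reset_parameters: init_bank is a hypothetical helper, not part of the module, and would be called once per type.

```python
import math
from typing import Optional

import torch
import torch.nn as nn


def init_bank(weight: torch.Tensor, bias: Optional[torch.Tensor], mode: str) -> None:
    """Initialize one type's 2D weight (and optional bias) in place."""
    if mode == "torch":
        # Same recipe as nn.Linear.reset_parameters.
        nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
        if bias is not None:
            bound = 1.0 / math.sqrt(weight.shape[1])  # 1 / sqrt(fan_in)
            nn.init.uniform_(bias, -bound, bound)
    elif mode in ("truncnormal", "truncnormal002"):
        nn.init.trunc_normal_(weight, std=0.02)
        if bias is not None:
            nn.init.zeros_(bias)
    elif mode == "zeros":
        nn.init.zeros_(weight)
        if bias is not None:
            nn.init.zeros_(bias)
    else:
        raise ValueError(f"unknown init mode: {mode!r}")


# Three types, each with its own (output_dim=5, input_dim=4) matrix and bias.
weights = [torch.empty(5, 4) for _ in range(3)]
biases = [torch.ones(5) for _ in range(3)]
for w, b in zip(weights, biases):
    init_bank(w, b, "truncnormal002")
```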
- forward(x, token_specs)¶
Apply per-domain linear projections.
- Parameters:
x (torch.Tensor) – Input tensor (B, S, D_in) with tokens concatenated in token_specs order. Specs for the same domain must be adjacent.
token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Token specifications whose sizes sum to S.
- Returns:
Output tensor (B, S, D_out) in the same positional order as x.
- Return type:
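The two input constraints (same-domain specs adjacent, sizes summing to S) can be checked with a small pure-Python sketch. Spec is a hypothetical stand-in for TokenSpec that carries only the fields this example needs, and group_by_domain is an illustrative helper rather than a function of the module.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for TokenSpec (name, size, domain only)."""
    name: str
    size: int
    domain: str


def group_by_domain(specs: list, seq_len: int) -> list:
    """Merge adjacent specs of one domain into (domain, total_size) groups."""
    if sum(s.size for s in specs) != seq_len:
        raise ValueError("token spec sizes must sum to S")
    groups = []
    seen = set()
    for spec in specs:
        if groups and groups[-1][0] == spec.domain:
            # Same domain as the previous spec: extend the current group.
            groups[-1] = (spec.domain, groups[-1][1] + spec.size)
        elif spec.domain in seen:
            # The domain reappears after another domain: not consecutive.
            raise ValueError(f"domain {spec.domain!r} is not consecutive")
        else:
            seen.add(spec.domain)
            groups.append((spec.domain, spec.size))
    return groups


specs = [
    Spec("surface_anchor", 3, "surface"),
    Spec("surface_query", 2, "surface"),
    Spec("volume_anchor", 4, "volume"),
]
groups = group_by_domain(specs, seq_len=9)
```

The resulting group sizes are what a split-based or bmm-based path would consume; whether groups are equal-sized determines which strategy is applicable.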
- class noether.modeling.modules.untied.UntiedMixedAttention(config)¶
Bases: noether.modeling.modules.attention.anchor_attention.mixed.MixedAttention
Multi-head attention with per-type QKV and output projections.
Each token type has its own QKV and output projection weights (via UntiedLinear). The attention computation is pattern-aware, mirroring MixedAttention:
  - If attention_patterns is supplied (typically by a wrapping MultiBranchAnchorAttention such as SelfAnchorAttention / CrossAnchorAttention / JointAnchorAttention), those patterns are honored.
  - If no patterns are supplied (standalone usage), it falls back to a single all-to-all pattern: every token attends to every other token.
This lets the same module act as either a drop-in pattern-free attention or the inner workhorse of a multi-branch anchor attention, while keeping QKV and output projection weights untied per token type.
- Parameters:
config (UntiedMixedAttentionConfig) – Configuration specifying attention dims and num_types.
config – Configuration for the MixedAttention module. See MixedAttentionConfig for the available options.
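The standalone (pattern-free) case can be sketched as follows: projections are applied per type, then a single all-to-all attention runs over the whole sequence. This is an illustration under simplifying assumptions, not the module's code: TinyUntiedAttention and per_type_linear are hypothetical names, group sizes replace TokenSpec objects, and attention patterns, masks, RoPE, and bias are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def per_type_linear(x, sizes, weights):
    # Apply one weight matrix per contiguous token group.
    chunks = torch.split(x, sizes, dim=1)
    return torch.cat([F.linear(c, w) for c, w in zip(chunks, weights)], dim=1)


class TinyUntiedAttention(nn.Module):
    """Hypothetical sketch: untied Q/K/V/output projections, shared attention."""

    def __init__(self, num_types: int, dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads

        def bank() -> nn.ParameterList:
            return nn.ParameterList(
                [nn.Parameter(torch.randn(dim, dim) * dim ** -0.5) for _ in range(num_types)]
            )

        self.q, self.k, self.v, self.proj = bank(), bank(), bank(), bank()

    def forward(self, x: torch.Tensor, sizes: list) -> torch.Tensor:
        b, s, d = x.shape

        def heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(b, s, self.num_heads, d // self.num_heads).transpose(1, 2)

        q = heads(per_type_linear(x, sizes, self.q))
        k = heads(per_type_linear(x, sizes, self.k))
        v = heads(per_type_linear(x, sizes, self.v))
        # All-to-all: every token attends to every token, regardless of type.
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(b, s, d)
        return per_type_linear(out, sizes, self.proj)


torch.manual_seed(0)
attn = TinyUntiedAttention(num_types=2, dim=8, num_heads=2)
x = torch.randn(2, 5, 8)
y = attn(x, [3, 2])  # 3 tokens of type 0 followed by 2 tokens of type 1
```

Only the projections are type-specific here; the attention scores still mix information across all tokens.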
- q¶
- k¶
- v¶
- proj¶
- forward(x, token_specs, attention_patterns, key_padding_mask=None, freqs=None, kv_cache=None)¶
Apply attention with per-type QKV/output projections.
Positional argument order matches MixedAttention.forward() so this module is a drop-in replacement for the shared MixedAttention that MultiBranchAnchorAttention owns: a wrapping anchor attention can call self.mixed_attention(x, token_specs, patterns, ...) unchanged.
- Parameters:
x (torch.Tensor) – Input tensor (B, S, D) with tokens concatenated in token_specs order. S must equal the sum of token_specs sizes.
token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Token specifications defining the input structure.
attention_patterns (collections.abc.Sequence[noether.core.schemas.modules.attention.AttentionPattern]) – Optional attention patterns. When None, falls back to a single all-to-all pattern over every name in token_specs (standalone usage).
key_padding_mask (torch.Tensor | None) – Optional boolean mask (B, S). True = real token.
freqs (torch.Tensor | None) – RoPE frequencies for positional encoding.
kv_cache (dict[str, torch.Tensor] | None) – Optional dictionary for caching key/value tensors. Not yet supported.
- Returns:
Output tensor (B, S, D).
- Return type:
- class noether.modeling.modules.untied.UntiedPerceiverAttention(config, num_types)¶
Bases: noether.modeling.modules.attention.perceiver.PerceiverAttention
Perceiver cross-attention with per-type Q and output projections.
The Q and output projections are UntiedLinear layers (one weight bank per token type), while the KV projection remains a shared nn.Linear since it operates on a single source (e.g. geometry encoding).
- Parameters:
config (noether.core.schemas.modules.attention.AttentionConfig) – Attention configuration (AttentionConfig-compatible).
num_types (int) – Number of distinct token types for per-type projections.
config – Configuration for the PerceiverAttention module. See AttentionConfig for available options.
- q¶
- proj¶
- forward(q, token_specs, kv=None, attn_mask=None, q_freqs=None, k_freqs=None, kv_cache=None)¶
Forward pass with per-type Q and output projections.
- Parameters:
q (torch.Tensor) – Query tensor (B, S, hidden_dim) with tokens concatenated in token_specs order.
token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Token specifications defining the per-type structure.
kv (torch.Tensor | None) – Key/value tensor from a single source (e.g. geometry). Can be None when kv_cache is provided.
attn_mask (torch.Tensor | None) – Optional attention mask.
q_freqs (torch.Tensor | None) – RoPE frequencies for queries.
k_freqs (torch.Tensor | None) – RoPE frequencies for keys.
kv_cache (dict[str, torch.Tensor] | None) – Cached K/V tensors from a previous forward pass.
- Returns:
Tuple of (output, new_kv_cache).
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor] | None]
- class noether.modeling.modules.untied.UntiedMLP(config)¶
Bases: torch.nn.Module
Multi-layer perceptron with per-type weights.
Mirrors the topology of MLP (input -> [hidden]*(num_layers+1) -> output, with activations between layers), but uses UntiedLinear for every linear layer so each token type learns independent weights.
- Parameters:
config (UntiedMLPConfig) – Configuration for the untied MLP.
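Reading the topology notation above literally, the per-layer shapes and the cost of untying can be worked out with a short sketch. The helpers untied_mlp_dims and untied_param_count are hypothetical, and the parameter names (hidden_dim, num_layers) are assumed rather than taken from MLPConfig; the real MLP may count its layers differently.

```python
def untied_mlp_dims(input_dim: int, hidden_dim: int, output_dim: int, num_layers: int) -> list:
    """(in, out) shape of each linear layer for input -> [hidden]*(num_layers+1) -> output."""
    widths = [input_dim] + [hidden_dim] * (num_layers + 1) + [output_dim]
    return list(zip(widths[:-1], widths[1:]))


def untied_param_count(dims: list, num_types: int, bias: bool = True) -> int:
    """Total parameters when every layer keeps one weight bank per token type."""
    per_type = sum(i * o + (o if bias else 0) for i, o in dims)
    return num_types * per_type


dims = untied_mlp_dims(input_dim=4, hidden_dim=8, output_dim=2, num_layers=1)
total = untied_param_count(dims, num_types=3)
```

Under these assumptions the parameter count grows linearly with num_types, since no layer is shared between types.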
- layers¶
- activation¶
- forward(x, token_specs)¶
Apply the untied MLP.
- Parameters:
x (torch.Tensor) – Input tensor (B, S, mlp_config.input_dim).
token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Token specifications defining the input structure.
- Returns:
Output tensor (B, S, mlp_config.output_dim).
- Return type:
- class noether.modeling.modules.untied.UntiedTransformerBlock(config)¶
Bases: noether.modeling.modules.blocks.transformer.TransformerBlock
Pre-norm transformer block with per-type (untied) attention and MLP weights.
Same architecture and control flow as TransformerBlock; only the attention and MLP sub-modules differ. The parent’s __init__ builds all the shared plumbing (norms, modulation, layer scale, drop path, and the full forward pass), including whichever attention module the configured attention_constructor selects. This subclass then:
  - Replaces self.mlp with UntiedMLP.
  - Injects per-type QKV/output projections into the attention sub-layer without disturbing its attention pattern:
    - When the parent built a MultiBranchAnchorAttention (SelfAnchorAttention / CrossAnchorAttention / JointAnchorAttention), only its inner mixed_attention is swapped for UntiedMixedAttention. The outer wrapper keeps emitting the correct per-branch attention patterns, so self_untied behaves like self with untied weights (not like a joint attention).
    - Otherwise (default dot_product attention, or any other non-multi-branch constructor), self.attention_block is replaced outright with UntiedMixedAttention, which falls back to a single all-to-all pattern.
  - Overrides _mlp_forward() to route token_specs from attn_kwargs into UntiedMLP.
Attention receives token_specs automatically because the parent’s forward passes **attn_kwargs to self.attention_block, which in turn forwards them on to the inner UntiedMixedAttention.
- Parameters:
config (UntiedTransformerBlockConfig) – Configuration for the untied transformer block.
config – Configuration for the transformer block. See TransformerBlockConfig for available options.
- num_types¶
- mlp¶
- forward(x, condition=None, attn_kwargs=None)¶
Validate token_specs upfront, then delegate to TransformerBlock.forward().
The untied weight banks are indexed by name-order in token_specs, so duplicates would silently merge types and missing types would index out of range. We check here (before attention runs) to surface a clear error.
- Parameters:
x (torch.Tensor)
condition (torch.Tensor | None)
- Return type:
tuple[torch.Tensor, dict[str, dict[str, torch.Tensor]] | None]
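A minimal sketch of the kind of upfront check described above, assuming the rule is "names are unique and there is exactly one per weight bank". The helper validate_token_names is hypothetical, and the block's actual checks and error messages may differ.

```python
def validate_token_names(names: list, num_types: int) -> dict:
    """Return a name -> bank index mapping, or raise a descriptive error."""
    duplicates = sorted({n for n in names if names.count(n) > 1})
    if duplicates:
        # Duplicates would otherwise collapse two types onto one bank.
        raise ValueError(f"duplicate token spec names: {duplicates}")
    if len(names) != num_types:
        # A mismatch would otherwise surface later as an index error.
        raise ValueError(f"expected {num_types} token types, got {len(names)}: {names}")
    # Bank index follows the order in which names appear.
    return {name: index for index, name in enumerate(names)}


bank_index = validate_token_names(["surface", "volume"], num_types=2)
```

Failing before attention runs keeps the error close to its cause, which is the motivation the docstring gives for validating in forward.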
- class noether.modeling.modules.untied.UntiedPerceiverBlock(config)¶
Bases: noether.modeling.modules.blocks.perceiver.PerceiverBlock
Perceiver block with per-type (untied) Q/output projections and MLP weights.
Same architecture and control flow as PerceiverBlock; the Q and output projections in the attention module become per-type via UntiedPerceiverAttention, and the feed-forward MLP becomes UntiedMLP. The KV projection remains shared since it operates on a single source (e.g. geometry encoding).
token_specs must be provided via attn_kwargs["token_specs"].
- Parameters:
config (UntiedPerceiverBlockConfig) – Configuration for the untied perceiver block.
config – Configuration of the PerceiverBlock. See PerceiverBlockConfig for available options.
- num_types¶
- attn¶
- mlp¶
- forward(q, kv=None, condition=None, attn_kwargs=None)¶
Forward pass with per-type Q/output projections and MLP.
Validates token_specs upfront, then runs the same perceiver block logic as the parent with token_specs routed to the untied sub-modules.
- Parameters:
q (torch.Tensor) – Query tensor (B, S, hidden_dim) with tokens from all domains concatenated in token_specs order.
kv (torch.Tensor | None) – Key/value tensor from a single source (e.g. geometry encoding). Can be None when kv_cache is provided in attn_kwargs.
condition (torch.Tensor | None) – Optional conditioning vector for modulation.
attn_kwargs (dict[str, Any] | None) – Must contain "token_specs" key. Other entries (kv_cache, RoPE frequencies, masks) are forwarded to attention.
- Returns:
Tuple of (output, kv_cache).
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor] | None]
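The attn_kwargs contract described for this forward pass (a required "token_specs" entry, everything else passed through to attention) might be handled along these lines. The helper split_attn_kwargs is hypothetical, and the example keys other than "token_specs" are placeholders rather than the block's real argument names.

```python
from typing import Any, Optional


def split_attn_kwargs(attn_kwargs: Optional[dict]) -> tuple:
    """Separate the required token_specs from the entries forwarded to attention."""
    if not attn_kwargs or "token_specs" not in attn_kwargs:
        raise ValueError('attn_kwargs must contain a "token_specs" entry')
    forwarded = {key: value for key, value in attn_kwargs.items() if key != "token_specs"}
    return attn_kwargs["token_specs"], forwarded


# Placeholder values: real callers would pass TokenSpec objects and tensors.
token_specs, forwarded = split_attn_kwargs(
    {"token_specs": ["surface", "volume"], "kv_cache": None, "attn_mask": None}
)
```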