noether.modeling.modules.attention.anchor_attention.mixed¶
Classes¶
- MixedAttention: Mixed attention with a selectable implementation for performance or readability.
Module Contents¶
- class noether.modeling.modules.attention.anchor_attention.mixed.MixedAttention(config)¶
Bases: noether.modeling.modules.attention.DotProductAttention
Mixed attention with a selectable implementation for performance or readability.
This module allows for structured attention patterns where different groups of tokens (defined by TokenSpec) have specific interaction patterns (defined by AttentionPattern). Instead of full self-attention, you can specify, for example, that one type of token can only attend to itself, while another can attend to all tokens.
This is achieved by splitting the main Q, K, V tensors based on the token specs and then performing separate attention computations for each pattern.
Supports KV caching for efficient inference. When a TokenSpec has size=None, its key/value representations are loaded from the provided kv_cache instead of being computed from the input tensor.
Example input structure (forward pass signature) for implementing Anchor Attention:

```python
x = torch.cat(
    [surface_anchors, surface_queries, volume_anchors, volume_queries],
    dim=1,  # sequence dim
)
token_specs = [
    TokenSpec("surface_anchors", 100),
    TokenSpec("surface_queries", 50),
    TokenSpec("volume_anchors", 80),
    TokenSpec("volume_queries", 60),
]
attention_patterns = [
    AttentionPattern(
        query_tokens=["surface_anchors", "surface_queries"],
        key_value_tokens=["surface_anchors"],
    ),
    AttentionPattern(
        query_tokens=["volume_anchors", "volume_queries"],
        key_value_tokens=["volume_anchors"],
    ),
]
```
- Parameters:
config (noether.core.schemas.modules.attention.MixedAttentionConfig) – Configuration for the MixedAttention module. See MixedAttentionConfig for the available options.
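The splitting-and-recombining scheme described above can be illustrated with a minimal, self-contained sketch. This is not the library implementation: it uses single-head NumPy attention with no projections, RoPE, masking, or caching, and the TokenSpec/AttentionPattern classes here are simplified stand-ins for the noether schemas.

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class TokenSpec:          # simplified stand-in for the noether schema
    name: str
    size: int

@dataclass
class AttentionPattern:   # simplified stand-in for the noether schema
    query_tokens: list
    key_value_tokens: list

def mixed_attention(x, token_specs, patterns):
    """x: [batch, n_tokens, dim] -> same shape, per-pattern attention."""
    # Slice x into named groups according to the token specs.
    groups, offset = {}, 0
    for spec in token_specs:
        groups[spec.name] = x[:, offset:offset + spec.size]
        offset += spec.size
    out = {}
    for p in patterns:
        # Gather queries and keys/values for this pattern only.
        q = np.concatenate([groups[n] for n in p.query_tokens], axis=1)
        kv = np.concatenate([groups[n] for n in p.key_value_tokens], axis=1)
        # Plain scaled dot-product attention (single head, no projections).
        scores = q @ kv.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
        scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        y = weights @ kv
        # Scatter the result back to the per-group outputs.
        off = 0
        for n in p.query_tokens:
            size = groups[n].shape[1]
            out[n] = y[:, off:off + size]
            off += size
    # Reassemble the outputs in the original input order.
    return np.concatenate([out[s.name] for s in token_specs], axis=1)
```

Because each pattern only sees its own key/value groups, tokens in one pattern are completely isolated from the others, which is the point of structured attention over full self-attention.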
- forward(x, token_specs, attention_patterns, key_padding_mask=None, freqs=None, kv_cache=None)¶
Apply mixed attention with flexible token-name-based patterns.
- Parameters:
x (torch.Tensor) – Input tensor [batch_size, n_tokens, dim]. Only contains tokens whose size is not None in token_specs.
token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Sequence of token specifications defining the input structure. Tokens with size=None are loaded from kv_cache.
attention_patterns (collections.abc.Sequence[noether.core.schemas.modules.attention.AttentionPattern]) – Sequence of attention patterns to apply. Each pattern defines which token groups (queries) attend to which other token groups (keys/values). The provided patterns must be exhaustive and non-overlapping: every input (non-cached) token group must be a query in exactly one pattern.
key_padding_mask (torch.Tensor | None) – Optional boolean mask of shape (batch_size, n_tokens). True indicates a real token; False indicates a padding token that should not be attended to. The mask is sliced per attention pattern to cover only the key/value tokens of that pattern. Not supported when using the KV cache (cached tokens are assumed to be valid).
freqs (torch.Tensor | None) – RoPE frequencies for positional encoding (only for input tokens).
kv_cache (dict[str, dict[str, torch.Tensor]] | None) – KV cache from a previous forward pass. Structure: {token_name: {"k": tensor, "v": tensor}}.
- Returns:
Tuple of (output, new_kv_cache). The output has the same shape as x (only input tokens). new_kv_cache contains the anchor K/V for future cached inference, or None when using cached tokens.
- Return type:
tuple[torch.Tensor, dict[str, dict[str, torch.Tensor]] | None]
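The "exhaustive and non-overlapping" constraint on attention_patterns can be made concrete with a small check. The helper below is hypothetical, not part of the documented API; the TokenSpec/AttentionPattern namedtuples are simplified stand-ins for the noether schemas.

```python
from collections import namedtuple

# Simplified stand-ins for the noether schema classes.
TokenSpec = namedtuple("TokenSpec", ["name", "size"])
AttentionPattern = namedtuple("AttentionPattern", ["query_tokens", "key_value_tokens"])

def validate_patterns(token_specs, attention_patterns):
    """Hypothetical check: every non-cached (size is not None) token group
    must appear as a query in exactly one attention pattern."""
    input_groups = {spec.name for spec in token_specs if spec.size is not None}
    queried = []
    for pattern in attention_patterns:
        queried.extend(pattern.query_tokens)
    if len(queried) != len(set(queried)):
        raise ValueError("overlapping patterns: a token group is a query in more than one pattern")
    if set(queried) != input_groups:
        missing = sorted(input_groups - set(queried))
        raise ValueError(f"non-exhaustive patterns: groups {missing} are never queried")
```

Note that cached groups (size=None) are excluded: they supply keys/values from kv_cache but contribute no queries, so they need not appear in any pattern's query_tokens.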