noether.modeling.modules.attention.anchor_attention.mixed

Classes

MixedAttention

Mixed attention with a selectable implementation for performance or readability.

Module Contents

class noether.modeling.modules.attention.anchor_attention.mixed.MixedAttention(config)

Bases: noether.modeling.modules.attention.DotProductAttention

Mixed attention with a selectable implementation for performance or readability.

This module allows for structured attention patterns where different groups of tokens (defined by TokenSpec) have specific interaction patterns (defined by AttentionPattern). Instead of full self-attention, you can specify, for example, that one type of token can only attend to itself, while another can attend to all tokens.

This is achieved by splitting the main Q, K, V tensors based on the token specs and then performing separate attention computations for each pattern.
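As a rough illustration of the splitting step, the sketch below (not the library's code) shows how a sequence of token specs maps to slices of the sequence dimension. TokenSpec is modeled as a plain (name, size) tuple; specs with size=None correspond to cached tokens and occupy no slice of the input tensor:

```python
# Hypothetical sketch: map each input token group to its [start, end)
# range along the sequence dimension of x. Specs with size=None are
# cached groups and are skipped (their K/V come from kv_cache, not x).

def token_slices(token_specs):
    """Return {name: (start, end)} for the input (non-cached) groups."""
    slices = {}
    offset = 0
    for name, size in token_specs:
        if size is None:
            continue  # cached group: no tokens in x
        slices[name] = (offset, offset + size)
        offset += size
    return slices

slices = token_slices([
    ("surface_anchors", 100),
    ("surface_queries", 50),
    ("volume_anchors", 80),
    ("volume_queries", 60),
])
# x[:, 100:150] would hold the surface_queries tokens, and so on.
```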

Supports KV caching for efficient inference. When a TokenSpec has size=None, its key/value representations are loaded from the provided kv_cache instead of being computed from the input tensor.

Example input structure (matching the forward() signature below) for implementing Anchor Attention:

x = torch.cat([surface_anchors, surface_queries, volume_anchors, volume_queries], dim=1)  # sequence dim
token_specs = [
    TokenSpec("surface_anchors", 100),
    TokenSpec("surface_queries", 50),
    TokenSpec("volume_anchors", 80),
    TokenSpec("volume_queries", 60),
]
attention_patterns = [
    AttentionPattern(query_tokens=["surface_anchors", "surface_queries"], key_value_tokens=["surface_anchors"]),
    AttentionPattern(query_tokens=["volume_anchors", "volume_queries"], key_value_tokens=["volume_anchors"]),
]
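For cached inference, the resolution of size=None specs against the kv_cache can be sketched as below. This is a hypothetical simplification in plain Python, not the module's implementation; the real cache values are torch.Tensors:

```python
# Hypothetical sketch: on a cached-inference pass, specs with size=None
# are resolved from kv_cache ({token_name: {"k": ..., "v": ...}}),
# while the remaining specs are computed from the input tensor x.

def split_cached(token_specs, kv_cache):
    """Separate token group names into cached and input-derived groups."""
    cached, from_input = [], []
    for name, size in token_specs:
        if size is None:
            if kv_cache is None or name not in kv_cache:
                raise KeyError(f"no cached K/V for token group {name!r}")
            cached.append(name)
        else:
            from_input.append(name)
    return cached, from_input

kv_cache = {"surface_anchors": {"k": "k_tensor", "v": "v_tensor"}}
cached, from_input = split_cached(
    [("surface_anchors", None), ("surface_queries", 50)], kv_cache
)
# cached == ["surface_anchors"]; from_input == ["surface_queries"]
```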
Parameters:

config (noether.core.schemas.modules.attention.MixedAttentionConfig) – Configuration for the MixedAttention module. See MixedAttentionConfig for the available options.

forward(x, token_specs, attention_patterns, key_padding_mask=None, freqs=None, kv_cache=None)

Apply mixed attention with flexible token-name-based patterns.

Parameters:
  • x (torch.Tensor) – Input tensor [batch_size, n_tokens, dim]. Contains only the token groups whose size is not None in token_specs.

  • token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Sequence of token specifications defining the input structure. Tokens with size=None are loaded from kv_cache.

  • attention_patterns (collections.abc.Sequence[noether.core.schemas.modules.attention.AttentionPattern]) – Sequence of attention patterns to apply. Each pattern defines which token groups (queries) attend to which other token groups (keys/values). The provided patterns must be exhaustive and non-overlapping. This means every input (non-cached) token group must be a query in exactly one pattern.

  • key_padding_mask (torch.Tensor | None) – Optional boolean mask of shape (batch_size, n_tokens). True indicates a real token; False indicates a padding token that should not be attended to. The mask is sliced per attention pattern to cover only the key/value tokens of that pattern. Not supported when using KV cache (cached tokens are assumed to be valid).

  • freqs (torch.Tensor | None) – RoPE frequencies for positional encoding (only for input tokens).

  • kv_cache (dict[str, dict[str, torch.Tensor]] | None) – KV cache from a previous forward pass. Structure: {token_name: {"k": tensor, "v": tensor}}.
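The "exhaustive and non-overlapping" requirement on attention_patterns can be sketched as a validation pass. This is a hypothetical illustration, with patterns modeled as (query_tokens, key_value_tokens) tuples rather than AttentionPattern objects:

```python
# Hypothetical sketch: every input (non-cached) token group must appear
# as a query in exactly one attention pattern.
from collections import Counter

def check_patterns(token_specs, attention_patterns):
    input_groups = {name for name, size in token_specs if size is not None}
    counts = Counter(
        name
        for query_tokens, _key_value_tokens in attention_patterns
        for name in query_tokens
    )
    missing = input_groups - set(counts)   # groups never used as queries
    repeated = {name for name, n in counts.items() if n > 1}
    if missing or repeated:
        raise ValueError(f"missing queries: {missing}, repeated: {repeated}")

# Valid: each input group is a query in exactly one pattern.
check_patterns(
    [("surface_anchors", 100), ("surface_queries", 50)],
    [(["surface_anchors", "surface_queries"], ["surface_anchors"])],
)
```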

Returns:

Tuple of (output, new_kv_cache). The output has the same shape as x (input tokens only). new_kv_cache holds the anchor key/value tensors for future cached inference, or is None when cached tokens were used.

Return type:

tuple[torch.Tensor, dict[str, dict[str, torch.Tensor]] | None]
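The per-pattern slicing of key_padding_mask described above can be sketched as follows. This is a hypothetical illustration using plain Python lists for a single batch row; the real module slices torch.Tensor masks of shape (batch_size, n_tokens):

```python
# Hypothetical sketch: the key_padding_mask covers all n_tokens columns,
# so for each attention pattern only the columns belonging to that
# pattern's key/value token groups are kept (True = real token,
# False = padding that must not be attended to).

def slice_mask(mask_row, token_specs, key_value_tokens):
    """Keep mask entries for the pattern's key/value groups, in spec order."""
    out, offset = [], 0
    for name, size in token_specs:
        if name in key_value_tokens:
            out.extend(mask_row[offset:offset + size])
        offset += size
    return out

specs = [("surface_anchors", 2), ("surface_queries", 2)]
mask = [True, True, True, False]  # last surface_query token is padding
# A pattern whose keys/values are surface_anchors only sees the first
# two mask columns: slice_mask(mask, specs, ["surface_anchors"]).
```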