noether.modeling.modules.attention.anchor_attention.mixed

Classes

MixedAttention

Mixed attention with a selectable implementation for performance or readability.

Module Contents

class noether.modeling.modules.attention.anchor_attention.mixed.MixedAttention(config)

Bases: noether.modeling.modules.attention.DotProductAttention

Mixed attention with a selectable implementation for performance or readability.

This module allows for structured attention patterns where different groups of tokens (defined by TokenSpec) have specific interaction patterns (defined by AttentionPattern). Instead of full self-attention, you can specify, for example, that one type of token can only attend to itself, while another can attend to all tokens.

This is achieved by splitting the main Q, K, V tensors according to the token specs and then performing a separate attention computation for each pattern (an illustrative sketch appears at the end of this section).

Example input structure (forward pass signature) for implementing Anchor Attention:

x = torch.cat([surface_anchors, surface_queries, volume_anchors, volume_queries], dim=1)  # sequence dim
token_specs = [
    TokenSpec("surface_anchors", 100),
    TokenSpec("surface_queries", 50),
    TokenSpec("volume_anchors", 80),
    TokenSpec("volume_queries", 60),
]
attention_patterns = [
    AttentionPattern(query_tokens=["surface_anchors", "surface_queries"], key_value_tokens=["surface_anchors"]),
    AttentionPattern(query_tokens=["volume_anchors", "volume_queries"], key_value_tokens=["volume_anchors"]),
]
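A minimal usage sketch built on the example above. The MixedAttentionConfig fields shown here (dim, num_heads) are hypothetical placeholders; consult the actual config schema for the real field names:

import torch

from noether.core.schemas.modules.attention import (
    AttentionPattern,
    MixedAttentionConfig,
    TokenSpec,
)
from noether.modeling.modules.attention.anchor_attention.mixed import MixedAttention

# Hypothetical config fields; the real MixedAttentionConfig schema may differ.
config = MixedAttentionConfig(dim=256, num_heads=8)
attn = MixedAttention(config)

batch_size, dim = 4, 256
surface_anchors = torch.randn(batch_size, 100, dim)
surface_queries = torch.randn(batch_size, 50, dim)
volume_anchors = torch.randn(batch_size, 80, dim)
volume_queries = torch.randn(batch_size, 60, dim)

x = torch.cat([surface_anchors, surface_queries, volume_anchors, volume_queries], dim=1)
token_specs = [
    TokenSpec("surface_anchors", 100),
    TokenSpec("surface_queries", 50),
    TokenSpec("volume_anchors", 80),
    TokenSpec("volume_queries", 60),
]
attention_patterns = [
    AttentionPattern(query_tokens=["surface_anchors", "surface_queries"], key_value_tokens=["surface_anchors"]),
    AttentionPattern(query_tokens=["volume_anchors", "volume_queries"], key_value_tokens=["volume_anchors"]),
]

out = attn(x, token_specs, attention_patterns)  # [batch_size, 290, dim]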

Parameters:

config (noether.core.schemas.modules.attention.MixedAttentionConfig) – Configuration for the MixedAttention module.

forward(x, token_specs, attention_patterns, attention_mask=None, freqs=None)

Apply mixed attention with flexible token-name-based patterns.

Parameters:
  • x (torch.Tensor) – Input tensor of shape [batch_size, n_tokens, dim]

  • token_specs (collections.abc.Sequence[noether.core.schemas.modules.attention.TokenSpec]) – Sequence of token specifications defining the input structure; the input x is assumed to be a concatenation of the token groups in the order given by token_specs.

  • attention_patterns (collections.abc.Sequence[noether.core.schemas.modules.attention.AttentionPattern]) – Sequence of attention patterns to apply. Each pattern defines which token groups (queries) attend to which other token groups (keys/values). The provided patterns must be exhaustive and non-overlapping. This means every token group defined in token_specs must be a query in exactly one pattern.

  • attention_mask (torch.Tensor | None) – Optional attention mask (not currently supported)

  • freqs (torch.Tensor | None) – RoPE frequencies for positional encoding

Return type:

torch.Tensor
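The splitting described in the class docstring can be illustrated with the following self-contained, single-head sketch. It is not the module's actual implementation, and the attribute names spec.name, spec.num_tokens, pattern.query_tokens, and pattern.key_value_tokens are assumptions that mirror the example above:

import torch
import torch.nn.functional as F

def mixed_attention_sketch(x, token_specs, attention_patterns):
    """Illustrative single-head sketch of per-pattern attention (not the library code).

    Assumes each spec exposes .name and .num_tokens, and each pattern exposes
    .query_tokens and .key_value_tokens, mirroring the example above.
    """
    # Slice x into named groups following the order of token_specs.
    groups, offset = {}, 0
    for spec in token_specs:
        groups[spec.name] = x[:, offset : offset + spec.num_tokens, :]
        offset += spec.num_tokens

    # Patterns are exhaustive and non-overlapping: each group is a query in exactly one pattern.
    out = {}
    for pattern in attention_patterns:
        q = torch.cat([groups[name] for name in pattern.query_tokens], dim=1)
        kv = torch.cat([groups[name] for name in pattern.key_value_tokens], dim=1)
        attended = F.scaled_dot_product_attention(q, kv, kv)
        # Write results back per query group, preserving group boundaries.
        start = 0
        for name in pattern.query_tokens:
            length = groups[name].shape[1]
            out[name] = attended[:, start : start + length, :]
            start += length

    # Reassemble the output in the original token_specs order.
    return torch.cat([out[spec.name] for spec in token_specs], dim=1)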