noether.modeling.modules.attention.anchor_attention.mixed

Classes

MixedAttention

Mixed attention with a selectable implementation for performance or readability.

Module Contents

class noether.modeling.modules.attention.anchor_attention.mixed.MixedAttention(config)

Bases: noether.modeling.modules.attention.DotProductAttention

Mixed attention with a selectable implementation for performance or readability.

This module allows for structured attention patterns where different groups of tokens (defined by TokenSpec) have specific interaction patterns (defined by AttentionPattern). Instead of full self-attention, you can specify, for example, that one type of token can only attend to itself, while another can attend to all tokens.

This is achieved by splitting the main Q, K, V tensors based on the token specs and then performing separate attention computations for each pattern.

Example input structure for the forward pass, here used to implement anchor attention:

x = torch.cat([surface_anchors, surface_queries, volume_anchors, volume_queries], dim=1)  # sequence dim
token_specs = [
    TokenSpec("surface_anchors", 100),
    TokenSpec("surface_queries", 50),
    TokenSpec("volume_anchors", 80),
    TokenSpec("volume_queries", 60),
]
attention_patterns = [
    AttentionPattern(query_tokens=["surface_anchors", "surface_queries"], key_value_tokens=["surface_anchors"]),
    AttentionPattern(query_tokens=["volume_anchors", "volume_queries"], key_value_tokens=["volume_anchors"]),
]
Parameters:

config (noether.core.schemas.modules.attention.MixedAttentionConfig) – Configuration for the MixedAttention module. See MixedAttentionConfig for the available options.
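The split-and-attend mechanism described above can be illustrated with a self-contained sketch. This is not the library implementation: the `(name, length)` tuples and `(query_names, key_value_names)` pairs below are plain stand-ins for `TokenSpec` and `AttentionPattern`, and a single unprojected attention call replaces the module's real Q/K/V projections.

```python
# Illustrative sketch (not the library implementation): split a sequence into
# named token groups, then run a separate attention computation per pattern.
import torch
import torch.nn.functional as F

def mixed_attention_sketch(x, token_specs, attention_patterns):
    # token_specs: list of (name, length); attention_patterns: list of
    # (query_names, key_value_names). Groups are slices along dim=1.
    segments, offset = {}, 0
    for name, length in token_specs:
        segments[name] = (offset, offset + length)
        offset += length
    out = torch.zeros_like(x)
    for query_names, kv_names in attention_patterns:
        q = torch.cat([x[:, slice(*segments[n])] for n in query_names], dim=1)
        kv = torch.cat([x[:, slice(*segments[n])] for n in kv_names], dim=1)
        # Single-head attention for brevity; the real module projects Q, K, V.
        attended = F.scaled_dot_product_attention(q, kv, kv)
        start = 0  # scatter the result back to each query group's positions
        for n in query_names:
            s, e = segments[n]
            out[:, s:e] = attended[:, start:start + (e - s)]
            start += e - s
    return out

x = torch.randn(2, 290, 32)  # 100 + 50 + 80 + 60 tokens, embedding dim 32
specs = [("surface_anchors", 100), ("surface_queries", 50),
         ("volume_anchors", 80), ("volume_queries", 60)]
patterns = [(["surface_anchors", "surface_queries"], ["surface_anchors"]),
            (["volume_anchors", "volume_queries"], ["volume_anchors"])]
y = mixed_attention_sketch(x, specs, patterns)  # shape (2, 290, 32)
```

Note that each pattern's queries only ever see the keys/values of its listed groups, so surface tokens here never attend to volume tokens.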

forward(x, token_specs, attention_patterns, attention_mask=None, freqs=None)

Apply mixed attention with flexible token-name-based patterns.

Parameters:

x (torch.Tensor) – Input of concatenated token groups along the sequence dimension.

token_specs – List of TokenSpec entries naming each token group and giving its length.

attention_patterns – List of AttentionPattern entries defining which query groups attend to which key/value groups.

attention_mask – Optional attention mask. Defaults to None.

freqs – Optional frequency tensor (e.g. for rotary position embeddings). Defaults to None.

Return type:

torch.Tensor
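The token-name-based patterns accepted by forward are equivalent to a block mask over the full sequence. The hypothetical helper below (not part of the library) builds that mask explicitly, which can be useful for debugging which positions are allowed to attend where; it again uses plain tuples in place of TokenSpec and AttentionPattern.

```python
# Hypothetical debugging helper: materialize the block attention mask implied
# by a set of token specs and attention patterns. True = query may attend.
import torch

def build_pattern_mask(token_specs, attention_patterns):
    total = sum(length for _, length in token_specs)
    bounds, offset = {}, 0
    for name, length in token_specs:
        bounds[name] = (offset, offset + length)
        offset += length
    mask = torch.zeros(total, total, dtype=torch.bool)
    for query_names, kv_names in attention_patterns:
        for qn in query_names:
            qs, qe = bounds[qn]
            for kn in kv_names:
                ks, ke = bounds[kn]
                mask[qs:qe, ks:ke] = True  # this query block sees this kv block
    return mask

specs = [("surface_anchors", 100), ("surface_queries", 50),
         ("volume_anchors", 80), ("volume_queries", 60)]
patterns = [(["surface_anchors", "surface_queries"], ["surface_anchors"]),
            (["volume_anchors", "volume_queries"], ["volume_anchors"])]
mask = build_pattern_mask(specs, patterns)
# Rows 0:150 (surface tokens) may attend only to columns 0:100 (surface anchors).
```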