noether.modeling.modules.attention.perceiver¶
Classes¶
Perceiver style attention module. This module is similar to a cross-attention module.
Module Contents¶
- class noether.modeling.modules.attention.perceiver.PerceiverAttention(config)¶
Bases: torch.nn.Module

Perceiver style attention module. This module is similar to a cross-attention module.

Supports KV caching: when kv_cache is provided, the projected K/V tensors (with RoPE already applied) are loaded from the cache instead of being recomputed from kv.

- Parameters:
config (noether.core.schemas.modules.AttentionConfig) – Configuration for the PerceiverAttention module. See AttentionConfig for available options.
- num_heads = None¶
- head_dim¶
- init_weights = None¶
- use_rope = None¶
- kv¶
- q¶
- proj¶
- dropout = None¶
- proj_dropout¶
- forward(q, kv=None, attn_mask=None, q_freqs=None, k_freqs=None, kv_cache=None)¶
Forward function of the PerceiverAttention module.
- Parameters:
q (torch.Tensor) – Query tensor, shape (batch size, number of points/tokens, hidden_dim).
kv (torch.Tensor | None) – Key/value tensor, shape (batch size, number of latent tokens, kv_dim). Can be None when kv_cache is provided.
attn_mask (torch.Tensor | None) – When applying causal attention, an attention mask is required. Defaults to None.
q_freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of queries. None if use_rope=False.
k_freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of keys. None if use_rope=False. Not needed when loading from kv_cache (RoPE was already applied).
kv_cache (dict[str, torch.Tensor] | None) – Cached K/V tensors from a previous forward pass. Structure: {"k": tensor, "v": tensor}. When provided, kv and k_freqs are ignored.
- Returns:
Tuple of (output, new_kv_cache).
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor] | None]
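To illustrate the forward contract described above, here is a minimal, self-contained sketch of a Perceiver-style cross-attention block with the same KV-caching behavior: on the first call K/V are projected from kv and returned in a cache dict, and on later calls the cache is reused and kv may be None. This is an illustrative approximation, not the actual noether implementation; the class name PerceiverAttentionSketch is hypothetical, and RoPE and dropout are omitted for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerceiverAttentionSketch(nn.Module):
    """Illustrative cross-attention with optional KV caching (not the noether class)."""

    def __init__(self, hidden_dim: int, kv_dim: int, num_heads: int):
        super().__init__()
        assert hidden_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.q = nn.Linear(hidden_dim, hidden_dim)        # query projection
        self.kv = nn.Linear(kv_dim, 2 * hidden_dim)       # fused key/value projection
        self.proj = nn.Linear(hidden_dim, hidden_dim)     # output projection

    def forward(self, q, kv=None, attn_mask=None, kv_cache=None):
        b, n, _ = q.shape
        qh = self.q(q).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        if kv_cache is not None:
            # Reuse projected K/V from the cache; `kv` is ignored, mirroring the
            # documented behavior (in the real module RoPE is already applied here).
            k, v = kv_cache["k"], kv_cache["v"]
        else:
            m = kv.shape[1]
            k, v = self.kv(kv).chunk(2, dim=-1)
            k = k.view(b, m, self.num_heads, self.head_dim).transpose(1, 2)
            v = v.view(b, m, self.num_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(qh, k, v, attn_mask=attn_mask)
        out = out.transpose(1, 2).reshape(b, n, -1)
        # Return (output, new_kv_cache), matching the documented return type.
        return self.proj(out), {"k": k, "v": v}
```

A second forward pass that passes the returned cache (and kv=None) produces the same output as the first, since the cached K/V tensors are identical to the freshly projected ones.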