noether.modeling.modules.attention.perceiver

Classes

PerceiverAttention

Perceiver-style attention module, similar to a cross-attention module.

Module Contents

class noether.modeling.modules.attention.perceiver.PerceiverAttention(config)

Bases: torch.nn.Module

Perceiver-style attention module, similar to a cross-attention module.

Supports KV caching: when kv_cache is provided, the projected K/V tensors (with RoPE already applied) are loaded from the cache instead of being recomputed from kv.
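The cache branch described above can be sketched as follows. This is a minimal illustration, not the module's actual code: `select_kv` and `project_kv` are hypothetical names standing in for the module's internal K/V projection (which also applies RoPE to the keys).

```python
def select_kv(kv, kv_cache, project_kv):
    """Sketch of the KV-cache branch.

    project_kv is a hypothetical stand-in for the module's K/V
    projection (with RoPE already applied to the keys).
    """
    if kv_cache is not None:
        # Cached K/V already have RoPE applied; kv (and k_freqs) are ignored.
        k, v = kv_cache["k"], kv_cache["v"]
    else:
        # No cache: recompute projected K/V from the kv tensor.
        k, v = project_kv(kv)
    # Return the tensors plus a fresh cache for subsequent calls.
    return k, v, {"k": k, "v": v}
```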

Parameters:

config (noether.core.schemas.modules.AttentionConfig) – Configuration for the PerceiverAttention module. See AttentionConfig for available options.

num_heads = None
head_dim
init_weights = None
use_rope = None
kv
q
proj
dropout = None
proj_dropout
forward(q, kv=None, attn_mask=None, q_freqs=None, k_freqs=None, kv_cache=None)

Forward function of the PerceiverAttention module.

Parameters:
  • q (torch.Tensor) – Query tensor, shape (batch size, number of points/tokens, hidden_dim).

  • kv (torch.Tensor | None) – Key/value tensor, shape (batch size, number of latent tokens, kv_dim). Can be None when kv_cache is provided.

  • attn_mask (torch.Tensor | None) – Optional attention mask; required when applying causal attention. Defaults to None.

  • q_freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of queries. None if use_rope=False.

  • k_freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of keys. None if use_rope=False. Not needed when loading from kv_cache (RoPE was already applied).

  • kv_cache (dict[str, torch.Tensor] | None) – Cached K/V tensors from a previous forward pass. Structure: {"k": tensor, "v": tensor}. When provided, kv and k_freqs are ignored.

Returns:

Tuple of (output, new_kv_cache): the attention output and a {"k": tensor, "v": tensor} dict that can be passed as kv_cache in subsequent calls.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor] | None]
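The forward pass can be sketched end to end with a stripped-down reimplementation. This is an illustrative sketch only: it mirrors the documented signature (q, kv, attn_mask, kv_cache) and the {"k": ..., "v": ...} cache structure, but omits RoPE, dropout, and the AttentionConfig plumbing of the real module; the class and parameter names below are assumptions, not the library's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyPerceiverAttention(nn.Module):
    """Minimal sketch of perceiver-style cross-attention with KV caching."""

    def __init__(self, hidden_dim: int, kv_dim: int, num_heads: int):
        super().__init__()
        assert hidden_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.kv = nn.Linear(kv_dim, 2 * hidden_dim)
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, q, kv=None, attn_mask=None, kv_cache=None):
        b, n, _ = q.shape
        # Project queries and split into heads: (b, heads, n, head_dim).
        q = self.q(q).view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        if kv_cache is not None:
            # Reuse cached projections; kv is ignored (may be None).
            k, v = kv_cache["k"], kv_cache["v"]
        else:
            # Project keys/values from the kv tensor and split into heads.
            k, v = self.kv(kv).chunk(2, dim=-1)
            k = k.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
            v = v.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        out = out.transpose(1, 2).reshape(b, n, -1)
        # Return the output and a cache reusable in subsequent calls.
        return self.proj(out), {"k": k, "v": v}
```

A second call with kv_cache set (and kv=None) reproduces the first call's output without re-projecting K/V, which is the point of the cache.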