noether.modeling.modules.attention.perceiver

Classes

PerceiverAttention

Perceiver-style attention module, similar to a cross-attention module: queries come from one token set and keys/values from another.

Module Contents

class noether.modeling.modules.attention.perceiver.PerceiverAttention(config)

Bases: torch.nn.Module

Perceiver-style attention module, similar to a cross-attention module: queries come from one token set and keys/values from another.

Initialize the PerceiverAttention module.

Parameters:

config (noether.core.schemas.modules.AttentionConfig) – Configuration of the attention module.

num_heads = None
head_dim
init_weights = None
use_rope = None
kv
q
proj
dropout = None
proj_dropout
forward(q, kv, attn_mask=None, q_freqs=None, k_freqs=None)

Forward function of the PerceiverAttention module.

Parameters:
  • q (torch.Tensor) – Query tensor, shape (batch size, number of points/tokens, hidden_dim).

  • kv (torch.Tensor) – Key/value tensor, shape (batch size, number of latent tokens, hidden_dim).

  • attn_mask (torch.Tensor | None) – Optional attention mask; required when applying causal attention. Defaults to None.

  • q_freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of the queries. None if use_rope=False. Defaults to None.

  • k_freqs (torch.Tensor | None) – Frequencies for Rotary Positional Embedding (RoPE) of the keys. None if use_rope=False. Defaults to None.

Returns:

The attended output tensor, with the same shape as q.

Return type:

torch.Tensor