noether.modeling.modules.decoders

Submodules

Classes

DeepPerceiverDecoder

A deep Perceiver decoder module with a configurable number of layers and hidden dimensions.

Package Contents

class noether.modeling.modules.decoders.DeepPerceiverDecoder(config)

Bases: torch.nn.Module

A deep Perceiver decoder module with a configurable number of layers and hidden dimensions. Note that this module is not a full Perceiver: it uses only cross-attention, in which the queries attend to the key-value tokens, and has no self-attention.
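To make the cross-attention-only structure concrete, here is a minimal NumPy sketch. It is not the noether implementation: the function names `cross_attention` and `deep_cross_attention_decoder` are hypothetical, the weights are random, and any normalization, multi-head splitting or MLP layers the real blocks may contain are omitted. What it illustrates is that, without self-attention, each output query depends only on itself and on `kv`.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row maximum for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(x, kv, w_q, w_k, w_v, w_o):
    # x: (batch, num_queries, dim), kv: (batch, num_tokens, dim)
    q, k, v = x @ w_q, kv @ w_k, kv @ w_v
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])  # (batch, num_queries, num_tokens)
    return (softmax(scores) @ v) @ w_o  # (batch, num_queries, dim)


def deep_cross_attention_decoder(kv, queries, num_layers, seed=0):
    # Hypothetical stand-in for a stack of decoder blocks, using random weights.
    rng = np.random.default_rng(seed)
    dim = queries.shape[-1]
    x = queries
    for _layer in range(num_layers):
        w_q, w_k, w_v, w_o = (rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(4))
        # Residual update: queries read from kv but never attend to each other.
        x = x + cross_attention(x, kv, w_q, w_k, w_v, w_o)
    return x


rng = np.random.default_rng(1)
kv = rng.standard_normal((2, 16, 8))      # (batch_size, num_latent_tokens, dim)
queries = rng.standard_normal((2, 5, 8))  # (batch_size, num_output_queries, dim)
out = deep_cross_attention_decoder(kv, queries, num_layers=3)
print(out.shape)  # (2, 5, 8)

# Decoding only the first two queries matches slicing the full output.
subset = deep_cross_attention_decoder(kv, queries[:, :2], num_layers=3)
print(np.allclose(out[:, :2], subset))  # True
```

The second check is the practical consequence of having no self-attention: queries can be decoded in chunks or subsets without changing their predictions.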

Parameters:

config (noether.core.schemas.modules.decoders.DeepPerceiverDecoderConfig) – Configuration for the DeepPerceiverDecoder module. See DeepPerceiverDecoderConfig for available options.

blocks
forward(kv, queries, attn_kwargs=None, condition=None)

Forward pass of the model: decodes the queries by cross-attending to the key-value tokens.

Parameters:
  • kv (torch.Tensor) – The key-value tensor of shape (batch_size, num_latent_tokens, dim).

  • queries (torch.Tensor) – The query tensor of shape (batch_size, num_output_queries, dim).

  • attn_kwargs (dict[str, Any] | None) – Dict of arguments for the attention, such as the attention mask or RoPE frequencies. Defaults to None.

  • condition (torch.Tensor | None) – Optional conditioning tensor passed to the attention mechanism, for example to inject additional conditioning information. Defaults to None.
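Since `attn_kwargs` can carry an attention mask, the sketch below shows what such a mask does in cross-attention. Several things here are assumptions, not the library's API: the function name `masked_cross_attention`, the boolean convention (True means "may be attended to"), and the mask shape (batch_size, num_tokens). The learned projections are also left out, so `kv` serves directly as keys and values. The actual key names and mask conventions accepted through `attn_kwargs` depend on the attention implementation.

```python
import numpy as np


def masked_cross_attention(queries, kv, mask=None):
    # queries: (batch, num_queries, dim), kv: (batch, num_tokens, dim)
    # mask: optional bool array (batch, num_tokens); True = token may be attended to.
    # Each batch element needs at least one unmasked token, otherwise the softmax is undefined.
    scores = queries @ kv.transpose(0, 2, 1) / np.sqrt(queries.shape[-1])
    if mask is not None:
        # Broadcast the mask over the query axis and exclude masked tokens.
        scores = np.where(mask[:, None, :], scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ kv, weights


rng = np.random.default_rng(0)
kv = rng.standard_normal((2, 6, 4))
queries = rng.standard_normal((2, 3, 4))
mask = np.array([[True] * 6, [True] * 4 + [False] * 2])  # last 2 tokens of sample 1 are padding

out, weights = masked_cross_attention(queries, kv, mask)
print(weights[1, :, 4:].max())  # 0.0: padded tokens receive no attention

# Changing the padded tokens leaves the output unchanged.
kv_changed = kv.copy()
kv_changed[1, 4:] = 100.0
out_changed, _ = masked_cross_attention(queries, kv_changed, mask)
print(np.allclose(out, out_changed))  # True
```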

Returns:

The predictions as a sparse tensor of shape (batch_size * num_output_pos, num_out_values), i.e. with the batch and output-position dimensions flattened into one.

Return type:

torch.Tensor
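Because the return value is flattened, callers that need per-sample predictions have to undo the flattening. Below is a minimal NumPy sketch. It assumes the rows are ordered batch-major (all output positions of sample 0, then sample 1, and so on) and that every sample has the same number of output positions. Neither assumption is stated above, so check both against the implementation before relying on this.

```python
import numpy as np

batch_size, num_output_pos, num_out_values = 2, 3, 4

# Stand-in for the decoder output with shape (batch_size * num_output_pos, num_out_values).
flat = np.arange(batch_size * num_output_pos * num_out_values, dtype=float).reshape(
    batch_size * num_output_pos, num_out_values
)

# Assuming batch-major order, a reshape restores (batch_size, num_output_pos, num_out_values).
per_sample = flat.reshape(batch_size, num_output_pos, num_out_values)

# Alternatively, keep a batch index per row, e.g. for losses computed on the flat layout.
batch_idx = np.repeat(np.arange(batch_size), num_output_pos)
print(batch_idx)  # [0 0 0 1 1 1]
print(np.array_equal(flat[batch_idx == 1], per_sample[1]))  # True
```

With a PyTorch tensor, the equivalent reshape is `out.view(batch_size, num_output_pos, -1)`. It is valid only under the same batch-major assumption.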