noether.core.distributed.gather

Module Contents

noether.core.distributed.gather.get_device_and_bfloat16supported()
noether.core.distributed.gather.get_bool_gather_supported()
noether.core.distributed.gather.all_gather_grad(x, batch_dim=0)
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor
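Example (illustrative sketch, not the library's implementation): one common way to build a gradient-preserving all-gather is a custom autograd function that gathers in the forward pass and sums the incoming gradients across ranks in the backward pass. The names `_AllGatherWithGrad` and `gather_with_grad_sketch` are hypothetical. The sketch assumes every rank holds a tensor of the same shape and that the input is returned unchanged when no process group is initialized.

```python
import torch
import torch.distributed as dist


class _AllGatherWithGrad(torch.autograd.Function):
    """Hypothetical helper: all-gather whose backward routes gradients to the owning rank."""

    @staticmethod
    def forward(ctx, x):
        # Assumption: x has the same shape on every rank.
        out = [torch.empty_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(out, x.contiguous())
        return tuple(out)

    @staticmethod
    def backward(ctx, *grads):
        # Each rank holds a gradient for every gathered chunk. Sum them
        # across ranks, then keep the slice that belongs to this rank's input.
        stacked = torch.stack(grads)
        dist.all_reduce(stacked)
        return stacked[dist.get_rank()]


def gather_with_grad_sketch(x, batch_dim=0):
    # Assumption: without an initialized process group there is nothing to gather.
    if not (dist.is_available() and dist.is_initialized()):
        return x
    return torch.cat(_AllGatherWithGrad.apply(x), dim=batch_dim)


x = torch.arange(6.0).reshape(3, 2).requires_grad_()
y = gather_with_grad_sketch(x)
y.sum().backward()  # gradients still reach x
```

In a single process the call is a pass-through, so the same training code can run with or without distributed initialization.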

noether.core.distributed.gather.all_gather_nograd(x, batch_dim=0)
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor
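Example (illustrative sketch, not the library's implementation): a gather without gradient tracking is typically used for metrics or logging, where the result never takes part in a backward pass. The name `gather_nograd_sketch` is hypothetical. The sketch assumes equal tensor shapes on all ranks and assumes that the single-process fallback returns a detached tensor.

```python
import torch
import torch.distributed as dist


def gather_nograd_sketch(x, batch_dim=0):
    # Assumption: outside a distributed run, return the detached input.
    if not (dist.is_available() and dist.is_initialized()):
        return x.detach()
    with torch.no_grad():
        # Assumption: x has the same shape on every rank.
        out = [torch.empty_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(out, x.contiguous())
    # Concatenate the per-rank chunks along the batch dimension.
    return torch.cat(out, dim=batch_dim)


x = torch.ones(4, 3, requires_grad=True)
y = gather_nograd_sketch(x)
```

The returned tensor carries no autograd history, so using it in a loss would not propagate gradients back to `x`.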

noether.core.distributed.gather.all_gather_nograd_clipped(x, max_length=None, batch_dim=0)
Return type:

torch.Tensor
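Example (illustrative sketch, not the library's implementation): judging by its name and signature, the clipped variant presumably truncates the gathered result to at most `max_length` entries along `batch_dim`. That reading is an assumption. The hypothetical helper `clip_gathered_sketch` below shows only the truncation step, applied to a locally built tensor that stands in for an already gathered result.

```python
import torch


def clip_gathered_sketch(gathered, max_length=None, batch_dim=0):
    # Assumption: max_length=None means "keep everything".
    if max_length is None:
        return gathered
    # Never request more entries than the tensor actually has.
    length = min(max_length, gathered.size(batch_dim))
    return gathered.narrow(batch_dim, 0, length)


gathered = torch.arange(10).reshape(5, 2)  # stand-in for a gathered tensor
clipped = clip_gathered_sketch(gathered, max_length=3)
unclipped = clip_gathered_sketch(gathered)
```

`Tensor.narrow` returns a view rather than a copy, so the truncation itself allocates no new memory.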