28 changes: 28 additions & 0 deletions docs/source/en/training/distributed_inference.md
@@ -343,6 +343,34 @@ We ran a benchmark with Ulysses, Ring, and Unified Attention with [this script](

From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.


### Ulysses Anything Attention

The default Ulysses Attention mechanism requires the sequence length of the hidden states to be divisible by the number of devices, which significantly limits its practical applicability. [Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, making it far more versatile in practice.
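
The arbitrary-length support comes from how the Ulysses Anything sharder partitions tensors: it uses `torch.tensor_split`, which always produces exactly `world_size` shards even when the split is uneven, instead of `torch.chunk`, which can return fewer chunks than requested. The snippet below is a minimal, standalone illustration of that behavior; the tensor shape and world size are made up for the example.

```py
import torch

world_size = 4
# Hypothetical hidden states whose sequence length (3969) is not divisible by 4.
hidden_states = torch.randn(1, 3969, 128)  # (batch, seq_len, dim)

# tensor_split always returns exactly `world_size` shards and spreads the remainder
# over the first shards: the per-rank sequence lengths here are [993, 992, 992, 992].
shards = torch.tensor_split(hidden_states, world_size, dim=1)
print([s.shape[1] for s in shards])

# torch.chunk, by contrast, may return fewer chunks than requested for awkward sizes
# (e.g. chunking 5 elements into 4 pieces yields only 3 chunks).
print(len(torch.arange(5).chunk(4)))  # 3
```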

[`ContextParallelConfig`] supports Ulysses Anything Attention through the `ulysses_degree` and `ulysses_anything` options. Note that Ulysses Anything Attention is not currently compatible with Unified Attention. Pass a [`ContextParallelConfig`] with `ulysses_degree` greater than 1 and `ulysses_anything=True` to [`~ModelMixin.enable_parallelism`].

```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True))
```

> [!TIP]
> To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, add the **gloo** backend in `init_process_group` (for example, `backend="cpu:gloo,cuda:nccl"`). This significantly reduces communication latency.
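
Below is a minimal end-to-end sketch combining the config above with the gloo tip, intended to be launched with something like `torchrun --nproc_per_node=2`. The model id, prompt, resolution, and step count are only illustrative, and the top-level `ContextParallelConfig` import is assumed; adapt them to your setup and diffusers version.

```py
import torch
from diffusers import ContextParallelConfig, DiffusionPipeline

# Register a CPU (gloo) backend alongside NCCL so the small size all_gather used by
# Ulysses Anything can run on CPU and avoid forced H2D/D2H synchronizations.
torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)

# ulysses_anything lifts the "sequence length divisible by world size" requirement.
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True)
)

# 1008x1008 is a resolution where the standard Ulysses/Ring/unified backends fail in
# the benchmark below because the token count is not evenly divisible across devices.
image = pipeline(
    "a photo of a cat", height=1008, width=1008, num_inference_steps=28
).images[0]

if rank == 0:
    image.save("output.png")
torch.distributed.destroy_process_group()
```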

We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, and Ulysses Anything Attention using [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node of 4 L20 GPUs. The results are summarized below:

| CP Backend         | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW) |
|--------------------|------------------|-------------|------------------|------------|
| ulysses | 281.07 | 3.56 | 37.11 | 1024x1024 |
| ring | 351.34 | 2.85 | 37.01 | 1024x1024 |
| unified_balanced | 324.37 | 3.08 | 37.16 | 1024x1024 |
| ulysses_anything | 280.94 | 3.56 | 37.11 | 1024x1024 |
| ulysses | failed | failed | failed | 1008x1008 |
| ring | failed | failed | failed | 1008x1008 |
| unified_balanced | failed | failed | failed | 1008x1008 |
| ulysses_anything   | 278.40           | 3.59        | 36.99            | 1008x1008  |

> **Review comment** (Collaborator) on lines +367 to +369: Is this from a failed eval? Can it be removed?
>
> **Reply** (Contributor, author): I keep the failed results here to demonstrate that Ulysses Anything can handle cases that the standard Ulysses, Ring, and USP fail to process.

From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.
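
Note that, per the validation added to [`ContextParallelConfig`], `ulysses_anything` requires `ulysses_degree` greater than 1 and cannot be combined with ring attention (`ring_degree > 1`), which is also why it is unavailable with Unified Attention. The sketch below shows which configurations are accepted and which are rejected; the top-level import path is assumed.

```py
from diffusers import ContextParallelConfig

# Valid: pure Ulysses parallelism with arbitrary-length support.
ContextParallelConfig(ulysses_degree=4, ulysses_anything=True)

# Invalid: ulysses_anything needs more than one Ulysses rank.
try:
    ContextParallelConfig(ulysses_degree=1, ulysses_anything=True)
except ValueError as e:
    print(e)

# Invalid: ulysses_anything cannot be combined with ring (and hence unified) attention.
try:
    ContextParallelConfig(ulysses_degree=2, ring_degree=2, ulysses_anything=True)
except ValueError as e:
    print(e)
```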

### parallel_config

Pass `parallel_config` during model initialization to enable context parallelism.
89 changes: 85 additions & 4 deletions src/diffusers/hooks/context_parallel.py
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import functools
import inspect
from dataclasses import dataclass
from typing import Dict, List, Type, Union
from typing import Dict, List, Tuple, Type, Union

import torch
import torch.distributed as dist


if torch.distributed.is_available():
@@ -27,9 +29,10 @@
ContextParallelInput,
ContextParallelModelPlan,
ContextParallelOutput,
_gather_size_by_comm,
)
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
from .hooks import HookRegistry, ModelHook


@@ -208,6 +211,10 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->
)
return x
else:
if self.parallel_config.ulysses_anything:
return PartitionAnythingSharder.shard_anything(
x, cp_input.split_dim, self.parallel_config._flattened_mesh
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)


@@ -233,7 +240,14 @@ def post_forward(self, module, output):
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
if self.parallel_config.ulysses_anything:
output[i] = PartitionAnythingSharder.unshard_anything(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
else:
output[i] = EquipartitionSharder.unshard(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)

return output[0] if is_tensor else tuple(output)

@@ -274,6 +288,73 @@ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_
return tensor


class AllGatherAnythingFunction(torch.autograd.Function):
@staticmethod
    def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.ProcessGroup):
ctx.dim = dim
ctx.group = group
ctx.world_size = dist.get_world_size(group)
ctx.rank = dist.get_rank(group)
gathered_tensor = _all_gather_anything(tensor, dim, group)
return gathered_tensor

@staticmethod
def backward(ctx, grad_output):
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks!
grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
return grad_splits[ctx.rank], None, None


class PartitionAnythingSharder:
@classmethod
def shard_anything(
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.Tensor:
assert tensor.size()[dim] >= mesh.size(), (
f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}."
)
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks!
return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]

@classmethod
def unshard_anything(
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
return tensor


@functools.lru_cache(maxsize=64)
def _fill_gather_shapes(shape: Tuple[int, ...], gather_dims: Tuple[int, ...], dim: int, world_size: int) -> List[List[int]]:
gather_shapes = []
for i in range(world_size):
rank_shape = list(copy.deepcopy(shape))
rank_shape[dim] = gather_dims[i]
gather_shapes.append(rank_shape)
return gather_shapes


@maybe_allow_in_graph
def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.ProcessGroup) -> torch.Tensor:
world_size = dist.get_world_size(group=group)

tensor = tensor.contiguous()
shape = tensor.shape
rank_dim = shape[dim]
gather_dims = _gather_size_by_comm(rank_dim, group)

gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size)

gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes]

dist.all_gather(gathered_tensors, tensor, group=group)
gathered_tensor = torch.cat(gathered_tensors, dim=dim)
return gathered_tensor


def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name.count("*") > 1:
raise ValueError("Wildcard '*' can only be used once in the name")
45 changes: 45 additions & 0 deletions src/diffusers/models/_modeling_parallel.py
@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.distributed as dist

from ..utils import get_logger

@@ -67,6 +68,9 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
    # Whether to enable Ulysses Anything attention, which supports arbitrary
    # sequence lengths and arbitrary numbers of attention heads.
    ulysses_anything: bool = False

_rank: int = None
_world_size: int = None
@@ -94,6 +98,11 @@ def __post_init__(self):
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
if self.ulysses_anything:
if self.ulysses_degree == 1:
raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
if self.ring_degree > 1:
raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")

@property
def mesh_shape(self) -> Tuple[int, int]:
@@ -257,3 +266,39 @@ def __repr__(self):
#
# ContextParallelOutput:
# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to


# Below are utility functions for distributed communication in context parallelism.
def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
r"""Gather the local size from all ranks.
size: int, local size return: List[int], list of size from all ranks
"""
# NOTE(Serving/CP Safety):
# Do NOT cache this collective result.
#
# In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL)
# may legitimately differ across ranks. If we cache based on the *local* `size`,
# different ranks can have different cache hit/miss patterns across time.
#
# That can lead to a catastrophic distributed hang:
# - some ranks hit cache and *skip* dist.all_gather()
# - other ranks miss cache and *enter* dist.all_gather()
# This mismatched collective participation will stall the process group and
# eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL
# timeouts in Ulysses attention).
world_size = dist.get_world_size(group=group)
# HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
comm_backends = str(dist.get_backend(group=group))
# NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator()
gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
dist.all_gather(
gathered_sizes,
torch.tensor([size], device=gather_device, dtype=torch.int64),
group=group,
)

gathered_sizes = [s[0].item() for s in gathered_sizes]
# NOTE: DON'T use tolist here due to graph break - Explanation:
# Backend compiler `inductor` failed with aten._local_scalar_dense.default
return gathered_sizes