Skip to content

Conversation

@sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Dec 9, 2025

What does this PR do?

Adds CP support to the kernels-based attention backends.

Our CP support is quickly gaining traction. Currently, we have a few attention backends that are fully based on kernels. In order for their adoption to grow and make them a bit more complete in terms of feature parity, I think we should make them CP-compatible, too.

Code to test:
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig, AutoModel


CKPT_ID = "black-forest-labs/FLUX.1-dev"

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified"],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    parser.add_argument(
        "--attn-backend",
        type=str,
        choices=["flash_hub", "_flash_3_hub", "sage_hub"],
        default="flash_hub",
        help="Attention backend to use.",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    args = parse_args()

    device = setup_distributed()
    world_size = dist.get_world_size()
    if args.cp_backend == "ring":
        cp_config = ContextParallelConfig(ring_degree=world_size)
    elif args.cp_backend == "unified":
        cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
    else:
        cp_config = ContextParallelConfig(ulysses_degree=world_size)

    transformer = AutoModel.from_pretrained(
        CKPT_ID, 
        subfolder="transformer", 
        torch_dtype=torch.bfloat16, 
        parallel_config=cp_config
    )

    pipeline = DiffusionPipeline.from_pretrained(
        CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
    ).to(device)
    pipeline.transformer.set_attention_backend(args.attn_backend)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        prompt,
        guidance_scale=3.5,
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if dist.get_rank() == 0:
        image.save(f"output_{args.cp_backend}_{args.attn_backend}.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

Outputs:

FA2+ Ulysses FA3 + Ulysses SAGE + Ulysses
Ring Ulysses Unified

@sayakpaul sayakpaul requested a review from DN6 December 9, 2025 09:47
@sayakpaul sayakpaul added the performance Anything related to performance improvements, profiling and benchmarking label Dec 9, 2025
Comment on lines +280 to +281
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only FA2 provides these.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

performance Anything related to performance improvements, profiling and benchmarking

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants