Conversation

@YangKai0616
Contributor

What does this PR do?

torch.histc with deterministic algorithms enabled behaves differently across devices:

- CPU: only supports float input
- CUDA: only supports int input

This PR updates grouped_mm_experts_forward to select the torch.histc input dtype based on the device type, enabling compatibility with torch.use_deterministic_algorithms(True) on CUDA.

Simple reproduction script:

import torch

def test_histc_deterministic(device: str):
    print(f"\n=== Testing on {device.upper()} ===")
    
    num_experts = 8
    expert_ids = torch.randint(0, num_experts, (100,), device=device)
    
    torch.use_deterministic_algorithms(True)
    
    # Test float input
    try:
        result = torch.histc(expert_ids.float(), bins=num_experts, min=0, max=num_experts - 1)
        print(f"✅ float input works with deterministic mode")
    except RuntimeError as e:
        print(f"❌ float input fails: {e}")
    
    # Test int input
    try:
        result = torch.histc(expert_ids, bins=num_experts, min=0, max=num_experts - 1)
        print(f"✅ int input works with deterministic mode")
    except (RuntimeError, NotImplementedError) as e:
        print(f"❌ int input fails: {e}")
    
    torch.use_deterministic_algorithms(False)

if __name__ == "__main__":
    print(f"PyTorch version: {torch.__version__}")
    
    test_histc_deterministic("cpu")
    
    if torch.cuda.is_available():
        test_histc_deterministic("cuda")
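
For reference, given the behavior described above, the CPU run should show the float call succeeding and the int call failing, while the CUDA run (with deterministic mode on) should show the opposite; the exact error messages vary across PyTorch versions.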

Hi @vasqu, please help review this PR, thanks!

  # using histc instead of bincount to avoid cuda graph issues
- num_tokens_per_expert = torch.histc(expert_ids_g.float(), bins=self.num_experts, min=0, max=self.num_experts - 1)
+ # With deterministic algorithms, CPU only supports float input, CUDA only supports int input.
+ histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g
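
For context, a minimal sketch of how the selected input then feeds the token count. The variable names (expert_ids_g, device, self.num_experts) follow the diff above; the exact surrounding code in grouped_mm_experts_forward may differ:

# Pick a histc-compatible dtype per device: with deterministic algorithms,
# CPU histc only accepts float input and CUDA histc only accepts int input.
histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g
# Count how many tokens were routed to each expert (one bin per expert).
num_tokens_per_expert = torch.histc(
    histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1
)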
Member

we can also apply .int() in the case of cuda to be sure

Suggested change
- histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g
+ histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int()
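
The cast is worth it for robustness: deterministic histc on CUDA only accepts integer input, so the explicit .int() guarantees a valid dtype even if expert_ids_g ever arrives as something other than an integer tensor.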

Member

@IlyasMoutawwakil left a comment

LGTM! Thanks for the fix!
