Conversation

@liqiangxl
Collaborator

No description provided.

@liqiangxl
Collaborator Author

!test

@github-actions

Description

  • Add validation to prevent TMA usage when tensor views contain broadcast dimensions

  • Move the early return for the no-suitable-inputs case (bits_per_element == 0) ahead of TMA domain splitting to avoid unnecessary computation

  • Add three test cases to verify TMA is disabled for tensors with dimension size of 1

  • Fixes TMA lowering validation errors caused by merging iteration and broadcast domains

Changes walkthrough

Relevant files
Bug fix
pointwise_tma.cpp
Add broadcast dimension validation for TMA suitability     

csrc/scheduler/pointwise_tma.cpp

  • Added check in isTvSuitableForTma to return false for tensors with
    broadcast dimensions
  • Moved early return check for bits_per_element == 0 to avoid
    unnecessary computation
  • Added comment explaining the broadcast dimension restrictions for TMA
  • +25/-1

Tests
test_pointwise.cpp
Add test cases for TMA with broadcast dimensions

tests/cpp/test_pointwise.cpp

  • Added OuterDimOne test case for tensors with first dimension = 1
  • Added InnerDimOne test case for tensors with last dimension = 1
  • Added MiddleDimOne test case for tensors with middle dimension = 1
  • All tests verify TMA is disabled when broadcast dimensions are present
  • +69/-0   
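
The three test names above suggest one case per position of the size-1 dimension. The following is a minimal sketch of that test matrix, not the PR's actual test code: `DimOneCaseSketch`, `dimOneCasesSketch`, `hasSizeOneDim`, and the 3D extents are all hypothetical, and it assumes a concrete extent of 1 is what ends up being treated as a broadcast dimension.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical description of one test case: a name and a concrete shape.
struct DimOneCaseSketch {
  std::string name;
  std::vector<int64_t> shape;
};

// Hypothetical shapes; only the position of the size-1 dimension matters.
std::vector<DimOneCaseSketch> dimOneCasesSketch() {
  return {
      {"OuterDimOne", {1, 1024, 2048}},
      {"InnerDimOne", {8, 1024, 1}},
      {"MiddleDimOne", {8, 1, 2048}},
  };
}

// Assumption: a dimension with extent 1 stands in for a broadcast dimension.
bool hasSizeOneDim(const std::vector<int64_t>& shape) {
  return std::any_of(
      shape.begin(), shape.end(), [](int64_t extent) { return extent == 1; });
}
```

Under this assumption, each of the three shapes should be rejected for TMA, while a shape such as `{8, 1024, 2048}` should not be rejected by this particular check.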

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    TMA Broadcast Detection

    The new check in isTvSuitableForTma correctly identifies broadcast dimensions and prevents TMA usage. The implementation looks solid with clear comments explaining the two potential issues: merge of iteration domain with broadcast dimension, and single broadcast domains violating 2D tile assumptions.

    if (std::any_of(
            tv->getLogicalDomain().begin(),
            tv->getLogicalDomain().end(),
            [](const IterDomain* id) { return id->isBroadcast(); })) {
      return false;
    }
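
To make the predicate above easier to try in isolation, here is a minimal standalone sketch. `MockIterDomain`, `MockLogicalDomain`, and `isSuitableForTmaSketch` are hypothetical stand-ins, not nvFuser types; the real `isTvSuitableForTma` presumably performs further checks that are not modeled here.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for an iteration domain; only the broadcast flag is modeled.
struct MockIterDomain {
  bool is_broadcast;
  bool isBroadcast() const { return is_broadcast; }
};

// Hypothetical stand-in for a tensor view's logical domain.
using MockLogicalDomain = std::vector<MockIterDomain>;

// Same shape as the quoted check: a single broadcast axis is enough to reject
// the tensor. Returns true only when no axis is a broadcast.
bool isSuitableForTmaSketch(const MockLogicalDomain& logical) {
  return std::none_of(
      logical.begin(), logical.end(), [](const MockIterDomain& id) {
        return id.isBroadcast();
      });
}
```

In this sketch, a domain with flags `{false, true}` is rejected and `{false, false}` is accepted; the position of the broadcast axis does not matter.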
    Early Return Logic

    The early return when bits_per_element == 0 is a good optimization that prevents unnecessary computation when no suitable inputs are found for TMA. This change improves both correctness and performance.

    const int64_t bits_per_element = getInputBitsPerElement(prop);
    if (bits_per_element == 0) {
      return nullptr;
    }
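
The ordering point can be illustrated with a small standalone sketch. Everything here is hypothetical: `TmaParamsSketch`, `inputBitsPerElementSketch`, and `getTmaHeuristicsSketch` are illustrative names, taking the maximum width is an assumed policy, and the division stands in for whatever later work depends on a nonzero `bits_per_element`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical parameter bundle; the real heuristic parameters are richer.
struct TmaParamsSketch {
  int64_t bits_per_element = 0;
  int64_t elems_per_16_bytes = 0;
};

// Hypothetical helper: widest element among TMA-suitable inputs,
// or 0 when the list of suitable inputs is empty.
int64_t inputBitsPerElementSketch(const std::vector<int64_t>& suitable_bits) {
  int64_t bits = 0;
  for (int64_t b : suitable_bits) {
    bits = std::max(bits, b);
  }
  return bits;
}

std::unique_ptr<TmaParamsSketch> getTmaHeuristicsSketch(
    const std::vector<int64_t>& suitable_bits) {
  const int64_t bits_per_element = inputBitsPerElementSketch(suitable_bits);
  if (bits_per_element == 0) {
    return nullptr;  // no suitable input: bail out before any further work
  }
  // Placeholder for the later work (domain splitting, tile sizing).
  // Safe only because bits_per_element is known to be nonzero here.
  auto params = std::make_unique<TmaParamsSketch>();
  params->bits_per_element = bits_per_element;
  params->elems_per_16_bytes = 128 / bits_per_element;
  return params;
}
```

Placing the zero check first means the later arithmetic never sees a zero divisor in this sketch, which is one way an early return can matter for correctness as well as for cost.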

    @greptile-apps
    Contributor

    greptile-apps bot commented Dec 22, 2025

    Greptile Summary

    Prevented TMA (Tensor Memory Accelerator) usage for tensors with broadcast dimensions by adding an early validation check in isTvSuitableForTma.

    • Rejected tensors with broadcast dimensions to avoid merging iteration domains with broadcast domains, which triggers TMA lowering validation errors
    • Moved bits_per_element calculation earlier (before TMA domain splitting) to fail fast when no suitable TMA inputs are found
    • Added three test cases covering edge cases with dimension size of 1 (outer, inner, and middle positions)

    Confidence Score: 4/5

    • This PR is safe to merge with minor considerations
    • The changes correctly address TMA compatibility issues by rejecting broadcast dimensions early. The logic is sound and well-tested. Minor deduction for the redundant bits_per_element == 0 check at line 187, which could be removed since the same check already exists at line 125.
    • No files require special attention

    Important Files Changed

    Filename Overview
    csrc/scheduler/pointwise_tma.cpp Added early broadcast domain check in isTvSuitableForTma and moved bits_per_element calculation earlier to reject unsuitable tensors before domain splitting
    tests/cpp/test_pointwise.cpp Added three comprehensive test cases (OuterDimOne, InnerDimOne, MiddleDimOne) to verify TMA is disabled for tensors with dimension size of 1

    Sequence Diagram

    sequenceDiagram
        participant Scheduler as Pointwise Scheduler
        participant Heuristics as getPointwiseHeuristics
        participant Check as isTvSuitableForTma
        participant Validation as TMA Validation
        
        Scheduler->>Heuristics: Request TMA heuristics
        Heuristics->>Heuristics: Determine break point
        
        Note over Heuristics,Check: NEW: Early validation
        Heuristics->>Check: Check each input tensor
        Check->>Check: Scan logical domain for broadcast
        alt Has broadcast dimensions
            Check-->>Heuristics: return false (not suitable)
            Heuristics-->>Scheduler: return nullptr (disable TMA)
        else No broadcast dimensions
            Check-->>Heuristics: return true (suitable)
            Heuristics->>Heuristics: Calculate bits_per_element
            alt bits_per_element == 0
                Heuristics-->>Scheduler: return nullptr (no suitable inputs)
            else bits_per_element > 0
                Heuristics->>Heuristics: Compute TMA domain dimensions
                Heuristics->>Heuristics: Configure tile sizes
                Heuristics-->>Scheduler: return TMA params
                Scheduler->>Validation: Lower with TMA
                Validation->>Validation: Validate TMA domain == allocation domain
                Note over Validation: No merge errors (broadcast rejected early)
            end
        end
    

    Contributor

    @greptile-apps bot left a comment

    2 files reviewed, 1 comment


    Comment on lines 187 to 189
    if (bits_per_element == 0) {
      return nullptr;
    }
    Contributor

    style: redundant check. bits_per_element == 0 is already checked at lines 125-127, so this condition will never be true.

    Suggested change
    - if (bits_per_element == 0) {
    -   return nullptr;
    - }
    + // bits_per_element already validated at line 125-127
