
Collaborator

@naoyam naoyam commented Dec 15, 2025

This PR just adds asNested as a TensorView operation. It works like a reshape: the output tensor has one input IterDomain replaced by a component IterDomain and a RaggedIterDomain, both produced by RaggedIterDomain::partition.
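To make the partitioning concrete, here is a minimal, self-contained sketch of the shape bookkeeping only. It assumes the extents tensor lists the per-component sizes of the partitioned dimension; `NestedShape` and `partitionExtent` are hypothetical names for illustration and are not part of nvFuser. The sum check is also part of the illustration, not a claim about what the PR validates.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the shape information produced by partitioning
// one dimension. This is not an nvFuser type.
struct NestedShape {
  int64_t num_components;               // extent of the component dimension
  std::vector<int64_t> ragged_extents;  // ragged extent of each component
  std::vector<int64_t> offsets;         // start of each component in the original dimension
};

// Splits a flat extent into a component dimension and per-component ragged
// extents, recording the exclusive prefix sum of the extents as offsets.
NestedShape partitionExtent(int64_t total_extent,
                            const std::vector<int64_t>& extents) {
  NestedShape out;
  out.num_components = static_cast<int64_t>(extents.size());
  out.ragged_extents = extents;
  out.offsets.reserve(extents.size());

  int64_t running = 0;
  for (int64_t e : extents) {
    if (e < 0) {
      throw std::invalid_argument("extents must be non-negative");
    }
    out.offsets.push_back(running);
    running += e;
  }
  // Illustration-only consistency check (an assumption of this sketch).
  if (running != total_extent) {
    throw std::invalid_argument("extents must sum to the partitioned extent");
  }
  return out;
}
```

For example, under these assumptions a dimension of extent 6 with extents {1, 2, 3} becomes a component dimension of extent 3 whose components start at offsets 0, 1, and 3.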


github-actions bot commented Dec 15, 2025

Review updated until commit 5b99432

Description

  • Adds new asNested TensorView operation for creating nested tensor representations

  • Implements RaggedIterDomain partitioning to convert regular dimensions into component and ragged dimensions

  • Includes comprehensive test coverage with validation for edge cases and error conditions

  • Provides detailed API documentation explaining usage and behavior

Changes walkthrough

Relevant files

Enhancement

alias.cpp: Implement asNested function for nested tensor creation
csrc/ops/alias.cpp (+61/-0)

  • Implements asNested function that partitions specified tensor
    dimensions into component and ragged dimensions
  • Validates inputs (null checks, 1D extents requirement) and handles
    dimension wrapping
  • Creates new TensorView with partitioned structure using
    RaggedIterDomain::partition
  • Uses LoadStoreOp to represent the nesting operation with proper error
    handling

alias.h: Add asNested API declaration and documentation
csrc/ops/alias.h (+21/-0)

  • Adds function declaration for asNested with comprehensive
    documentation
  • Documents parameters, return type, and usage examples
  • Explains the partitioning behavior and expected tensor shapes

Tests

test_ragged_iter_domain.cpp: Add comprehensive tests for asNested operation
tests/cpp/test_ragged_iter_domain.cpp (+142/-0)

  • Adds basic functionality test for asNested with 2D tensor input
  • Tests partitioning on different dimensions (middle dimension of 3D
    tensor)
  • Tests edge case with 1D tensor input
  • Includes validation tests for null inputs and multi-dimensional
    extents

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Implementation correctness

    The implementation correctly uses RaggedIterDomain::partition to split the specified dimension into component and ragged IterDomains. The logic for building the logical domain by replacing the ragged_dim with component_id and ragged_id appears sound. The use of LoadStoreOp for representing the nesting operation is consistent with the comment about potentially using a specific TensorView op like ReshapeOp in the future.

    Input validation

    Input validation is comprehensive: null checks for both data and extents tensors, and validation that extents is 1D. The wrapDim call for ragged_dim handling appears correct for dimension indexing.
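The dimension wrapping mentioned above can be sketched as follows. This is a hypothetical helper (`wrapDimSketch`) that mimics the behavior described in this thread, where a negative index counts from the end and out-of-range values are rejected; it is not nvFuser's wrapDim and its exact signature and error type are assumptions.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical helper, not nvFuser's wrapDim: maps a possibly negative
// dimension index into [0, ndims) and rejects anything out of range.
int64_t wrapDimSketch(int64_t dim, int64_t ndims) {
  if (dim < -ndims || dim >= ndims) {
    throw std::out_of_range("dimension index out of range");
  }
  return dim < 0 ? dim + ndims : dim;
}
```

With this sketch, a ragged_dim of -1 on a 3D tensor resolves to 2, while a ragged_dim of 3 is rejected before any partitioning happens.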

    Test coverage

    Tests cover basic functionality, different dimensions, 1D tensors, and validation scenarios. The tests verify the correct structure of the output nested tensor including component and ragged IterDomains, and proper error handling for invalid inputs.

    Test failures

    • (High, 95) CUDA driver too old for runtime on dlcluster_h100 (nvFuser tests failing)

      Test Name H100 Source
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/1024_3_1_0 Link
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_2_0_1 Link
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_3_0_0 Link
      BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize32_ItemsPerThread5 Link
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_15_dtype___bfloat Link
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_16_dtype___bfloat Link
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_4_dtype___bfloat Link
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_5_dtype___bfloat Link
      CombineMulSumAsMmaTestWithLayout.UseMatmulScheduler/TN Link
      General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KN_512_256_128_MmaMacro_m128_n128_k16_tma_store Link
      ... with 85 more test failures omitted. Check internal logs.
    • (High, 16) CUDA driver too old on dlcluster_h100 causing RNGTest failure

      Test Name H100 Source
      .thunder.tests.opinfos
      .thunder.tests.test_apex_cross_entropy_executor
      .thunder.tests.test_auto_register_torchops
      .thunder.tests.test_cudnn_executor
      .thunder.tests.test_einops
      .thunder.tests.test_grad
      .thunder.tests.test_nvfuser
      .thunder.tests.test_ops
      .thunder.tests.test_sdpaex_executor
      .thunder.tests.test_torch_compile_executor
      ... with 6 more test failures omitted. Check internal logs.

    Contributor

    greptile-apps bot commented Dec 15, 2025

    Greptile Summary

    Implements the asNested TensorView operation that converts a regular tensor into a nested tensor representation by partitioning a specified dimension into component and ragged dimensions.

    Key changes:

    • Added asNested(data, extents, ragged_dim) API in csrc/ops/alias.h with comprehensive documentation
    • Implementation in csrc/ops/alias.cpp follows established patterns: validates inputs, creates new TensorView with partitioned domain structure, uses LoadStoreOp for aliasing
    • Comprehensive test suite covering basic functionality, different dimensions, 1D tensors, and error cases
    • Bounds validation for ragged_dim parameter handled by wrapDim utility (contrary to previous thread comment)

    Implementation details:

    • Root domain clones input's logical domain
    • Logical domain replaces target dimension with (component, ragged) pair using RaggedIterDomain::partition
    • Uses LoadStoreOp (similar to reshape) rather than a dedicated op type (noted in comment for future consideration)
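The domain construction in the list above can be illustrated with a small stand-alone sketch. It models IterDomains as plain strings, so `replaceWithPair` and the names passed to it are hypothetical and only show how the logical domain ends up one entry longer than the root domain.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical model: builds a logical domain from a root domain by replacing
// the entry at ragged_dim with a (component, ragged) pair. Strings stand in
// for IterDomains; this is not nvFuser code.
std::vector<std::string> replaceWithPair(const std::vector<std::string>& root,
                                         std::size_t ragged_dim,
                                         const std::string& component,
                                         const std::string& ragged) {
  if (ragged_dim >= root.size()) {
    throw std::out_of_range("ragged_dim out of range");
  }
  std::vector<std::string> logical;
  logical.reserve(root.size() + 1);
  for (std::size_t i = 0; i < root.size(); ++i) {
    if (i == ragged_dim) {
      // The partitioned dimension contributes two entries.
      logical.push_back(component);
      logical.push_back(ragged);
    } else {
      logical.push_back(root[i]);
    }
  }
  return logical;
}
```

Partitioning the middle dimension of a 3D root domain {i0, i1, i2} in this model yields the 4-entry logical domain {i0, component, ragged, i2}.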

    Confidence Score: 5/5

    • This PR is safe to merge with no identified issues
    • The implementation follows established patterns in the codebase (similar to reshape) and includes proper input validation. Test coverage is comprehensive: 6 test cases cover basic functionality, edge cases, and error conditions. The previous thread's concern about missing bounds validation is already handled by the wrapDim utility function
    • No files require special attention

    Important Files Changed

    Filename | Overview
    csrc/ops/alias.h | Added well-documented asNested API with clear parameter descriptions and usage example
    csrc/ops/alias.cpp | Implemented asNested with proper validation, domain construction, and LoadStoreOp aliasing
    tests/cpp/test_ragged_iter_domain.cpp | Added comprehensive test coverage for asNested including basic functionality, edge cases, and validation

    Sequence Diagram

    sequenceDiagram
        participant User
        participant asNested
        participant TensorDomain
        participant RaggedIterDomain
        participant IrBuilder
        
        User->>asNested: asNested(data, extents, ragged_dim)
        
        asNested->>asNested: Validate data != null
        asNested->>asNested: Validate extents != null
        asNested->>asNested: Validate extents is 1D
        
        asNested->>TensorDomain: Get logical domain (noReductions)
        TensorDomain-->>asNested: inp_logical
        
        asNested->>asNested: Clone logical domain to root_domain
        asNested->>asNested: wrapDim(ragged_dim, size) - validates bounds
        
        asNested->>RaggedIterDomain: partition(root_domain[ragged_dim], extents)
        RaggedIterDomain->>RaggedIterDomain: Validate input IterDomain
        RaggedIterDomain->>RaggedIterDomain: Validate not already ragged
        RaggedIterDomain->>RaggedIterDomain: Validate extents dtype is Index
        RaggedIterDomain->>RaggedIterDomain: Create component_id and ragged_id
        RaggedIterDomain-->>asNested: (component_id, ragged_id)
        
        asNested->>asNested: Build logical_domain by replacing<br/>ragged_dim with (component, ragged)
        
        asNested->>IrBuilder: create TensorDomain(root, logical, logical, contiguity)
        IrBuilder-->>asNested: TensorDomain
        
        asNested->>IrBuilder: create TensorView(domain, dtype)
        IrBuilder-->>asNested: output TensorView
        
        asNested->>IrBuilder: create LoadStoreOp(Set, out, data)
        IrBuilder-->>asNested: LoadStoreOp (defines aliasing)
        
        asNested-->>User: nested TensorView
    

    Contributor

    @greptile-apps greptile-apps bot left a comment

    3 files reviewed, 1 comment

    Contributor

    @greptile-apps greptile-apps bot left a comment

    3 files reviewed, no comments

    Collaborator Author

    naoyam commented Dec 15, 2025

    !test

    Base automatically changed from raggediterdomain_partition to main December 18, 2025 21:32
    Collaborator Author

    naoyam commented Dec 18, 2025

    !test

    @naoyam naoyam requested a review from wujingyue December 18, 2025 22:02