Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions csrc/ops/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1267,4 +1267,65 @@ TensorView* repeat(
return out_tv;
}

TensorView* asNested(
TensorView* data,
TensorView* extents,
int64_t ragged_dim) {
NVF_ERROR(data != nullptr, "asNested: data tensor is null");
NVF_ERROR(extents != nullptr, "asNested: extents tensor is null");

// Only 1D extents tensors are currently supported
NVF_ERROR_EQ(
extents->nDims(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
extents->nDims(),
std::ranges::distance(extents->getLogicalDomain() | TensorDomain::kNoReductions),

extents can have reduction dimensions, e.g., tokens_per_expert is a reduction of tokens_per_expert_per_device.

1,
"asNested currently only supports 1D extents tensors");

// Get the logical domain of the input, excluding reductions
auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain());
auto inp_logical = data->getLogicalDomain() | TensorDomain::kNoReductions;

And there's also std::ranges::distance to replace std::ssize.


// Clone the logical domain to create the root domain for output
std::vector<IterDomain*> root_domain;
root_domain.reserve(inp_logical.size());
for (auto* id : inp_logical) {
root_domain.push_back(id->cloneWithoutRFactor());
}

ragged_dim = wrapDim(ragged_dim, std::ssize(inp_logical));

// Partition the specified dimension in root domain
// This replaces one IterDomain with (component_id, ragged_id)
auto [component_id, ragged_id] =
RaggedIterDomain::partition(root_domain.at(ragged_dim), extents);

// Build the logical domain: replace ragged_dim with component and ragged
std::vector<IterDomain*> logical_domain;
logical_domain.reserve(root_domain.size() + 1); // One extra for the split

for (const auto i : arange(root_domain.size())) {
if (static_cast<int64_t>(i) == ragged_dim) {
// Replace with component and ragged dimensions
logical_domain.push_back(component_id);
logical_domain.push_back(ragged_id);
} else {
logical_domain.push_back(root_domain.at(i));
}
}

// Create the output TensorView with the partitioned structure
auto* out = IrBuilder::create<TensorView>(
IrBuilder::create<TensorDomain>(
root_domain,
logical_domain,
logical_domain,
TensorDomain::getContiguityFilledWith(logical_domain, true)),
data->getDataType().value());

// For now, just use LoadStoreOp to represent the nesting
// operation. Does it make more sense to have a specific TensorView
// op like ReshapeOp?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid the codebase made too many assumptions that LoadStoreOp keeps the logical domain unchanged (modulo reduction). So it's probably safer to use a different op. ReshapeOp sounds fine -- it's really like a splitting reshape except for non-uniform extents.

IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, data);

return out;
}

} // namespace nvfuser
21 changes: 21 additions & 0 deletions csrc/ops/alias.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,25 @@ NVF_API TensorView* repeat(
TensorView* inp,
const std::vector<int64_t>& repeat_times);

//! Create a nested tensor view from a data tensor and extents.
//!
//! The function partitions the specified dimension of the data tensor into
//! a component dimension and a ragged dimension based on the provided extents.
//!
//! \param data Input tensor to be converted to nested representation
//! \param extents Extents tensor defining the size of each component
//! Shape: [num_components], values: [extent0, extent1, ..., extent(n-1)]
//! \param ragged_dim Dimension to partition into nested structure
//! \return TensorView with a RaggedIterDomain at the specified dimension
//!
//! Example:
//! data shape: [10, ...]
//! extents: [3, 5, 2]
//! ragged_dim: 0
//! Result: nested tensor with 3 components. [3, [3, 5, 2], ...]
NVF_API TensorView* asNested(
TensorView* data,
TensorView* extents,
int64_t ragged_dim);

} // namespace nvfuser
142 changes: 142 additions & 0 deletions tests/cpp/test_ragged_iter_domain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,4 +340,146 @@ TEST_F(RaggedIterDomainTest, TensorViewPartition) {
EXPECT_EQ(tv0->axis(0)->definition(), tv0->axis(1)->definition());
}

// asNested basic functionality
TEST_F(RaggedIterDomainTest, AsNestedBasic) {
Fusion fusion;
FusionGuard fg(&fusion);

auto data = makeSymbolicTensor(2, DataType::Float);
fusion.addInput(data);

auto extents = makeSymbolicTensor(1, DataType::Index);
fusion.addInput(extents);

// Create nested tensor from dimension 0
auto nested = asNested(data, extents, 0);

fusion.addOutput(nested);

// Verify the output is a new TensorView
EXPECT_TRUE(nested != nullptr);
EXPECT_NE(nested, data);
EXPECT_TRUE(nested->isA<TensorView>());

// Verify nested tensor has 3 dimensions: [component, ragged, original_dim1]
EXPECT_EQ(nested->nDims(), 3);

// First axis should be a regular IterDomain (component)
EXPECT_TRUE(nested->axis(0)->isStrictlyA<IterDomain>());
EXPECT_FALSE(nested->axis(0)->isA<RaggedIterDomain>());

// Second axis should be a RaggedIterDomain
EXPECT_TRUE(nested->axis(1)->isA<RaggedIterDomain>());

// Third axis should be the original second dimension
EXPECT_TRUE(nested->axis(2)->isStrictlyA<IterDomain>());

// Verify the definition exists (LoadStoreOp for aliasing)
EXPECT_TRUE(nested->definition() != nullptr);
EXPECT_TRUE(nested->definition()->isA<LoadStoreOp>());

// Verify the component and ragged IterDomains have Partition as their
// definition
EXPECT_TRUE(nested->axis(0)->definition() != nullptr);
EXPECT_TRUE(nested->axis(0)->definition()->isA<Partition>());
EXPECT_EQ(nested->axis(0)->definition(), nested->axis(1)->definition());
}

// asNested on different dimensions
TEST_F(RaggedIterDomainTest, AsNestedDifferentDimension) {
Fusion fusion;
FusionGuard fg(&fusion);

auto data = makeSymbolicTensor(3, DataType::Float);
fusion.addInput(data);

auto extents = makeSymbolicTensor(1, DataType::Index);
fusion.addInput(extents);

// Partition dimension 1 (middle dimension)
auto nested = asNested(data, extents, 1);

// Verify dimensions: [dim0, component, ragged, dim2]
EXPECT_EQ(nested->nDims(), 4);

// First axis is original dim0
EXPECT_TRUE(nested->axis(0)->isStrictlyA<IterDomain>());

// Second axis is component
EXPECT_TRUE(nested->axis(1)->isStrictlyA<IterDomain>());

// Third axis is ragged
EXPECT_TRUE(nested->axis(2)->isA<RaggedIterDomain>());

// Fourth axis is original dim2
EXPECT_TRUE(nested->axis(3)->isA<IterDomain>());
}

// asNested with 1D tensor
TEST_F(RaggedIterDomainTest, AsNested1DTensor) {
Fusion fusion;
FusionGuard fg(&fusion);

// Create a 1D TensorView [10]
auto data = makeSymbolicTensor(1, DataType::Float);
fusion.addInput(data);

// Create extents tensor
auto extents = makeSymbolicTensor(1, DataType::Index);
fusion.addInput(extents);

// Create nested tensor from the only dimension
auto nested = asNested(data, extents, 0);

fusion.addOutput(nested);

// Verify dimensions: [component, ragged]
EXPECT_EQ(nested->nDims(), 2);

// First axis is component
EXPECT_TRUE(nested->axis(0)->isStrictlyA<IterDomain>());

// Second axis is ragged
EXPECT_TRUE(nested->axis(1)->isA<RaggedIterDomain>());
}

// asNested validation - null data
TEST_F(RaggedIterDomainTest, AsNestedValidationNullData) {
Fusion fusion;
FusionGuard fg(&fusion);

auto extents = makeSymbolicTensor(1, DataType::Index);
fusion.addInput(extents);

// Null data should throw
EXPECT_THROW(asNested(nullptr, extents, 0), nvfuser::nvfError);
}

// asNested validation - null extents
TEST_F(RaggedIterDomainTest, AsNestedValidationNullExtents) {
Fusion fusion;
FusionGuard fg(&fusion);

auto data = makeSymbolicTensor(2, DataType::Float);
fusion.addInput(data);

// Null extents should throw
EXPECT_THROW(asNested(data, nullptr, 0), nvfuser::nvfError);
}

// asNested validation - multi-dimensional extents (not yet supported)
TEST_F(RaggedIterDomainTest, AsNestedValidationMultiDimExtents) {
Fusion fusion;
FusionGuard fg(&fusion);

auto data = makeSymbolicTensor(2, DataType::Float);
fusion.addInput(data);

// 2D extents should fail (only 1D supported currently)
auto extents_2d = makeSymbolicTensor(2, DataType::Index);
fusion.addInput(extents_2d);

EXPECT_THROW(asNested(data, extents_2d, 0), nvfuser::nvfError);
}

} // namespace nvfuser