Adds asNested TensorView operation #5684
The diff adds `asNested` immediately after the existing `repeat` implementation (starting at line 1267):

```cpp
// ... end of the existing repeat() implementation:
  return out_tv;
}

TensorView* asNested(
    TensorView* data,
    TensorView* extents,
    int64_t ragged_dim) {
  NVF_ERROR(data != nullptr, "asNested: data tensor is null");
  NVF_ERROR(extents != nullptr, "asNested: extents tensor is null");

  // Only 1D extents tensors are currently supported
  NVF_ERROR_EQ(
      extents->nDims(),
      1,
      "asNested currently only supports 1D extents tensors");

  // Get the logical domain of the input, excluding reductions
  auto inp_logical = TensorDomain::noReductions(data->getLogicalDomain());
```
**Collaborator** (inline, with a suggested change attached):
> And there's also `std::ranges::distance` to replace `std::ssize`.
```cpp
  // Clone the logical domain to create the root domain for output
  std::vector<IterDomain*> root_domain;
  root_domain.reserve(inp_logical.size());
  for (auto* id : inp_logical) {
    root_domain.push_back(id->cloneWithoutRFactor());
  }

  ragged_dim = wrapDim(ragged_dim, std::ssize(inp_logical));

  // Partition the specified dimension in root domain.
  // This replaces one IterDomain with (component_id, ragged_id).
  auto [component_id, ragged_id] =
      RaggedIterDomain::partition(root_domain.at(ragged_dim), extents);

  // Build the logical domain: replace ragged_dim with component and ragged
  std::vector<IterDomain*> logical_domain;
  logical_domain.reserve(root_domain.size() + 1); // One extra for the split

  for (const auto i : arange(root_domain.size())) {
    if (static_cast<int64_t>(i) == ragged_dim) {
      // Replace with component and ragged dimensions
      logical_domain.push_back(component_id);
      logical_domain.push_back(ragged_id);
    } else {
      logical_domain.push_back(root_domain.at(i));
    }
  }

  // Create the output TensorView with the partitioned structure
  auto* out = IrBuilder::create<TensorView>(
      IrBuilder::create<TensorDomain>(
          root_domain,
          logical_domain,
          logical_domain,
          TensorDomain::getContiguityFilledWith(logical_domain, true)),
      data->getDataType().value());

  // For now, just use LoadStoreOp to represent the nesting
  // operation. Does it make more sense to have a specific TensorView
  // op like ReshapeOp?
```
**Collaborator** (inline):
> I'm afraid the codebase made too many assumptions that `LoadStoreOp` keeps the logical domain unchanged (modulo reduction), so it's probably safer to use a different op. `ReshapeOp` sounds fine -- it's really like a splitting reshape, except for non-uniform extents.
```cpp
  IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, data);

  return out;
}

} // namespace nvfuser
```
**Collaborator** (on the 1D-extents check):
> `extents` can have reduction dimensions, e.g., `tokens_per_expert` is a reduction of `tokens_per_expert_per_device`.