Skip to content

Commit 13a7c05

Browse files
ssjiaSS-JIA
authored andcommitted
[ET-VK] Fix output offset calculation and add symint support to ComputeGraph
Fix output argument indexing in VulkanBackend::execute() and extend ComputeGraph to transparently handle symint values. The output loop previously computed the args index as `i + num_inputs`, which breaks when non-tensor arguments (e.g. symints) sit between the tensor inputs and outputs in the args array. Fix by computing the offset from the end: `args.size() - num_outputs`. ComputeGraph changes add symint support so that operators can read symint values uniformly: - `extract_scalar<T>()` now handles SymInt values, allowing operators to call extract_scalar on arguments that may be either plain ints or symints without special-casing. - `read_symint()` falls back to reading plain Int values, so values stored as Int (rather than SymInt objects) can be read uniformly. Pull Request resolved: pytorch#18050 ghstack-source-id: 353546683 @exported-using-ghexport Differential Revision: [D95970167](https://our.internmc.facebook.com/intern/diff/D95970167/)
1 parent 981bc60 commit 13a7c05

3 files changed

Lines changed: 10 additions & 4 deletions

File tree

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
671671
ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);
672672

673673
const size_t num_inputs = compute_graph->inputs().size();
674+
const size_t num_outputs = compute_graph->outputs().size();
674675
bool should_propagate_resize = false;
675676
#ifdef ET_EVENT_TRACER_ENABLED
676677
runtime::EventTracer* event_tracer = context.event_tracer();
@@ -770,14 +771,13 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
770771
"ETVK_COPY_OUTPUTS",
771772
/* delegate_debug_id = */ -1);
772773
#endif // ET_EVENT_TRACER_ENABLED
773-
for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
774-
const size_t o = i + num_inputs;
774+
const size_t output_offset = args.size() - num_outputs;
775+
for (size_t i = 0; i < num_outputs; i++) {
776+
const size_t o = output_offset + i;
775777
const ValueRef oref = compute_graph->outputs()[i].value;
776778
if (compute_graph->val_is_tensor(oref)) {
777779
VK_CHECK_COND(args[o]->isTensor());
778780
maybe_resize_output(compute_graph, i, args[o]->toTensor());
779-
// args holds inputs directly followed by outputs, so the i'th output
780-
// for compute_graph corresponds to the o'th arg
781781
compute_graph->maybe_cast_and_copy_from_staging(
782782
compute_graph->outputs()[i].staging,
783783
args[o]->toTensor().mutable_data_ptr(),

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,9 @@ void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
725725
}
726726

727727
int32_t ComputeGraph::read_symint(const ValueRef idx) {
728+
if (values_.at(idx).isInt()) {
729+
return static_cast<int32_t>(values_.at(idx).toInt());
730+
}
728731
return get_symint(idx)->get();
729732
}
730733

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,9 @@ class ComputeGraph final {
573573
if (value.isBool()) {
574574
return static_cast<T>(value.toBool());
575575
}
576+
if (value.isSymInt()) {
577+
return utils::safe_downcast<T>(read_symint(idx));
578+
}
576579
VK_THROW("Cannot extract scalar from Value with type ", value.type());
577580
}
578581

0 commit comments

Comments
 (0)