Add Reshape scalar optimization and Gather scalar input support #4146
Conversation
Updated tensor API to treat empty tensors as no-ops in slice_assign and to filter out tensors with size 0 along the concatenation dimension in cat. Added tests to verify correct behavior when handling empty tensors and empty slice assignments. This prevents backend errors and ensures consistent results for edge cases.
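The filtering rule described above can be sketched in plain Python (shape metadata only; the representation and names here are illustrative, not Burn's actual Rust API):

```python
def cat_shapes(shapes, dim):
    """Concatenate shape metadata, dropping inputs with size 0 along `dim`,
    so backend implementations never receive empty concat inputs."""
    non_empty = [s for s in shapes if s[dim] != 0]
    if not non_empty:
        raise ValueError("cat requires at least one non-empty input")
    out = list(non_empty[0])
    out[dim] = sum(s[dim] for s in non_empty)
    return out

# The [0, 3] input is empty along dim 0 and is filtered out before concat.
print(cat_shapes([[2, 3], [0, 3], [4, 3]], dim=0))  # [6, 3]
```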
Added documentation notes to BoolTensorOps, IntTensorOps, and FloatTensorOps traits specifying that empty slice assignments and empty tensors are handled at the high-level tensor API and will not be passed to backend implementations. This clarifies backend responsibilities and prevents unnecessary checks for empty cases.
Streamlined the filtering of non-empty tensors by directly mapping to primitives, reducing unnecessary intermediate steps and improving code clarity in the tensor concatenation logic.
This update enables correct handling of outer-scope references for If, Loop, and Scan nodes by extracting referenced names from subgraphs, adding them as additional node inputs, and deferring subgraph construction until type inference when all types are resolved. The changes include new mechanisms for tracking and passing outer-scope types, updating code generation to bind these references, and ensuring subgraph inputs use the correct names and types from the parent graph. This improves compatibility with ONNX models that use control flow and nested subgraphs referencing parent-scope values.
Introduces a new subgraph_helper module to centralize code generation logic for If, Loop, and Scan nodes in burn-import. This reduces duplication and improves maintainability by sharing routines for outer-scope bindings, scope registration, and forward code generation. Minor cleanups and formatting improvements are made in onnx-ir for subgraph processing and logging.
This commit eliminates the use of the GraphBuilder and GraphBuilders variants from AttributeValue and related code paths. Subgraph handling for If, Loop, and Scan nodes now only supports DeferredGraph and Graph, simplifying the control flow and post-processing logic. The subgraph rewiring logic for GraphBuilder has been removed, as subgraphs are now always finalized before post-processing.
The documentation for DeferredGraph now provides a detailed explanation of why subgraph construction is deferred in ONNX control flow nodes, including an example and a step-by-step description of the lazy building process. This improves clarity for future maintainers and users of the code.
Introduces new ONNX models and Rust tests to verify subgraph handling of outer-scope variable references in If, Loop, and Scan operators. Updates build.rs to include these models and expands the test suite to cover single and multi-variable outer-scope references, ensuring correct behavior for nested control flow and parent-scope variable access.
Replaces usage of ArgType with full Argument for outer-scope references in GraphState and node processors (If, Loop, Scan). This change ensures constant values (e.g., LSTM weights) are preserved when subgraphs reference parent graph values. Also adds logic to merge output types for If node branches and updates related documentation/comments.
Enhanced the extract_config method in LstmProcessor to robustly determine input_size by checking multiple sources: weight tensor's static shape, weight constant data, and input tensor's static shape. This improves compatibility with models where weight tensors are dynamically computed or input shapes are more reliably available.
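The fallback chain can be sketched as follows (function and parameter names are hypothetical). Per the ONNX spec, the LSTM weight tensor W has shape [num_directions, 4*hidden_size, input_size] and the input tensor has shape [seq_length, batch_size, input_size], so input_size is the last dimension in either case:

```python
def infer_lstm_input_size(weight_static_shape, weight_value_shape, input_static_shape):
    """Try the weight's static shape, then the weight constant's data shape,
    then the input tensor's static shape, as described above."""
    for shape in (weight_static_shape, weight_value_shape):
        if shape is not None and len(shape) == 3:
            return shape[2]  # last dim of W is input_size
    if input_static_shape is not None and len(input_static_shape) == 3:
        return input_static_shape[2]  # last dim of the input is input_size
    raise ValueError("cannot determine input_size from any source")

# Static weight shape missing, but the constant's data shape is available.
print(infer_lstm_input_size(None, [1, 256, 64], None))  # 64
```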
Replaces usage of 'outer_name' with 'outer_var' for scalar argument bindings in the generate_outer_scope_bindings function, ensuring correct variable assignment.
Refactors get_onnx_input_count to log a warning when the __onnx_input_count attribute is present but not an Int64, and falls back to using the input count. This improves debuggability and robustness when encountering unexpected attribute types.
Moved the build_outer_scope_from_inputs function from if_node.rs, loop_node.rs, and scan_node.rs into processor.rs to eliminate code duplication. Updated imports in affected node files to use the shared implementation.
Enhanced the merge_branch_types function to handle tensor rank and dtype differences more robustly. Now, it prefers higher rank for tensors with matching dtypes, logs warnings for incompatible dtypes or type categories, and clears static shapes when ranks differ. This provides more accurate static type inference for ONNX If nodes.
Corrects logic to preserve original names for Constant arguments and use sanitized names only for non-Constant (Dynamic) arguments during subgraph code generation.
Expanded comments in subgraph tests to include step-by-step calculations for expected outputs, clarifying how Relu, Sigmoid, and Tanh operations combine for given input values. This improves test readability and aids future maintenance.
Expanded comments to explain why only dynamic and constant node inputs are tracked for future use, detailing the rationale for excluding static initializers from runtime clone management.
When resolving outer-scope references, the original name is now preserved for constant arguments, matching the logic used for sanitized lookups. This ensures consistent naming for constants and improves clarity in graph state resolution.
Updated comments to better explain the logic for preserving or sanitizing argument names based on whether they are constant or dynamic. No functional changes were made.
Expanded comments to clarify the distinction between loop-provided inputs and outer-scope references in Loop/Scan body subgraphs. Added references to ONNX documentation for better context.
Introduced a helper closure to clarify and centralize the logic for identifying outer-scope inputs in ONNX subgraphs. This improves readability and ensures consistent handling of inputs with and without initializers.
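A minimal sketch of the helper-closure idea, in plain Python with hypothetical names: an input is an outer-scope reference if the subgraph neither produces it nor carries a local initializer for it.

```python
def outer_scope_inputs(input_names, produced_in_subgraph, local_initializers):
    # Closure centralizing the classification rule, so inputs with and
    # without initializers are handled consistently.
    is_outer = lambda name: (
        name not in produced_in_subgraph and name not in local_initializers
    )
    return [name for name in input_names if is_outer(name)]

names = ["x", "iter_count", "parent_const"]
# "x" has a local initializer, "iter_count" is produced inside the subgraph,
# so only "parent_const" is an outer-scope reference.
print(outer_scope_inputs(names, {"iter_count"}, {"x"}))  # ['parent_const']
```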
Cleaned up the import statements in outer_scope_multi_var.py by removing the unused numpy_helper import from onnx.
# Conflicts:
#	crates/burn-import/src/burn/node/if_node.rs
#	crates/burn-import/src/burn/node/loop_node.rs
#	crates/burn-import/src/burn/node/scan_node.rs
Added logic to collect referenced names in subgraphs and filter outer-scope bindings to only generate bindings for variables actually used in the subgraph. This prevents unused variable warnings in generated code for If, Loop, and Scan nodes.
This commit adds robust support for referencing parent graph constants and initializers from subgraphs (e.g., If branches), preserving value_store references for outer-scope Static/Constant arguments. It introduces a new ONNX test (outer_scope_constant) and updates the Reshape operator to handle scalar inputs, addressing issues encountered in models like Silero VAD. Additional logging and debugging utilities are included to aid in diagnosing value store and tensor store state.
Simplifies the code generation for scalar-to-tensor reshapes by using Tensor::full with TensorKind, removing type-specific match arms. Adds ToTokens implementation for TensorKind and new tests for scalar reshape scenarios.
Introduces logic to remove unused constant nodes from the ONNX graph after subgraphs are constructed. This improves memory usage and performance by ensuring only referenced constants are retained.
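The pruning step can be sketched like this (plain Python, hypothetical node representation): once subgraphs are built, collect every name referenced as a node input, then retain only the constants actually referenced.

```python
def prune_unused_constants(constants, nodes):
    """Drop constants that no node input references after subgraph construction."""
    referenced = {name for node in nodes for name in node["inputs"]}
    return {name: value for name, value in constants.items() if name in referenced}

constants = {"w": [1.0], "unused": [2.0]}
nodes = [{"inputs": ["x", "w"], "outputs": ["y"]}]
print(sorted(prune_unused_constants(constants, nodes)))  # ['w']
```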
Enhances the Silero VAD model check by downloading test audio, generating ONNX Runtime reference outputs, and running a suite of 12 test cases comparing Burn model outputs to ONNX results with tolerance. Updates README with new workflow and test details, adds serde dependencies, and refactors main.rs for automated test validation.
Removes merging logic for branch output types and now infers output types solely from the then_branch. The merge_branch_types function is deleted, as runtime compatibility is assumed and only one branch executes.
Adds handling for scalar inputs in ONNX Gather and Reshape nodes, including codegen, type inference, and tests. This enables correct processing of patterns like Reshape(scalar, [-1]) followed by Gather, as seen in models such as Silero VAD. New ONNX test cases and Rust tests verify that scalar inputs are passed through and remain scalars after these operations.
Pull request overview
This PR adds an optimization to keep scalar values as scalars when reshaped with [-1] or [1], rather than unnecessarily converting them to rank-1 tensors. It also adds support for gathering from scalar inputs in the Gather operation to complete the optimization chain. This optimization appears in models like Silero VAD and avoids wasteful tensor conversions.
Key Changes
- Added scalar-to-scalar optimization in Reshape for shapes `[-1]` and `[1]`
- Extended Gather operator to handle scalar inputs in both onnx-ir and burn-import
- Added comprehensive ONNX integration tests for both optimization paths
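The combined effect of the two changes can be sketched in plain Python (type tags and names here are illustrative, not the onnx-ir API): a scalar reshaped to `[-1]` or `[1]` keeps the Scalar type, so a downstream Gather can pass it straight through.

```python
def reshape_output_type(input_type, shape_values):
    """Type-inference sketch: keep scalars as scalars for single-element shapes."""
    if (
        input_type == "Scalar"
        and len(shape_values) == 1
        and shape_values[0] in (-1, 1)
    ):
        return "Scalar"  # effectively a no-op reshape
    return f"Tensor(rank={len(shape_values)})"

print(reshape_output_type("Scalar", [-1]))  # Scalar
print(reshape_output_type("Scalar", [1]))   # Scalar
print(reshape_output_type("Scalar", [2]))   # Tensor(rank=1)
```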
Reviewed changes
Copilot reviewed 8 out of 10 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| crates/onnx-ir/src/node/reshape.rs | Added Case 2 logic to keep scalars as scalars when reshaped to [-1] or [1], updated determine_output_type signature |
| crates/onnx-ir/src/node/gather.rs | Added scalar input support in type inference and config extraction, replaced error handling with scalar-specific logic, added unit test |
| crates/burn-import/src/burn/node/gather.rs | Added forward_scalar_gather function for scalar pass-through, updated match to handle scalar inputs, added snapshot tests |
| crates/burn-import/onnx-tests/tests/reshape/reshape_scalar_to_scalar.py | Python script to generate ONNX test model for reshape scalar optimization |
| crates/burn-import/onnx-tests/tests/reshape/reshape_scalar_to_scalar.onnx | Generated ONNX model for reshape scalar test |
| crates/burn-import/onnx-tests/tests/reshape/mod.rs | Added reshape_scalar_to_scalar integration test |
| crates/burn-import/onnx-tests/tests/gather/gather_scalar_input.py | Python script to generate ONNX test model for gather with scalar input |
| crates/burn-import/onnx-tests/tests/gather/gather_scalar_input.onnx | Generated ONNX model for gather scalar test |
| crates/burn-import/onnx-tests/tests/gather/mod.rs | Added gather_scalar_input integration test |
| crates/burn-import/onnx-tests/build.rs | Added new test models to build inputs |
Codecov Report

```
@@           Coverage Diff            @@
##             main    #4146    +/-   ##
==========================================
+ Coverage   68.39%   68.43%   +0.03%
==========================================
  Files        1281     1281
  Lines      156146   156274    +128
==========================================
+ Hits       106802   106948    +146
+ Misses      49344    49326     -18
==========================================
```
Cleaned up the GatherNode implementation by removing commented-out panic code for unsupported input types.
Introduces tests to verify that Reshape operations with scalar inputs and single-element shapes ([-1] or [1]) retain scalar output, while multi-element shapes convert to tensors. This ensures correct behavior and optimization for scalar reshaping.
```rust
// Shape is [-1] or [1] - effectively a single element, keep as scalar
if shape_values.len() == 1 && (shape_values[0] == -1 || shape_values[0] == 1) {
    return ArgType::Scalar(input_info.dtype);
}
```
So this is only possible when the value is static or from a constant right? No runtime values
Yes, exactly. The optimization only applies when the shape value is statically known at graph-build time (from an initializer or a constant); runtime-computed shapes are unaffected.
(Note: ignore the large number of commits; this PR is a leftover from #4119.)
This optimization appears in models like Silero VAD where scalar values are reshaped but should remain scalars for efficiency. Previously, Reshape(scalar, [-1]) would convert the scalar to a rank-1 tensor with one element, which is wasteful. The downstream Gather operation also needed to handle scalar inputs to complete the optimization chain.
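A toy sketch of the pass-through behavior (plain Python, not the actual generated code): when Gather's data input is already a scalar there is nothing to index, so the value is returned unchanged, completing the Reshape(scalar, [-1]) → Gather chain.

```python
def gather(data, index, axis=0):
    """Gather with scalar pass-through: a scalar data input is returned as-is."""
    if not isinstance(data, list):  # scalar input: pass through unchanged
        return data
    return data[index]

print(gather(3.5, 0))       # 3.5
print(gather([10, 20], 1))  # 20
```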
Eventually we should generalize to mark a node no-op, see #4147
Example:
Unoptimized:
Optimized:
Checklist
The `cargo run-checks` command has been executed.

Related Issues/PRs
#4119
Changes
Testing