Skip to content

[inductor] add cat_linear_fused pre-grad pass for F.linear(cat(...))#3233

Open
reger-men wants to merge 1 commit into
ROCm:developfrom
reger-men:pr2-cat-linear-fused
Open

[inductor] add cat_linear_fused pre-grad pass for F.linear(cat(...))#3233
reger-men wants to merge 1 commit into
ROCm:developfrom
reger-men:pr2-cat-linear-fused

Conversation

@reger-men
Copy link
Copy Markdown

Pre-grad FX pass that rewrites F.linear(torch.cat([t_0, ...], dim=-1), W, b) into a reduce-sum of per-piece F.linear calls on contiguous weight slices. Avoids materialising the cat on the forward path (and the cat's grad on the backward path), which on bf16 GEMMs is a measurable HBM-bandwidth win.

Conservative gating: only fires when every cat operand is a last-dim slice with the same leading shape, bias is None (or only on the first partial linear), total cat width is below MAX_TOTAL_CAT_WIDTH, and the cat axis is the last axis (handles negative indexing).

Off by default; opt in via

torch._inductor.config.pre_grad_custom_pass = cat_linear_fused_pre_grad_pass

Implementation note: the helper _val_of uses an explicit is None check rather than or, because node.meta.get("val") or node.meta.get("example_value") calls Tensor.__bool__ and raises when meta["val"] is a multi-element tensor (which can happen on a post-grad graph or a test graph that populates meta["val"] directly).

Test plan

  • test_cat_linear_fused.py matcher tests: canonical pattern fires, 3-part fires, rejects more than MAX_PARTS, rejects pieces below MIN_PIECE_WIDTH, rejects total > MAX_TOTAL_CAT_WIDTH, rejects mul-parented operand, rejects cat on a non-last axis
  • integration tests: torch.compile with the pass installed produces results within 1e-4 of eager and bumps the cat_linear_fused counter; with the pass not installed the counter stays at 0

@rocm-repo-management-api
Copy link
Copy Markdown

rocm-repo-management-api Bot commented May 19, 2026

Jenkins build for afa7652a0774f49d4fe22fbff8c00725a81557cc commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

@reger-men reger-men force-pushed the pr2-cat-linear-fused branch from afa7652 to fb14886 Compare May 20, 2026 18:06
@rocm-repo-management-api
Copy link
Copy Markdown

rocm-repo-management-api Bot commented May 20, 2026

Jenkins build for fb148869e8839c1ec24ce6920b9a6f0ec7f1ba5e commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

@reger-men reger-men force-pushed the pr2-cat-linear-fused branch from fb14886 to d928e13 Compare May 21, 2026 09:24
Pre-grad FX pass that rewrites F.linear(torch.cat([t_0, ...], dim=-1), W, b)
into a reduce-sum of per-piece F.linear calls on contiguous weight
slices. Avoids materialising the cat on the forward path (and the
cat-s grad on the backward path), which on bf16 GEMMs is a measurable
HBM-bandwidth win.

Conservative gating: only fires when every cat operand is a last-dim
slice with the same leading shape, bias is None (or only on the first
partial linear), total cat width is below MAX_TOTAL_CAT_WIDTH, and
cat axis is the last axis (handles negative indexing).

Off by default; opt in via
torch._inductor.config.pre_grad_custom_pass = cat_linear_fused_pre_grad_pass

Test under test/inductor/test_cat_linear_fused.py covers correctness
vs eager (forward + bf16 gradients), fire counter under the flag, and
the negative gates (cat-on-non-last-axis, mismatched leading shape,
mul-parented operand, too-many-parts, too-narrow piece, total > cap).

Implementation note: the helper _val_of uses an explicit "is None"
check rather than "or", because
    node.meta.get("val") or node.meta.get("example_value")
calls Tensor.__bool__ and raises when meta["val"] is a multi-element
tensor (as set in unit tests that mock shape metadata).
@rocm-repo-management-api
Copy link
Copy Markdown

rocm-repo-management-api Bot commented May 21, 2026

Jenkins build for d928e13248451100d4e5008013164ebefdb012b9 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

Detected error during Pytorch building:

20 warnings generated when compiling for host.
[5080/8176] Building CXX object third_party/ideep/mkl-dnn/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_binary_list.cpp.o
[5081/8176] Building CXX object third_party/ideep/mkl-dnn/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_concat.cpp.o
[5082/8176] Building CXX object third_party/ideep/mkl-dnn/src/cpu/CMakeFiles/dnnl_cpu.dir/ncsp_group_normalization.cpp.o
[5083/8176] Building CXX object third_party/ideep/mkl-dnn/src/cpu/CMakeFiles/dnnl_cpu.dir/ref_binary.cpp.o
FAILED: third_party/ideep/mkl-dnn/src/cpu/CMakeFiles/dnnl_cpu.dir/ref_binary.cpp.o 
/opt/cache/bin/sccache /opt/cache/bin/c++ -DDNNL_ENABLE_CPU_ISA_HINTS -DDNNL_ENABLE_ITT_TASKS -DDNNL_ENABLE_MAX_CPU_ISA -DDNNL_X64=1 -DIDEEP_USE_MKL -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DROCM_VERSION=70203 -DTORCH_HIP_VERSION=702 -DUSE_LAYERNORM_FAST_RECIPROCAL -D__STDC_CONSTANT_MACROS -D__STDC_LIMIT_MACROS -I/var/lib/jenkins/pytorch/build/third_party/ideep/mkl-dnn/include -I/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/include -I/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/third_party -I/var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/src -isystem /opt/rocm-7.2.3/include -isystem /var/lib/jenkins/pytorch/build/third_party/gloo -isystem /var/lib/jenkins/pytorch/cmake/../third_party/gloo -isystem /var/lib/jenkins/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /var/lib/jenkins/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /var/lib/jenkins/pytorch/cmake/../third_party/googletest/googletest/include -isystem /var/lib/jenkins/pytorch/third_party/protobuf/src -isystem /opt/conda/envs/py_3.12/include -isystem /var/lib/jenkins/pytorch/third_party/XNNPACK/include -isystem /var/lib/jenkins/pytorch/third_party/ittapi/include -isystem /var/lib/jenkins/pytorch/cmake/../third_party/eigen -isystem /opt/rocm/include -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -fopenmp -fvisibility-inlines-hidden  -fno-omit-frame-pointer -mno-omit-leaf-frame-pointer -Wall -Wno-unknown-pragmas -Wundef -fvisibility=internal   -fPIC -Wformat -Wformat-security -D_FORTIFY_SOURCE=2 -fstack-protector-strong -fcf-protection=full  -Wmissing-field-initializers  -Wno-strict-overflow -Wno-maybe-uninitialized -Wno-stringop-overflow -Wno-array-bounds  -O3 -DNDEBUG -DNDEBUG -std=c++20 -fPIC -DMKL_HAS_SBGEMM -DMKL_HAS_SHGEMM -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -MD -MT third_party/ideep/mkl-dnn/src/cpu/CMakeFiles/dnnl_cpu.dir/ref_binary.cpp.o -MF third_party/ideep/mkl-dnn/src/cpu/CMakeFiles/dnnl_cpu.dir/ref_binary.cpp.o.d -o third_party/ideep/mkl-dnn/src/cpu/CMakeFiles/dnnl_cpu.dir/ref_binary.cpp.o -c /var/lib/jenkins/pytorch/third_party/ideep/mkl-dnn/src/cpu/ref_binary.cpp
thread 'main' panicked at 'failed to shut down worker thread', /root/.cargo/registry/src/github.com-1ecc6299db9ec823/jobserver-0.1.9/src/lib.rs:650:16
note: Run with `RUST_BACKTRACE=1` for a backtrace.
[5084/8176] Building CXX object third_party/ideep/mkl-dnn/src/cpu/CMakeFiles/dnnl_cpu.dir/cpu_resampling_list.cpp.o
[5085/8176] Building CXX object third_party/ideep/mkl-dnn/src/cpu/CMakeFiles/dnnl_cpu.dir/primitive_attr_postops.cpp.o

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant