[inductor] add cat_linear_fused pre-grad pass for F.linear(cat(...)) #3233
Open
reger-men wants to merge 1 commit into
Jenkins build for afa7652a0774f49d4fe22fbff8c00725a81557cc commit finished as FAILURE
Force-pushed from afa7652 to fb14886
Jenkins build for fb148869e8839c1ec24ce6920b9a6f0ec7f1ba5e commit finished as FAILURE
Force-pushed from fb14886 to d928e13
Pre-grad FX pass that rewrites F.linear(torch.cat([t_0, ...], dim=-1), W, b)
into a reduce-sum of per-piece F.linear calls on contiguous weight
slices. Avoids materialising the cat on the forward path (and the
cat's grad on the backward path), which on bf16 GEMMs is a measurable
HBM-bandwidth win.
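
The identity behind the rewrite can be checked in eager mode. The following is a minimal sketch, not the pass itself: the shapes, the float64 dtype, and the variable names are assumptions chosen for illustration. It compares `F.linear` on the materialised cat against the sum of per-piece `F.linear` calls on the matching column ranges of `W`, with the bias applied only on the first partial linear.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical shapes, chosen only for illustration.
widths = [3, 5, 4]
parts = [torch.randn(2, 7, w, dtype=torch.float64) for w in widths]
W = torch.randn(6, sum(widths), dtype=torch.float64)
b = torch.randn(6, dtype=torch.float64)

# Reference: materialise the cat, then run a single linear.
reference = F.linear(torch.cat(parts, dim=-1), W, b)

# Rewritten form: one linear per piece against the matching column range of W.
# The bias is added exactly once, on the first partial linear.
fused = None
offset = 0
for i, (t, w) in enumerate(zip(parts, widths)):
    partial = F.linear(t, W[:, offset:offset + w], b if i == 0 else None)
    fused = partial if fused is None else fused + partial
    offset += w

# Largest elementwise difference between the two forms.
max_abs_diff = (reference - fused).abs().max().item()
print(max_abs_diff)
```

In lower precision such as bf16 the two forms can differ by rounding, because the partial products are summed in a different order.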
Conservative gating: only fires when every cat operand is a last-dim
slice with the same leading shape, bias is None (or only on the first
partial linear), total cat width is below MAX_TOTAL_CAT_WIDTH, and
the cat axis is the last axis (handles negative indexing).
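
The shape-level part of this gating could look roughly like the sketch below. It is an illustration under stated assumptions, not the PR's code: the function name `should_fuse`, its signature, and the limit values in the usage lines are hypothetical, and it does not cover the bias check or the check that each operand is produced by a slice.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def should_fuse(
    shapes: Sequence[Shape],
    dim: int,
    *,
    max_parts: int,
    min_piece_width: int,
    max_total_cat_width: int,
) -> bool:
    """Return True if a cat over `shapes` along `dim` passes the shape gates."""
    # Assumption: a cat with fewer than two operands has nothing to split.
    if len(shapes) < 2 or len(shapes) > max_parts:
        return False
    rank = len(shapes[0])
    if rank == 0 or any(len(s) != rank for s in shapes):
        return False
    # Reject out-of-range axes, then normalise negative ones so that
    # dim=-1 and dim=rank-1 are treated alike.
    if not -rank <= dim < rank:
        return False
    if dim % rank != rank - 1:
        return False
    # Every operand must share the same leading (non-cat) shape.
    leading = shapes[0][:-1]
    if any(s[:-1] != leading for s in shapes):
        return False
    widths = [s[-1] for s in shapes]
    if any(w < min_piece_width for w in widths):
        return False
    # Assumption: a total equal to the cap is allowed; the source only says
    # totals above the cap are rejected.
    return sum(widths) <= max_total_cat_width


# Hypothetical limits, for illustration only.
limits = dict(max_parts=4, min_piece_width=16, max_total_cat_width=4096)
print(should_fuse([(8, 128, 64), (8, 128, 32)], dim=-1, **limits))  # True
print(should_fuse([(8, 128, 64), (8, 128, 32)], dim=1, **limits))   # False
```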
Off by default; opt in via
torch._inductor.config.pre_grad_custom_pass = cat_linear_fused_pre_grad_pass
Test under test/inductor/test_cat_linear_fused.py covers correctness
vs eager (forward + bf16 gradients), fire counter under the flag, and
the negative gates (cat-on-non-last-axis, mismatched leading shape,
mul-parented operand, too-many-parts, too-narrow piece, total > cap).
Implementation note: the helper _val_of uses an explicit "is None"
check rather than "or", because
node.meta.get("val") or node.meta.get("example_value")
calls Tensor.__bool__ and raises when meta["val"] is a multi-element
tensor (as set in unit tests that mock shape metadata).
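
The failure mode described in the implementation note is easy to reproduce outside the pass. In the sketch below, `val_of_buggy` and `val_of` are hypothetical stand-ins that take a plain metadata dict rather than an FX node; the PR's actual `_val_of` signature is not shown in this conversation.

```python
import torch


def val_of_buggy(meta):
    # `or` evaluates the truthiness of the left operand, which calls
    # Tensor.__bool__ when meta["val"] is a tensor.
    return meta.get("val") or meta.get("example_value")


def val_of(meta):
    # An explicit None check never asks the tensor for a truth value.
    val = meta.get("val")
    if val is None:
        val = meta.get("example_value")
    return val


meta = {"val": torch.zeros(2, 3)}

try:
    val_of_buggy(meta)
    raised = False
except RuntimeError:
    # The truth value of a multi-element tensor is ambiguous.
    raised = True

safe = val_of(meta)
fallback = val_of({"example_value": torch.ones(4)})
print(raised, tuple(safe.shape), tuple(fallback.shape))
```

The `or` form has a second, quieter problem: a one-element tensor holding zero is falsy, so it would be skipped in favour of `example_value` without raising.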
Jenkins build for d928e13248451100d4e5008013164ebefdb012b9 commit finished as FAILURE Detected error during Pytorch building:
Pre-grad FX pass that rewrites `F.linear(torch.cat([t_0, ...], dim=-1), W, b)` into a reduce-sum of per-piece `F.linear` calls on contiguous weight slices. Avoids materialising the cat on the forward path (and the cat's grad on the backward path), which on bf16 GEMMs is a measurable HBM-bandwidth win.

Conservative gating: only fires when every cat operand is a last-dim slice with the same leading shape, bias is None (or only on the first partial linear), total cat width is below `MAX_TOTAL_CAT_WIDTH`, and the cat axis is the last axis (handles negative indexing).

Off by default; opt in via `torch._inductor.config.pre_grad_custom_pass = cat_linear_fused_pre_grad_pass`.

Implementation note: the helper `_val_of` uses an explicit `is None` check rather than `or`, because `node.meta.get("val") or node.meta.get("example_value")` calls `Tensor.__bool__` and raises when `meta["val"]` is a multi-element tensor (which can happen on a post-grad graph or a test graph that populates `meta["val"]` directly).

Test plan

- `test_cat_linear_fused.py` matcher tests: canonical pattern fires, 3-part fires, rejects more than MAX_PARTS, rejects pieces below MIN_PIECE_WIDTH, rejects total > MAX_TOTAL_CAT_WIDTH, rejects `mul`-parented operand, rejects cat on a non-last axis
- `torch.compile` with the pass installed produces results within 1e-4 of eager and bumps the `cat_linear_fused` counter; with the pass not installed the counter stays at 0