Skip to content

perf: move to seq-parallel atomic reduction in bwd_b#3

Open
TimoImhof wants to merge 1 commit into
mainfrom
perf/bwd_b
Open

perf: move to seq-parallel atomic reduction in bwd_b#3
TimoImhof wants to merge 1 commit into
mainfrom
perf/bwd_b

Conversation

@TimoImhof
Copy link
Copy Markdown
Owner

@TimoImhof TimoImhof commented Apr 12, 2026

The bwd_b kernel is currently the primary bottleneck in the fused backward pas:

fused_lora_bwd

It iterates over the sequence dimension ($S$) within a single thread block:

# dL/d(out).T @ (X @ A.T)
    for s in range(0, tl.cdiv(dL_dout_S, BLOCK_S)):
        # ... inner = X @ A.T ...
        # ... _total_acc = tl.dot(dL_out, _interm_tile) ...

leading to redundant memory traffic: Both $X$ and $A$ are re-loaded from global memory $S/BLOCK_S$ times. Since $S$ is typically the largest dimension, this creates massive, unnecessary VRAM pressure.

Proposed Fix: Sequence-Parallelism Mirroring the successful optimization of bwd_a - this PR refactors bwd_b to:

  • Parallelize over $S$: Map the grid to the sequence dimension so that multiple blocks process chunks of the sequence concurrently.
  • Eliminate Redundant Loads: Each block loads its chunk of $X$ and $A$ exactly once.
  • Atomic Accumulation: Uses tl.atomic_add to aggregate partial gradients into the final $B$ matrix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant