
[BUG] L2/98 Fused Triton Kernel Crashes with double free or corruption / malloc(): invalid next size #147

@wuyii8941

Description

Summary

The KernelAgent-generated Triton kernel for L2 Task 98 (98_Matmul_AvgPool_GELU_Scale_Max) crashes with a glibc heap corruption abort on every invocation. The PyTorch reference implementation runs correctly; the crash is isolated to the fused Triton kernel.

Environment:

  • GPU: NVIDIA RTX A6000 (Ampere)
  • PyTorch: 2.5.1+cu121
  • CUDA: 12.1

Reproduction

import torch
import torch.nn.functional as F

B, K, N = 32, 512, 8192
x = torch.randn(B, K, dtype=torch.bfloat16, device="cuda")
W = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
b = torch.randn(N,    dtype=torch.bfloat16, device="cuda")

# Path A: PyTorch reference — OK
y = F.linear(x.float(), W.float(), b.float())
y = F.avg_pool1d(y.unsqueeze(1), 16, 16).squeeze(1)
y = F.gelu(y) * 2.0
y = y.max(dim=1).values
print("Path A [OK]:", y.shape)  # [32]

# Path B: Fused Triton kernel — CRASH
# (kernel_function is the KernelAgent-generated fused kernel under test)
out = kernel_function(x, W, b, pool_kernel_size=16, scale_factor=2.0)
# → malloc(): invalid next size (unsorted)  /  double free or corruption (out)
# → Aborted (core dumped)

Root Cause

The kernel fuses the AvgPool averaging loop (while r < POOL) inside the K-reduction loop, creating a 3-level nested structure:

g-loop (dynamic)
  k-loop (dynamic)
    r-loop  ← POOL=16 is tl.constexpr → fully unrolled by Triton
      sumW_T += tl.load(...)
    tl.dot(x, sumW_T / POOL, acc)

Because POOL is a tl.constexpr, Triton's LLVM backend fully unrolls the r-loop into 16 consecutive load+add sequences. The sumW_T accumulator tensor ([BLOCK_K, BLOCK_G], fp32) must remain live across all 16 unrolled iterations, and its live range overlaps with both x ([BLOCK_B, BLOCK_K]) and acc ([BLOCK_B, BLOCK_G]) in the same scheduling scope.

Triton's register allocator cannot correctly schedule the spill/fill operations for sumW_T across the unrolled body, causing spill-buffer writes to overflow into adjacent CUDA-managed memory. The corruption propagates to CPU heap metadata, which glibc detects at the next free() call and triggers abort().

This affects all four autotune configs because the problem is structural (loop nesting), not a tile-size threshold.

Additionally, the original W_TRANSPOSED=False branch applied tl.trans(sumW) before passing it to tl.dot, forcing a layout conversion that Triton handles unreliably for non-contiguous intermediate tensors; this is a secondary contributing factor.


Fix

Move the POOL-averaging out of the Triton kernel into the Python wrapper as a one-time PyTorch operation. The kernel then receives a pre-averaged W_avg [G, K] and b_avg [G] and performs a standard 2-level (g, k) matmul loop, with no r-loop and no tl.trans.

Wrapper (add these lines before the kernel launch)

G = N // POOL

# Pre-compute averaged weight and bias on GPU (pure PyTorch, no kernel math)
if not W_TRANSPOSED:
    W_avg = W.float().view(G, POOL, K).mean(dim=1).contiguous()    # [G, K]
else:
    W_avg = W.float().view(K, G, POOL).mean(dim=2).t().contiguous() # [G, K]
b_avg = b.float().view(G, POOL).mean(dim=1).contiguous()            # [G]
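As a sanity check, both layout branches should yield the same [G, K] averaged weight. A minimal numpy sketch (numpy standing in for the PyTorch reshapes; sizes are illustrative, not the task's):

```python
import numpy as np

# Illustrative sizes (the real task uses K=512, N=8192, POOL=16)
K, G, POOL = 8, 4, 2
N = G * POOL

rng = np.random.default_rng(0)
W = rng.standard_normal((N, K)).astype(np.float32)  # W_TRANSPOSED=False layout [N, K]
W_T = W.T.copy()                                    # W_TRANSPOSED=True layout  [K, N]

# Branch 1: [N, K] -> [G, POOL, K], average over the POOL axis -> [G, K]
W_avg_a = W.reshape(G, POOL, K).mean(axis=1)

# Branch 2: [K, N] -> [K, G, POOL], average over POOL, then transpose -> [G, K]
W_avg_b = W_T.reshape(K, G, POOL).mean(axis=2).T

assert np.allclose(W_avg_a, W_avg_b)
```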

Kernel (replace the entire if not W_TRANSPOSED / else block)

# Before: 3-level loop (g, k, r) with sumW_T accumulator + tl.trans
# After:  standard 2-level loop (g, k), direct load of pre-averaged W_avg

# Load W_avg tile as [BLOCK_K, BLOCK_G] — no tl.trans needed.
# w_avg layout is [G, K]; swap index order to get the transposed tile directly.
w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_g[None, :] * stride_wg)
w_tile = tl.load(w_ptrs, mask=mask_k[:, None] & mask_g[None, :], other=0.0)
# [BLOCK_B, BLOCK_K] x [BLOCK_K, BLOCK_G] → [BLOCK_B, BLOCK_G]
acc = tl.dot(x.to(tl.bfloat16), w_tile.to(tl.bfloat16), acc)

The bias loop is similarly simplified to a single tl.load of b_avg[offs_g].

Why this is correct

avg_pool1d(linear(x, W, b), kernel=POOL) is mathematically equivalent to linear(x, W_avg, b_avg), where W_avg[g, :] = mean(W[g*POOL:(g+1)*POOL, :]) and b_avg[g] = mean(b[g*POOL:(g+1)*POOL]): averaging commutes with the affine map by linearity. Pre-computing the average in the wrapper produces identical numerical results while giving the kernel a standard matmul structure that Triton compiles reliably.
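The equivalence can be checked directly with a small numpy sketch (shapes here are illustrative, smaller than the task's, and avg_pool1d is expressed as a reshape-and-mean):

```python
import numpy as np

# Illustrative shapes (assumed for the demo, not the task's real sizes)
B, K, N, POOL = 3, 8, 16, 4
G = N // POOL
rng = np.random.default_rng(1)
x = rng.standard_normal((B, K))
W = rng.standard_normal((N, K))
b = rng.standard_normal(N)

# Path 1: linear, then average consecutive groups of POOL outputs
y = x @ W.T + b                               # [B, N]
y_pooled = y.reshape(B, G, POOL).mean(axis=2) # [B, G]

# Path 2: pre-average W and b, then a single matmul
W_avg = W.reshape(G, POOL, K).mean(axis=1)    # [G, K]
b_avg = b.reshape(G, POOL).mean(axis=1)       # [G]
y_fused = x @ W_avg.T + b_avg                 # [B, G]

assert np.allclose(y_pooled, y_fused)
```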


Impact

  • Severity: P0 — 100% crash rate, no fallback; the kernel is completely unusable.
  • Scope: All batch sizes, all autotune configs, both W_TRANSPOSED=True/False paths.

Verification

After applying the fix, Path B returns values numerically close to Path A:

Path A [OK]: tensor([31.15, 40.82, 38.61, ...])
Path B [OK]: tensor([31.15, 40.82, 38.61, ...])   # max_diff < 0.1 (bf16 rounding)
