[Benchmark] Add compute_seq_len_sweep_config_with_probe with linear/quadratic scaling support#1218
Conversation
…r scaling (linkedin#1200) Adds a new helper alongside the existing compute_seq_len_sweep_config that internalizes both the probe and the seq-len inversion, with a scaling_method argument supporting "linear" (default) and "quadratic". For O(L^2) kernels, the inversion uses L_max = sqrt(usable / (B * c_per_BL2)) instead of the linear max_tokens / batch_size path. Migrates benchmark_sparse_multi_token_attention.py to the new helper and drops its manual `peak_bytes // (probe_L * probe_L)` workaround. The existing estimate_kernel_peak_memory and compute_seq_len_sweep_config are unchanged; linear-scaling benchmark callers don't need to migrate.
|
@Tcc0403 @Mecoli1219 Please take a look |
| batch_size = max(1, min(max_batch_size, probe_batch_size)) | ||
|
|
||
| if scaling_method == "linear": | ||
| c_per_BL = max(1.0, peak_bytes / (probe_batch_size * probe_seq_len)) | ||
| max_seq_len_from_mem = max(1, int(usable_bytes / (batch_size * c_per_BL))) | ||
| else: | ||
| c_per_BL2 = max(1.0, peak_bytes / (probe_batch_size * probe_seq_len * probe_seq_len)) | ||
| max_seq_len_from_mem = max(1, int(math.sqrt(usable_bytes / (batch_size * c_per_BL2)))) | ||
|
|
||
| seq_len = min(max_seq_len, max_seq_len_from_mem) | ||
| seq_len = 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024 |
There was a problem hiding this comment.
Is it possible to just plug this part to compute_seq_len_sweep_config?
There was a problem hiding this comment.
@Tcc0403 Good call. Pushed 6c204db which extracts two private helpers — _max_seqlen_under_memory (handles both linear and quadratic inversion) and _snap_pow2_seqlen — and collapses both public functions to thin orchestration over them.
compute_seq_len_sweep_config treats kernel_bytes_per_token as a unit-probe (B=L=1, linear) so the inversion math reduces to the existing max_tokens = usable / bpt quantity. No behavior change for the existing callers; the duplicated inversion/snap logic is gone.
Per @Tcc0403 review: instead of two parallel implementations of the inversion + power-of-2 snap, extract `_max_seqlen_under_memory` (handles both linear and quadratic) and `_snap_pow2_seqlen`. Both public APIs become thin orchestration layers over them. `compute_seq_len_sweep_config` now treats `kernel_bytes_per_token` as a unit-probe (B=L=1, scaling=linear) so the math collapses to the existing `max_tokens = usable / bpt` behavior — no behavior change for the 16+ existing callers.
| return 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024 | ||
|
|
||
|
|
||
| def compute_seq_len_sweep_config_with_probe( |
There was a problem hiding this comment.
we can replace all occurrences of compute_seq_len_sweep_config with yours, keeping only one helper function
Per @Tcc0403's review on linkedin#1218: replace all callers of the old compute_seq_len_sweep_config with the probe-aware variant and delete the old function. Single public helper, single way to compute a sweep config. The unified compute_seq_len_sweep_config takes probe_fn + probe_seq_len directly and runs estimate_kernel_peak_memory internally. The scaling_method="quadratic" path that compute_seq_len_sweep_config_with_probe existed to support is now first-class on the unified function. Caller pattern simplifies from: peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) kernel_bpt = peak_bytes // probe_seq_len config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) to: config = compute_seq_len_sweep_config(model, probe_fn=_probe, probe_seq_len=probe_seq_len) Net -154 lines across 33 benchmark scripts. The benchmark_multi_token_attention caller, which previously did a manual peak_bytes // (probe_L * probe_L) quadratic inversion, now uses scaling_method="quadratic" via the unified API.
|
@Tcc0403 Pushed a8a8f40. Collapsed the two public helpers into one — The |
Summary
Refs #1200. Addresses non-linear memory scaling in benchmark sweep config inference.
The existing
compute_seq_len_sweep_configinverts memory viamax_tokens = usable_bytes / kernel_bytes_per_token, which only holds for linear-scaling kernels. For O(L²) kernels (e.g.benchmark_sparse_multi_token_attention.py), this overestimates capacity by orders of magnitude — the existing workaround there divides byprobe_L * probe_L, but the downstream sweep math still treats the result as linear bytes-per-token.Per discussion on the issue (#1200 (comment)), this PR adds a new helper rather than threading
scaling_methodthrough the existing function — 16+ benchmark scripts callestimate_kernel_peak_memorytoday, and a wider signature change would conflict with in-flight benchmark refactors (#1199, #1180). Linear-scaling callers are unchanged; only quadratic-scaling benchmarks opt in.What changed
benchmark/scripts/benchmark_model_configs.py— addscompute_seq_len_sweep_config_with_probe(model_cfg, probe_fn, probe_seq_len, probe_batch_size=1, scaling_method="linear" | "quadratic", ...). Internalizes the probe call + inversion; reusesestimate_kernel_peak_memoryfor the measurement.benchmark/scripts/benchmark_sparse_multi_token_attention.py— switches thetoken_lengthsweep mode to the new helper withscaling_method="quadratic", dropping the manualpeak_bytes // (probe_L * probe_L)workaround.estimate_kernel_peak_memoryandcompute_seq_len_sweep_configare untouched.Validation
Hardware: A10G 24GB (g5.xlarge).
Synthetic O(L²) probe (B=2, L=2048, allocates
B * L * Lfloats) usingLLAMA_3_8Bconfig andmax_seq_len=2**20to bypass the model cap so the raw inversion is visible:The 8× gap (≈17× before snap-to-power-of-2) demonstrates the inversion difference:
linearclaims a sweep at L=65536 fits, when in reality L² at that size would require multiple TBs.quadraticlands at a realistic L=8192. This matches the issue's premise — for non-linear-scaling kernels, the existing inversion overestimates capacity and would OOM at the predicted boundary.Testing Done
quadraticpredicts L=8192 vslinearpredicts L=65536 for the same probe (8× separation, scales as expected).benchmark_sparse_multi_token_attention.pyimports + helper resolution verified locally.cc @Tcc0403