-
Notifications
You must be signed in to change notification settings - Fork 27
KernelAgent-Oink: Add SM100 CuTeDSL RMSNorm custom ops plugin for vLLM #69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
oink::fused_add_rms_norm backed by an SM100 CuTeDSL RMSNorm kernel. The ops are torch.compile-friendly (stride-preserving for padded-row inputs) and the fused op matches vLLM's in-place residual-add RMSNorm semantics.
Jack-Khuu
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Initial comments
Need to go through rmsnorm.py
|
|
||
| import math | ||
| import operator | ||
| from typing import Callable, Optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Prefer <type> | None instead of Optional keyword
| numerical behaviour and performance close to the original reference | ||
| implementations. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
original reference implementations.
Commit hash would be nice if you have it handy
| if sm >= 100: | ||
| # Use the tuned CuTeDSL SM100 kernel. The public API already | ||
| # contains all necessary gating and layout checks internally. | ||
| _rms = _get_rmsnorm_mod() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: pull _rms out of conditional
sm = _get_sm(x.device)
_rms = _get_rmsnorm_mod()
if sm >= 100:
return <>
return _rms.rmsnorm_ref(...)| assert weight.dim() == 1, "weight must be 1D [N]" | ||
|
|
||
| sm = _get_sm(x.device) | ||
| if sm >= 100: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Check inverse to reduce nesting
if sm < 100:
# Non-SM100: keep semantics in-place (correctness-first).| local_rank = os.environ.get("LOCAL_RANK") | ||
| if local_rank is not None: | ||
| try: | ||
| return int(local_rank) | ||
| except ValueError: | ||
| pass | ||
| return 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ignore suggestion if we want to guard/enable on "off/on" "yes/no"
| local_rank = os.environ.get("LOCAL_RANK") | |
| if local_rank is not None: | |
| try: | |
| return int(local_rank) | |
| except ValueError: | |
| pass | |
| return 0 | |
| rank = os.environ.get("LOCAL_RANK", "0") | |
| return int(rank) |
| import subprocess | ||
| import sys | ||
| import threading | ||
| from typing import Optional, Tuple |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we want to keep it for <= Python 3.9 support that's fine. If not let's use | None and tuple for 3.10+
| f"falling back to staged SMEM path (returncode={rc}).", | ||
| file=sys.stderr, | ||
| ) | ||
| failing_proc = proc_128 if proc_128 is not None else proc_256 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to spit out both error traces since both exist+fail?
Since 128 is the fallback, fixing the 256 probe makes more sense right?
| _CLUSTER_DIRECT_GMEM_PROBE_WARNED = False | ||
|
|
||
|
|
||
| def _probe_cluster_direct_gmem_max_copy_bits() -> int: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't called until ~line 2560, do we want to move this lower?
Specifically somewhere after 263-299 which are still configurating the env variables (and called on import)
| """Resolve copy width (in bits) from the (import-time) policy string.""" | ||
| if _COPY_BITS_POLICY in {"128"}: | ||
| return 128 | ||
| if _COPY_BITS_POLICY in {"256"} and can_use_256: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why in instead of ==?
| # This relies on internal CuTeDSL runtime pointer fields (`_desc`, `_pointer`, | ||
| # etc.). If these internals change in a future CuTeDSL upgrade, callers | ||
| # should catch AttributeError and fall back to the regular launch path. | ||
| device_ptr = int(device_ptr) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why cast if the function expects an int?
| sm_count = ( | ||
| sm_count * sm_count_multiple | ||
| if N <= 8192 | ||
| else sm_count // 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems strange, why would we ever want to run fewer than sm_count?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for clustered launches sm_count is effectively a cluster-count heuristic (matching Quack’s naming/launch shape). Launch uses grid=[sm_count, cluster_n, 1], so total CTAs is sm_count * cluster_n.
| _PTR_FAST_LAUNCH_TLS = threading.local() | ||
|
|
||
|
|
||
| def _env_flag(name: str, default: bool) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this already exists in init, remove dupes like that
| elif N <= 8192: | ||
| # Allow an override (used by 2-rows/CTA path for N≈6k/8k) | ||
| try: | ||
| return self._tpr_override # type: ignore[attr-defined] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems redundant wrt the first override check?
| @@ -0,0 +1,2927 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the late review, it's nice that we have an extension system now to try things in VLLM so I mostly want to spend time reviewing the kernel itself and what'd make it easier for vendors like VLLM to actually merge this in. Mostly reiterating points I made here https://x.com/marksaroufim/status/2009096176789016600?s=20
A lot of it stems from this file is too long but I think it shouldn't be too hard to clean it up
- we don't need the cache to work over multiple cute DSL versions, presumably they're making breaking changes fairly frequently so let's just pick the latest version and update as needed
- The code almost looks like a splatted autotune run because it's trying to handle many cases and choose between different optimization. I think we should just try and ship the one specific config that is fast on some specific shapes on a specific model that the VLLM team cares about on B200. Otherwise they'll have trouble reviewing this code even if it's faster and I'd rather we generalize the code progressively as the need arises
- A lot of the pointer marshalling code can be deleted in favor of using
tvm-ffi, a good chunk of the file is doing this and this will be error prone - Point 2 also will have unexpected side effects, where tons of fallback makes it unpredictable for an end user precisely which kernel configuration will run which is something all of our numerics sensitive customers will really care about. A user would often like to explicitly state whether they want an op to be in place or not. I'd argue that instead of environment variables gating specific optimizations we should have arguments to a function or separate functions. Even further PyTorch now has an intra kernel dispatcher where we can make guarantees on which specific kernel will be called for a specific shape
- Finally while I think an e2e test in VLLM works great, we probably also want some smaller unit tests comparing numerics vs vanilla PyTorch code and Quack right here
- Switch correctness gate to PyTorch ref + record err stats\n- Tighten Softmax/LayerNorm tolerances (Quack-like)\n- Quack-style benchmark suite layout + SVG plots\n- Packaging/README polish for publishability
Add the kernelagent-oink vLLM plugin that registers Blackwell (SM100) RMSNorm
custom ops via torch.library.custom_op under the oink:: namespace:
The SM100 CuTeDSL implementation is layout-aware and preserves padded-row
strides (stride(1)==1, stride(0)>=N) so torch.compile/CUDA-graph capture sees a
stable stride contract. Includes small-M latency tuning for DSv3-like N=7168
and maintains high-M bandwidth, with correctness-first fallbacks on non-SM100.