Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,23 @@ def compute_hash(self) -> str:
factors.append(self.tensor_parallel_size)
factors.append(self.enable_dp_attention)

# ATOM_DEPTH_AWARE_COMPILE_CACHE: the piecewise torch.compile graph's structure
# (and its input count) depends on the model depth/shape. Those structural
# fields are NOT covered by the vllm/compilation/parallel sub-hashes
# above, so without them a graph compiled for one num_hidden_layers
# is silently reused for another (e.g. a 7-layer debug copy of a
# 61-layer model), tripping IndexError in inductor
# copy_misaligned_inputs at warmup. Fold the structural HF fields
# into the cache key so each shape gets its own compiled-graph dir.
hf_config = getattr(self, "hf_config", None)
if hf_config is not None:
for _attr in (
"num_hidden_layers",
"first_k_dense_replace",
"num_nextn_predict_layers",
):
factors.append(getattr(hf_config, _attr, None))

hash_str = hashlib.md5(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
Expand Down
92 changes: 72 additions & 20 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm,
)

# ATOM_GFX1250_WORKAROUND (ported from ce1809f8473f): aiter's HIP
# concat_and_cache_mla and fused_qk_rope_concat_and_cache_mla use opus::gmem
# buffer-resource intrinsics that fault on gfx1250 ("Memory access fault on
# (nil)") regardless of dtype. The call sites below swap them (plus the
# preceding self.rotary_emb call) for aiter's Triton fused_qk_rope_cat_and_cache_mla,
# which lowers to plain global_load/global_store.
from aiter.ops.triton.fusions.fused_kv_cache import (
fused_qk_rope_cat_and_cache_mla as _atom_triton_fused_rope_cat_and_cache_mla,
)

concat_and_cache_mla = mark_trace(
concat_and_cache_mla, prefix="kv_cache", torch_compile=False
)
Expand Down Expand Up @@ -868,21 +878,55 @@ def forward_impl_server_mode(
kv_cache = kv_cache_data[f"layer_{self.layer_num}"].k_cache

if context.is_prefill and not use_prefill_mla:
prefill_q = self.q_proj(q, x_scale=q_scale).view(
prefill_q_proj = self.q_proj(q, x_scale=q_scale).view(
-1, self.num_heads, self.qk_head_dim
)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :]
self.rotary_emb(positions, prefill_q_pe, k_rope)
prefill_q_nope = prefill_q_proj[..., : self.qk_nope_head_dim]
prefill_q_pe = prefill_q_proj[..., self.qk_nope_head_dim :]

if kv_cache.numel() > 0:
concat_and_cache_mla(
k_nope,
k_rope.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=self._k_scale,
# ATOM_GFX1250_WORKAROUND (ported from ce1809f8473f): aiter's HIP
# concat_and_cache_mla faults on gfx1250. Use the aiter Triton
# fused RoPE+concat+cache kernel, which also subsumes the
# standalone rotary_emb call.
#
# k_pe rotation: the HIP path rotated k_rope IN PLACE via
# self.rotary_emb(positions, prefill_q_pe, k_rope) before the
# cache write, so downstream consumers (_forward_prefill_mha,
# _forward_prefill_cached_chunked) saw the rotated k_rope. The
# Triton kernel does NOT rotate the input k_pe in place — it
# returns the rotated copy as _k_pe_out. We must rebind k_rope
# to that copy so the new-tokens attention uses rotated k_pe
# against rotated q_pe. (The cached-prefix path reads from
# kv_cache, which already holds the rotated k_pe.)
_k_nope_3d = k_nope.unsqueeze(1) if k_nope.dim() == 2 else k_nope
_k_rope_3d = k_rope.unsqueeze(1) if k_rope.dim() == 2 else k_rope
# Return tuple order: (q_out, decode_q_pe_out, k_pe_out,
# q_nope_zeros_out). We use the first (bound to prefill_q) and the third
# (the rotated k_pe, rebound to k_rope below); the other two are unused.
prefill_q, _decode_q_pe_out, _k_pe_out, _q_nope_zeros_out = (
_atom_triton_fused_rope_cat_and_cache_mla(
prefill_q_nope,
prefill_q_pe,
_k_nope_3d,
_k_rope_3d,
kv_cache,
attn_metadata.slot_mapping.flatten(),
positions,
self.rotary_emb.cos_cache,
self.rotary_emb.sin_cache,
self._k_scale,
is_neox=self.rotary_emb.is_neox_style,
apply_scale=(self.kv_cache_dtype == "fp8"),
q_out_dtype=prefill_q_proj.dtype,
)
)
k_rope = _k_pe_out
else:
# No kv-cache write needed (e.g. dummy run); apply RoPE the
# old way so prefill_q has the rotated pe slice. rotary_emb
# rotates k_rope in place here, matching the HIP behaviour.
self.rotary_emb(positions, prefill_q_pe, k_rope)
prefill_q = prefill_q_proj

if attn_metadata.has_cached:
chunk_meta = getattr(attn_metadata, "mla_chunk_meta", None)
Expand Down Expand Up @@ -911,23 +955,31 @@ def forward_impl_server_mode(
device=q_nope.device,
)
if kv_cache.numel() > 0:
fused_qk_rope_concat_and_cache_mla(
# ATOM_GFX1250_WORKAROUND (ported from ce1809f8473f): aiter's HIP
# fused_qk_rope_concat_and_cache_mla
# ('fuse_qk_rope_concat_and_cache_mla_per_head_kernel') faults
# inside captured cudagraphs on gfx1250 (opus::gmem
# buffer-resource issue). Swap to the equivalent aiter Triton
# kernel.
_kv_cache_v = kv_cache.view(
kv_cache.shape[0], -1, self.kv_lora_rank + self.qk_rope_head_dim
)
_k_nope_3d = k_nope.unsqueeze(1) if k_nope.dim() == 2 else k_nope
_k_rope_3d = k_rope.unsqueeze(1) if k_rope.dim() == 2 else k_rope
_atom_triton_fused_rope_cat_and_cache_mla(
q_nope,
q_rope,
k_nope,
k_rope,
kv_cache.view(
kv_cache.shape[0], -1, self.kv_lora_rank + self.qk_rope_head_dim
),
q_out,
_k_nope_3d,
_k_rope_3d,
_kv_cache_v,
attn_metadata.slot_mapping,
self._k_scale,
self._q_scale,
positions,
self.rotary_emb.cos_cache,
self.rotary_emb.sin_cache,
self._k_scale,
is_neox=self.rotary_emb.is_neox_style,
is_nope_first=True,
apply_scale=(self.kv_cache_dtype == "fp8"),
q_out=q_out,
)
# q_out = self.fused_kv_bmm(q, q_scale, k_nope, k_rope, positions, kv_cache, attn_metadata)

Expand Down
38 changes: 33 additions & 5 deletions atom/model_ops/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.gated_rmsnorm_fp8_group_quant import gated_rmsnorm_fp8_group_quant
from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
# ATOM_GFX1250_WORKAROUND (ported from ce1809f8473f): aiter HIP rmsnorm2d_fwd
# faults on gfx1250; route to the AITER Triton equivalents at every call site.
from aiter.ops.triton.normalization.rmsnorm import (
_rmsnorm_forward as _triton_rmsnorm_forward,
_rmsnorm_forward_with_add as _triton_rmsnorm_forward_with_add,
)
from atom.config import QuantizationConfig
from atom.model_ops.utils import atom_parameter
from atom.quant_spec import LayerQuantConfig
Expand Down Expand Up @@ -51,24 +57,46 @@ def silu(input: Tensor, inplace: bool = False) -> Tensor:
return torch._C._nn.silu(input)


@torch_compile_guard()
def rmsnorm2d_fwd_fake_tensors(
x: torch.Tensor, weight: torch.Tensor, eps: float, dim: int
) -> torch.Tensor:
return torch.empty_like(x)


@torch_compile_guard(gen_fake=rmsnorm2d_fwd_fake_tensors)
def rmsnorm2d_fwd_(
x: torch.Tensor, weight: torch.Tensor, eps: float, dim: int
) -> torch.Tensor:
# ATOM_GFX1250_WORKAROUND: route to Triton; HIP rmsnorm2d_fwd faults on gfx1250.
ori_shape = x.shape
x = x.reshape(-1, dim)
return rmsnorm2d_fwd(x, weight, eps).view(ori_shape)
out, _ = _triton_rmsnorm_forward(x, weight, eps)
return out.view(ori_shape)


@torch_compile_guard()
def rmsnorm2d_fwd_with_add_fake_tensors(
x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, eps: float, dim: int
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(x)


@torch_compile_guard(gen_fake=rmsnorm2d_fwd_with_add_fake_tensors)
def rmsnorm2d_fwd_with_add_(
x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor, eps: float, dim: int
) -> Tuple[torch.Tensor, torch.Tensor]:
# ATOM_GFX1250_WORKAROUND: route to Triton; HIP rmsnorm2d_fwd_with_add faults on gfx1250.
# The Triton _fused_add_rmsnorm_kernel only takes ONE input-side row-stride and reuses
# it for x/residual_in/residual_out, so force the same contiguous layout before launch.
ori_shape = x.shape
x = x.reshape(-1, dim)
x = x.reshape(-1, dim).contiguous()
residual = residual.contiguous()
weight_c = weight.contiguous()
out = torch.empty_like(x)
residual_out = torch.empty_like(x)
rmsnorm2d_fwd_with_add(out, x, residual, residual_out, weight, eps)
rsigma = torch.empty((x.shape[0],), dtype=torch.float32, device=x.device)
_triton_rmsnorm_forward_with_add(
out, x, residual, residual_out, weight_c, rsigma, eps
)
return out.view(ori_shape), residual_out.view(ori_shape)


Expand Down
72 changes: 59 additions & 13 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass
from enum import Enum
from typing import Callable, List, Optional, Tuple
import logging

import torch
from aiter import ActivationType, QuantType, dtypes, get_hip_quant, topk_gating
Expand Down Expand Up @@ -58,6 +59,8 @@
from atom.plugin.moe import FusedMoEDecoratorForPluginMode
from atom.quantization.quark.utils import weight_dequant_fp8

logger = logging.getLogger("atom")


class FusedMoeWeightScaleSupported(Enum):
"""Supported quantization strategies for MoE weight scales."""
Expand Down Expand Up @@ -749,14 +752,22 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig):
or self.quant_type == QuantType.per_1x32
)
gfx = get_gfx()
self.is_gfx1250 = gfx == "gfx1250"
# gfx1250 grouped a8w4 MoE kernel only supports the non-interleaved
# (gate|up separated) scale layout; reject is_guinterleave up front.
if self.is_gfx1250 and self.is_guinterleave:
raise NotImplementedError(
"gfx1250 MoE only supports is_guinterleave=False; "
"unset ATOM_MOE_GU_ITLV."
)
if envs.is_set("ATOM_USE_TRITON_MOE"):
self.use_triton = envs.ATOM_USE_TRITON_MOE
else:
self.use_triton = (
gfx.startswith("gfx94")
or gfx.startswith("gfx12")
or (gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM)
)
logger.info(f"Mxfp4MoEMethod use_triton = {self.use_triton}")
if self.use_triton:
from atom.model_ops.utils import has_triton_kernels

Expand Down Expand Up @@ -935,19 +946,54 @@ def process_weights_after_loading(self, layer):
layer.w2_weight.is_shuffled = True

# shuffle scale
w13_scale_2d = layer.w13_weight_scale.reshape(
-1, layer.w13_weight_scale.shape[-1]
)
w2_scale_2d = layer.w2_weight_scale.reshape(-1, layer.w2_weight_scale.shape[-1])
if self.is_gfx1250:
# gfx1250 grouped a8w4 MoE kernel reads the e8m0 scale preshuffled by
# _grouped_a8w4_prepare_scale_batch (warp_tile = tile_n // n_warp =
# 64 // 2 = 32, tile_k = 128). w13 is the (gate|up) operand
# (rows = 2*inter, k_dim = model_dim); w2 (down) has
# rows = model_dim, k_dim = inter.
from aiter.fused_moe import _grouped_a8w4_prepare_scale_batch

_GROUPED_WARP_TILE_N = 32
_GROUPED_TILE_K = 128
layer.w13_weight_scale = atom_parameter(
_grouped_a8w4_prepare_scale_batch(
layer.w13_weight_scale.data,
experts=self.num_experts,
rows=2 * self.intermediate_size,
k_dim=self.hidden_size,
warp_tile=_GROUPED_WARP_TILE_N,
tile_k=_GROUPED_TILE_K,
device=layer.w13_weight_scale.device,
)
)
layer.w2_weight_scale = atom_parameter(
_grouped_a8w4_prepare_scale_batch(
layer.w2_weight_scale.data,
experts=self.num_experts,
rows=self.hidden_size,
k_dim=self.intermediate_size,
warp_tile=_GROUPED_WARP_TILE_N,
tile_k=_GROUPED_TILE_K,
device=layer.w2_weight_scale.device,
)
)
else:
w13_scale_2d = layer.w13_weight_scale.reshape(
-1, layer.w13_weight_scale.shape[-1]
)
w2_scale_2d = layer.w2_weight_scale.reshape(
-1, layer.w2_weight_scale.shape[-1]
)

shuffled_w13_scale = shuffle_scale(
w13_scale_2d, self.num_experts, self.is_guinterleave, True
)
shuffled_w2_scale = shuffle_scale(
w2_scale_2d, self.num_experts, self.is_guinterleave, False
)
layer.w13_weight_scale = atom_parameter(shuffled_w13_scale)
layer.w2_weight_scale = atom_parameter(shuffled_w2_scale)
shuffled_w13_scale = shuffle_scale(
w13_scale_2d, self.num_experts, self.is_guinterleave, True
)
shuffled_w2_scale = shuffle_scale(
w2_scale_2d, self.num_experts, self.is_guinterleave, False
)
layer.w13_weight_scale = atom_parameter(shuffled_w13_scale)
layer.w2_weight_scale = atom_parameter(shuffled_w2_scale)

def get_fused_moe_quant_config(
self, layer: torch.nn.Module
Expand Down
27 changes: 7 additions & 20 deletions atom/model_ops/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,27 +109,14 @@ def _temperature_sample(
) -> torch.Tensor:
"""Temperature-based Gumbel-max sampling.

When ``needs_independent_noise`` is True the per-row exponential noise
tensor is freshly drawn with shape ``(num_tokens, vocab_size)`` so that
fan-out siblings produced by ``SamplingParams.n > 1`` diverge instead
of collapsing onto the same token when they share logits. Otherwise we
keep the cached ``(1, vocab_size)`` row broadcasted across the batch,
which preserves the existing run-to-run determinism optimization.
ATOM_GFX1250_WORKAROUND (ported from ce1809f8473f): aiter HIP
mixed_sample_outer_exponential faults on gfx1250 (opus::gmem
buffer-resource issue). Use a torch-native equivalent. Simplification:
this always returns the greedy argmax of the logits for now, so no
temperature scaling or sampling noise is applied; bring-up uses
temperature=0. Revisit when non-greedy sampling is needed.
"""
num_tokens, vocab_size = logits.shape
sampled_tokens = torch.empty(num_tokens, dtype=torch.int, device=logits.device)
if needs_independent_noise:
exponential = torch.empty(
(num_tokens, vocab_size), dtype=torch.float, device=logits.device
).exponential_(1)
else:
exponential = get_per_token_exponential(vocab_size, logits.device).expand(
num_tokens, vocab_size
)
mixed_sample_outer_exponential(
sampled_tokens, logits, exponential, temperatures, eps=self.eps
)
return sampled_tokens
del needs_independent_noise # noqa: F841 — unused on the workaround path
return logits.argmax(dim=-1).to(torch.int)

def _topk_topp_sample(
self,
Expand Down
Loading
Loading