312 changes: 278 additions & 34 deletions kernels/mfma_preshuffle_pipeline.py
@@ -82,11 +82,11 @@ def make_preshuffle_b_layout(

if elem_bytes not in (1, 2):
raise ValueError(f"elem_bytes must be 1 or 2, got {elem_bytes!r}")
c_k_bytes = c_k * arith.constant(int(elem_bytes), index=True)
c_k_bytes = c_k * fx.Index(int(elem_bytes))
c_k0 = c_k_bytes // c64
n0 = c_n // c16

c_kpack_elems = c_kpack if elem_bytes == 1 else (c_kpack // arith.constant(int(elem_bytes), index=True))
c_kpack_elems = c_kpack if elem_bytes == 1 else (c_kpack // fx.Index(int(elem_bytes)))

stride_nlane = c_kpack_elems
stride_klane = c16 * stride_nlane
@@ -108,6 +108,30 @@ def make_preshuffle_b_layout(
return PreshuffleBLayout(layout_b=layout_b, kpack_bytes=kpack_bytes)


def _unpack_int4_to_int8_pair(packed32):
"""Split packed int4 dword into two int8 dwords (even/odd nibbles).

7-op bit manipulation shared by all int4 unpack paths (W4A8, W4A16, W4A_FP8).
"""
c_08 = fx.Int32(0x08080808)
c_0f = fx.Int32(0x0F0F0F0F)
c_1e = fx.Int32(0x1E)
c_4 = fx.Int32(4)
s0 = (packed32 & c_08) * c_1e
even = (packed32 & c_0f) | s0
t = packed32 >> c_4
s1 = (t & c_08) * c_1e
odd = (t & c_0f) | s1
return even, odd
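

# Illustrative reference, not part of the kernel: a pure-Python sketch of the same
# nibble trick for checking the 7-op sequence offline.  Per byte, (x & 0x0F) keeps
# the low nibble and (x & 0x08) * 0x1E turns the nibble's sign bit into 0xF0, so
# ORing the two sign-extends a 4-bit value to int8 without carries crossing bytes.
def _unpack_int4_reference(packed32: int):
    """Return (even, odd) dwords of sign-extended int8 bytes as unsigned 32-bit ints."""
    even = ((packed32 & 0x0F0F0F0F) | ((packed32 & 0x08080808) * 0x1E)) & 0xFFFFFFFF
    t = packed32 >> 4
    odd = ((t & 0x0F0F0F0F) | ((t & 0x08080808) * 0x1E)) & 0xFFFFFFFF
    return even, odd


# Nibbles (element 0 first) 1, -1, 7, -8, 3, -3, 0, 5 pack to 0x50D387F1; the even
# elements unpack to bytes 0x01, 0x07, 0x03, 0x00 and the odd ones to 0xFF, 0xF8, 0xFD, 0x05.
assert _unpack_int4_reference(0x50D387F1) == (0x00030701, 0x05FDF8FF)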


def _pack_i32_pair_to_i64(lo, hi, vector):
"""Pack two i32 values into one i64 via vector bitcast."""
v2 = vector.from_elements(T.vec(2, T.i32), [lo, hi])
v64 = vector.bitcast(T.vec(1, T.i64), v2)
return vector.extract(v64, static_position=[0], dynamic_position=[])
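

# Assumption for illustration: on a little-endian target the vec<2 x i32> -> i64
# bitcast above places ``lo`` in the low dword, i.e. the result equals the scalar
# expression below.
def _pack_i32_pair_reference(lo: int, hi: int) -> int:
    return ((hi & 0xFFFFFFFF) << 32) | (lo & 0xFFFFFFFF)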


def _i8x4_in_i32_to_bf16x4_i64(val_i32, arith, vector, scale_val=None):
"""Convert one i32 (4 signed int8 bytes) to 4 bf16 packed as i64.

@@ -205,18 +229,7 @@ def unpack_b_w4a16(packed32, arith, vector, scale_val=None):
Takes raw packed32 from load_b_raw_w4a16 and produces (b0, b1) --
two i64 values each containing 4 bf16 for one MFMA.
"""
c_08080808 = fx.Int32(0x08080808)
c_0f0f0f0f = fx.Int32(0x0F0F0F0F)
c_1e = fx.Int32(0x1E)
c_4_i32 = fx.Int32(4)

s0 = (packed32 & c_08080808) * c_1e
even = (packed32 & c_0f0f0f0f) | s0

t = packed32 >> c_4_i32
s1 = (t & c_08080808) * c_1e
odd = (t & c_0f0f0f0f) | s1

even, odd = _unpack_int4_to_int8_pair(packed32)
b0 = _i8x4_in_i32_to_bf16x4_i64(even, arith, vector, scale_val=scale_val)
b1 = _i8x4_in_i32_to_bf16x4_i64(odd, arith, vector, scale_val=scale_val)
return (b0, b1)
Expand Down Expand Up @@ -252,12 +265,12 @@ def load_b_pack_k32(
raise ValueError(f"elem_bytes must be 1 or 2, got {elem_bytes!r}")

c64 = fx.Index(64)
base_k_bytes = base_k * arith.constant(int(elem_bytes), index=True)
base_k_bytes = base_k * fx.Index(int(elem_bytes))
k0_base = base_k_bytes // c64
k0 = k0_base + arith.constant(ki_step // 2, index=True)
k0 = k0_base + fx.Index(ki_step // 2)
k1 = lane_div_16
half_bytes = kpack_bytes // 2
k2_base = arith.constant((ki_step % 2) * half_bytes, index=True)
k2_base = fx.Index((ki_step % 2) * half_bytes)

coord_pack = (n_blk, k0, k1, n_intra, fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
@@ -273,22 +286,8 @@
static_position=[0],
dynamic_position=[],
)

c_08080808 = fx.Int32(0x08080808)
c_0f0f0f0f = fx.Int32(0x0F0F0F0F)
c_1e = fx.Int32(0x1E)
c_4_i32 = fx.Int32(4)

s0 = (packed32 & c_08080808) * c_1e
even = (packed32 & c_0f0f0f0f) | s0

t = packed32 >> c_4_i32
s1 = (t & c_08080808) * c_1e
odd = (t & c_0f0f0f0f) | s1

v2 = vector.from_elements(T.vec(2, T.i32), [even, odd])
v64 = vector.bitcast(T.vec(1, T.i64), v2)
return vector.extract(v64, static_position=[0], dynamic_position=[])
even, odd = _unpack_int4_to_int8_pair(packed32)
return _pack_i32_pair_to_i64(even, odd, vector)

vec_elems = kpack_bytes // int(elem_bytes)
b16 = _buffer_load_vec(
@@ -324,7 +323,7 @@ def tile_chunk_coord_i32(
"""Map (thread, chunk_id) -> (row_local, col_local_i32) for X/A loads."""
if chunk_i32 not in (1, 2, 4):
raise ValueError(f"chunk_i32 must be one of (1,2,4), got {chunk_i32!r}")
chunk_off_i32 = arith.constant(i * total_threads * chunk_i32, index=True)
chunk_off_i32 = fx.Index(i * total_threads * chunk_i32)
tile_idx_i32 = tx_i32_base + chunk_off_i32
coord_local = fx.idx2crd(tile_idx_i32, layout_tile_div4)
row_local = fx.get(coord_local, 0)
@@ -476,6 +475,251 @@ def lds_load_pack_k32(
"lds_store_16b_xor16",
"make_preshuffle_b_layout",
"load_b_pack_k32",
"load_b_raw_w4a16",
"unpack_b_w4a16",
"load_b_raw_w4a16_groupwise",
"unpack_b_w4a16_groupwise",
"load_b_raw_w4a8_k64",
"load_b_raw_w4a8_groupwise_k64",
"unpack_b_w4a8",
"unpack_b_w4a_fp8",
"swizzle_xor16",
"tile_chunk_coord_i32",
]


# ---------------------------------------------------------------------------
# Groupwise scale load helper (shared by W4A16 and W4A8 groupwise paths)
# ---------------------------------------------------------------------------

def _load_groupwise_scale(
buffer_ops,
arith,
*,
scale_rsrc,
expert_offset,
n_blk,
n_intra,
k_pos,
num_groups: int,
group_size: int,
n_per_expert: int,
):
"""Load one per-group scale value from the scale buffer.

Computes the linear index into the scale tensor from expert offset,
N position, and group index derived from ``k_pos``.
"""
c16 = fx.Index(16)
n_global = n_blk * c16 + n_intra
c_group_size = fx.Index(group_size)
c_gm1 = fx.Index(num_groups - 1)
c_npe = fx.Index(n_per_expert)
# n_global is the GLOBAL N index (includes expert offset), so use (G-1)
# to compensate: expert_offset*(G-1) + (expert_offset + n_within) = expert_offset*G + n_within
base_scale = expert_offset * c_gm1 + n_global
group_idx = k_pos // c_group_size
scale_idx_i32 = arith.index_cast(T.i32, base_scale + group_idx * c_npe)
return buffer_ops.buffer_load(scale_rsrc, scale_idx_i32, vec_width=1, dtype=T.f32)
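

# Worked example with hypothetical sizes, for illustration only: with num_groups=4,
# group_size=128 and n_per_expert=256, the index computed above is
#     expert_offset*(G-1) + n_global + (k_pos // group_size) * n_per_expert
# which equals expert*G*n_per_expert + group*n_per_expert + n_within, i.e. a
# row-major [expert][group][n] scale layout.
def _scale_index_reference(expert_offset, n_global, k_pos,
                           num_groups, group_size, n_per_expert):
    base_scale = expert_offset * (num_groups - 1) + n_global
    return base_scale + (k_pos // group_size) * n_per_expert


# Expert 1 (offset 256), n_global 300 (n_within 44), k_pos 130 (group 1):
assert _scale_index_reference(256, 300, 130, 4, 128, 256) == 1 * (4 * 256) + 1 * 256 + 44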


# ---------------------------------------------------------------------------
# W4A16 groupwise load / unpack helpers
# ---------------------------------------------------------------------------

def load_b_raw_w4a16_groupwise(
buffer_ops,
arith,
vector,
*,
arg_b,
b_rsrc,
layout_b,
base_k,
ku: int,
n_blk,
n_intra,
lane_div_16,
elem_type,
scale_rsrc,
expert_offset,
num_groups: int,
group_size: int,
n_per_expert: int,
kpack_bytes: int = 8,
):
"""Phase 1 of W4A16 groupwise B load: buffer_loads for weight + scale.

Reuses :func:`load_b_raw_w4a16` for the weight load, then issues an
additional ``buffer_load_dword`` for the per-group scale.

Returns ``(packed32, scale_val)``.
"""
packed32 = load_b_raw_w4a16(
buffer_ops, arith, vector,
arg_b=arg_b, b_rsrc=b_rsrc, layout_b=layout_b,
base_k=base_k, ku=ku,
n_blk=n_blk, n_intra=n_intra,
lane_div_16=lane_div_16, elem_type=elem_type,
kpack_bytes=kpack_bytes,
)
k_pos = base_k + fx.Index(ku * 32)
scale_val = _load_groupwise_scale(
buffer_ops, arith,
scale_rsrc=scale_rsrc, expert_offset=expert_offset,
n_blk=n_blk, n_intra=n_intra, k_pos=k_pos,
num_groups=num_groups, group_size=group_size, n_per_expert=n_per_expert,
)
return (packed32, scale_val)


def unpack_b_w4a16_groupwise(packed32, scale_val, arith, vector):
"""Phase 2 of W4A16 groupwise: unpack + scale + convert to bf16."""
return unpack_b_w4a16(packed32, arith, vector, scale_val=scale_val)


# ---------------------------------------------------------------------------
# W4A8 load / unpack helpers (8B K64 loads)
# ---------------------------------------------------------------------------

def load_b_raw_w4a8_k64(
buffer_ops,
arith,
vector,
*,
arg_b,
b_rsrc,
layout_b,
base_k,
ku: int,
n_blk,
n_intra,
lane_div_16,
elem_type,
kpack_bytes: int = 8,
):
"""Phase 1 of W4A8 per-row B load: 8-byte buffer_load_dwordx2 for one K64 step.

Loads both K32 halves in a single VMEM instruction (``buffer_load_dwordx2``).
Returns ``(packed32_half0, packed32_half1)`` for :func:`unpack_b_w4a8`.
"""
if kpack_bytes != 8:
raise ValueError(f"W4A8 requires kpack_bytes=8, got {kpack_bytes!r}")

c64 = fx.Index(64)
k0_base = base_k // c64
k0 = k0_base + fx.Index(ku)
k1 = lane_div_16

coord_pack = (n_blk, k0, k1, n_intra, fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)

b8 = _buffer_load_vec(
buffer_ops, vector, b_rsrc, idx_pack,
elem_type=elem_type, vec_elems=8, elem_bytes=1, offset_in_bytes=True,
)
b_i32x2 = vector.bitcast(T.vec(2, T.i32), b8)
half0 = vector.extract(b_i32x2, static_position=[0], dynamic_position=[])
half1 = vector.extract(b_i32x2, static_position=[1], dynamic_position=[])
return (half0, half1)


def load_b_raw_w4a8_groupwise_k64(
buffer_ops,
arith,
vector,
*,
arg_b,
b_rsrc,
layout_b,
base_k,
ku: int,
n_blk,
n_intra,
lane_div_16,
elem_type,
scale_rsrc,
expert_offset,
num_groups: int,
group_size: int,
n_per_expert: int,
kpack_bytes: int = 8,
):
"""Phase 1 of W4A8 groupwise B load: 8B weight + two scale loads per K64.

Reuses :func:`load_b_raw_w4a8_k64` for the weight load, then issues two
``buffer_load_dword`` instructions for the per-group scales (each K32 half
may belong to a different group).

Returns ``(half0, half1, scale0, scale1)``.
"""
half0, half1 = load_b_raw_w4a8_k64(
buffer_ops, arith, vector,
arg_b=arg_b, b_rsrc=b_rsrc, layout_b=layout_b,
base_k=base_k, ku=ku,
n_blk=n_blk, n_intra=n_intra,
lane_div_16=lane_div_16, elem_type=elem_type,
kpack_bytes=kpack_bytes,
)

scale_kw = dict(
scale_rsrc=scale_rsrc, expert_offset=expert_offset,
n_blk=n_blk, n_intra=n_intra,
num_groups=num_groups, group_size=group_size, n_per_expert=n_per_expert,
)
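# With (for example) a hypothetical group_size of 32, the two K32 halves of one K64
# step fall in consecutive groups, so a single shared scale would be wrong for one
# half -- hence the two separate scale loads below.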
scale0 = _load_groupwise_scale(
buffer_ops, arith, k_pos=base_k + fx.Index(ku * 2 * 32), **scale_kw,
)
scale1 = _load_groupwise_scale(
buffer_ops, arith, k_pos=base_k + fx.Index((ku * 2 + 1) * 32), **scale_kw,
)
return (half0, half1, scale0, scale1)


def unpack_b_w4a8(packed32, arith, vector):
"""Phase 2 of W4A8 B load: 7-op unpack from packed int4 to int8 i64.

Takes a raw ``packed32`` (one dword of packed int4) and produces one i64
value containing 8 signed int8 bytes for one MFMA K32 step.
"""
even, odd = _unpack_int4_to_int8_pair(packed32)
return _pack_i32_pair_to_i64(even, odd, vector)
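
# Byte ordering note: the i64 produced above holds the eight int4 elements as
# e0,e2,e4,e6 in the low dword and e1,e3,e5,e7 in the high dword; the preshuffled
# B layout is assumed to arrange the weights so this is the order the K32 MFMA
# consumes.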


def unpack_b_w4a_fp8(packed32, arith, vector, rocdl):
"""Unpack packed int4 (i32) to fp8 i64 for mfma_f32_16x16x32_fp8_fp8.

Pipeline: int4 -> int8 (7-op unpack) -> f32 (byte extract + sitofp)
-> fp8 (cvt_pk_fp8_f32) -> i64.
"""
even, odd = _unpack_int4_to_int8_pair(packed32)

c_8 = fx.Int32(8)
c_16 = fx.Int32(16)
c_24 = fx.Int32(24)

from flydsl._mlir.dialects._arith_ops_gen import ShRSIOp as _ShRSIOp
_uw = arith._to_raw
_av = arith.ArithValue

def _i32_int8x4_to_fp8x4(val):
"""Convert i32 containing 4 signed int8 bytes -> i32 containing 4 fp8 bytes."""
def _sext_byte(src, shl_amount, shr_amount):
shifted = src << shl_amount
shrsi_result = _ShRSIOp(_uw(shifted), _uw(shr_amount)).result
return _uw(arith.sitofp(T.f32, _av(shrsi_result)))

f0 = _sext_byte(val, c_24, c_24)
f1 = _sext_byte(val, c_16, c_24)
f2 = _sext_byte(val, c_8, c_24)
b3 = _ShRSIOp(_uw(val), _uw(c_24)).result
f3 = _uw(arith.sitofp(T.f32, _av(b3)))

zero = _uw(fx.Int32(0))
pk = rocdl.cvt_pk_fp8_f32(src_a=f0, src_b=f1, old=zero, word_sel=0, res=T.i32)
pk = rocdl.cvt_pk_fp8_f32(src_a=f2, src_b=f3, old=_uw(pk), word_sel=1, res=T.i32)
return pk

even_fp8 = _i32_int8x4_to_fp8x4(even)
odd_fp8 = _i32_int8x4_to_fp8x4(odd)
return _pack_i32_pair_to_i64(even_fp8, odd_fp8, vector)
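

# Illustrative reference, not part of the kernel: the shl/ashr pairs in _sext_byte
# above extract signed bytes from an i32 (byte 0 via shl 24 + ashr 24, byte 3 via a
# bare ashr 24).  A pure-Python equivalent for one byte position:
def _extract_signed_byte_reference(val32: int, byte_idx: int) -> int:
    b = (val32 >> (8 * byte_idx)) & 0xFF
    return b - 256 if b >= 0x80 else b


assert _extract_signed_byte_reference(0x05FDF8FF, 0) == -1
assert _extract_signed_byte_reference(0x05FDF8FF, 3) == 5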
4 changes: 3 additions & 1 deletion kernels/moe_blockscale_2stage.py
@@ -26,7 +26,9 @@

from flydsl._mlir import ir
from flydsl._mlir.dialects import llvm, scf, memref
from flydsl._mlir.dialects import fly as _fly_dialect
from flydsl._mlir.dialects import math as math_dialect
from kernels.kernels_common import _create_llvm_ptr
from flydsl.expr.typing import T
from flydsl.expr.arith import ArithValue

@@ -2339,7 +2341,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag):
byte_off = idx_elem_even * c2_i32
byte_off_idx = arith.index_cast(T.index, byte_off)
ptr_addr_idx = out_base_idx + byte_off_idx
out_ptr = buffer_ops.create_llvm_ptr(ptr_addr_idx, address_space=1)
out_ptr = _create_llvm_ptr(ptr_addr_idx, address_space=1)
out_ptr_v = out_ptr._value if hasattr(out_ptr, "_value") else out_ptr
frag_v = frag._value if hasattr(frag, "_value") else frag
llvm.AtomicRMWOp(