106 changes: 106 additions & 0 deletions .github/workflows/ci-paddle.yml
@@ -0,0 +1,106 @@
name: CI Paddle
on:
  push:
    branches: [paddle]
    tags: ["v*"]
  pull_request:
  merge_group:
  workflow_dispatch:

permissions:
  contents: read

concurrency:
  group: "${{ github.workflow }}-${{ github.ref }}"
  cancel-in-progress: true

jobs:
  test:
    name: Test
    runs-on:
      group: H20
    timeout-minutes: 30
    env:
      container_name: tilelang-paddle-test-${{ github.run_id }}
    steps:
      - name: Check docker image and run container
        env:
          FLAGS_fraction_of_gpu_memory_to_use: 0.15
          CTEST_PARALLEL_LEVEL: 2
          WITH_GPU: "ON"
          CUDA_ARCH_NAME: Hopper
          WITH_AVX: "ON"
          PY_VERSION: "3.12"
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          GPU_DEVICES: 3
          no_proxy: "bcebos.com,apiin.im.baidu.com,gitee.com,aliyun.com,.baidu.com,.tuna.tsinghua.edu.cn"
        run: |
          docker_image=ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:cuda129-coverage-test
          docker run -d -t --gpus device=${GPU_DEVICES} --name ${{ env.container_name }} \
            -v "/dev/shm:/dev/shm" \
            -v ${{ github.workspace }}/../../..:${{ github.workspace }}/../../.. \
            -v ${{ github.workspace }}:/workspace \
            -e FLAGS_fraction_of_gpu_memory_to_use \
            -e CTEST_PARALLEL_LEVEL \
            -e WITH_GPU \
            -e CUDA_ARCH_NAME \
            -e WITH_AVX \
            -e PY_VERSION \
            -e GITHUB_TOKEN \
            -e no_proxy \
            -w /workspace \
            --network host \
            ${docker_image}

      - name: Checkout repository
        run: |
          docker exec -t ${{ env.container_name }} /bin/bash -c '
            set -e
            source ${{ github.workspace }}/../../../proxy
            git config --global --add safe.directory "*"
            # Clean workspace
            find . -maxdepth 1 ! -name "." -exec rm -rf {} +
            # Checkout
            git init
            git remote add origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}
            git fetch origin ${{ github.ref }} --depth=1
            git checkout FETCH_HEAD
            git submodule update --init --recursive
          '

      - name: Install dependencies
        run: |
          docker exec -t ${{ env.container_name }} /bin/bash -c '
            set -e
            source ${{ github.workspace }}/../../../proxy

            # Install uv
            curl -LsSf https://astral.sh/uv/install.sh | sh
            source $HOME/.local/bin/env

            # Create and activate virtual environment
            uv venv .venv --seed -p 3.12
            source .venv/bin/activate

            # Install paddle
            uv pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu129/

            # Install project and minimal test runner
            uv pip install -r requirements.txt
            uv pip install -e .
          '

      - name: Run tests
        run: |
          docker exec -t ${{ env.container_name }} /bin/bash -c '
            set -e
            source .venv/bin/activate
            RUN_IN_PADDLE_CI=ON make test
          '

      - name: Terminate and delete the container
        if: always()
        run: |
          set +e
          docker stop ${{ env.container_name }}
          docker rm ${{ env.container_name }}
28 changes: 27 additions & 1 deletion README.md
@@ -2,14 +2,40 @@
Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
******************************************************************************** -->

# SonicMoE: Accelerating MoE with IO and Tile-aware Optimizations
# SonicMoE: Accelerating MoE with IO and Tile-aware Optimizations ❤️ PaddlePaddle
[![arXiv](https://img.shields.io/badge/arXiv-2512.14080-b31b1b.svg)](https://arxiv.org/abs/2512.14080)

**SonicMoE** is a simple but blazing-fast Mixture-of-Experts (MoE) implementation optimized for NVIDIA Hopper architecture GPUs. It mainly leverages [CuTeDSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html) and [Triton](https://triton-lang.org/main/getting-started/tutorials/index.html) to deliver state-of-the-art performance through IO-aware optimizations. These 2 figures provide an overview of activation memory usage and training throughput.

![image](./assets/mem.png)
![image](./assets/tput.png)

> [!NOTE]
>
> This repo is a fork of the original SonicMoE project, with modifications to enhance compatibility and integration with PaddlePaddle.
>
> **Installation**
>
> ```bash
> git clone https://github.com/PFCCLab/sonic-moe.git
> cd sonic-moe
> pip install -r requirements.txt
> pip install .
> ```
>
> **Usage**
>
> ```python
> import paddle
> paddle.enable_compat(scope={"sonicmoe", "quack", "triton"}) # Enable torch proxy before importing sonicmoe
> import sonicmoe
> # use sonicmoe
> ```

The original README.md content is as follows:

---

## 📦 Installation

### Prerequisites
6 changes: 4 additions & 2 deletions requirements.txt
@@ -2,10 +2,12 @@
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
# ********************************************************************************

torch>=2.7.1
# torch>=2.7.1
nvidia-cutlass-dsl==4.2.1
quack-kernels @ git+https://github.com/Dao-AILab/quack.git@3d0ab3ec2164749caac8f269f771e66a40efd2de
# quack-kernels @ git+https://github.com/Dao-AILab/quack.git@3d0ab3ec2164749caac8f269f771e66a40efd2de
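# switched to the PFCCLab fork of quack (presumably carrying the Paddle-compat changes this port needs)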
quack-kernels @ git+https://github.com/PFCCLab/quack.git@7ef82a90403f4a407f82d522d658d2b7e87ef733
pytest
parameterized
ninja
rich
filelock # For JIT compile
20 changes: 20 additions & 0 deletions sonicmoe/__init__.py
@@ -2,6 +2,26 @@
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
# ********************************************************************************

import paddle
import inspect

if not (hasattr(paddle.library.CustomOpDef, "__call__") and inspect.isfunction(paddle.library.CustomOpDef.__call__)):
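# Fallback for Paddle builds where CustomOpDef instances are not directly callable:
# dispatch the registered op through paddle.ops.<namespace>.<name> so calls still reach the custom kernel.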
def __call__(self, *args, **kwargs):
return getattr(getattr(paddle.ops, self._namespace), self._name)(*args, **kwargs)

paddle.library.CustomOpDef.__call__ = __call__

def torch_compat_empty(*args, **kwargs):
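# paddle.empty has no device keyword; drop a torch-style device="cuda" argument and let the tensor
# be allocated on the current Paddle device. Registered below as the torch.empty override of the compat proxy.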
if "device" in kwargs and kwargs["device"] == "cuda":
del kwargs["device"]
return paddle.empty(*args, **kwargs)

paddle.compat.proxy._extend_torch_proxy_overrides(
{
"torch.empty": paddle.compat.proxy.RawOverriddenAttribute(torch_compat_empty),
}
)

from .count_cumsum import count_cumsum
from .enums import KernelBackendMoE
from .functional import enable_quack_gemm, moe_TC_softmax_topk_layer
5 changes: 5 additions & 0 deletions sonicmoe/enums.py
@@ -26,5 +26,10 @@ class ActivationType(Enum):
SILU = "silu"


class ScoringFuncType(Enum):
SOFTMAX = "softmax"
SIGMOID = "sigmoid"


def is_glu(activation_type: ActivationType):
return activation_type in [ActivationType.SWIGLU, ActivationType.REGLU, ActivationType.GEGLU]
62 changes: 40 additions & 22 deletions sonicmoe/functional/__init__.py
@@ -9,7 +9,7 @@
from quack.gemm_interface import gemm

from ..count_cumsum import count_cumsum
from ..enums import ActivationType, is_glu
from ..enums import ActivationType, ScoringFuncType, is_glu
from ..quack_utils import gemm_dgated, gemm_gated
from .backward import _down_projection_backward, _softmax_topk_bwd, _token_broadcast_backward, _up_projection_backward
from .forward import _down_projection_forward, _router_forward, _softmax_topk_fwd, _up_projection_forward
@@ -84,31 +84,35 @@ def general_routing_router_metadata(

class TC_Softmax_Topk_Router_Function(torch.autograd.Function):
@staticmethod
def forward(ctx, router_logits: torch.Tensor, E: int, K: int) -> tuple[torch.Tensor, torch.Tensor]:
def forward(
ctx, router_logits: torch.Tensor, E: int, K: int, scoring_func: ScoringFuncType
) -> tuple[torch.Tensor, torch.Tensor]:
T = router_logits.size(0)

# changing this to router_logits.dtype (bfloat16) gains another ~5 TFLOPS in the forward pass at the cost of numerical accuracy
topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device)
topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device)
ctx.mark_non_differentiable(topk_router_indices)

_softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K)
_softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K, scoring_func)

ctx.save_for_backward(topk_router_score, topk_router_indices)
ctx.E = E
ctx.dtype = router_logits.dtype
ctx.scoring_func = scoring_func

return topk_router_score, topk_router_indices
outs = topk_router_score, topk_router_indices
return outs

@staticmethod
def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
T, K = dtopk_score.size()

topk_router_score, topk_router_indices = ctx.saved_tensors
topk_router_score, topk_router_indices = ctx.saved_tensor()
dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device)

_softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K)
_softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K, ctx.scoring_func)

return dlogits, None, None
return dlogits


class _UpProjection(torch.autograd.Function):
@@ -138,7 +142,7 @@ def forward(
TK = total_expert_freq

if is_using_quack_gemm():
assert not torch.compiler.is_compiling()
# assert not torch.compiler.is_compiling()
assert is_glu_activation, "QuACK GEMM does not support non GLU activation yet"
z, y1 = gemm_gated(
x,
@@ -194,7 +198,8 @@ def forward(

@staticmethod
def backward(ctx, _: None, dz: torch.Tensor):
is_compiling = torch.compiler.is_compiling()
# is_compiling = torch.compiler.is_compiling()
is_compiling = False

if not is_compiling:
assert _ is None
@@ -217,10 +222,10 @@ def backward(ctx, _: None, dz: torch.Tensor):
s_scatter_idx,
s_reverse_scatter_idx,
num_activated_expert_per_token_offset,
) = ctx.saved_tensors
) = ctx.saved_tensor()

dw1 = torch.empty_like(w1)
db1 = None if b1 is None else torch.empty_like(b1)
dw1 = torch.empty_like(w1).as_strided(w1.shape, w1.stride())
db1 = None if b1 is None else torch.empty_like(b1).as_strided(b1.shape, b1.stride())

if is_using_quack_gemm():
assert not is_compiling
@@ -264,7 +269,13 @@ def backward(ctx, _: None, dz: torch.Tensor):
is_varlen_K=is_varlen_K,
)

return dx_reduced, dw1, db1, *[None] * 12
# return dx_reduced, dw1, db1, *[None] * 12
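# Build the gradient tuple to match forward's tensor inputs: Paddle's PyLayer appears to expect
# one gradient per tensor input (non-tensor arguments are skipped), so db1 is only included
# when a bias b1 was actually passed.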
grads = []
grads.extend([dx_reduced, dw1])
if db1 is not None:
grads.append(db1)
grads.extend([None] * 5)
return tuple(grads)


class _DownProjection(torch.autograd.Function):
@@ -291,7 +302,7 @@ def forward(
H, I, E = w2.shape

if is_using_quack_gemm():
assert not torch.compiler.is_compiling()
# assert not torch.compiler.is_compiling()

assert b2 is None
y2 = gemm(y1, w2.permute(2, 1, 0), cu_seqlens_m=expert_frequency_offset)
@@ -358,14 +369,14 @@ def backward(ctx, dout: torch.Tensor):
x_gather_idx,
s_scatter_idx,
s_reverse_scatter_idx,
) = ctx.saved_tensors
) = ctx.saved_tensor()

dw2 = torch.empty_like(w2)
db2 = None if b2 is None else torch.empty_like(b2)
dz = torch.empty_like(z)
dw2 = torch.empty_like(w2).as_strided(w2.shape, w2.stride())
db2 = None if b2 is None else torch.empty_like(b2).as_strided(b2.shape, b2.stride())
dz = torch.empty_like(z).as_strided(z.shape, z.stride())

if is_using_quack_gemm():
assert not torch.compiler.is_compiling()
# assert not torch.compiler.is_compiling()
assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet"

s = topk_scores[s_scatter_idx]
@@ -393,7 +404,7 @@ def backward(ctx, dout: torch.Tensor):

ds = ds[s_reverse_scatter_idx]
else:
ds = torch.empty_like(topk_scores)
ds = torch.empty_like(topk_scores).as_strided(topk_scores.shape, topk_scores.stride())
_down_projection_backward(
dout=dout,
z=z,
@@ -417,7 +428,13 @@ def backward(ctx, dout: torch.Tensor):
if not is_varlen_K:
ds = ds.view(T, K)

return None, dz, dw2, db2, ds, *[None] * 10
# return None, dz, dw2, db2, ds, *[None] * 10
grads = []
grads.extend([None, dz, dw2])
if db2 is not None:
grads.append(db2)
grads.extend([ds, *[None] * 5])
return tuple(grads)


def moe_TC_softmax_topk_layer(
@@ -430,13 +447,14 @@ def moe_TC_softmax_topk_layer(
K: int,
stream_id: int,
activation_type: ActivationType | str = ActivationType.SWIGLU,
scoring_func: ScoringFuncType | str = ScoringFuncType.SOFTMAX,
is_inference_mode_enabled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert ((b1 is None) and (b2 is None)) or (
(b1 is not None) and (b2 is not None)
), "b1 and b2 has to be None or not None at the same time!"
router_logits = F.linear(x, router_w)
topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, router_w.size(0), K)
topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, router_w.size(0), K, scoring_func)
expert_frequency, expert_frequency_offset = count_cumsum(topk_indices.view(-1), router_w.size(0), do_cumsum=True)

(