57 changes: 48 additions & 9 deletions docs/AMD_workshop/README.md
@@ -23,15 +23,46 @@ powershell -ExecutionPolicy Bypass -Command "iwr -UseBasicParsing https://raw.gi
popcorn-cli register github
```

3. **Submit your solution:**
```bash
popcorn-cli submit --gpu MI300 --leaderboard amd-fp8-mm --mode test v1.py
```

4. **Interactive mode** (choose GPU and options):
```bash
popcorn-cli submit v1.py
```

## 🏃 Run Examples

Try out the example implementations to get familiar with the system:

### For Linux/macOS:
```bash
# Download and test v1.py (reference implementation)
wget https://raw.githubusercontent.com/gpu-mode/popcorn-cli/main/docs/AMD_workshop/v1.py
popcorn-cli submit --gpu MI300 --leaderboard amd-fp8-mm --mode test v1.py

# Download and test v2.py (basic optimization)
wget https://raw.githubusercontent.com/gpu-mode/popcorn-cli/main/docs/AMD_workshop/v2.py
popcorn-cli submit --gpu MI300 --leaderboard amd-fp8-mm --mode test v2.py

# Download and test v3.py (advanced optimization)
wget https://raw.githubusercontent.com/gpu-mode/popcorn-cli/main/docs/AMD_workshop/v3.py
popcorn-cli submit --gpu MI300 --leaderboard amd-fp8-mm --mode test v3.py
```

### For Windows (PowerShell):
```powershell
# Download and test v1.py (reference implementation)
Invoke-WebRequest -Uri "https://raw.githubusercontent.com/gpu-mode/popcorn-cli/main/docs/AMD_workshop/v1.py" -OutFile "v1.py"
popcorn-cli submit --gpu MI300 --leaderboard amd-fp8-mm --mode test v1.py

# Download and test v2.py (basic optimization)
Invoke-WebRequest -Uri "https://raw.githubusercontent.com/gpu-mode/popcorn-cli/main/docs/AMD_workshop/v2.py" -OutFile "v2.py"
popcorn-cli submit --gpu MI300 --leaderboard amd-fp8-mm --mode test v2.py

# Download and test v3.py (advanced optimization)
Invoke-WebRequest -Uri "https://raw.githubusercontent.com/gpu-mode/popcorn-cli/main/docs/AMD_workshop/v3.py" -OutFile "v3.py"
popcorn-cli submit --gpu MI300 --leaderboard amd-fp8-mm --mode test v3.py
```

### 💡 Pro Tips:
- Start with **v1.py** (reference implementation) to understand the baseline
- Try **v2.py** for basic optimizations
- Challenge yourself with **v3.py** for advanced Triton optimizations
- Use `--mode benchmark` instead of `--mode test` to see performance metrics (see the example below)
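
For example, a benchmark run of the reference implementation reuses the exact flags from the test commands above; only `--mode` changes (the metrics printed may vary with your CLI version):

```bash
# Benchmark v1.py on MI300 instead of only running the correctness tests
popcorn-cli submit --gpu MI300 --leaderboard amd-fp8-mm --mode benchmark v1.py
```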


## 🛠️ Manual Installation

Expand All @@ -58,3 +89,11 @@ If the scripts don't work, you can manually install:
- Run `popcorn-cli --help` for usage information
- Check the [main repository](https://github.com/gpu-mode/popcorn-cli) and open an issue
- Join the [GPU Mode Discord](https://discord.gg/gpumode) and ask a question in #amd-competition

## 🧑‍🎓 Learn more from our favorite writeups

* https://github.com/luongthecong123/fp8-quant-matmul
* https://seb-v.github.io/optimization/update/2025/01/20/Fast-GPU-Matrix-multiplication.html
* https://akashkarnatak.github.io/amd-challenge/
* https://www.bilibili.com/read/cv41954307/?opus_fallback=1
* https://github.com/Snektron/gpumode-amd-fp8-mm
3 changes: 3 additions & 0 deletions docs/AMD_workshop/example.py → docs/AMD_workshop/v1.py
@@ -1,3 +1,6 @@
#!POPCORN leaderboard amd-fp8-mm
#!POPCORN gpu MI300

import torch
from task import input_t, output_t

125 changes: 125 additions & 0 deletions docs/AMD_workshop/v2.py
@@ -0,0 +1,125 @@
#!POPCORN leaderboard amd-fp8-mm
#!POPCORN gpu MI300

from task import input_t, output_t
import torch
import triton
import triton.language as tl


@triton.jit
def kernel(
A_ptr,
B_ptr,
A_scale_ptr,
B_scale_ptr,
C_ptr,
M: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_Q: tl.constexpr = 128,
):
program_id = tl.program_id(0)
num_pid_across_n = tl.cdiv(N, BLOCK_N)

program_id_m = program_id // num_pid_across_n
program_id_n = program_id % num_pid_across_n

# Simple stride assumptions (no transpose)
A_stride_m, A_stride_k = 1, M
B_stride_n, B_stride_k = 1, N
C_stride_m, C_stride_n = N, 1

# Scale matrices: A is 1x128, B is 128x128 chunks
A_scale_stride_m, A_scale_stride_k = 1, M
B_scale_stride_n, B_scale_stride_k = 1, tl.cdiv(N, BLOCK_Q)

# Calculate output block position
offset_m = program_id_m * BLOCK_M
offset_n = program_id_n * BLOCK_N

# Create block offset arrays
block_offsets_m = offset_m + tl.arange(0, BLOCK_M)
block_offsets_n = offset_n + tl.arange(0, BLOCK_N)
block_offsets_k = tl.arange(0, BLOCK_K)

# Create pointers for A and B blocks
A_block_ptrs = A_ptr + (
block_offsets_m[:, None] * A_stride_m + block_offsets_k[None, :] * A_stride_k
)
B_block_ptrs = B_ptr + (
block_offsets_k[:, None] * B_stride_k + block_offsets_n[None, :] * B_stride_n
)

# Scale pointers
A_scale_block_ptrs = A_scale_ptr + (block_offsets_m[:, None] * A_scale_stride_m)
B_scale_block_ptrs = B_scale_ptr + (offset_n // BLOCK_Q) * B_scale_stride_n

# Main accumulator
master_accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

# Process K dimension in BLOCK_Q chunks (128 elements at a time)
num_k_iters = K // BLOCK_Q
for _ in range(0, num_k_iters):
# Inner accumulator for current 128-element K chunk
inner_accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

# Process the 128-element chunk in smaller BLOCK_K pieces
for _ in tl.range(0, BLOCK_Q // BLOCK_K):
A_block = tl.load(A_block_ptrs) # (BLOCK_M, BLOCK_K)
B_block = tl.load(B_block_ptrs) # (BLOCK_K, BLOCK_N)
inner_accumulator = tl.dot(A_block, B_block, inner_accumulator)

# Move to next K chunk
A_block_ptrs += BLOCK_K * A_stride_k
B_block_ptrs += BLOCK_K * B_stride_k

# Load scales and apply to inner result
A_scales = tl.load(A_scale_block_ptrs) # (BLOCK_M, 1)
B_scales = tl.load(B_scale_block_ptrs) # scalar
master_accumulator += inner_accumulator * (A_scales * B_scales)

# Move to next scale block
A_scale_block_ptrs += A_scale_stride_k
B_scale_block_ptrs += B_scale_stride_k

# Store final result
block_offsets_m = (program_id_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None])
block_offsets_n = (program_id_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :])
mask = (block_offsets_m < M) & (block_offsets_n < N)
C_block_ptrs = C_ptr + (block_offsets_m * C_stride_m + block_offsets_n * C_stride_n)
tl.store(C_block_ptrs, master_accumulator, mask=mask)


def custom_kernel(data: input_t) -> output_t:
A_tensor, B_tensor, A_scale_tensor, B_scale_tensor, C_tensor = data

M, K = A_tensor.shape
N, _ = B_tensor.shape

# Fixed, simple configuration - no dynamic tuning
BLOCK_M = 64
BLOCK_N = 64
BLOCK_K = 32

# Launch grid
num_blocks = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)

kernel[(num_blocks,)](
A_tensor,
B_tensor,
A_scale_tensor,
B_scale_tensor,
C_tensor,
M, N, K,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
num_warps=4,
num_stages=2,
)

return C_tensor
152 changes: 152 additions & 0 deletions docs/AMD_workshop/v3.py
@@ -0,0 +1,152 @@
#!POPCORN leaderboard amd-fp8-mm
#!POPCORN gpu MI300

from task import input_t, output_t
import torch
import triton
import triton.language as tl

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
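# NUM_SMS is the number of compute units (multiprocessors) reported by the runtime;
# the block-size heuristic in get_config below compares the output tile count against it.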


@triton.jit
def kernel(
A_ptr,
B_ptr,
A_scale_ptr,
B_scale_ptr,
C_ptr,
M: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_Q: tl.constexpr = 128,
TRANSPOSE: tl.constexpr = False,
):
program_id = tl.program_id(0)
num_pid_across_n = tl.cdiv(N, BLOCK_N)

program_id_m = program_id // num_pid_across_n
program_id_n = program_id % num_pid_across_n

if not TRANSPOSE:
A_stride_m, A_stride_k = 1, M
B_stride_n, B_stride_k = 1, N
else:
A_stride_m, A_stride_k = K, 1
B_stride_n, B_stride_k = K, 1
C_stride_m, C_stride_n = N, 1
# Scale matrices are stored in column-major order, with A being 1x128 and B being 128x128 chunks
# BLOCK_Q is 128
A_scale_stride_m, A_scale_stride_k = 1, M
B_scale_stride_n, B_scale_stride_k = 1, tl.cdiv(N, BLOCK_Q)

# Calculate the row and column indices in the output matrix for the current pid
offset_m = program_id_m * BLOCK_M
offset_n = program_id_n * BLOCK_N

# Use tl.arange to build row and column offsets
block_offsets_m = offset_m + tl.arange(0, BLOCK_M)
block_offsets_n = offset_n + tl.arange(0, BLOCK_N)
block_offsets_k = tl.arange(0, BLOCK_K)

# ptrs for BLOCK_M rows of A and BLOCK_N columns of B
A_block_ptrs = A_ptr + (
block_offsets_m[:, None] * A_stride_m + block_offsets_k[None, :] * A_stride_k
)
B_block_ptrs = B_ptr + (
block_offsets_k[:, None] * B_stride_k + block_offsets_n[None, :] * B_stride_n
)
# since a_scales are 1x128, a_scale_ptrs need to be of shape (BLOCK_M, 1)
# since N, K <= BLOCK_Q, b_scale_ptrs is always a scalar ptr
A_scale_block_ptrs = A_scale_ptr + (block_offsets_m[:, None] * A_scale_stride_m)
B_scale_block_ptrs = B_scale_ptr + (offset_n // BLOCK_Q) * B_scale_stride_n

# Initialize accumulator for the current pid (responsible for BLOCK_M * BLOCK_N elements)
master_accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

# In each iteration we load BLOCK_Q elements from the K dimension for BLOCK_M rows, resp. BLOCK_N columns
# We choose this so that only 1 scale is needed per iteration
num_k_iters = K // BLOCK_Q
for _ in range(0, num_k_iters):
# Initialize accumulator for the current k iteration
inner_accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# In each iteration we load BLOCK_K elements from K dimension for BLOCK_M rows, resp. BLOCK_N columns
# We choose this to use small `tl.dot` for the inner accumulator
for _ in tl.range(0, BLOCK_Q // BLOCK_K):
A_block = tl.load(A_block_ptrs) # (BLOCK_M, BLOCK_K)
B_block = tl.load(B_block_ptrs) # (BLOCK_K, BLOCK_N)
inner_accumulator = tl.dot(
A_block, B_block, inner_accumulator
) # (BLOCK_M, BLOCK_N)

# Move along the K dimension of A, B
A_block_ptrs += BLOCK_K * A_stride_k
B_block_ptrs += BLOCK_K * B_stride_k

A_scales = tl.load(A_scale_block_ptrs) # (BLOCK_M, 1)
B_scales = tl.load(B_scale_block_ptrs) # ()
master_accumulator += inner_accumulator * (A_scales * B_scales)

# Move along the K dimension of A, B scales
A_scale_block_ptrs += A_scale_stride_k
B_scale_block_ptrs += B_scale_stride_k

# Store the result for the current pid
block_offsets_m = (
program_id_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
) # (BLOCK_M, 1)
block_offsets_n = (
program_id_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
) # (1, BLOCK_N)
mask = (block_offsets_m < M) & (block_offsets_n < N) # (BLOCK_M, BLOCK_N)
C_block_ptrs = C_ptr + (block_offsets_m * C_stride_m + block_offsets_n * C_stride_n)
tl.store(C_block_ptrs, master_accumulator, mask=mask)


@torch.compile(dynamic=False, mode="max-autotune-no-cudagraphs")
def contiguous(x):
return x.contiguous()


def get_config(M, N, K):
num_blocks_ref = (M // 128) * (N // 128)
TRANSPOSE = False
matrix_instr_nonkdim = 16
BLOCK_M, BLOCK_N, BLOCK_K = (128, 128, 64)
if num_blocks_ref * 8 < NUM_SMS: # 2 and 7
BLOCK_M, BLOCK_N, BLOCK_K = (32, 64, 128)
matrix_instr_nonkdim = 16
elif num_blocks_ref < NUM_SMS:
BLOCK_M, BLOCK_N, BLOCK_K = (64, 64, 64)

config = dict(
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
waves_per_eu=2,
matrix_instr_nonkdim=matrix_instr_nonkdim,
num_warps=4,
num_stages=2,
TRANSPOSE=TRANSPOSE,
)
return config


def custom_kernel(data: input_t) -> output_t:
A_tensor, B_tensor, A_scale_tensor, B_scale_tensor, C_tensor = data

M, K = A_tensor.shape
N, _ = B_tensor.shape

# Heuristic: choose block sizes based on the number of output tiles relative to the number of CUs
config = get_config(M, N, K)

num_blocks = triton.cdiv(M, config["BLOCK_M"]) * triton.cdiv(N, config["BLOCK_N"])
kernel[(num_blocks,)](
A_tensor, B_tensor, A_scale_tensor, B_scale_tensor, C_tensor, M, N, K, **config
)

return C_tensor