Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
8a0ea47
initial impl
matthiasdiener Mar 9, 2026
4270296
Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
matthiasdiener Mar 11, 2026
6ddb77d
put into benchmarks subfolder
matthiasdiener Mar 11, 2026
fb2b3f3
restructure comment
matthiasdiener Mar 11, 2026
d4e9b1e
misc updates
matthiasdiener Mar 11, 2026
95358f4
python fix
matthiasdiener Mar 11, 2026
a675d17
Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
matthiasdiener Mar 12, 2026
d0a320d
another embedded python fix
matthiasdiener Mar 12, 2026
6f45853
replace py code
matthiasdiener Mar 12, 2026
e5eaf10
Merge branch 'dev' into mdiener/ci-microbench
matthiasdiener Mar 13, 2026
55e7eb5
restore disabled parts of CI
matthiasdiener Mar 13, 2026
7072e82
Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
matthiasdiener Mar 13, 2026
9c771b4
add attention, casting, normalization
matthiasdiener Mar 13, 2026
64e8da8
add timestamp and commit ID
matthiasdiener Mar 13, 2026
c986c97
add FP8 GEMM
matthiasdiener Mar 13, 2026
4f6dc86
fix name
matthiasdiener Mar 14, 2026
811e329
Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
matthiasdiener Mar 16, 2026
bd6c3e7
Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
matthiasdiener Mar 17, 2026
c9d6d4d
updates casting
matthiasdiener Mar 17, 2026
4bc11df
Merge branch 'dev' into mdiener/ci-microbench
matthiasdiener Apr 20, 2026
de21a77
remove attention
matthiasdiener Apr 20, 2026
1d6f869
fix grouped gemm
matthiasdiener Apr 20, 2026
12b4218
remove CI part
matthiasdiener Apr 20, 2026
75c8291
use adaptive_autorange, cleanups
matthiasdiener Apr 21, 2026
2e6da68
add csv to asv converter
matthiasdiener Apr 21, 2026
6353411
Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
matthiasdiener Apr 26, 2026
a1c6453
Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
matthiasdiener Apr 27, 2026
33a3137
Merge remote-tracking branch 'upstream/dev' into mdiener/ci-microbench
matthiasdiener May 7, 2026
aa8997c
refactor
matthiasdiener May 7, 2026
fefaf13
remove asv converter
matthiasdiener May 7, 2026
7f2669d
cleanups, misc fixes
matthiasdiener May 7, 2026
117c2d7
Merge remote-tracking branch 'origin/dev' into mdiener/ci-microbench
matthiasdiener May 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions benchmarks/microbenchmarks/benchmark_casting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""
FP8 casting micro-benchmark.

Benchmarks quantization (BF16 -> FP8) and dequantization (FP8 -> BF16) for
both E4M3 (activations/weights) and E5M2 (gradients) formats.

These casts are memory-bound; we report GB/s (input + output bytes).
Output: benchmark_casting.csv (written to cwd)
"""

import torch
import transformer_engine
import transformer_engine_torch as tex
from transformer_engine.pytorch import Float8Quantizer
from utils import (
MODEL_HIDDEN_SIZES, M_SIZE_LIST,
time_func, compute_gbps, run_benchmarks,
)

# TE enum values for the two FP8 formats exercised below.
TE_FP8_E4M3 = tex.DType.kFloat8E4M3  # 4 exponent / 3 mantissa bits (activations/weights)
TE_FP8_E5M2 = tex.DType.kFloat8E5M2  # 5 exponent / 2 mantissa bits (gradients)

# Each entry describes one cast kernel to benchmark.
CAST_CONFIGS = [
    # (name, direction, fp8_dtype)
    ("BF16-to-FP8-E4M3", "quantize", TE_FP8_E4M3),
    ("FP8-E4M3-to-BF16", "dequantize", TE_FP8_E4M3),
    ("BF16-to-FP8-E5M2", "quantize", TE_FP8_E5M2),
    ("FP8-E5M2-to-BF16", "dequantize", TE_FP8_E5M2),
]


def _generate_cast_test_cases():
    """Return one test-case dict per (model, cast config, M) combination.

    Iteration order (model outermost, M innermost) matches the order in
    which results appear in the CSV.
    """
    return [
        {
            "Case": f"{model_name}/{cast_name}",
            "M": m,
            "hidden_size": hidden,
            "direction": direction,
            "fp8_dtype": fp8_dtype,
            "dtype_str": cast_name,
        }
        for model_name, hidden in MODEL_HIDDEN_SIZES
        for cast_name, direction, fp8_dtype in CAST_CONFIGS
        for m in M_SIZE_LIST
    ]


def bench_cast(Case, M, hidden_size, direction, fp8_dtype, dtype_str):
    """Time one cast direction for an (M, hidden_size) tensor.

    ``Case`` and ``dtype_str`` are carried only for CSV labelling.
    Returns the metric columns as formatted strings.
    """
    device = "cuda"
    numel = M * hidden_size

    # Identity scale / zero amax: we measure the raw cast kernel, not scaling.
    scale = torch.ones(1, dtype=torch.float32, device=device)
    amax = torch.zeros(1, dtype=torch.float32, device=device)
    quantizer = Float8Quantizer(scale, amax, fp8_dtype)

    src = torch.randn(M, hidden_size, dtype=torch.bfloat16, device=device)

    if direction == "quantize":
        dst = quantizer(src)

        def cast_func():
            return quantizer.quantize(src, out=dst)

        total_bytes = numel * (2 + 1)  # BF16 read + FP8 write
    else:
        fp8_tensor = quantizer(src)

        def cast_func():
            return fp8_tensor.dequantize()

        total_bytes = numel * (1 + 2)  # FP8 read + BF16 write

    ms = time_func(cast_func, method="blocked")
    gbps = compute_gbps(total_bytes, ms)

    print(f" {ms:.4f} ms | {gbps:.1f} GB/s")

    return {
        "Cast Time (ms)": f"{ms:.4f}",
        "Cast GB/s": f"{gbps:.1f}",
    }


if __name__ == "__main__":
    # Run all cast cases and write the CSV to the current working directory.
    run_benchmarks(
        test_cases=_generate_cast_test_cases(),
        bench_fn=bench_cast,
        param_columns=["Case", "M", "hidden_size", "dtype_str"],
        metric_columns=["Cast Time (ms)", "Cast GB/s"],
        default_csv="benchmark_casting.csv",
    )
82 changes: 82 additions & 0 deletions benchmarks/microbenchmarks/benchmark_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python
###############################################################################
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################


import torch
import transformer_engine.pytorch as te
from utils import (
MODEL_CONFIGS, M_SIZE_LIST, gemm_shapes,
time_func, compute_tflops, run_benchmarks,
)

# Per-model (N, K) GEMM shapes derived from the shared model configs.
ACTIVE_SHAPES = gemm_shapes(MODEL_CONFIGS)


def _generate_gemm_test_cases():
    """Return one BF16 test-case dict per (M, model shape) pair."""
    return [
        {
            "Case": case_name,
            "M": m,
            "N": n,
            "K": k,
            "dtype": torch.bfloat16,
        }
        for m in M_SIZE_LIST
        for case_name, (n, k) in ACTIVE_SHAPES.items()
    ]


def bench_gemm(Case, M, N, K, dtype):
    """Time te.Linear forward and forward+backward for one (M, N, K) shape.

    Backward time is not measured in isolation; it is derived as
    (fwd+bwd) - fwd, so the reported backward TFLOPS is approximate.
    ``Case`` is carried only for CSV labelling.
    """
    device = "cuda"

    linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype)
    x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True)

    def fwd_func():
        return linear(x)

    # Warm-up forward; its output also fixes the shape of grad_out.
    grad_out = torch.randn_like(fwd_func())

    def fwd_bwd_func():
        y = linear(x)
        y.backward(grad_out)
        # Drop accumulated grads so repeated timed calls stay identical.
        x.grad = None
        linear.weight.grad = None

    # Warm-up combined pass before timing.
    fwd_bwd_func()

    fwd_flops = 2 * M * N * K
    bwd_flops = 2 * fwd_flops  # dX + dW

    fwd_ms = time_func(fwd_func)
    fwd_bwd_ms = time_func(fwd_bwd_func)
    bwd_ms = fwd_bwd_ms - fwd_ms

    fwd_tflops = compute_tflops(fwd_flops, fwd_ms)
    bwd_tflops = compute_tflops(bwd_flops, bwd_ms)

    print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS")
    print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)")

    return {
        "TE Forward Time (ms)": f"{fwd_ms:.2f}",
        "TE Forward TFLOPS": f"{fwd_tflops:.2f}",
        "TE Backward Time (ms)": f"{bwd_ms:.2f}",
        "TE Backward TFLOPS": f"{bwd_tflops:.2f}",
    }


if __name__ == "__main__":
    # Run all BF16 GEMM cases and write the CSV to the current working directory.
    run_benchmarks(
        test_cases=_generate_gemm_test_cases(),
        bench_fn=bench_gemm,
        param_columns=["Case", "M", "N", "K", "dtype"],
        metric_columns=[
            "TE Forward Time (ms)", "TE Forward TFLOPS",
            "TE Backward Time (ms)", "TE Backward TFLOPS",
        ],
        default_csv="benchmark_gemm.csv",
    )
94 changes: 94 additions & 0 deletions benchmarks/microbenchmarks/benchmark_gemm_fp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python
###############################################################################
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################
"""
FP8 GEMM micro-benchmark using te.Linear under fp8_autocast.

Same model shapes as benchmark_gemm.py.
Output: benchmark_gemm_fp8.csv (written to cwd)
"""

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from utils import (
MODEL_CONFIGS, M_SIZE_LIST, gemm_shapes,
time_func, compute_tflops, run_benchmarks,
)

# Delayed-scaling recipe; HYBRID format uses E4M3 for forward tensors and
# E5M2 for gradients, with amax reduced as max over a 16-step history.
FP8_RECIPE = DelayedScaling(
    fp8_format=Format.HYBRID,
    amax_history_len=16,
    amax_compute_algo="max",
)

# Per-model (N, K) GEMM shapes — same set as benchmark_gemm.py.
ACTIVE_SHAPES = gemm_shapes(MODEL_CONFIGS)


def _generate_gemm_test_cases():
    """Return one BF16 test-case dict per (M, model shape) pair."""
    cases = []
    for m in M_SIZE_LIST:
        cases.extend(
            {
                "Case": case_name,
                "M": m,
                "N": n,
                "K": k,
                "dtype": torch.bfloat16,
            }
            for case_name, (n, k) in ACTIVE_SHAPES.items()
        )
    return cases


def bench_fp8_gemm(Case, M, N, K, dtype):
    """Time te.Linear forward and forward+backward under fp8_autocast.

    ``Case`` is carried only for CSV labelling. Backward time is derived as
    (fwd+bwd) - fwd rather than measured in isolation, so the reported
    backward TFLOPS is approximate.
    """
    device = "cuda"

    linear = te.Linear(K, N, bias=False).to(device=device, dtype=dtype)
    x = torch.randn(M, K, dtype=dtype, device=device, requires_grad=True)
    grad_out = torch.randn(M, N, dtype=dtype, device=device)

    def fwd_func():
        with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
            return linear(x)

    def fwd_bwd_func():
        with te.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
            out = linear(x)
            out.backward(grad_out)
        # Drop accumulated grads so repeated timed calls stay identical.
        x.grad = None
        linear.weight.grad = None

    # Warm up outside the timed region (matches benchmark_gemm.py): the first
    # FP8 iterations populate the delayed-scaling amax history and allocate
    # FP8 weight/output buffers, which must not be counted in the timing.
    fwd_bwd_func()

    fwd_flops = 2 * M * N * K
    bwd_flops = 2 * fwd_flops  # dX + dW

    fwd_ms = time_func(fwd_func)
    fwd_bwd_ms = time_func(fwd_bwd_func)
    bwd_ms = fwd_bwd_ms - fwd_ms

    fwd_tflops = compute_tflops(fwd_flops, fwd_ms)
    bwd_tflops = compute_tflops(bwd_flops, bwd_ms)

    print(f" Forward {fwd_ms:.3f} ms | {fwd_tflops:.2f} TFLOPS")
    print(f" Backward {bwd_ms:.3f} ms | {bwd_tflops:.2f} TFLOPS (derived)")

    return {
        "FP8 Forward Time (ms)": f"{fwd_ms:.2f}",
        "FP8 Forward TFLOPS": f"{fwd_tflops:.2f}",
        "FP8 Backward Time (ms)": f"{bwd_ms:.2f}",
        "FP8 Backward TFLOPS": f"{bwd_tflops:.2f}",
    }


if __name__ == "__main__":
    # Run all FP8 GEMM cases and write the CSV to the current working directory.
    run_benchmarks(
        test_cases=_generate_gemm_test_cases(),
        bench_fn=bench_fp8_gemm,
        param_columns=["Case", "M", "N", "K", "dtype"],
        metric_columns=[
            "FP8 Forward Time (ms)", "FP8 Forward TFLOPS",
            "FP8 Backward Time (ms)", "FP8 Backward TFLOPS",
        ],
        default_csv="benchmark_gemm_fp8.csv",
    )
Loading