Merged

37 commits
4bc109f
Enable gptoss sink
LJ-underdog Dec 25, 2025
b86a860
Merge branch 'develop' into gptoss_sink
LJ-underdog Dec 25, 2025
6711cec
Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipel…
LJ-underdog Dec 25, 2025
1e3c54d
Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipel…
LJ-underdog Dec 25, 2025
7c9cb83
add gptoss sink test
LJ-underdog Dec 25, 2025
5163868
update CHANGELOG.md
LJ-underdog Dec 25, 2025
4d73213
Merge branch 'develop' into gptoss_sink
LJ-underdog Dec 29, 2025
5ab683b
fix test args error
LJ-underdog Dec 30, 2025
5c0e07a
Update test_fmha_fwd.cpp
LJ-underdog Dec 30, 2025
970b4f1
update sink test
LJ-underdog Dec 30, 2025
0eeedeb
Revert "update sink test"
LJ-underdog Dec 30, 2025
31db412
update sink test
LJ-underdog Dec 30, 2025
b37b174
update valid sink_v in splitkv pipeline
LJ-underdog Jan 5, 2026
a20868e
Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
LJ-underdog Jan 5, 2026
81b02a6
Update example_fmha_fwd.cpp
LJ-underdog Jan 6, 2026
ebf445f
fix lse error
LJ-underdog Jan 7, 2026
f2fddfa
fix clangformat error
LJ-underdog Jan 7, 2026
1d4e219
fix aiter scale error
LJ-underdog Jan 8, 2026
a667752
Update block_fmha_pipeline_qr_ks_vs.hpp
LJ-underdog Jan 8, 2026
68ff3e2
div scale_s for sink_value
LJ-underdog Jan 8, 2026
1984232
Update fmha_fwd_runner.hpp
LJ-underdog Jan 9, 2026
1db4995
update sink_value with bias
LJ-underdog Jan 12, 2026
8e1e5c1
Merge branch 'develop' into gptoss_sink
LJ-underdog Jan 12, 2026
7ae150a
Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
LJ-underdog Jan 12, 2026
d9a2b40
Fix typo in dropout parameter in fmha_batch_prefill_kernel
LJ-underdog Jan 12, 2026
1bc51b3
Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
LJ-underdog Jan 12, 2026
815044c
Update example_fmha_fwd.cpp
LJ-underdog Jan 13, 2026
b68d460
Merge branch 'develop' into gptoss_sink
LJ-underdog Jan 13, 2026
26a3e9f
Merge branch 'develop' into gptoss_sink
poyenc Jan 13, 2026
9cd4f13
Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs…
LJ-underdog Jan 13, 2026
3c68266
Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipel…
LJ-underdog Jan 13, 2026
f0e1d50
optimized some code
LJ-underdog Jan 13, 2026
2862c13
fix splitkv error
LJ-underdog Jan 13, 2026
be92171
update sink reference
LJ-underdog Jan 14, 2026
45c2b22
Merge branch 'develop' into gptoss_sink
LJ-underdog Jan 14, 2026
2ad9d5e
Update fmha_fwd_runner.hpp
LJ-underdog Jan 14, 2026
fdd767e
Update smoke_test_fwd_sink.sh
LJ-underdog Jan 14, 2026
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -8,12 +8,13 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added preshuffleB support for abquant mode in blockscale GEMM.
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
* Added attention sink support for FMHA FWD, including qr_ks_vs, qr_async and splitkv pipelines.
* Added StreamingLLM sink support for FMHA FWD, including qr_ks_vs, qr_async and splitkv pipelines.
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
* Added FP8 KV cache support for FMHA batch prefill.
* Added support for gfx1153 target.
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
* Added gpt-oss sink support for FMHA FWD, including qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.

### Changed

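For background: the gpt-oss style attention sink added by this PR is a learned per-head logit that takes part in the softmax normalization of every query row but contributes no value vector. A sketch of the reference math the new validation path checks against (the notation below is assumed for illustration, not taken from the source): with attention scores s_ij, per-head sink logit sink_h, and row max m_i = max(sink_h, max_k s_ik),

```latex
p_{ij} = \frac{e^{s_{ij} - m_i}}{e^{\mathrm{sink}_h - m_i} + \sum_k e^{s_{ik} - m_i}},
\qquad
o_i = \sum_j p_{ij}\, v_j,
\qquad
\mathrm{lse}_i = m_i + \log\Big(e^{\mathrm{sink}_h - m_i} + \sum_k e^{s_{ik} - m_i}\Big)
```

Equivalently, one can append a single extra column holding sink_h to each score row, run a standard softmax, and drop that column afterwards; this is what the reference path added to fmha_fwd_runner.hpp further down does.
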
5 changes: 4 additions & 1 deletion example/ck_tile/01_fmha/example_fmha_fwd.cpp
@@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[])
.insert("kv_eff_lens",
"",
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
"Comma-separated list of length 'b'. If empty, no override.");
"Comma-separated list of length 'b'. If empty, no override.")
.insert("init_sink", "0", "value to init the output tensor sink value for validation");

bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
int init_sink_value = arg_parser.get_int("init_sink");

ck_tile::stream_config stream_config{nullptr,
true,
@@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
init_method,
seed,
do_validation,
init_sink_value,
stream_config,
json);
}
29 changes: 21 additions & 8 deletions example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -230,6 +230,7 @@ struct fmha_fwd_args
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* sink_ptr;

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
const void* sink_ptr;

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
const void* sink_ptr;

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr

const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
const void* sink_ptr;

ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
@@ -519,6 +523,7 @@ struct fmha_batch_prefill_args
// 1) +
// kargs.kv_last_page_lens[b]
const void* seqstart_q_ptr;
const void* sink_ptr;

ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -638,7 +643,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.s_randval,
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
args.cu_seqlen_k_ptr,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -688,7 +694,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.s_randval,
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
args.cu_seqlen_k_ptr,
args.sink_ptr);
}
}();

@@ -848,7 +855,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.window_size_right,
args.sink_size,
args.mask_type,
args.min_seqlen_q);
args.min_seqlen_q,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -893,7 +901,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
args.mask_type,
args.sink_ptr);
}
}();

@@ -960,7 +969,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
args.mask_type,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -1008,7 +1018,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.window_size_left,
args.window_size_right,
args.sink_size,
args.mask_type);
args.mask_type,
args.sink_ptr);
}
}();

@@ -1187,7 +1198,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.drop_seed_offset,
args.sink_ptr);
}
else
{ // create batch mode kernel arguments
@@ -1239,7 +1251,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.drop_seed_offset,
args.sink_ptr);
}
}();

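The header diff above threads a new sink_ptr through each forward argument struct and its kargs factory. A minimal host-side sketch of how a caller might feed it, modeled on the example runner below (the float element type, the one-logit-per-head shape, the 30-60 init range, and the nhead/seed locals are assumptions carried over from that runner, not requirements of the API):

```cpp
// Sketch: one sink logit per query head; leaving sink_ptr as nullptr keeps the
// sink path disabled.
ck_tile::HostTensor<float> sink_host({nhead});
ck_tile::FillUniformDistributionIntegerValue<float>{30.f, 60.f, seed}(sink_host);

ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes());
sink_buf.ToDevice(sink_host.data());

fmha_fwd_args args{};
args.sink_ptr = sink_buf.GetDeviceBuffer();
// ... fill the remaining fmha_fwd_args fields as before ...
```
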
92 changes: 83 additions & 9 deletions example/ck_tile/01_fmha/fmha_fwd_runner.hpp
@@ -149,6 +149,28 @@ int override_num_splits_if_necessary(
return num_splits;
}

template <typename SMPLComputeDataType>
void copy_attention_scores_with_sink(const ck_tile::HostTensor<SMPLComputeDataType>& s_host_ref,
const ck_tile::HostTensor<SMPLComputeDataType>& sink_host,
ck_tile::HostTensor<SMPLComputeDataType>& s_with_sinks_ref,
ck_tile::index_t nhead,
ck_tile::index_t real_seqlen_q,
ck_tile::index_t real_seqlen_k)
{
for(auto i_h = 0; i_h < nhead; i_h++)
{
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
{
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
{
s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c);
}
// Append sink token at the end of each row
s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h);
}
}
}

template <typename DataTypeConfig>
fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::index_t batch,
@@ -184,6 +206,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::string init_method,
uint32_t seed,
int do_validation,
int init_sink_value,
const ck_tile::stream_config& stream_config,
std::optional<std::string> json = std::nullopt)
{
@@ -527,6 +550,7 @@ fwd_result fmha_fwd_run(mode_enum mode,

ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
ck_tile::HostTensor<KDataType> k_host(
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
@@ -609,6 +633,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
bias_host);
}

else if(init_method == "ni")
{
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
@@ -695,10 +720,17 @@ fwd_result fmha_fwd_run(mode_enum mode,

iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);

if(init_sink_value != 0)
{
// sink values are initialized to random integer values for easy debugging; the 30 to 60
// range keeps them close to the row max of the attention scores.
ck_tile::FillUniformDistributionIntegerValue<SMPLComputeDataType>{30.f, 60.f, next_seed()}(
sink_host);
}
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
@@ -743,6 +775,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
v_buf.ToDevice(v_host.data());
sink_buf.ToDevice(sink_host.data());
knew_buf.ToDevice(knew_host.data());
vnew_buf.ToDevice(vnew_host.data());
bias_buf.ToDevice(bias_host.data());
@@ -971,7 +1004,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();

if(init_sink_value != 0)
args.sink_ptr = sink_buf.GetDeviceBuffer();
else
args.sink_ptr = nullptr;
args.batch = batch;
args.seqlen_q = shape_seqlen_q; // unused in group mode
args.hdim_q = hdim_q;
@@ -1675,19 +1711,57 @@ fwd_result fmha_fwd_run(mode_enum mode,
mask.type == mask_enum::mask_top_left));
}
const ck_tile::HostTensor<SaccDataType> masked_s_host_ref = s_host_ref;
if(lse)
if(init_sink_value != 0)
{
ck_tile::
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
// Create extended tensor with sink token
ck_tile::HostTensor<SMPLComputeDataType> s_with_sinks_ref(
{nhead, real_seqlen_q, real_seqlen_k + 1});

// Copy original attention scores and append sink values
copy_attention_scores_with_sink(
s_host_ref, sink_host, s_with_sinks_ref, nhead, real_seqlen_q, real_seqlen_k);

// Compute softmax on extended tensor
ck_tile::HostTensor<PDataType> p_extended(
{nhead, real_seqlen_q, real_seqlen_k + 1});

if(lse)
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref);
}
else
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_with_sinks_ref, p_extended, p_compute_element_func);
}

// Extract only the original columns (exclude sink token column)
p_host_ref.ForEach(
[&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); });
}
else
{
ck_tile::
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
// No sink tokens - compute softmax directly
if(lse)
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
}
else
{
ck_tile::reference_batched_softmax<SMPLComputeDataType,
SMPLComputeDataType,
PDataType>(
s_host_ref, p_host_ref, p_compute_element_func);
}
}

if(p_drop > 0)
{
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
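The reference path above builds an (nhead, real_seqlen_q, real_seqlen_k + 1) tensor purely so it can reuse reference_batched_softmax and then drop the extra column. Row by row that is equivalent to the closed form below; a standalone sketch in plain C++ (illustrative only, not CK code) that can be handy when sanity-checking the sink handling:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax of one row of attention scores with a sink logit: the sink joins the
// normalization but its own probability column is discarded.
std::vector<float> row_softmax_with_sink(const std::vector<float>& s, float sink)
{
    float m = sink;
    for(float x : s)
        m = std::max(m, x); // row max includes the sink logit for numerical stability

    float denom = std::exp(sink - m); // sink contributes to the denominator only
    for(float x : s)
        denom += std::exp(x - m);

    std::vector<float> p(s.size());
    for(std::size_t j = 0; j < s.size(); ++j)
        p[j] = std::exp(s[j] - m) / denom;
    return p;
}
```

With a sink present the probabilities in each row sum to less than one; the missing mass is what the sink absorbed.
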
7 changes: 7 additions & 0 deletions example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
@@ -84,3 +84,10 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l
# 1 1 1 1 1 1 1 1 1 1
# l=2/r=0(br) l=2/r=0/s=2(br)

$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1

$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0

$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1

$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1