Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,11 @@ def check_hdim(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
False
return True

all_mask_keys = list(get_mask_map("simplified").keys()) + list(
get_mask_map("generic").keys()
)
no_mask_keys = [mask_key for mask_key in all_mask_keys if "no" in mask_key]

def check_feature(
problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> bool:
Expand All @@ -821,6 +826,12 @@ def check_feature(
or kernel_ctx.pipeline.F_logits == "f"
):
return False
# sink is only meaningful when masking is applied, so reject no-mask kernels that enable it
if (
kernel_ctx.pipeline.F_mask in no_mask_keys
and kernel_ctx.pipeline.F_sink == "t"
):
return False
return True

return [check_mode, check_hdim, check_feature]
Expand Down
9 changes: 9 additions & 0 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,11 @@ def get_fwd_splitkv_blobs(

factories = get_factories_for_targets(targets, get_factory)

all_mask_keys = list(get_mask_map("simplified").keys()) + list(
get_mask_map("generic").keys()
)
no_mask_keys = [mask_key for mask_key in all_mask_keys if "no" in mask_key]

for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
Expand All @@ -899,6 +904,10 @@ def get_fwd_splitkv_blobs(
or pipeline.F_logits == "f"
):
continue
# sink is only meaningful when masking is applied, so skip no-mask pipelines that enable it
if pipeline.F_mask in no_mask_keys and pipeline.F_sink == "t":
continue

k = Kernel(
F_arch=factory.arch,
F_idx=0,
Expand Down
9 changes: 9 additions & 0 deletions example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,11 @@ def get_fwd_blobs(

factories = get_factories_for_targets(targets, get_factory)

all_mask_keys = list(get_mask_map("simplified").keys()) + list(
get_mask_map("generic").keys()
)
no_mask_keys = [mask_key for mask_key in all_mask_keys if "no" in mask_key]

for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
Expand All @@ -666,6 +671,10 @@ def get_fwd_blobs(
or pipeline.F_logits == "f"
):
continue
# sink is only meaningful when masking is applied, so skip no-mask pipelines that enable it
if pipeline.F_mask in no_mask_keys and pipeline.F_sink == "t":
continue

k = FmhaFwdKernel(
F_arch=factory.arch,
F_idx=0,
Expand Down
5 changes: 4 additions & 1 deletion example/ck_tile/01_fmha/mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ struct mask_info
}
else if(str == "0")
{
tmp.type = mask_enum::no_mask;
tmp.type = mask_enum::no_mask;
tmp.left = -1;
tmp.right = -1;
tmp.sink = 0;
}
else if(str == "1" || str == "t")
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ struct BlockFmhaPipelineProblem
static constexpr auto QScaleEnum = Traits::QScaleEnum;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
static_assert(FmhaMask::IsMasking || !kHasSink);
};

template <typename QDataType_,
Expand Down Expand Up @@ -182,6 +183,7 @@ struct BlockFmhaFwdPagedKVPipelineProblem
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
static_assert(FmhaMask::IsMasking || !kHasSink);
};

template <typename QDataType_,
Expand Down Expand Up @@ -236,6 +238,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kHasSink = Traits::kHasSink;
static_assert(FmhaMask::IsMasking || !kHasSink);
};

// extract tile size attributes to remove dependency on traits
Expand Down