Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions example/ck_tile/03_gemm/gemm_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
arg_parser.print();

std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
Expand Down Expand Up @@ -41,6 +43,20 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
return run_gemm_example_prec_type<GemmConfig, Invoker, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp32")
{
return run_gemm_example_prec_type<GemmConfig, Invoker, float>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "tf32")
{
return run_gemm_example_prec_type<GemmConfig,
Invoker,
float,
float,
float,
ck_tile::tf32_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig,
Expand Down
42 changes: 30 additions & 12 deletions example/ck_tile/03_gemm/gemm_basic_invoker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@ struct BasicInvoker
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise>
typename CDEElementWise,
typename ComputeDataType = ADataType>
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if constexpr(Persistent)
{
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
}

constexpr bool is_fp32_input = std::is_same_v<ADataType, float>;
[[maybe_unused]] constexpr bool is_tf32_compute = std::is_same_v<ComputeDataType, ck_tile::tf32_t>;

// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t M_Tile = is_fp32_input ? 128 : 256;
constexpr ck_tile::index_t N_Tile = is_fp32_input ? 128 : 256;
constexpr ck_tile::index_t K_Tile = 64;

#if CK_TILE_USE_WMMA
Expand All @@ -37,13 +41,23 @@ struct BasicInvoker
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#elif defined(CK_GFX950_SUPPORT)
// gfx950: fp32 uses 16x16x16 tile (native MFMA)
// tf32 uses 32x32x16 tile (3x bf16 32x32x16 MFMA emulation)
constexpr ck_tile::index_t M_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
constexpr ck_tile::index_t N_Warp = (is_fp32_input && !is_tf32_compute) ? 4 : 2;
constexpr ck_tile::index_t K_Warp = 1;

constexpr ck_tile::index_t M_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
constexpr ck_tile::index_t N_Warp_Tile = (is_fp32_input && !is_tf32_compute) ? 16 : 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t M_Warp = is_fp32_input ? 4 : 2;
constexpr ck_tile::index_t N_Warp = is_fp32_input ? 4 : 2;
constexpr ck_tile::index_t K_Warp = 1;

constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = is_fp32_input ? 16 : 32;
constexpr ck_tile::index_t N_Warp_Tile = is_fp32_input ? 16 : 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif

Expand All @@ -61,11 +75,15 @@ struct BasicInvoker
BLayout,
CLayout>;

using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
ComputeDataType>;

using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

Expand Down
20 changes: 19 additions & 1 deletion example/ck_tile/03_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,24 @@ struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<Pre
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;

template <>
struct GemmTypeConfig<float>
{
using ADataType = float;
using BDataType = float;
using AccDataType = float;
using CDataType = float;
};

template <>
struct GemmTypeConfig<ck_tile::tf32_t, ck_tile::tf32_t, float>
{
using ADataType = float;
using BDataType = float;
using AccDataType = float;
using CDataType = float;
};

template <>
struct GemmTypeConfig<ck_tile::half_t>
{
Expand Down Expand Up @@ -446,7 +464,7 @@ inline auto create_args()
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/fp32/pk_int4_t/tf32")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
Expand Down
56 changes: 31 additions & 25 deletions example/ck_tile/03_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ template <typename GemmConfig,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
typename ComputeDataType = ADataType,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
Expand Down Expand Up @@ -151,7 +152,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
DsLayout,
CLayout,
true,
CDEElementWise>(
CDEElementWise,
ComputeDataType>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
Expand All @@ -169,7 +171,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
DsLayout,
CLayout,
false,
CDEElementWise>(
CDEElementWise,
ComputeDataType>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
Expand Down Expand Up @@ -209,11 +212,12 @@ std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> inline parse_ge
template <typename GemmConfig,
typename Invoker,
typename ADataType,
typename BDataType = ADataType,
typename CDataType = ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename BDataType = ADataType,
typename CDataType = ADataType,
typename ALayout = ck_tile::tensor_layout::gemm::RowMajor,
typename BLayout = ck_tile::tensor_layout::gemm::ColumnMajor,
typename CLayout = ck_tile::tensor_layout::gemm::RowMajor,
typename ComputeDataType = ADataType>
int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
Expand Down Expand Up @@ -349,21 +353,22 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat,
persistent,
flush_cache,
rotating_count);
CLayout,
ComputeDataType>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat,
persistent,
flush_cache,
rotating_count);

c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

Expand Down Expand Up @@ -393,7 +398,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,

if(arg_parser.get_int("v") == 1)
{
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType, ComputeDataType>(
a_m_k, b_k_n, c_m_n_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
Expand Down Expand Up @@ -427,7 +432,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
CLayout,
ComputeDataType>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);

c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());

Expand Down
11 changes: 8 additions & 3 deletions example/ck_tile/03_gemm/run_gemm_example_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
template <typename GemmConfig,
typename Invoker,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
typename BPrecType = APrecType,
typename CPrecType = APrecType,
typename ComputeDataType = APrecType>
int run_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
Expand Down Expand Up @@ -54,7 +55,11 @@ int run_gemm_example_prec_type(std::string a_layout,
Invoker,
APrecType,
BPrecType,
CPrecType>(
CPrecType,
decltype(a_layout_type),
decltype(b_layout_type),
Row,
ComputeDataType>(
arg_parser, a_layout_type, b_layout_type, Row{});
}
},
Expand Down
21 changes: 21 additions & 0 deletions include/ck_tile/core/numeric/numeric.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

namespace ck_tile {

using tf32_t = _BitInt(19); // 1 sign bit, 8 exponent bits, 10 mantissa bits

// this struct has the information of
// 1. limit of a certain type, simliar to std::numeric_limits
// 2. some pre-defined value, zero, one...
Expand Down Expand Up @@ -101,6 +103,25 @@ struct numeric_traits<float>
using bitwise_type = uint32_t;
};

template <>
struct numeric_traits<tf32_t>
{
static constexpr int exp = 8;
static constexpr int mant = 10;
static constexpr int bias = 127;
static constexpr uint32_t nan_mask = 0x7F800000;
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t abs_mask = 0x7FFFFFFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
static constexpr int PackedSize = 1;
using bitwise_type = uint32_t;
};

} // namespace ck_tile

#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
Expand Down
27 changes: 27 additions & 0 deletions include/ck_tile/core/numeric/type_convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,33 @@ CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)

enum class tf32_rounding_mode
{
truncate = 0,
standard = 1, // RTNE
};

template <tf32_rounding_mode rounding = tf32_rounding_mode::truncate>
CK_TILE_HOST_DEVICE constexpr float float_to_tf32(float x)
{
uint32_t i = bit_cast<uint32_t>(x);
if constexpr(rounding == tf32_rounding_mode::standard)
{
if((i & 0x7f800000) != 0x7f800000)
{
i += 0xfff + ((i >> 13) & 1);
}
}
i &= 0xFFFFE000u;
return bit_cast<float>(i);
}

template <typename Y, std::enable_if_t<std::is_same_v<Y, tf32_t>, bool> = false>
CK_TILE_HOST_DEVICE constexpr float type_convert(float x)
{
return float_to_tf32(x);
}

CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
Expand Down
Loading
Loading