Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,18 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,

// The specialized rowwise cast-only kernel vectorizes full 32-element chunks.
// Shapes with a partial row tail (for example, N=48) must use the generic kernel,
// otherwise the last chunk reads/writes past the logical end of the row.
const bool is_full_rowwise_chunk =
(cols % specialized::CastTraits<IType, OType, true, false>::chunkElems == 0);

const bool scaling_type_has_specialized_support =
(scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) ||
(scaling_type == ScalingType::BIDIMENSIONAL);

if (specialized::hasSpec<IS_DBIAS, IS_DACT, IS_ACT, IType, OType>() &&
!WITH_GEMM_SWIZZLED_SCALES) {
!WITH_GEMM_SWIZZLED_SCALES && scaling_type_has_specialized_support) {
switch (scaling_type) {
case ScalingType::ROWWISE: {
using traits = specialized::CastTraits<IType, OType, true, false>;
Expand All @@ -664,10 +674,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop,

break;
}
case ScalingType::COLWISE: {
NVTE_WARN("Colwise scaling will fallback to original kernel.");
break;
}
case ScalingType::BIDIMENSIONAL: {
using traits = specialized::CastTraits<IType, OType, true, true>;
auto kernel = specialized::quantize_mxfp8_kernel_cast_only<traits>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,6 @@ __device__ __forceinline__ e8m0_t to_e8m0(IType amax) {
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
} // anonymous namespace

/// Tells whether the ENABLE_CAST_ONLY environment variable is switched on.
///
/// The variable is looked up only on the first call; the outcome is stored in a
/// function-local static, so later changes to the environment are not observed.
///
/// @return true when ENABLE_CAST_ONLY is set and its first character is '1'
///         (only the first character is inspected), false otherwise.
inline bool is_cast_only_enabled() {
  static const bool kCachedFlag = [] {
    const char *const value = std::getenv("ENABLE_CAST_ONLY");
    if (value == nullptr) {
      return false;  // Variable not set at all.
    }
    return value[0] == '1';
  }();
  return kCachedFlag;
}

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename IType, typename OType>
inline bool hasSpec() {
return false;
Expand All @@ -112,19 +100,19 @@ inline bool hasSpec() {
// OType could be [fp8e5m2, fp8e4m3]
// hasSpec specialization for IS_DBIAS = IS_DACT = IS_ACT = false with IType = fp16 and
// OType = fp8e5m2.
template <>
inline bool hasSpec<false, false, false, fp16, fp8e5m2>() {
// As written, control leaves here, so the result is whatever is_cast_only_enabled() reports.
return is_cast_only_enabled();
// NOTE(review): unreachable after the return above. The two returns look like the removed and
// added lines of a diff shown together -- confirm which one is intended.
return true;
}
// hasSpec specialization for IS_DBIAS = IS_DACT = IS_ACT = false with IType = fp16 and
// OType = fp8e4m3.
template <>
inline bool hasSpec<false, false, false, fp16, fp8e4m3>() {
// As written, control leaves here, so the result is whatever is_cast_only_enabled() reports.
return is_cast_only_enabled();
// NOTE(review): unreachable after the return above. The two returns look like the removed and
// added lines of a diff shown together -- confirm which one is intended.
return true;
}
// hasSpec specialization for IS_DBIAS = IS_DACT = IS_ACT = false with IType = bf16 and
// OType = fp8e5m2.
template <>
inline bool hasSpec<false, false, false, bf16, fp8e5m2>() {
// As written, control leaves here, so the result is whatever is_cast_only_enabled() reports.
return is_cast_only_enabled();
// NOTE(review): unreachable after the return above. The two returns look like the removed and
// added lines of a diff shown together -- confirm which one is intended.
return true;
}
// hasSpec specialization for IS_DBIAS = IS_DACT = IS_ACT = false with IType = bf16 and
// OType = fp8e4m3.
template <>
inline bool hasSpec<false, false, false, bf16, fp8e4m3>() {
// As written, control leaves here, so the result is whatever is_cast_only_enabled() reports.
return is_cast_only_enabled();
// NOTE(review): unreachable after the return above. The two returns look like the removed and
// added lines of a diff shown together -- confirm which one is intended.
return true;
}

template <int32_t _M, int32_t _N>
Expand Down
Loading