Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 118 additions & 5 deletions backends/metax_gpu/cinn/compiler/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ typedef long long int64_t;
// Compatible with __half references in CINN-generated code
typedef __half float16;

// BFloat16 software type for CINN-generated code
struct bfloat16 {
unsigned short x;
__device__ bfloat16() {}
__device__ explicit bfloat16(float val) {
unsigned int ival = *(unsigned int*)&val;
// Round to nearest even
unsigned int lsb = (ival >> 16) & 1;
unsigned int rounding_bias = 0x7fff + lsb;
ival += rounding_bias;
x = (unsigned short)(ival >> 16);
}
__device__ explicit operator float() const {
unsigned int val = ((unsigned int)x) << 16;
return *(float*)&val;
}
__device__ bfloat16 operator+(const bfloat16& o) const { return bfloat16((float)*this + (float)o); }
__device__ bfloat16 operator-(const bfloat16& o) const { return bfloat16((float)*this - (float)o); }
__device__ bfloat16 operator*(const bfloat16& o) const { return bfloat16((float)*this * (float)o); }
__device__ bfloat16 operator/(const bfloat16& o) const { return bfloat16((float)*this / (float)o); }
__device__ bool operator<(const bfloat16& o) const { return (float)*this < (float)o; }
__device__ bool operator>(const bfloat16& o) const { return (float)*this > (float)o; }
__device__ bool operator==(const bfloat16& o) const { return x == o.x; }
__device__ bool operator!=(const bfloat16& o) const { return x != o.x; }
__device__ bool operator<=(const bfloat16& o) const { return (float)*this <= (float)o; }
__device__ bool operator>=(const bfloat16& o) const { return (float)*this >= (float)o; }
};

#define CINN_BF16_MIN bfloat16(-3.38953139e+38f)
#define CINN_BF16_MAX bfloat16(3.38953139e+38f)

#define CINN_UINT8_MIN 0
#define CINN_UINT8_MAX 255
#define CINN_INT16_MIN -32768
Expand Down Expand Up @@ -308,6 +339,51 @@ __device__ inline float16 FN_FP16(fma)(float16 a, float16 b, float16 c) { return
__device__ inline float16 FN_FP16(max)(float16 a, float16 b) { return __hgt(a, b) ? a : b; }
__device__ inline float16 FN_FP16(min)(float16 a, float16 b) { return __hlt(a, b) ? a : b; }

// ===============================================================
// BFloat16 Functions
// ===============================================================
#define FN_BF16(func) cinn_custom_device_##func##_bf16

__device__ inline bfloat16 FN_BF16(ceil)(bfloat16 x) { return bfloat16(ceilf((float)x)); }
__device__ inline bfloat16 FN_BF16(floor)(bfloat16 x) { return bfloat16(floorf((float)x)); }
__device__ inline bfloat16 FN_BF16(round)(bfloat16 x) { return bfloat16(roundf((float)x)); }
__device__ inline bfloat16 FN_BF16(trunc)(bfloat16 x) { return bfloat16(truncf((float)x)); }
__device__ inline bfloat16 FN_BF16(sin)(bfloat16 x) { return bfloat16(sinf((float)x)); }
__device__ inline bfloat16 FN_BF16(cos)(bfloat16 x) { return bfloat16(cosf((float)x)); }
__device__ inline bfloat16 FN_BF16(exp)(bfloat16 x) { return bfloat16(expf((float)x)); }
__device__ inline bfloat16 FN_BF16(log)(bfloat16 x) { return bfloat16(logf((float)x)); }
__device__ inline bfloat16 FN_BF16(log2)(bfloat16 x) { return bfloat16(log2f((float)x)); }
__device__ inline bfloat16 FN_BF16(log10)(bfloat16 x) { return bfloat16(log10f((float)x)); }
__device__ inline bfloat16 FN_BF16(sqrt)(bfloat16 x) { return bfloat16(sqrtf((float)x)); }
__device__ inline bfloat16 FN_BF16(rsqrt)(bfloat16 x) { return bfloat16(rsqrtf((float)x)); }
__device__ inline bfloat16 FN_BF16(cbrt)(bfloat16 x) { return bfloat16(cbrtf((float)x)); }
__device__ inline bfloat16 FN_BF16(abs)(bfloat16 x) { return bfloat16(fabsf((float)x)); }
__device__ inline bool FN_BF16(isnan)(bfloat16 x) { return isnan((float)x); }
__device__ inline bool FN_BF16(isinf)(bfloat16 x) { return isinf((float)x); }
__device__ inline bool FN_BF16(isfinite)(bfloat16 x) { return isfinite((float)x); }
__device__ inline bfloat16 FN_BF16(erf)(bfloat16 x) { return bfloat16(erff((float)x)); }
__device__ inline bfloat16 FN_BF16(tan)(bfloat16 x) { return bfloat16(tanf((float)x)); }
__device__ inline bfloat16 FN_BF16(sinh)(bfloat16 x) { return bfloat16(sinhf((float)x)); }
__device__ inline bfloat16 FN_BF16(cosh)(bfloat16 x) { return bfloat16(coshf((float)x)); }
__device__ inline bfloat16 FN_BF16(tanh)(bfloat16 x) { return bfloat16(tanhf((float)x)); }
__device__ inline bfloat16 FN_BF16(asin)(bfloat16 x) { return bfloat16(asinf((float)x)); }
__device__ inline bfloat16 FN_BF16(acos)(bfloat16 x) { return bfloat16(acosf((float)x)); }
__device__ inline bfloat16 FN_BF16(atan)(bfloat16 x) { return bfloat16(atanf((float)x)); }
__device__ inline bfloat16 FN_BF16(asinh)(bfloat16 x) { return bfloat16(asinhf((float)x)); }
__device__ inline bfloat16 FN_BF16(acosh)(bfloat16 x) { return bfloat16(acoshf((float)x)); }
__device__ inline bfloat16 FN_BF16(atanh)(bfloat16 x) { return bfloat16(atanhf((float)x)); }
__device__ inline bfloat16 FN_BF16(sigmoid)(bfloat16 x) { return bfloat16(1.0f / (1.0f + expf(-(float)x))); }
__device__ inline bfloat16 FN_BF16(mod)(bfloat16 a, bfloat16 b) { return bfloat16(fmodf((float)a, (float)b)); }
__device__ inline bfloat16 FN_BF16(pow)(bfloat16 a, bfloat16 b) { return bfloat16(powf((float)a, (float)b)); }
__device__ inline bfloat16 FN_BF16(add)(bfloat16 a, bfloat16 b) { return bfloat16((float)a + (float)b); }
__device__ inline bfloat16 FN_BF16(sub)(bfloat16 a, bfloat16 b) { return bfloat16((float)a - (float)b); }
__device__ inline bfloat16 FN_BF16(mul)(bfloat16 a, bfloat16 b) { return bfloat16((float)a * (float)b); }
__device__ inline bfloat16 FN_BF16(div)(bfloat16 a, bfloat16 b) { return bfloat16((float)a / (float)b); }
__device__ inline bfloat16 FN_BF16(neg)(bfloat16 a) { return bfloat16(-(float)a); }
__device__ inline bfloat16 FN_BF16(fma)(bfloat16 a, bfloat16 b, bfloat16 c) { return bfloat16(fmaf((float)a, (float)b, (float)c)); }
__device__ inline bfloat16 FN_BF16(max)(bfloat16 a, bfloat16 b) { return (float)a > (float)b ? a : b; }
__device__ inline bfloat16 FN_BF16(min)(bfloat16 a, bfloat16 b) { return (float)a < (float)b ? a : b; }

// ===============================================================
// Warp Shuffle Functions (used by reduce operators)
// ===============================================================
Expand Down Expand Up @@ -348,6 +424,20 @@ __device__ inline __half FN_SHUFFLE(warp_shuffle_down_fp16)(__half v, int factor
unsigned short res = (unsigned short)__shfl_down((int)val, factor);
return __ushort_as_half(res);
}

// BFloat16 warp shuffle (bitcast through unsigned short)
__device__ inline bfloat16 FN_SHUFFLE(warp_shuffle_xor_bf16)(bfloat16 v, int factor) {
unsigned short res = (unsigned short)__shfl_xor((int)v.x, factor);
bfloat16 r; r.x = res; return r;
}
__device__ inline bfloat16 FN_SHUFFLE(warp_shuffle_up_bf16)(bfloat16 v, int factor) {
unsigned short res = (unsigned short)__shfl_up((int)v.x, factor);
bfloat16 r; r.x = res; return r;
}
__device__ inline bfloat16 FN_SHUFFLE(warp_shuffle_down_bf16)(bfloat16 v, int factor) {
unsigned short res = (unsigned short)__shfl_down((int)v.x, factor);
bfloat16 r; r.x = res; return r;
}
} // extern "C"

// ===============================================================
Expand Down Expand Up @@ -459,11 +549,10 @@ __device__ inline float16 cinn_max_fp16(const float16 left, const float16 right)
__device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) { return __hlt(left, right) ? left : right; }

// --- BF16 (BFloat16) ---
// [Note] If mxcc does not support __nv_bfloat16, this section should be commented out or produce an error
#if defined(__MACACC__) || defined(__CUDACC__) // Assuming support is available
// Placeholder: comment out the BF16 section if compilation errors occur
// __device__ inline __nv_bfloat16 cinn_sum_bf16(...) ...
#endif
__device__ inline bfloat16 cinn_sum_bf16(const bfloat16 left, const bfloat16 right) { return bfloat16((float)left + (float)right); }
__device__ inline bfloat16 cinn_prod_bf16(const bfloat16 left, const bfloat16 right) { return bfloat16((float)left * (float)right); }
__device__ inline bfloat16 cinn_max_bf16(const bfloat16 left, const bfloat16 right) { return (float)left > (float)right ? left : right; }
__device__ inline bfloat16 cinn_min_bf16(const bfloat16 left, const bfloat16 right) { return (float)left < (float)right ? left : right; }

// ===============================================================
// 3. Reduce Initialization Macros
Expand Down Expand Up @@ -512,6 +601,13 @@ __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right)
MACRO(max_fp16, -65504.0, float16, ##__VA_ARGS__) \
MACRO(min_fp16, 65504.0, float16, ##__VA_ARGS__)

// BF16 initial values
#define EXPAND_REDUCE_BF16_MACRO(MACRO, ...) \
MACRO(sum_bf16, 0.0, bfloat16, ##__VA_ARGS__) \
MACRO(prod_bf16, 1.0, bfloat16, ##__VA_ARGS__) \
MACRO(max_bf16, -3.38953139e+38f, bfloat16, ##__VA_ARGS__) \
MACRO(min_bf16, 3.38953139e+38f, bfloat16, ##__VA_ARGS__)


// ===============================================================
// 4. Warp Shuffle Wrappers (Using Legacy API & Full Down Strategy)
Expand Down Expand Up @@ -559,6 +655,11 @@ __device__ inline float16 cinn_warp_shuffle_down_float16_wrapper(float16 v, int
return __ushort_as_half((unsigned short)__shfl_down((int)val, factor));
}

__device__ inline bfloat16 cinn_warp_shuffle_down_bfloat16_wrapper(bfloat16 v, int factor) {
unsigned short res = (unsigned short)__shfl_down((int)v.x, factor);
bfloat16 r; r.x = res; return r;
}

__device__ inline welford_fp32 cinn_warp_shuffle_down_welford_fp32_wrapper(welford_fp32 v, int factor) {
float m = __shfl_down(v.mean, factor);
float m2 = __shfl_down(v.m2, factor);
Expand All @@ -582,6 +683,11 @@ __device__ inline float16 cinn_warp_shuffle_idx_float16_wrapper(float16 v, int l
return __ushort_as_half((unsigned short)__shfl((int)val, lane));
}

__device__ inline bfloat16 cinn_warp_shuffle_idx_bfloat16_wrapper(bfloat16 v, int lane) {
unsigned short res = (unsigned short)__shfl((int)v.x, lane);
bfloat16 r; r.x = res; return r;
}

__device__ inline double cinn_warp_shuffle_idx_double_wrapper(double v, int lane) {
unsigned long long int val_u64 = *(unsigned long long int*)&v;
int lo = __shfl((int)val_u64, lane);
Expand Down Expand Up @@ -617,6 +723,7 @@ EXPAND_REDUCE_FP32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
EXPAND_REDUCE_FP64_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
EXPAND_REDUCE_BOOL_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
EXPAND_REDUCE_FP16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)
EXPAND_REDUCE_BF16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL)

// ===============================================================
// 5. Block Reduce & Discrete Reduce & Grid Reduce
Expand Down Expand Up @@ -673,6 +780,7 @@ EXPAND_REDUCE_FP32_MACRO(CINN_BLOCK_REDUCE_MACRO)
EXPAND_REDUCE_FP64_MACRO(CINN_BLOCK_REDUCE_MACRO)
EXPAND_REDUCE_BOOL_MACRO(CINN_BLOCK_REDUCE_MACRO)
EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_MACRO)
EXPAND_REDUCE_BF16_MACRO(CINN_BLOCK_REDUCE_MACRO)

#define CINN_DISCRETE_REDUCE_IMPL(REDUCE_TYPE, value) \
int tid = threadIdx.y * blockDim.x + threadIdx.x; \
Expand All @@ -699,6 +807,7 @@ EXPAND_REDUCE_FP32_MACRO(CINN_DISCRETE_REDUCE_MACRO)
EXPAND_REDUCE_FP64_MACRO(CINN_DISCRETE_REDUCE_MACRO)
EXPAND_REDUCE_BOOL_MACRO(CINN_DISCRETE_REDUCE_MACRO)
EXPAND_REDUCE_FP16_MACRO(CINN_DISCRETE_REDUCE_MACRO)
EXPAND_REDUCE_BF16_MACRO(CINN_DISCRETE_REDUCE_MACRO)

// ===============================================================
// ArgMin/ArgMax Support (ArgIdx Structures & Combine Functions)
Expand Down Expand Up @@ -800,6 +909,7 @@ EXPAND_REDUCE_FP32_MACRO(CINN_GRID_REDUCE_MACRO)
EXPAND_REDUCE_FP64_MACRO(CINN_GRID_REDUCE_MACRO)
EXPAND_REDUCE_BOOL_MACRO(CINN_GRID_REDUCE_MACRO)
EXPAND_REDUCE_FP16_MACRO(CINN_GRID_REDUCE_MACRO)
EXPAND_REDUCE_BF16_MACRO(CINN_GRID_REDUCE_MACRO)

__device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) {
__shared__ bool done;
Expand Down Expand Up @@ -888,6 +998,7 @@ CINN_CUSTOM_DEVICE_LT_NUM(int16, int16_t)
CINN_CUSTOM_DEVICE_LT_NUM(int32, int)
CINN_CUSTOM_DEVICE_LT_NUM(int64, int64_t)
CINN_CUSTOM_DEVICE_LT_NUM(fp16, float16)
CINN_CUSTOM_DEVICE_LT_NUM(bf16, bfloat16)
#undef CINN_CUSTOM_DEVICE_LT_NUM

#define CINN_CUSTOM_DEVICE_GT_NUM(TYPE_SUFFIX, TYPE) \
Expand All @@ -910,6 +1021,7 @@ CINN_CUSTOM_DEVICE_GT_NUM(int16, int16_t)
CINN_CUSTOM_DEVICE_GT_NUM(int32, int)
CINN_CUSTOM_DEVICE_GT_NUM(int64, int64_t)
CINN_CUSTOM_DEVICE_GT_NUM(fp16, float16)
CINN_CUSTOM_DEVICE_GT_NUM(bf16, bfloat16)
#undef CINN_CUSTOM_DEVICE_GT_NUM

#define CINN_CUSTOM_DEVICE_INDEX_ADD(TYPE_SUFFIX, TYPE) \
Expand Down Expand Up @@ -939,6 +1051,7 @@ CINN_CUSTOM_DEVICE_INDEX_ADD(int64, int64_t)
CINN_CUSTOM_DEVICE_INDEX_ADD(fp32, float)
CINN_CUSTOM_DEVICE_INDEX_ADD(fp64, double)
CINN_CUSTOM_DEVICE_INDEX_ADD(fp16, float16)
CINN_CUSTOM_DEVICE_INDEX_ADD(bf16, bfloat16)
#undef CINN_CUSTOM_DEVICE_INDEX_ADD

__device__ int cinn_custom_device_resize_bilinear(const int *buf,
Expand Down
Loading