-
Notifications
You must be signed in to change notification settings - Fork 223
Finish project 03_nf4_dequant #41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xfarawayx
wants to merge
2
commits into
InfiniTensor:2025-winter-project
Choose a base branch
from
xfarawayx:2025-winter-project
base: 2025-winter-project
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+3,504
−0
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| # NF4 反量化 CUDA Kernel | ||
|
|
||
| NF4(Normal Float 4)双重量化权重的 GPU 反量化实现,兼容 bitsandbytes 格式。技术细节与实验结果见 [`docs/report.md`](docs/report.md)。 | ||
|
|
||
| ## 目录结构 | ||
|
|
||
| ``` | ||
| 03_nf4_dequant/ | ||
| ├── run.sh # 统一入口脚本 | ||
| ├── kernel/ | ||
| │ ├── CMakeLists.txt # CMake 构建 (自动检测 GPU 架构) | ||
| │ ├── main.cu # 主程序: 文件 IO、kernel 启动、性能计时 | ||
| │ └── nf4_dequant_kernel.cuh # 反量化 kernel 实现 | ||
| ├── kernel_noncuda/ # 国产 GPU 适配 | ||
| │ ├── iluvatar/ # 天数智芯 (clang++ / CUDA 兼容) | ||
| │ ├── moore/ # 摩尔线程 (mcc / MUSA) | ||
| │ └── mutex/ # 沐曦 (mxcc / MACA) | ||
| ├── scripts/ | ||
| │ ├── generate_data.py # 用 bitsandbytes 生成 NF4 量化数据 + 参考输出 | ||
| │ ├── verify.py # 正确性验证: CUDA 输出 vs bitsandbytes 参考 | ||
| │ └── bench_bnb.py # bitsandbytes 性能基准 | ||
| ├── docs/ | ||
| │ └── report.md # 实现报告 | ||
| └── data/ # 生成的数据 (自动创建) | ||
| ``` | ||
|
|
||
| ## 快速开始 | ||
|
|
||
| ```bash | ||
| # 全流程: 生成数据 → 编译 → 运行 kernel → 验证正确性 → bnb 性能对比 | ||
| ./run.sh all | ||
| ./run.sh # 等价于 ./run.sh test | ||
| ``` | ||
|
|
||
| ## 子命令与选项 | ||
|
|
||
| | 命令 | 说明 | | ||
| |------|------| | ||
| | `./run.sh generate` | 仅生成 NF4 量化测试数据 | | ||
| | `./run.sh build` | 仅编译 CUDA kernel | | ||
| | `./run.sh test` | 生成数据 → 编译 → 运行 → 验证正确性 (默认) | | ||
| | `./run.sh bench` | bitsandbytes 基准性能测试 | | ||
| | `./run.sh all` | 完整流程 | | ||
|
|
||
| | 选项 | 默认值 | 说明 | | ||
| |------|--------|------| | ||
| | `--rows` | 4096 | 矩阵行数 | | ||
| | `--cols` | 4096 | 矩阵列数 | | ||
| | `--blocksize` | 64 | 量化块大小 (64/128/256/…/4096) | | ||
| | `--compute_type` | bf16 | 输出类型 (bf16/fp16) | | ||
| | `--seed` | 42 | 随机种子 | | ||
| | `--gpu_arch` | 自动检测 | GPU 架构, 如 80(A100)/89(4090)/90(H100) | | ||
| | `--warmup` | 10 | 预热次数 | | ||
| | `--repeats` | 100 | 计时重复次数 | | ||
| | `--sweep` | - | bench 时扫描多种矩阵大小 | | ||
|
|
||
| ```bash | ||
| # 示例 | ||
| ./run.sh --rows 4096 --cols 11008 --blocksize 128 | ||
| ./run.sh --compute_type fp16 | ||
| ./run.sh bench --sweep | ||
| ``` | ||
|
|
||
| ## 依赖 | ||
|
|
||
| - CUDA Toolkit | ||
| - Python 3.8+, PyTorch (CUDA), bitsandbytes >= 0.43, NumPy |
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
cmake_minimum_required(VERSION 3.18)
project(nf4_dequant LANGUAGES CXX CUDA)

# Require the standards instead of silently falling back to an older one.
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# ---------- GPU architecture ----------
# Usage:
#   cmake .. -DGPU_ARCH=80            # A100
#   cmake .. -DGPU_ARCH=89            # RTX 4090
#   cmake .. -DGPU_ARCH=90            # H100
#   cmake .. -DGPU_ARCH="80;89;90"    # multiple architectures
#   cmake ..                          # auto-detect via nvidia-smi
if(DEFINED GPU_ARCH)
  set(CMAKE_CUDA_ARCHITECTURES ${GPU_ARCH})
  message(STATUS "GPU architecture (user-specified): ${CMAKE_CUDA_ARCHITECTURES}")
else()
  # Auto-detect the compute capability of the GPU on the build machine.
  # NOTE(review): configure-time detection only sees the local GPU; for
  # cross builds pass -DGPU_ARCH explicitly.
  execute_process(
    COMMAND nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits
    OUTPUT_VARIABLE _detected_arch
    OUTPUT_STRIP_TRAILING_WHITESPACE
    RESULT_VARIABLE _detect_result
    ERROR_QUIET  # a missing nvidia-smi should fall through, not spam stderr
  )
  if(_detect_result EQUAL 0 AND _detected_arch)
    # Take the first GPU only; "8.0" -> "80".
    string(REGEX MATCH "^[0-9]+\\.[0-9]+" _first_arch "${_detected_arch}")
    string(REPLACE "." "" _arch_num "${_first_arch}")
    set(CMAKE_CUDA_ARCHITECTURES ${_arch_num})
    message(STATUS "GPU architecture (auto-detected): sm_${_arch_num}")
  else()
    # Detection failed: fall back to a set of common architectures.
    set(CMAKE_CUDA_ARCHITECTURES "80;89;90")
    message(STATUS "GPU architecture (fallback): ${CMAKE_CUDA_ARCHITECTURES}")
  endif()
endif()

# ---------- Private temp directory ----------
# Route compiler/linker temp files into the build tree (helps on hosts
# where /tmp is small or mounted noexec).
set(NF4_LOCAL_TMP_DIR "${CMAKE_BINARY_DIR}/.tmp")
file(MAKE_DIRECTORY "${NF4_LOCAL_TMP_DIR}")
set(NF4_TMP_ENV_PREFIX
    "${CMAKE_COMMAND} -E env TMPDIR=${NF4_LOCAL_TMP_DIR} TMP=${NF4_LOCAL_TMP_DIR} TEMP=${NF4_LOCAL_TMP_DIR}")
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${NF4_TMP_ENV_PREFIX}")
set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "${NF4_TMP_ENV_PREFIX}")

# ---------- Build targets ----------
add_executable(nf4_dequant main.cu)
target_include_directories(nf4_dequant PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,269 @@ | ||
| // NF4 反量化 CUDA 程序 | ||
| // 用法: ./nf4_dequant <weight_file> <output_file> [bf16|fp16] [warmup] [repeats] | ||
|
|
||
| #include <cstdio> | ||
| #include <cstdlib> | ||
| #include <cstdint> | ||
| #include <cstring> | ||
| #include <cmath> | ||
| #include <vector> | ||
| #include <string> | ||
| #include <chrono> | ||
| #include <algorithm> | ||
|
|
||
| #include <cuda_runtime.h> | ||
|
|
||
| #include "nf4_dequant_kernel.cuh" | ||
|
|
||
// Abort the process with file/line and the CUDA error string if a runtime
// API call does not return cudaSuccess.
#define CUDA_CHECK(call)                                          \
    do {                                                          \
        cudaError_t err = (call);                                 \
        if (err != cudaSuccess) {                                 \
            fprintf(stderr, "CUDA error at %s:%d: %s\n",          \
                    __FILE__, __LINE__, cudaGetErrorString(err)); \
            exit(EXIT_FAILURE);                                   \
        }                                                         \
    } while (0)
|
|
||
// Binary weight-file layout:
//   header (rows, cols, blocksize) + packed_weights + absmax_q
//   + absmax2 + code2 + offset
struct NF4Data {
    int64_t num_rows;    // matrix rows (from file header)
    int64_t num_cols;    // matrix cols (from file header)
    int32_t blocksize;   // first-level quantization block size

    std::vector<uint8_t>  packed_weights;  // two 4-bit NF4 codes per byte
    std::vector<uint8_t>  absmax_q;        // quantized per-block absmax
    std::vector<uint16_t> absmax2;         // fp16 raw bits
    std::vector<uint16_t> code2;           // fp16[256] raw bits
    float offset;                          // dequantization offset

    int64_t n_elements;    // num_rows * num_cols
    int32_t num_blocks;    // ceil(n_elements / blocksize)
    int32_t num_groups;    // second-level groups (inferred from file size)
    int32_t s2_blocksize;  // ceil(num_blocks / num_groups)
};
|
|
||
// Read an NF4 weight dump produced by scripts/generate_data.py.
//
// File layout (in order):
//   int64 rows, int64 cols, int32 blocksize,
//   packed_weights[ceil(n/2)], absmax_q[num_blocks],
//   absmax2[num_groups] (fp16 bits), code2[256] (fp16 bits), float offset
// num_groups is not stored explicitly; it is recovered from the byte count
// remaining before the fixed-size tail.
//
// Returns true on success. On open failure, truncated reads, or a malformed
// header it prints an error and returns false; the FILE* is always closed.
// (Fixes over the original: every fread result is now checked, the header
// values are validated before use, and packed_size rounds up to match the
// (n_elements + 1) / 2 bytes the kernel consumes.)
bool read_nf4_data(const char* filepath, NF4Data& data) {
    FILE* f = fopen(filepath, "rb");
    if (!f) {
        fprintf(stderr, "[ERROR] Cannot open file: %s\n", filepath);
        return false;
    }

    // Header — a short read or nonsense values must not "succeed" silently:
    // blocksize <= 0 would divide by zero below.
    bool ok = fread(&data.num_rows, sizeof(int64_t), 1, f) == 1 &&
              fread(&data.num_cols, sizeof(int64_t), 1, f) == 1 &&
              fread(&data.blocksize, sizeof(int32_t), 1, f) == 1;
    if (!ok || data.num_rows <= 0 || data.num_cols <= 0 || data.blocksize <= 0) {
        fprintf(stderr, "[ERROR] Invalid or truncated header in %s\n", filepath);
        fclose(f);
        return false;
    }

    data.n_elements = data.num_rows * data.num_cols;
    data.num_blocks = (int32_t)((data.n_elements + data.blocksize - 1) / data.blocksize);

    // Two 4-bit codes per byte; round up so an odd element count keeps its
    // trailing nibble (matches the (n+1)/2 used at kernel launch).
    int64_t packed_size = (data.n_elements + 1) / 2;
    data.packed_weights.resize(packed_size);
    data.absmax_q.resize(data.num_blocks);
    if (fread(data.packed_weights.data(), 1, packed_size, f) != (size_t)packed_size ||
        fread(data.absmax_q.data(), 1, data.num_blocks, f) != (size_t)data.num_blocks) {
        fprintf(stderr, "[ERROR] Truncated weight payload in %s\n", filepath);
        fclose(f);
        return false;
    }

    // Infer num_groups from the bytes left before the fixed tail
    // (it is not stored in the file).
    long current_pos = ftell(f);
    fseek(f, 0, SEEK_END);
    long file_size = ftell(f);
    fseek(f, current_pos, SEEK_SET);

    long remaining = file_size - current_pos;
    long fixed_tail = 256 * 2 + 4;  // code2 (512 B) + offset (4 B)
    long absmax2_bytes = remaining - fixed_tail;
    // Reject malformed tails: num_groups == 0 would divide by zero below.
    if (absmax2_bytes <= 0 || (absmax2_bytes % 2) != 0) {
        fprintf(stderr, "[ERROR] Malformed file tail in %s\n", filepath);
        fclose(f);
        return false;
    }
    data.num_groups = (int32_t)(absmax2_bytes / 2);
    data.s2_blocksize = (data.num_blocks + data.num_groups - 1) / data.num_groups;

    data.absmax2.resize(data.num_groups);
    data.code2.resize(256);
    ok = fread(data.absmax2.data(), 2, data.num_groups, f) == (size_t)data.num_groups &&
         fread(data.code2.data(), 2, 256, f) == 256 &&
         fread(&data.offset, sizeof(float), 1, f) == 1;
    fclose(f);
    if (!ok) {
        fprintf(stderr, "[ERROR] Truncated scale/codebook tail in %s\n", filepath);
        return false;
    }
    return true;
}
|
|
||
// Entry point.
// Usage: ./nf4_dequant <weight_file> <output_file> [bf16|fp16] [warmup] [repeats]
// Loads an NF4 dump, dequantizes it on the GPU, reports timing statistics,
// and writes the dequantized tensor to <output_file>.
int main(int argc, char* argv[]) {
    if (argc < 3) {
        fprintf(stderr, "用法: %s <weight_file> <output_file> [bf16|fp16] [warmup] [repeats]\n", argv[0]);
        return 1;
    }

    const char* weight_file = argv[1];
    const char* output_file = argv[2];
    std::string compute_type = (argc > 3) ? argv[3] : "bf16";
    int warmup  = (argc > 4) ? atoi(argv[4]) : 10;
    int repeats = (argc > 5) ? atoi(argv[5]) : 100;
    // Guard against nonsensical CLI values: the statistics below index
    // times[repeats / 2] and call front()/back() on the sorted vector,
    // which is undefined for repeats <= 0.
    if (warmup < 0) warmup = 0;
    if (repeats < 1) repeats = 1;

    bool use_bf16 = (compute_type == "bf16");

    // ---- Load quantized weights and metadata ----
    printf("[INFO] 读取权重文件: %s\n", weight_file);
    NF4Data data;
    if (!read_nf4_data(weight_file, data)) return 1;

    printf(" num_rows = %ld\n", (long)data.num_rows);
    printf(" num_cols = %ld\n", (long)data.num_cols);
    printf(" blocksize = %d\n", data.blocksize);
    printf(" n_elements = %ld\n", (long)data.n_elements);
    printf(" num_blocks = %d\n", data.num_blocks);
    printf(" num_groups = %d\n", data.num_groups);
    printf(" s2_blocksize = %d\n", data.s2_blocksize);
    printf(" offset = %f\n", data.offset);
    printf(" compute_type = %s\n", compute_type.c_str());

    // ---- Device allocations ----
    uint8_t* d_packed_weights;
    uint8_t* d_absmax_q;
    half* d_absmax2;
    half* d_code2;
    void* d_output;

    // Allocate ceil(n/2) packed bytes so the (n_elements + 1) / 2 bytes the
    // kernel reads are always inside the allocation, even for odd n.
    // (The original allocated n/2, which under-allocates by one byte when
    // n_elements is odd.)
    int64_t packed_size = (data.n_elements + 1) / 2;
    int64_t output_bytes = data.n_elements * 2;  // bf16/fp16 = 2 bytes each

    CUDA_CHECK(cudaMalloc(&d_packed_weights, packed_size));
    CUDA_CHECK(cudaMalloc(&d_absmax_q, data.num_blocks));
    CUDA_CHECK(cudaMalloc(&d_absmax2, data.num_groups * sizeof(half)));
    CUDA_CHECK(cudaMalloc(&d_code2, 256 * sizeof(half)));
    CUDA_CHECK(cudaMalloc(&d_output, output_bytes));

    // ---- Host-to-device transfers ----
    // Copy exactly the bytes read from the file; the allocation above is
    // at least that large.
    CUDA_CHECK(cudaMemcpy(d_packed_weights, data.packed_weights.data(),
                          data.packed_weights.size(), cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_absmax_q, data.absmax_q.data(),
                          data.num_blocks, cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_absmax2, data.absmax2.data(),
                          data.num_groups * sizeof(half), cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(d_code2, data.code2.data(),
                          256 * sizeof(half), cudaMemcpyHostToDevice));

    // ---- Kernel launch configuration ----
    int n_packed = (int)((data.n_elements + 1) / 2);
    int n_packed_vec = (n_packed + 3) / 4;  // each thread consumes 4 packed bytes
    int threads_per_block = 256;
    int num_blocks_kernel = (n_packed_vec + threads_per_block - 1) / threads_per_block;

    // Precompute log2 so the kernel can replace divisions with shifts.
    // NOTE(review): assumes blocksize and s2_blocksize are powers of two —
    // confirm log2_pow2's behavior for non-power-of-two inputs.
    int log2_bs = log2_pow2(data.blocksize);
    int log2_s2 = log2_pow2(data.s2_blocksize);

    printf("\n[INFO] Kernel 配置:\n");
    printf(" n_packed = %d\n", n_packed);
    printf(" n_packed_vec = %d (向量化后)\n", n_packed_vec);
    printf(" threads_per_block = %d\n", threads_per_block);
    printf(" grid_size = %d\n", num_blocks_kernel);
    printf(" log2_blocksize = %d\n", log2_bs);
    printf(" log2_s2_blocksize = %d\n", log2_s2);

    // ---- Warmup launches (untimed) ----
    printf("\n[INFO] 预热 %d 次...\n", warmup);
    for (int i = 0; i < warmup; i++) {
        if (use_bf16) {
            nf4_dequantize_kernel<__nv_bfloat16><<<num_blocks_kernel, threads_per_block>>>(
                d_packed_weights, d_absmax_q, d_absmax2, d_code2,
                data.offset, log2_bs, log2_s2,
                data.n_elements, (__nv_bfloat16*)d_output
            );
        } else {
            nf4_dequantize_kernel<half><<<num_blocks_kernel, threads_per_block>>>(
                d_packed_weights, d_absmax_q, d_absmax2, d_code2,
                data.offset, log2_bs, log2_s2,
                data.n_elements, (half*)d_output
            );
        }
    }
    CUDA_CHECK(cudaDeviceSynchronize());

    // ---- Timing: CUDA events, syncing each iteration to isolate runs ----
    printf("[INFO] 计时 %d 次...\n", repeats);

    cudaEvent_t ev_start, ev_end;
    CUDA_CHECK(cudaEventCreate(&ev_start));
    CUDA_CHECK(cudaEventCreate(&ev_end));

    std::vector<float> times(repeats);

    for (int i = 0; i < repeats; i++) {
        CUDA_CHECK(cudaDeviceSynchronize());
        CUDA_CHECK(cudaEventRecord(ev_start));
        if (use_bf16) {
            nf4_dequantize_kernel<__nv_bfloat16><<<num_blocks_kernel, threads_per_block>>>(
                d_packed_weights, d_absmax_q, d_absmax2, d_code2,
                data.offset, log2_bs, log2_s2,
                data.n_elements, (__nv_bfloat16*)d_output
            );
        } else {
            nf4_dequantize_kernel<half><<<num_blocks_kernel, threads_per_block>>>(
                d_packed_weights, d_absmax_q, d_absmax2, d_code2,
                data.offset, log2_bs, log2_s2,
                data.n_elements, (half*)d_output
            );
        }
        CUDA_CHECK(cudaEventRecord(ev_end));
        CUDA_CHECK(cudaEventSynchronize(ev_end));
        // Fixed: the source contained the HTML-entity residue "×[i]"
        // here; the intended argument is the address of times[i].
        CUDA_CHECK(cudaEventElapsedTime(&times[i], ev_start, ev_end));
    }

    // ---- Statistics: the median resists scheduler/clock outliers ----
    std::vector<float> sorted_times = times;
    std::sort(sorted_times.begin(), sorted_times.end());

    float total_ms = 0.0f;
    float min_ms = sorted_times.front();
    float max_ms = sorted_times.back();
    for (int i = 0; i < repeats; i++) total_ms += times[i];
    float avg_ms = total_ms / repeats;
    float median_ms = sorted_times[repeats / 2];

    // Effective memory bandwidth (based on the median time).
    double read_bytes = (double)packed_size + data.num_blocks + data.num_groups * 2 + 256 * 2;
    double write_bytes = (double)output_bytes;
    double total_bytes = read_bytes + write_bytes;
    double bandwidth_gbps = total_bytes / (median_ms * 1e-3) / 1e9;

    printf("\n========================================\n");
    printf(" NF4 反量化 Kernel 性能\n");
    printf("========================================\n");
    printf(" 矩阵大小 : (%ld, %ld)\n", (long)data.num_rows, (long)data.num_cols);
    printf(" 块大小 : %d\n", data.blocksize);
    printf(" 输出类型 : %s\n", compute_type.c_str());
    printf(" 平均耗时 : %.4f ms\n", avg_ms);
    printf(" 中位数耗时 : %.4f ms\n", median_ms);
    printf(" 最小耗时 : %.4f ms\n", min_ms);
    printf(" 最大耗时 : %.4f ms\n", max_ms);
    printf(" 有效带宽 : %.2f GB/s (基于中位数)\n", bandwidth_gbps);
    printf("========================================\n");

    // ---- Copy result back and write it out ----
    std::vector<uint8_t> h_output(output_bytes);
    CUDA_CHECK(cudaMemcpy(h_output.data(), d_output, output_bytes, cudaMemcpyDeviceToHost));

    FILE* fout = fopen(output_file, "wb");
    if (!fout) {
        fprintf(stderr, "[ERROR] Cannot open output file: %s\n", output_file);
        return 1;
    }
    // Fixed: verify the write completed (disk full / quota would otherwise
    // produce a silently truncated output file).
    size_t written = fwrite(h_output.data(), 1, output_bytes, fout);
    fclose(fout);
    if (written != (size_t)output_bytes) {
        fprintf(stderr, "[ERROR] Short write to output file: %s\n", output_file);
        return 1;
    }
    printf("\n[INFO] 已写入解量化输出: %s (%ld bytes)\n", output_file, (long)output_bytes);

    // ---- Cleanup ----
    cudaEventDestroy(ev_start);
    cudaEventDestroy(ev_end);
    CUDA_CHECK(cudaFree(d_packed_weights));
    CUDA_CHECK(cudaFree(d_absmax_q));
    CUDA_CHECK(cudaFree(d_absmax2));
    CUDA_CHECK(cudaFree(d_code2));
    CUDA_CHECK(cudaFree(d_output));

    printf("[DONE] 完成\n");
    return 0;
}
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.