3 changes: 3 additions & 0 deletions .gitmodules
@@ -7,3 +7,6 @@
[submodule "third_party/eigen"]
path = third_party/eigen
url = git@github.com:InfiniTensor/eigen-mirror.git
[submodule "third_party/flash_attention"]
path = third_party/flash_attention
url = https://github.com/Dao-AILab/flash-attention.git
20 changes: 12 additions & 8 deletions CMakeLists.txt
@@ -75,15 +75,24 @@ if(USE_CUDA)
add_compile_definitions(USE_CUDA=1)
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)

# ========== cuDNN library ==========
find_library(CUDNN_LIBRARY cudnn REQUIRED)
message(STATUS "Found cuDNN at: ${CUDNN_LIBRARY}")
# ========================================

include_directories(${CUDAToolkit_INCLUDE_DIRS})

# CUDA compilation options
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")

# Only compile CUDA kernels / cuda sources here (your original used src/*.cu)
# Only compile CUDA kernels / cuda sources here
file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu)

add_library(infini_train_cuda_kernels STATIC ${CUDA_KERNELS})
target_include_directories(infini_train_cuda_kernels PUBLIC
${PROJECT_SOURCE_DIR}/third_party/cudnn-frontend/include
)
set_target_properties(infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80;90")

target_link_libraries(infini_train_cuda_kernels
@@ -92,6 +101,7 @@ if(USE_CUDA)
CUDA::cudart
CUDA::cublas
CUDA::cuda_driver
${CUDNN_LIBRARY}
)

if(USE_NCCL)
@@ -116,8 +126,6 @@ target_link_libraries(infini_train
)

if(USE_CUDA)
# infini_train contains cuda runtime wrappers (*.cc) like cuda_blas_handle.cc/cuda_guard.cc
# Those may need CUDA runtime/driver/cublas symbols at final link, so attach them here too.
target_link_libraries(infini_train
PUBLIC
infini_train_cuda_kernels
@@ -127,15 +135,12 @@ if(USE_CUDA)
)

if(USE_NCCL)
# If your core library code also directly references NCCL symbols (not only kernels),
# keep this. Otherwise it's harmless.
target_link_libraries(infini_train PUBLIC nccl)
endif()
endif()

# ------------------------------------------------------------------------------
# Helper: link libraries in a group to fix static lib one-pass resolution
# (THIS is what fixes "undefined reference" from cuda_kernels -> core symbols)
# ------------------------------------------------------------------------------
function(link_infini_train_exe target_name)
if(USE_CUDA)
@@ -160,7 +165,6 @@ function(link_infini_train_exe target_name)
endif()
endfunction()


# ------------------------------------------------------------------------------
# Examples
# ------------------------------------------------------------------------------
@@ -199,4 +203,4 @@ add_executable(test_hook test/hook/test_hook.cc)
target_link_libraries(test_hook infini_train)

add_executable(test_precision_check test/hook/test_precision_check.cc)
target_link_libraries(test_precision_check infini_train)
target_link_libraries(test_precision_check infini_train)
88 changes: 88 additions & 0 deletions InfiniTrain报告.md
@@ -0,0 +1,88 @@
# InfiniTrain Assignment Report

## 1. Functional Correctness Verification
gpt2_1_bfloat16
![gpt2_1_bfloat16](image-3.png)
gpt2_bfloat16_flash
![gpt2_bfloat16_flash](image-4.png)
llama3_1_bfloat16
![llama3_1_bfloat16](image-2.png)
llama3_1_bfloat16_flash
![llama3_1_bfloat16_flash](image-5.png)


## 2. Performance Evaluation
### 2.1 Experimental Environment

**Hardware**
- GPU model: NVIDIA A100-SXM4-80GB
- Memory per GPU: 81920 MiB (80 GB)
- GPUs on the machine: 8 (index 0-7)
- Visible devices for this test: `CUDA_VISIBLE_DEVICES=4,5,6,7`
- Actual parallel configuration: the logs report `DP=1, TP=1, SP=1, PP=1`, i.e. a single process on a single GPU

**Software**
- CUDA: 12.8 (`nvcc` build `cuda_12.8.r12.8`)
- Driver: 570.133.20
- C++ compiler: `c++ (Ubuntu 13.3.0) 13.3.0`
- CMake: 3.31.4
- Build command: `cmake -DUSE_CUDA=ON -DUSE_NCCL=ON .. && make -j`

### 2.2 Experimental Configuration

Based on four log files:
- `gpt2_1_bfloat16.log` (baseline)
- `gpt2_1_bfloat16_fla.log` (FlashAttention)
- `llama3_1_bfloat16.log` (baseline)
- `llama3_1_bfloat16_fla.log` (FlashAttention)

Key parameters (confirmed from the program defaults and the command line):
- `dtype=bfloat16`
- `batch_size=4`
- `sequence_length=64`
- `total_batch_size=256 tokens/step`
- Training steps: 10
- baseline: the unfused small-operator version (without `--flash true`)
- Experimental group: the fused FlashAttention version (with `--flash true`)

> Note: to reduce the impact of the first-step cold start, the main table below uses the mean over **steps 2-10** as the steady-state metric.
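The steady-state averaging described in the note above can be sketched as a small helper; the latency values used in the usage example below are illustrative, not taken from the actual logs:

```cpp
#include <numeric>
#include <vector>

// Mean per-step latency over steps 2..N, skipping the cold-start step 1.
// This mirrors how the steady-state numbers in the tables below are computed.
double SteadyStateMean(const std::vector<double> &step_latency_ms) {
    // step_latency_ms[0] is step 1 (cold start); average the remaining steps.
    return std::accumulate(step_latency_ms.begin() + 1, step_latency_ms.end(), 0.0)
           / static_cast<double>(step_latency_ms.size() - 1);
}
```

For example, with per-step latencies `{300.0, 120.0, 118.0, 122.0}` the first (warm-up) step is dropped and the steady-state mean is 120 ms.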

### 2.3 Performance Metric Definitions

- Average latency (avg latency): mean per-iteration time (ms)
- Throughput (tokens/s): mean of the per-step tokens/s reported in the logs
- GPU memory usage (MB): maximum of `peak used` in the logs
- Speedup: $\text{Speedup} = \frac{\text{Latency}_{baseline}}{\text{Latency}_{flash}}$
- Memory saving ratio: $\text{MemSaving} = \frac{\text{Mem}_{baseline}-\text{Mem}_{flash}}{\text{Mem}_{baseline}} \times 100\%$
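As a sanity check, the two formulas can be evaluated directly on the measured values reported in the results table (a negative MemSaving means the FlashAttention run used more memory):

```cpp
// Speedup = Latency_baseline / Latency_flash
double Speedup(double baseline_latency_ms, double flash_latency_ms) {
    return baseline_latency_ms / flash_latency_ms;
}

// MemSaving = (Mem_baseline - Mem_flash) / Mem_baseline * 100%
// A negative result means the FlashAttention run used MORE memory.
double MemSavingPercent(double baseline_mb, double flash_mb) {
    return (baseline_mb - flash_mb) / baseline_mb * 100.0;
}
```

Plugging in the measured latencies (119.71/63.58 ms for GPT2, 768.33/336.90 ms for LLaMA3) reproduces the 1.88x and 2.28x speedups and the -59.67% / -8.11% memory savings quoted below.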

### 2.4 Results (baseline vs. FlashAttention)

| Model | Variant | Avg Latency (ms) | Throughput (tok/s) | Peak Used (MB) |
|---|---|---:|---:|---:|
| GPT2 | baseline | 119.71 | 2153.67 | 1914 |
| GPT2 | FlashAttention | 63.58 | 4057.67 | 3056 |
| LLaMA3 | baseline | 768.33 | 333.78 | 24561 |
| LLaMA3 | FlashAttention | 336.90 | 765.33 | 26552 |

**Aggregate metrics (per model)**

| Model | Speedup (baseline/flash) | Throughput gain (flash/baseline) | Memory saving |
|---|---:|---:|---:|
| GPT2 | 1.88x | 1.88x | -59.67% |
| LLaMA3 | 2.28x | 2.29x | -8.11% |

### 2.5 Analysis and Conclusions

1. **FlashAttention gives a clear win on GPT2**:
   - latency drops from 119.71 ms to 63.58 ms, a speedup of **1.88x**;
   - throughput rises from 2153.67 to 4057.67 tok/s (about **1.88x**).

2. **The gain on LLaMA3 is even larger**:
   - latency drops from 768.33 ms to 336.90 ms, a speedup of **2.28x**;
   - throughput rises from 333.78 to 765.33 tok/s (about **2.29x**).

3. **Memory usage**:
   - for GPT2, `peak used` is higher with FlashAttention in these logs (1914 MB -> 3056 MB, a memory saving of -59.67%);
   - for LLaMA3, `peak used` is likewise higher with FlashAttention (24561 MB -> 26552 MB, a memory saving of -8.11%);
   - in this experiment, FlashAttention's benefit therefore shows up in compute efficiency (latency/throughput) rather than in reduced memory.

6 changes: 6 additions & 0 deletions example/gpt2/main.cc
@@ -78,6 +78,7 @@ DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)")
DEFINE_string(
precision_check, "",
"precision check config: level=N,format=simple|table,output_md5=true|false,output_path=PATH,baseline=PATH");
DEFINE_bool(flash, false, "Whether to enable flash attention");

using namespace infini_train;

@@ -140,6 +141,7 @@ void Train(const nn::parallel::Rank &rank) {

if (rank.IsParallel()) {
device = Device(Device::DeviceType::kCUDA, rank.thread_rank());
//
auto *pg_factory = ProcessGroupFactory::Instance(device.type());

if (ddp_world_size > 1) {
@@ -322,6 +324,10 @@ void Train(const nn::parallel::Rank &rank) {
}

for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
if (auto dist_optimizer = std::dynamic_pointer_cast<nn::parallel::DistributedOptimizer>(optimizer)) {
dist_optimizer->SetIsLastMicrobatch(micro_step == grad_accum_steps - 1);
}

// enable autocast for the current step
infini_train::AutocastGuard autocast_guard(device.type(), dtype);

50 changes: 37 additions & 13 deletions example/gpt2/net.cc
@@ -12,6 +12,7 @@
#include <vector>

#include "glog/logging.h"
#include "gflags/gflags.h"

#include "example/common/utils.h"
#include "infini_train/include/device.h"
@@ -29,6 +30,7 @@
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/tensor.h"


using namespace infini_train;
namespace nn = infini_train::nn;

@@ -78,6 +80,7 @@ CausalSelfAttention::CausalSelfAttention(const GPT2Config &config)
->View({1, 1, config_.block_size, config_.block_size});
}

DECLARE_bool(flash);
std::vector<std::shared_ptr<infini_train::Tensor>>
CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
auto tp_world_size = nn::parallel::global::GetTensorParallelSize();
@@ -96,7 +99,7 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
auto k = qkv[1];
auto v = qkv[2];

// NOTE(zbl): Acquire full T after AllGather is performed in ColumnParallelLinear
// NOTE(zbl): Acquire full T after AllGather is performed in ColumnParallelLinear T->seq_len
const auto T = q->Dims()[1];

// View to multi-head: local_n_head * head_dim == local_C
@@ -105,18 +108,39 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
std::shared_ptr<infini_train::Tensor> y;
if (FLAGS_flash) {
// cuDNN SDPA path: causal masking should be enabled by `is_causal=true`.
// Do not pass the 0/1 tril mask as additive bias (it is not -inf mask).
auto q_flash = q;
auto k_flash = k;
auto v_flash = v;
if (q->Dtype() == DataType::kFLOAT32) {
q_flash = std::make_shared<Tensor>(q->To(DataType::kBFLOAT16));
k_flash = std::make_shared<Tensor>(k->To(DataType::kBFLOAT16));
v_flash = std::make_shared<Tensor>(v->To(DataType::kBFLOAT16));
}
y = nn::function::ScaledDotProductAttention(q_flash, k_flash, v_flash, nullptr, 0.0, true, std::nullopt,
false);
if (y->Dtype() != q->Dtype()) {
y = std::make_shared<Tensor>(y->To(q->Dtype()));
}
// ensure expected layout: (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
} else {
// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
}

// Get full tensor
// (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
5 changes: 5 additions & 0 deletions example/llama3/main.cc
@@ -91,6 +91,7 @@ constexpr char kDtypeBF16[] = "bfloat16";
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_bool(flash, false, "Whether to enable flash attention");

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
@@ -298,6 +299,10 @@ void Train(const nn::parallel::Rank &rank) {
}

for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
if (auto dist_optimizer = std::dynamic_pointer_cast<nn::parallel::DistributedOptimizer>(optimizer)) {
dist_optimizer->SetIsLastMicrobatch(micro_step == grad_accum_steps - 1);
}

// enable autocast for the current step
infini_train::AutocastGuard autocast_guard(device.type(), dtype);

58 changes: 39 additions & 19 deletions example/llama3/net.cc
@@ -11,6 +11,7 @@
#include <unordered_map>
#include <vector>

#include "gflags/gflags.h"
#include "glog/logging.h"

#include "example/common/utils.h"
@@ -138,6 +139,7 @@ std::vector<std::shared_ptr<Tensor>> RMSNorm::Forward(const std::vector<std::sha
return {norm * parameters_[kParamWeightName]};
}

DECLARE_bool(flash);
CausalSelfAttention::CausalSelfAttention(const LLaMA3Config &config)
: CloneableModule(kType), config_(config), n_head_(config.n_head), n_embd_(config.n_embd),
n_kv_head_(config.n_kv_head), n_rep_(config.n_head / config.n_kv_head), head_dim_(config.n_embd / config.n_head) {
@@ -217,26 +219,44 @@ std::vector<std::shared_ptr<Tensor>> CausalSelfAttention::Forward(const std::vec
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);

// TODO(zbl): support flash attention later
// if (flash_) { ... }

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
std::shared_ptr<Tensor> y;
if (FLAGS_flash) {
// cuDNN SDPA path: causal masking should be enabled by `is_causal=true`.
// Do not pass Triu(ones, 1) mask as additive bias.
auto q_flash = q;
auto k_flash = k;
auto v_flash = v;
if (q->Dtype() == DataType::kFLOAT32) {
q_flash = std::make_shared<Tensor>(q->To(DataType::kBFLOAT16));
k_flash = std::make_shared<Tensor>(k->To(DataType::kBFLOAT16));
v_flash = std::make_shared<Tensor>(v->To(DataType::kBFLOAT16));
}
y = nn::function::ScaledDotProductAttention(q_flash, k_flash, v_flash, nullptr, 0.0, true, std::nullopt,
false);
if (y->Dtype() != q->Dtype()) {
y = std::make_shared<Tensor>(y->To(q->Dtype()));
}
// ensure expected layout: (B, H_local, T, D) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
} else {
// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
auto y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
// output projection
// (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C)
y = (*modules_[kCProjLayerName])({y})[0];
1 change: 1 addition & 0 deletions example/mnist/main.cc
@@ -35,6 +35,7 @@ constexpr char kDeviceCUDA[] = "cuda";

DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_bool(flash, false, "Whether to enable flash attention");

int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
Binary file added image-1.png
Binary file added image-2.png
Binary file added image-3.png
Binary file added image-4.png
Binary file added image-5.png
Binary file added image.png