Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ build/
.cache/
.vscode/

*.log
*.report.rank*
*.records.log.rank*
.claude
CLAUDE.md
scripts/compare_logs
my_tmp
# *.sh
# g*.txt
# l*.txt
15 changes: 14 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ if(USE_CUDA)
# CUDA compilation options
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")

# Only compile CUDA kernels / cuda sources here (your original used src/*.cu)
# Compile regular CUDA kernels (sm_75+), excluding flash_attention.cu which needs sm_80+
file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu)
list(FILTER CUDA_KERNELS EXCLUDE REGEX ".*flash_attention\\.cu$")

add_library(infini_train_cuda_kernels STATIC ${CUDA_KERNELS})
set_target_properties(infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80;90")
Expand All @@ -94,6 +95,13 @@ if(USE_CUDA)
CUDA::cuda_driver
)

# Flash attention kernel requires sm_80+ (cp.async, bf16 mma, ldmatrix).
# Build as a separate library targeting sm_80 and sm_90 only.
add_library(infini_train_flash_attention STATIC
${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cuda/flash_attention.cu)
set_target_properties(infini_train_flash_attention PROPERTIES CUDA_ARCHITECTURES "80;90")
target_link_libraries(infini_train_flash_attention PUBLIC glog CUDA::cudart)

if(USE_NCCL)
message(STATUS "Add USE_NCCL, use NCCL with CUDA")
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
Expand Down Expand Up @@ -121,6 +129,7 @@ if(USE_CUDA)
target_link_libraries(infini_train
PUBLIC
infini_train_cuda_kernels
infini_train_flash_attention
CUDA::cudart
CUDA::cublas
CUDA::cuda_driver
Expand All @@ -145,6 +154,7 @@ function(link_infini_train_exe target_name)
infini_train
infini_train_cpu_kernels
infini_train_cuda_kernels
infini_train_flash_attention
"-Wl,--no-whole-archive"
"-Wl,--end-group"
)
Expand Down Expand Up @@ -200,3 +210,6 @@ target_link_libraries(test_hook infini_train)

add_executable(test_precision_check test/hook/test_precision_check.cc)
target_link_libraries(test_precision_check infini_train)

add_executable(test_flash_attention test/test_flash_attention.cc)
link_infini_train_exe(test_flash_attention)
9 changes: 8 additions & 1 deletion example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

// flash attention
DEFINE_bool(flash, false, "Use FlashAttention for self-attention");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
// precision check
Expand Down Expand Up @@ -181,9 +183,10 @@ void Train(const nn::parallel::Rank &rank) {
GPT2Config model_config;
std::shared_ptr<nn::Module> model = nullptr;
if (!FLAGS_llmc_filepath.empty()) {
model = GPT2::FromLLMC(FLAGS_llmc_filepath);
model = GPT2::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else if (kModelToConfigs.count(FLAGS_model)) {
model_config = kModelToConfigs.at(FLAGS_model);
model_config.flash = FLAGS_flash;
model = std::make_shared<GPT2>(model_config);
} else {
model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
Expand All @@ -193,6 +196,10 @@ void Train(const nn::parallel::Rank &rank) {

utils::PrecisionChecker::BuildNameMap(model.get());

if (FLAGS_flash && FLAGS_dtype != kDtypeBF16) {
LOG(FATAL) << "--flash=true requires --dtype=bfloat16 (FlashAttention only supports bfloat16)";
}

// select the data type
// TODO(lzm): change to solely rely on the weight file info for determining the dtype when autocast is supported
DataType dtype;
Expand Down
38 changes: 24 additions & 14 deletions example/gpt2/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "glog/logging.h"

#include "example/common/utils.h"
#include "infini_train/include/autograd/flash_attention.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/functional.h"
#include "infini_train/include/nn/init.h"
Expand Down Expand Up @@ -105,18 +106,26 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
std::shared_ptr<Tensor> y;
if (config_.flash) {
// FlashAttention: q, k, v are (B, h_l, T, Dh)
y = std::make_shared<autograd::FlashAttention>(/*is_causal=*/true)->Apply({q, k, v})[0];
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
} else {
// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto att_v = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = att_v->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
}

// Get full tensor
// (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
Expand Down Expand Up @@ -351,7 +360,7 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
}
} // namespace

std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath, bool flash) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -379,7 +388,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
.original_vocab_size = vocab_size,
.n_layer = n_layer,
.n_head = n_head,
.n_embd = n_embd});
.n_embd = n_embd,
.flash = flash});

LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
<< " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
Expand Down
3 changes: 2 additions & 1 deletion example/gpt2/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct GPT2Config {
int64_t n_layer = 12;
int64_t n_head = 12;
int64_t n_embd = 768;
bool flash = false;
};

class NewGELU : public infini_train::nn::CloneableModule<NewGELU> {
Expand Down Expand Up @@ -140,7 +141,7 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<GPT2> FromPretrained(ModelType model_type);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath, bool flash = false);

int GetChunkSize() const;

Expand Down
9 changes: 8 additions & 1 deletion example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
// flash attention
DEFINE_bool(flash, false, "Use FlashAttention for self-attention");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
// precision check
Expand Down Expand Up @@ -161,9 +163,10 @@ void Train(const nn::parallel::Rank &rank) {
// ManualSeed(42);

LLaMA3Config model_config = LLaMA3Config();
model_config.flash = FLAGS_flash;
std::shared_ptr<nn::Module> model = nullptr;
if (!FLAGS_llmc_filepath.empty()) {
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else {
model = std::make_shared<LLaMA3>(model_config);
}
Expand All @@ -174,6 +177,10 @@ void Train(const nn::parallel::Rank &rank) {

LOG(INFO) << "Rank " << rank.GlobalRank() << ": Model loaded to device.";

if (FLAGS_flash && FLAGS_dtype != kDtypeBF16) {
LOG(FATAL) << "--flash=true requires --dtype=bfloat16 (FlashAttention only supports bfloat16)";
}

DataType dtype;
if (FLAGS_dtype == kDtypeFP32) {
dtype = DataType::kFLOAT32;
Expand Down
52 changes: 32 additions & 20 deletions example/llama3/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "glog/logging.h"

#include "example/common/utils.h"
#include "infini_train/include/autograd/flash_attention.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/functional.h"
#include "infini_train/include/nn/init.h"
Expand Down Expand Up @@ -77,6 +78,11 @@ ApplyRotaryEmbedding(const std::shared_ptr<Tensor> &xq, const std::shared_ptr<Te
std::vector<int64_t> target_shape(cos_sin->Dims().begin(), cos_sin->Dims().end() - 1);
auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2)
auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2)
// Cast cos/sin to match xq dtype to avoid float32 promotion when freqs_cis is float32
if (cos->Dtype() != xq->Dtype()) {
cos = std::make_shared<Tensor>(cos->To(xq->Dtype()));
sin = std::make_shared<Tensor>(sin->To(xq->Dtype()));
}

auto slice_pair = [](const std::shared_ptr<Tensor> &x) {
auto even = x->Slice(-1, 0, x->Dims().back(), 2);
Expand Down Expand Up @@ -217,26 +223,31 @@ std::vector<std::shared_ptr<Tensor>> CausalSelfAttention::Forward(const std::vec
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);

// TODO(zbl): support flash attention later
// if (flash_) { ... }

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
std::shared_ptr<Tensor> y;
if (config_.flash) {
// FlashAttention: q, k, v are (B, H_local, T, D)
y = std::make_shared<autograd::FlashAttention>(/*is_causal=*/true)->Apply({q, k, v})[0];
// (B, H_local, T, D) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
} else {
// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
auto att_v = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = att_v->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
auto y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
// output projection
// (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C)
y = (*modules_[kCProjLayerName])({y})[0];
Expand Down Expand Up @@ -457,7 +468,7 @@ constexpr int32_t kLLaMA3Magic = 20240803;
constexpr int32_t kLLaMA3FP32Version = 3;
} // namespace

std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath, bool flash) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -496,6 +507,7 @@ std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
.rope_theta = rope_theta,
.use_scaled_rope = static_cast<bool>(use_scaled_rope),
.norm_eps = norm_eps,
.flash = flash,
.max_gen_batch_size = max_gen_bs});

// ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
Expand Down
2 changes: 1 addition & 1 deletion example/llama3/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<LLaMA3> FromPretrained(ModelType model_type);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath, bool flash = false);

int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); }

Expand Down
Loading
Loading