Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ if(USE_CUDA)
# CUDA compilation options
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")

# Only compile CUDA kernels / cuda sources here (your original used src/*.cu)
# Compile regular CUDA kernels (sm_75+), excluding flash_attention.cu which needs sm_80+
file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu)
list(FILTER CUDA_KERNELS EXCLUDE REGEX ".*flash_attention\\.cu$")

add_library(infini_train_cuda_kernels STATIC ${CUDA_KERNELS})
set_target_properties(infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80;90")
Expand All @@ -94,6 +95,13 @@ if(USE_CUDA)
CUDA::cuda_driver
)

# Flash attention kernel requires sm_80+ (cp.async, bf16 mma, ldmatrix).
# Build as a separate library targeting sm_80 and sm_90 only.
add_library(infini_train_flash_attention STATIC
${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cuda/flash_attention.cu)
set_target_properties(infini_train_flash_attention PROPERTIES CUDA_ARCHITECTURES "80;90")
target_link_libraries(infini_train_flash_attention PUBLIC glog CUDA::cudart)

if(USE_NCCL)
message(STATUS "Add USE_NCCL, use NCCL with CUDA")
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
Expand Down Expand Up @@ -121,6 +129,7 @@ if(USE_CUDA)
target_link_libraries(infini_train
PUBLIC
infini_train_cuda_kernels
infini_train_flash_attention
CUDA::cudart
CUDA::cublas
CUDA::cuda_driver
Expand All @@ -145,6 +154,7 @@ function(link_infini_train_exe target_name)
infini_train
infini_train_cpu_kernels
infini_train_cuda_kernels
infini_train_flash_attention
"-Wl,--no-whole-archive"
"-Wl,--end-group"
)
Expand Down
7 changes: 6 additions & 1 deletion example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

// flash attention
DEFINE_bool(flash, false, "Use FlashAttention for self-attention");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
// precision check
Expand Down Expand Up @@ -191,9 +193,10 @@ void Train(const nn::parallel::Rank &rank) {
std::shared_ptr<nn::Module> model = nullptr;

if (!FLAGS_llmc_filepath.empty()) {
model = GPT2::FromLLMC(FLAGS_llmc_filepath);
model = GPT2::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else if (kModelToConfigs.count(FLAGS_model)) {
model_config = kModelToConfigs.at(FLAGS_model);
model_config.flash = FLAGS_flash;
model = std::make_shared<GPT2>(model_config);
} else {
model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
Expand All @@ -203,6 +206,8 @@ void Train(const nn::parallel::Rank &rank) {

utils::PrecisionChecker::BuildNameMap(model.get());

    if (FLAGS_flash && FLAGS_dtype != kDtypeBF16) {
        LOG(FATAL) << "--flash=true requires --dtype=bfloat16 (FlashAttention only supports bfloat16)";
    }
// Get chunk size before wrapping with LoRA (needed for PipelineParallel)
auto gpt2_model = std::dynamic_pointer_cast<GPT2>(model);
CHECK(gpt2_model) << "GPT2 example expects GPT2 model.";
Expand Down
38 changes: 24 additions & 14 deletions example/gpt2/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "glog/logging.h"

#include "example/common/utils.h"
#include "infini_train/include/autograd/flash_attention.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/functional.h"
#include "infini_train/include/nn/init.h"
Expand Down Expand Up @@ -105,18 +106,26 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
std::shared_ptr<Tensor> y;
if (config_.flash) {
// FlashAttention: q, k, v are (B, h_l, T, Dh)
y = std::make_shared<autograd::FlashAttention>(/*is_causal=*/true)->Apply({q, k, v})[0];
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
} else {
// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto att_v = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = att_v->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
}

// Get full tensor
// (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
Expand Down Expand Up @@ -356,7 +365,7 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
}
} // namespace

std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath, bool flash) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -384,7 +393,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
.original_vocab_size = vocab_size,
.n_layer = n_layer,
.n_head = n_head,
.n_embd = n_embd});
.n_embd = n_embd,
.flash = flash});

LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
<< " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
Expand Down
3 changes: 2 additions & 1 deletion example/gpt2/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct GPT2Config {
int64_t n_layer = 12;
int64_t n_head = 12;
int64_t n_embd = 768;
bool flash = false;
};

class NewGELU : public infini_train::nn::CloneableModule<NewGELU> {
Expand Down Expand Up @@ -140,7 +141,7 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<GPT2> FromPretrained(ModelType model_type);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath, bool flash = false);

int GetChunkSize() const;

Expand Down
9 changes: 8 additions & 1 deletion example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
// flash attention
DEFINE_bool(flash, false, "Use FlashAttention for self-attention");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
// precision check
Expand Down Expand Up @@ -168,9 +170,10 @@ void Train(const nn::parallel::Rank &rank) {
// ManualSeed(42);

LLaMA3Config model_config = LLaMA3Config();
model_config.flash = FLAGS_flash;
std::shared_ptr<nn::Module> model = nullptr;
if (!FLAGS_llmc_filepath.empty()) {
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else {
model = std::make_shared<LLaMA3>(model_config);
}
Expand Down Expand Up @@ -200,6 +203,10 @@ void Train(const nn::parallel::Rank &rank) {

LOG(INFO) << "Rank " << rank.GlobalRank() << ": Model loaded to device.";

if (FLAGS_flash && FLAGS_dtype != kDtypeBF16) {
LOG(FATAL) << "--flash=true requires --dtype=bfloat16 (FlashAttention only supports bfloat16)";
}

DataType dtype;
if (FLAGS_dtype == kDtypeFP32) {
dtype = DataType::kFLOAT32;
Expand Down
52 changes: 32 additions & 20 deletions example/llama3/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "glog/logging.h"

#include "example/common/utils.h"
#include "infini_train/include/autograd/flash_attention.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/functional.h"
#include "infini_train/include/nn/init.h"
Expand Down Expand Up @@ -77,6 +78,11 @@ ApplyRotaryEmbedding(const std::shared_ptr<Tensor> &xq, const std::shared_ptr<Te
std::vector<int64_t> target_shape(cos_sin->Dims().begin(), cos_sin->Dims().end() - 1);
auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2)
auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2)
// Cast cos/sin to match xq dtype to avoid float32 promotion when freqs_cis is float32
if (cos->Dtype() != xq->Dtype()) {
cos = std::make_shared<Tensor>(cos->To(xq->Dtype()));
sin = std::make_shared<Tensor>(sin->To(xq->Dtype()));
}

auto slice_pair = [](const std::shared_ptr<Tensor> &x) {
auto even = x->Slice(-1, 0, x->Dims().back(), 2);
Expand Down Expand Up @@ -217,26 +223,31 @@ std::vector<std::shared_ptr<Tensor>> CausalSelfAttention::Forward(const std::vec
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);

// TODO(zbl): support flash attention later
// if (flash_) { ... }

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
std::shared_ptr<Tensor> y;
if (config_.flash) {
// FlashAttention: q, k, v are (B, H_local, T, D)
y = std::make_shared<autograd::FlashAttention>(/*is_causal=*/true)->Apply({q, k, v})[0];
// (B, H_local, T, D) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
} else {
// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
auto att_v = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = att_v->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
auto y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
// output projection
// (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C)
y = (*modules_[kCProjLayerName])({y})[0];
Expand Down Expand Up @@ -457,7 +468,7 @@ constexpr int32_t kLLaMA3Magic = 20240803;
constexpr int32_t kLLaMA3FP32Version = 3;
} // namespace

std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath, bool flash) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -496,6 +507,7 @@ std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
.rope_theta = rope_theta,
.use_scaled_rope = static_cast<bool>(use_scaled_rope),
.norm_eps = norm_eps,
.flash = flash,
.max_gen_batch_size = max_gen_bs});

// ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
Expand Down
2 changes: 1 addition & 1 deletion example/llama3/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<LLaMA3> FromPretrained(ModelType model_type);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath, bool flash = false);

int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); }

Expand Down
32 changes: 32 additions & 0 deletions infini_train/include/autograd/flash_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {
// Autograd function wrapping the device FlashAttention kernels.
// Apply() takes {q, k, v} and produces a single output tensor O; the forward
// kernel also produces the per-row logsumexp L, which is kept internally for
// the backward pass (see l_ below).
class FlashAttention : public Function {
public:
    static constexpr char kType[] = "FlashAttentionFunction";

    // is_causal: whether the kernel applies a causal mask.
    // scale: softmax scaling factor; negative means "use the kernel default"
    //        (1/sqrt(head_dim)).
    explicit FlashAttention(bool is_causal = true, float scale = -1.0f)
        : Function(kType), is_causal_(is_causal), scale_(scale) {}

    // Dispatches to "FlashAttentionForward"; returns only O to the caller.
    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
    // Saves {q, k, v, o, l} into saved_tensors_ for Backward.
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
    // Dispatches to "FlashAttentionBackward" with the saved tensors and dO.
    std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    bool is_causal_;
    float scale_; // <0 means use default 1/sqrt(head_dim)
    // L (logsumexp) is returned by the forward kernel alongside O, but is not an
    // output visible to the caller. We stash it here so SetupContext can save it.
    std::shared_ptr<Tensor> l_;
};
} // namespace infini_train::autograd
44 changes: 44 additions & 0 deletions infini_train/src/autograd/flash_attention.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include "infini_train/include/autograd/flash_attention.h"

#include "glog/logging.h"

#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"

namespace infini_train::autograd {

std::vector<std::shared_ptr<Tensor>>
FlashAttention::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
    // Exactly {q, k, v} is expected.
    CHECK_EQ(input_tensors.size(), 3);
    const auto &query = input_tensors[0];
    const auto &key = input_tensors[1];
    const auto &value = input_tensors[2];

    // Dispatch to the device-specific kernel. It returns two tensors: the
    // attention output O and the per-row logsumexp L.
    const auto device_type = query->GetDevice().type();
    auto outputs = Dispatcher::Instance().Call<std::vector<std::shared_ptr<Tensor>>>(
        {device_type, "FlashAttentionForward"}, query, key, value, is_causal_, scale_);
    CHECK_EQ(outputs.size(), 2);

    // L is not exposed to the caller; stash it so SetupContext can save it
    // alongside the other tensors needed by Backward.
    l_ = outputs[1];
    return {outputs[0]};
}

void FlashAttention::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                                  const std::vector<std::shared_ptr<Tensor>> &output_tensors) {
    // Forward must have run first and stashed the logsumexp tensor.
    CHECK(l_ != nullptr);
    const auto &q = input_tensors[0];
    const auto &k = input_tensors[1];
    const auto &v = input_tensors[2];
    const auto &o = output_tensors[0];
    // Save everything Backward needs: {q, k, v, o, logsumexp}.
    saved_tensors_ = {q, k, v, o, l_};
    // Drop the stash so a stale L can never leak into a later call.
    l_.reset();
}

std::vector<std::shared_ptr<Tensor>>
FlashAttention::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
    // Exactly one upstream gradient (dO) is expected, and SetupContext must
    // have saved {q, k, v, o, logsumexp}.
    CHECK_EQ(grad_outputs.size(), 1);
    CHECK_EQ(saved_tensors_.size(), 5);

    const auto &query = saved_tensors_[0];
    const auto &key = saved_tensors_[1];
    const auto &value = saved_tensors_[2];
    const auto &out = saved_tensors_[3];
    const auto &logsumexp = saved_tensors_[4];
    const auto &grad_out = grad_outputs[0];

    // Dispatch to the device-specific backward kernel; its result vector is
    // returned as-is (presumably the gradients w.r.t. q, k, v — matching the
    // three forward inputs).
    const auto device_type = query->GetDevice().type();
    return Dispatcher::Instance().Call<std::vector<std::shared_ptr<Tensor>>>(
        {device_type, "FlashAttentionBackward"}, query, key, value, out, logsumexp, grad_out, is_causal_, scale_);
}

} // namespace infini_train::autograd
Loading