Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
DEFINE_bool(flash, false, "Whether to enable FlashAttention in CausalSelfAttention");

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
Expand Down Expand Up @@ -191,11 +192,16 @@ void Train(const nn::parallel::Rank &rank) {
std::shared_ptr<nn::Module> model = nullptr;

if (!FLAGS_llmc_filepath.empty()) {
model = GPT2::FromLLMC(FLAGS_llmc_filepath);
LOG(INFO) << "Rank " << rank.GlobalRank() << ": Loading GPT2 from LLMC file: " << FLAGS_llmc_filepath;
model = GPT2::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else if (kModelToConfigs.count(FLAGS_model)) {
model_config = kModelToConfigs.at(FLAGS_model);
model_config.flash = FLAGS_flash;
model = std::make_shared<GPT2>(model_config);
} else {
if (FLAGS_flash) {
LOG(WARNING) << "--flash is ignored when loading GPT2 from pretrained checkpoint.";
}
model = GPT2::FromPretrained(kStrToModelType.at(FLAGS_model));
}

Expand Down
43 changes: 29 additions & 14 deletions example/gpt2/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "glog/logging.h"

#include "example/common/utils.h"
#include "infini_train/include/autograd/ScaledDotProductAttention.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/functional.h"
#include "infini_train/include/nn/init.h"
Expand Down Expand Up @@ -105,18 +106,31 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
std::shared_ptr<Tensor> y = nullptr;
if (config_.flash) {
// FlashAttention expects (B, T, H, D)
auto q_flash = q->Transpose(1, 2);
auto k_flash = k->Transpose(1, 2);
auto v_flash = v->Transpose(1, 2);
auto y_flash = std::make_shared<autograd::ScaledDotProductAttention>(
/*attn_mask=*/nullptr, /*dropout_p=*/0, /*is_causal=*/true,
/*scale=*/1.0 / std::sqrt(static_cast<double>(head_dim)), /*enable_gqa=*/false)
->Apply({q_flash, k_flash, v_flash})[0];
y = y_flash->Contiguous()->View({B, T, local_C});
} else {
// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
}

// Get full tensor
// (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
Expand Down Expand Up @@ -356,7 +370,7 @@ std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::
}
} // namespace

std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath, bool flash) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -384,7 +398,8 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {
.original_vocab_size = vocab_size,
.n_layer = n_layer,
.n_head = n_head,
.n_embd = n_embd});
.n_embd = n_embd,
.flash = flash});

LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size
<< " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head
Expand Down
3 changes: 2 additions & 1 deletion example/gpt2/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ struct GPT2Config {
int64_t n_layer = 12;
int64_t n_head = 12;
int64_t n_embd = 768;
bool flash = false;
};

class NewGELU : public infini_train::nn::CloneableModule<NewGELU> {
Expand Down Expand Up @@ -140,7 +141,7 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<GPT2> FromPretrained(ModelType model_type);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath, bool flash = false);

int GetChunkSize() const;

Expand Down
6 changes: 5 additions & 1 deletion example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
DEFINE_bool(flash, false, "Whether to enable FlashAttention in CausalSelfAttention");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
// precision check
Expand Down Expand Up @@ -168,9 +169,12 @@ void Train(const nn::parallel::Rank &rank) {
// ManualSeed(42);

LLaMA3Config model_config = LLaMA3Config();
model_config.flash = FLAGS_flash;
std::shared_ptr<nn::Module> model = nullptr;
LOG(INFO) << "Rank " << rank.GlobalRank() << ": FLAGS_flash = " << (FLAGS_flash ? "true" : "false");
if (!FLAGS_llmc_filepath.empty()) {
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
LOG(INFO) << "Rank " << rank.GlobalRank() << ": Loading LLaMA3 from LLMC file: " << FLAGS_llmc_filepath;
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath, FLAGS_flash);
} else {
model = std::make_shared<LLaMA3>(model_config);
}
Expand Down
64 changes: 34 additions & 30 deletions example/llama3/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "glog/logging.h"

#include "example/common/utils.h"
#include "infini_train/include/autograd/ScaledDotProductAttention.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/functional.h"
#include "infini_train/include/nn/init.h"
Expand Down Expand Up @@ -207,36 +208,38 @@ std::vector<std::shared_ptr<Tensor>> CausalSelfAttention::Forward(const std::vec
// TODO(zbl): use kv cache during inference
// if (use_kv_) { ... }

// align n_head in GQA
// (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV
k = RepeatKV(k, n_rep_);
v = RepeatKV(v, n_rep_);

// (B, T, H_local, D) -> (B, H_local, T, D)
q = q->Transpose(1, 2);
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);

// TODO(zbl): support flash attention later
// if (flash_) { ... }

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

// q: (B, H_local, T, D)
// k: (B, H_local, T, D) -> (B, H_local, D, T)
// q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
std::shared_ptr<Tensor> y = nullptr;
if (config_.flash) {
auto y_flash = std::make_shared<autograd::ScaledDotProductAttention>(
/*attn_mask=*/nullptr, /*dropout_p=*/0, /*is_causal=*/true,
/*scale=*/1.0 / std::sqrt(static_cast<double>(D)), /*enable_gqa=*/true)
->Apply({q, k, v})[0];
y = y_flash->Contiguous()->View({B, T, C_local});
} else {
// align n_head in GQA
// (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV
k = RepeatKV(k, n_rep_);
v = RepeatKV(v, n_rep_);

// (B, T, H_local, D) -> (B, H_local, T, D)
q = q->Transpose(1, 2);
k = k->Transpose(1, 2);
v = v->Transpose(1, 2);

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
// att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
auto y = att->Matmul(v);
// (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
// output projection
// (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C)
y = (*modules_[kCProjLayerName])({y})[0];
Expand Down Expand Up @@ -457,7 +460,7 @@ constexpr int32_t kLLaMA3Magic = 20240803;
constexpr int32_t kLLaMA3FP32Version = 3;
} // namespace

std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath, bool flash) {
if (!std::filesystem::exists(filepath)) {
LOG(FATAL) << "File not found: " << filepath;
}
Expand Down Expand Up @@ -496,6 +499,7 @@ std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
.rope_theta = rope_theta,
.use_scaled_rope = static_cast<bool>(use_scaled_rope),
.norm_eps = norm_eps,
.flash = flash,
.max_gen_batch_size = max_gen_bs});

// ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ==========
Expand Down
4 changes: 2 additions & 2 deletions example/llama3/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct LLaMA3Config {

// Inference
bool use_kv = false; // kv cache
bool flash = false; // flash attention
bool flash = false; // enable flash attention path in CausalSelfAttention
int64_t max_gen_batch_size = 4; // max batch size during inference
};

Expand Down Expand Up @@ -179,7 +179,7 @@ class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

static std::shared_ptr<LLaMA3> FromPretrained(ModelType model_type);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath, bool flash = false);

int GetChunkSize() const { return stage_info_.layer_ranges_per_chunk.size(); }

Expand Down
42 changes: 42 additions & 0 deletions infini_train/include/autograd/ScaledDotProductAttention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#pragma once

#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include "infini_train/include/autograd/function.h"
#include "infini_train/include/kernels/cuda/flash_attention.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {
/// Autograd function for scaled dot-product attention, backed by the
/// FlashAttention CUDA kernels. Callers in this repo pass q/k/v tensors in
/// (B, T, H, D) layout (batch, sequence, heads, head_dim).
class ScaledDotProductAttention : public Function {
public:
    static constexpr char kType[] = "ScaledDotProductAttentionFunction";

    /// @param attn_mask   Optional attention mask; nullptr means no explicit mask.
    /// @param dropout_p   Dropout probability in [0, 1]. Stored as double — the
    ///                    previous int64_t representation truncated every value
    ///                    in (0, 1) to 0, silently disabling dropout.
    /// @param is_causal   Apply a causal (lower-triangular) mask when true.
    /// @param scale       Optional softmax scaling applied to QK^T; callers in
    ///                    this repo pass 1/sqrt(head_dim).
    /// @param enable_gqa  Enable grouped-query attention (fewer KV heads than Q heads).
    explicit ScaledDotProductAttention(std::shared_ptr<Tensor> attn_mask = nullptr, double dropout_p = 0.0,
                                       bool is_causal = false, std::optional<double> scale = std::nullopt,
                                       bool enable_gqa = false)
        : Function(kType), attn_mask_(std::move(attn_mask)), dropout_p_(dropout_p), is_causal_(is_causal),
          scale_(scale), enable_gqa_(enable_gqa) {}

    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
    std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    std::shared_ptr<Tensor> attn_mask_; // optional mask; nullptr when unused
    double dropout_p_ = 0.0;            // dropout probability in [0, 1]
    bool is_causal_ = false;            // causal masking flag
    std::optional<double> scale_;       // softmax scale override (e.g. 1/sqrt(D))
    bool enable_gqa_ = false;           // grouped-query attention flag

    // Forward-pass results (output, logsumexp, dropout seed) cached here so
    // SetupContext can stash them for Backward.
    // Note: this type is defined in the infini_train::kernels::cuda namespace.
    kernels::cuda::FlashAttentionForwardOutput flash_output_;
};
} // namespace infini_train::autograd
27 changes: 27 additions & 0 deletions infini_train/include/kernels/cuda/flash_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include <memory>

namespace infini_train {
class Tensor;
}

namespace infini_train::kernels::cuda {

/**
 * Result bundle produced by the FlashAttention forward kernel.
 *
 * All members are tensors so the autograd context can stash them and the
 * backward pass can consume them later.
 *
 * Members:
 *   output:       Attention output, shape [batch_size, seq_len_q, num_heads, head_dim].
 *   logsumexp:    Log-sum-exp of the attention scores, kept for the backward
 *                 pass, shape [batch_size, num_heads, seq_len_q].
 *   dropout_seed: RNG seed used for dropout, shape [1]; needed so backward can
 *                 reproduce the same dropout mask.
 */
struct FlashAttentionForwardOutput {
    std::shared_ptr<Tensor> output;       // (B, Tq, H, D) attention result
    std::shared_ptr<Tensor> logsumexp;    // (B, H, Tq) softmax normalizer for backward
    std::shared_ptr<Tensor> dropout_seed; // single-element seed tensor for backward
};

} // namespace infini_train::kernels::cuda
Loading