Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
build/
build_script/
.cache/
.vscode/

Expand Down
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ link_infini_train_exe(test_hook)
add_executable(test_precision_check test/hook/test_precision_check.cc)
link_infini_train_exe(test_precision_check)

add_executable(test_sdpa test/hook/test_sdpa.cc)
link_infini_train_exe(test_sdpa)

add_executable(test_sdpa_cuda test/hook/test_sdpa_cuda.cc)
link_infini_train_exe(test_sdpa_cuda)

add_executable(test_lora test/lora/test_lora.cc)
link_infini_train_exe(test_lora)

1 change: 1 addition & 0 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
DEFINE_bool(overfit_single_batch, true, "overfit just one batch of data");
// memory management
DEFINE_string(device, "cuda", "device type (cpu/cuda), useless if using parallel training mode");
DEFINE_bool(flash, false, "enable flash attention");
// parallel
DEFINE_int32(
nthread_per_process, 1,
Expand Down
15 changes: 15 additions & 0 deletions example/gpt2/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <tuple>
#include <vector>

#include "gflags/gflags.h"
#include "glog/logging.h"

#include "example/common/utils.h"
Expand All @@ -32,6 +33,8 @@
using namespace infini_train;
namespace nn = infini_train::nn;

DECLARE_bool(flash);

namespace {
constexpr int kRandomSeed = 42;

Expand Down Expand Up @@ -105,6 +108,18 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

if (FLAGS_flash) {
// FlashAttention path (placeholder): use unified SDPA API.
// q/k/v: (B, h_l, T, Dh)
auto y_flash = nn::function::ScaledDotProductAttention(q, k, v, /*attn_mask=*/nullptr,
/*dropout_p=*/0.0, /*is_causal=*/true);
// y: (B, h_l, T, Dh)
auto y = y_flash;
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});
y = (*modules_[kCProjLayerName])({y})[0];
return {y};
}

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
Expand Down
1 change: 1 addition & 0 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
DEFINE_bool(overfit_single_batch, true, "overfit just one batch of data");
// memory management
DEFINE_string(device, "cuda", "device type (cpu/cuda), useless if using parallel training mode");
DEFINE_bool(flash, false, "enable flash attention");
// parallel
DEFINE_int32(
nthread_per_process, 1,
Expand Down
13 changes: 13 additions & 0 deletions example/llama3/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <unordered_map>
#include <vector>

#include "gflags/gflags.h"
#include "glog/logging.h"

#include "example/common/utils.h"
Expand All @@ -30,6 +31,8 @@
using namespace infini_train;
namespace nn = infini_train::nn;

DECLARE_bool(flash);

namespace {
constexpr int kRandomSeed = 42;

Expand Down Expand Up @@ -220,6 +223,16 @@ std::vector<std::shared_ptr<Tensor>> CausalSelfAttention::Forward(const std::vec
// TODO(zbl): support flash attention later
// if (flash_) { ... }

if (FLAGS_flash) {
// FlashAttention path (placeholder): use unified SDPA API.
// q/k/v: (B, H_local, T, D)
auto y = nn::function::ScaledDotProductAttention(q, k, v, /*attn_mask=*/mask,
/*dropout_p=*/0.0, /*is_causal=*/false);
y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
y = (*modules_[kCProjLayerName])({y})[0];
return {y};
}

// manual implementation of attention
// this materializes the large (T,T) matrix for all the queries and keys

Expand Down
44 changes: 44 additions & 0 deletions infini_train/include/autograd/sdpa.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#pragma once

#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {

// Autograd node for scaled dot-product attention.
//
// Construction captures the attention configuration; the tensors themselves
// (query/key/value) arrive through Forward's input_tensors. Backward-pass
// bookkeeping (scale_value_, n_rep_, has_mask_) is filled in by SetupContext.
class ScaledDotProductAttention : public Function {
public:
    static constexpr char kType[] = "ScaledDotProductAttentionFunction";

    // attn_mask: optional mask tensor (nullptr when unused).
    // dropout_p: attention dropout probability.
    // is_causal: apply a causal (upper-triangular) mask.
    // scale:     softmax scaling override; 1/sqrt(D) is used when nullopt
    //            (see the doc on nn::function::ScaledDotProductAttention).
    // enable_gqa: repeat key/value heads when they are fewer than query heads.
    ScaledDotProductAttention(const std::shared_ptr<Tensor> &attn_mask, double dropout_p, bool is_causal,
                              std::optional<double> scale, bool enable_gqa)
        : Function(kType), attn_mask_(attn_mask), dropout_p_(dropout_p), is_causal_(is_causal), scale_(scale),
          enable_gqa_(enable_gqa) {}

    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

    void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
                      const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;

    std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
    // Configuration captured at construction time.
    std::shared_ptr<Tensor> attn_mask_;
    double dropout_p_ = 0.0;
    bool is_causal_ = false;
    std::optional<double> scale_;
    bool enable_gqa_ = false;

    // State cached between Forward and Backward (populated in SetupContext).
    double scale_value_ = 1.0; // resolved softmax scale actually applied
    int64_t n_rep_ = 1;        // GQA head-repeat factor (1 when GQA is off)
    bool has_mask_ = false;    // whether attn_mask_ was non-null at forward time
};

} // namespace infini_train::autograd
19 changes: 19 additions & 0 deletions infini_train/include/nn/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstdint>
#include <memory>
#include <optional>
#include <vector>

namespace infini_train {
Expand Down Expand Up @@ -183,4 +184,22 @@ std::shared_ptr<Tensor> Stack(const std::vector<std::shared_ptr<Tensor>> &inputs
// Concatenation of the input tensors.
std::shared_ptr<Tensor> Concat(const std::vector<std::shared_ptr<Tensor>> &inputs, int64_t dim = 0);

// Scaled dot-product attention (PyTorch-like API).
//
// Expected tensor layout (current InfiniTrain examples):
//   query/key/value: (B, H, T, D)
//
// Semantics:
//   - If attn_mask is provided: positions where mask is non-zero are masked.
//     NOTE(review): this is the opposite of PyTorch's boolean-mask convention,
//     where true means "take part in attention" — confirm against the kernel /
//     autograd implementation, and reconcile with the llama3 caller that
//     passes its mask here.
//   - If is_causal is true: applies a causal (upper-triangular) mask.
//   - If scale is not provided: uses 1 / sqrt(D).
//   - If enable_gqa is true and key/value have fewer heads than query, key/value
//     will be repeated along the head dimension.
std::shared_ptr<Tensor> ScaledDotProductAttention(const std::shared_ptr<Tensor> &query,
                                                  const std::shared_ptr<Tensor> &key,
                                                  const std::shared_ptr<Tensor> &value,
                                                  const std::shared_ptr<Tensor> &attn_mask = nullptr,
                                                  double dropout_p = 0.0, bool is_causal = false,
                                                  std::optional<double> scale = std::nullopt, bool enable_gqa = false);

} // namespace infini_train::nn::function
Loading