Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
* text=auto
*.c text eol=lf
*.cc text eol=lf
*.cpp text eol=lf
*.cu text eol=lf
*.h text eol=lf
*.hpp text eol=lf
*.py text eol=lf
*.sh text eol=lf
*.bash text eol=lf
CMakeLists.txt text eol=lf
.gitignore text eol=lf
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,15 @@ build/
*.log
*.report.rank*
*.records.log.rank*

#------modify-start------------------------------------------
# Local sanity-check datasets (not part of repo)
tmp_data/
#---------modify-end-----------------------------------------
tmp/

# Generated Flash SDPA benchmark outputs
docs/flash_sdpa/logs/
docs/flash_sdpa/env/
docs/flash_sdpa/report_*.md
docs/flash_sdpa/summary_*.csv
44 changes: 38 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ endif()

if(USE_CUDA)
add_compile_definitions(USE_CUDA=1)

#------modify-start------------------------------------------
# CMake may fail to auto-detect nvcc / default architectures if CUDA is not in PATH.
# Pin nvcc path and a reasonable default arch for A100 (sm_80).
if(NOT DEFINED CMAKE_CUDA_COMPILER AND EXISTS "/usr/local/cuda/bin/nvcc")
set(CMAKE_CUDA_COMPILER "/usr/local/cuda/bin/nvcc")
endif()
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 80)
endif()
#---------modify-end-----------------------------------------

enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
include_directories(${CUDAToolkit_INCLUDE_DIRS})
Expand All @@ -90,10 +102,34 @@ if(USE_CUDA)
PUBLIC
glog
CUDA::cudart
#------modify-start------------------------------------------
CUDA::nvrtc
#---------modify-end-----------------------------------------
CUDA::cublas
CUDA::cuda_driver
)

#------modify-start------------------------------------------
# cuDNN + cudnn-frontend (header-only) for fused SDPA backend
find_library(CUDNN_LIBRARY cudnn HINTS /lib/x86_64-linux-gnu /usr/lib/x86_64-linux-gnu)
if(NOT CUDNN_LIBRARY)
message(FATAL_ERROR "cuDNN (libcudnn.so) not found")
endif()
find_path(
CUDNN_FRONTEND_INCLUDE_DIR
cudnn_frontend.h
HINTS
${PROJECT_SOURCE_DIR}/third_party/cudnn_frontend/include
/usr/include
/usr/local/include
)
if(NOT CUDNN_FRONTEND_INCLUDE_DIR)
message(FATAL_ERROR "cudnn_frontend.h not found. Install cudnn-frontend or clone it under third_party/cudnn_frontend/include")
endif()
target_link_libraries(infini_train_cuda_kernels PUBLIC ${CUDNN_LIBRARY})
target_include_directories(infini_train_cuda_kernels PUBLIC ${CUDNN_FRONTEND_INCLUDE_DIR})
#---------modify-end-----------------------------------------

if(USE_NCCL)
message(STATUS "Add USE_NCCL, use NCCL with CUDA")
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
Expand Down Expand Up @@ -196,11 +232,7 @@ set_target_properties(infini_run PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BIN

# Tests
add_executable(test_hook test/hook/test_hook.cc)
link_infini_train_exe(test_hook)
target_link_libraries(test_hook infini_train)

add_executable(test_precision_check test/hook/test_precision_check.cc)
link_infini_train_exe(test_precision_check)

add_executable(test_lora test/lora/test_lora.cc)
link_infini_train_exe(test_lora)

target_link_libraries(test_precision_check infini_train)
62 changes: 8 additions & 54 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/lora/lora_utils.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
Expand Down Expand Up @@ -75,19 +74,14 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
//------modify-start------------------------------------------
DEFINE_bool(flash, false, "enable fused scaled-dot-product attention (BF16 only)");
//---------modify-end-----------------------------------------
// precision check
DEFINE_string(
precision_check, "",
"precision check config: level=N,format=simple|table,output_md5=true|false,output_path=PATH,baseline=PATH");

// LoRA parameters
DEFINE_int32(lora_rank, 0, "LoRA rank (0 = disabled)");
DEFINE_double(lora_alpha, 16.0, "LoRA alpha scaling factor");
DEFINE_string(lora_target_modules, "c_attn,c_proj",
"LoRA target modules (comma-separated: c_attn,c_proj,c_fc,c_fc2,mlp.c_proj)");
DEFINE_string(lora_save_path, "", "Path to save LoRA weights after training");
DEFINE_string(lora_load_path, "", "Path to load LoRA weights from");

using namespace infini_train;

namespace {
Expand Down Expand Up @@ -189,7 +183,6 @@ void Train(const nn::parallel::Rank &rank) {
// init the model, either from scratch or from OpenAI pretrained checkpoint
GPT2Config model_config;
std::shared_ptr<nn::Module> model = nullptr;

if (!FLAGS_llmc_filepath.empty()) {
model = GPT2::FromLLMC(FLAGS_llmc_filepath);
} else if (kModelToConfigs.count(FLAGS_model)) {
Expand All @@ -203,29 +196,6 @@ void Train(const nn::parallel::Rank &rank) {

utils::PrecisionChecker::BuildNameMap(model.get());

// Get chunk size before wrapping with LoRA (needed for PipelineParallel)
auto gpt2_model = std::dynamic_pointer_cast<GPT2>(model);
CHECK(gpt2_model) << "GPT2 example expects GPT2 model.";

// Apply LoRA using GetLoRAModel (in-place injection)
bool lora_enabled = FLAGS_lora_rank > 0;
if (lora_enabled) {
nn::lora::LoRAConfig lora_config{FLAGS_lora_rank, static_cast<float>(FLAGS_lora_alpha), 0.0f,
nn::lora::ParseLoRATargetModules(FLAGS_lora_target_modules)};

// GetLoRAModel: in-place injection, modifies module tree directly
model = nn::lora::GetLoRAModel(model, lora_config);

// Load LoRA weights if specified
if (!FLAGS_lora_load_path.empty()) {
LOG(INFO) << "Loading LoRA weights from: " << FLAGS_lora_load_path;
nn::lora::LoadLoRAWeights(model, FLAGS_lora_load_path);
}

// Print LoRA summary
nn::lora::PrintLoRASummary(model, rank.GlobalRank());
}

// select the data type
// TODO(lzm): change to solely rely on the weight file info for determining the dtype when autocast is supported
DataType dtype;
Expand All @@ -239,24 +209,15 @@ void Train(const nn::parallel::Rank &rank) {

auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);

// Create optimizer - use GetLoRAParameters if LoRA is enabled
std::vector<std::shared_ptr<Tensor>> params_to_optimize;
if (lora_enabled) {
params_to_optimize = nn::lora::GetLoRAParameters(model);
LOG(INFO) << "Optimizing " << params_to_optimize.size() << " LoRA parameters";
} else {
params_to_optimize = model->Parameters();
LOG(INFO) << "Optimizing " << params_to_optimize.size() << " model parameters";
}

if (pp_world_size > 1) {
// NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
// when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
auto shapes = std::vector<std::vector<int64_t>>{
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, device, gpt2_model->GetChunkSize());
model = std::make_shared<nn::parallel::PipelineParallel>(
model, pp_world_size, num_micro_batches, shapes, pp_rank, device,
std::dynamic_pointer_cast<GPT2>(model)->GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
Expand Down Expand Up @@ -304,10 +265,10 @@ void Train(const nn::parallel::Rank &rank) {
auto model_chunks = (pp_world_size > 1)
? *(dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks())
: std::vector<std::shared_ptr<nn::Module>>{model};
optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(optimizer_creator, params_to_optimize,
optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(optimizer_creator, model->Parameters(),
model_chunks, ddp_world_size, ddp_rank);
} else {
optimizer = optimizer_creator(params_to_optimize);
optimizer = optimizer_creator(model->Parameters());
}

auto train_iter = train_loader.begin();
Expand Down Expand Up @@ -436,13 +397,6 @@ void Train(const nn::parallel::Rank &rank) {
}
}
}

// Save LoRA weights if enabled and path specified
if (lora_enabled && !FLAGS_lora_save_path.empty()) {
LOG(INFO) << "Saving LoRA weights to: " << FLAGS_lora_save_path;
nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path);
}

#ifdef PROFILE_MODE
Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage);
Profiler::Instance().PrintRecords("gpt2.records.log");
Expand Down
30 changes: 25 additions & 5 deletions example/gpt2/net.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
#include <tuple>
#include <vector>

#include "gflags/gflags.h"
#include "glog/logging.h"
//------modify-start------------------------------------------
// NOTE: --flash is a global gflags option defined in main.cc.
DECLARE_bool(flash);
//---------modify-end-----------------------------------------

#include "example/common/utils.h"
#include "infini_train/include/device.h"
Expand Down Expand Up @@ -105,6 +110,26 @@ CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Ten
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

//------modify-start------------------------------------------
// FlashAttention path (BF16 on CUDA only).
if (FLAGS_flash && q->GetDevice().type() == Device::DeviceType::kCUDA && q->Dtype() == DataType::kBFLOAT16) {
// cudnn SDPA expects a standard (B, H, T, D) layout; enforce contiguous strides.
q = q->Contiguous();
k = k->Contiguous();
v = v->Contiguous();

// (B, h_l, T, D) -> (B, h_l, T, D)
auto y = nn::function::ScaledDotProductAttention(q, k, v, /*attn_mask=*/nullptr, /*dropout_p=*/0.0,
/*is_causal=*/true);
// (B, h_l, T, D) -> (B, T, h_l, D) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});

// output projection
y = (*modules_[kCProjLayerName])({y})[0];
return {y};
}
//---------modify-end-----------------------------------------

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
Expand Down Expand Up @@ -307,11 +332,6 @@ GPT2::GPT2(const GPT2Config &config)
modules_[kTransformerLayerName] = std::make_shared<nn::ModuleDict>(std::move(transformer));

// FIXME(jym): Assigning the parameter values of wte to LMHead, which is not real tying operation
// TODO: Implement real GPT-2 weight tying: make lm_head.weight share the exact same Parameter/Tensor (same
// shared_ptr/storage) as transformer.wte.weight (pointer aliasing, not value copy), and ensure the tie is applied
// after loading weights so it won't be overwritten. Also fix GPT2::FromLLMC() loading logic to respect weight tying
// (do not create/load a separate lm_head.weight tensor; load once into the tied weight) so parameter counting
// matches PyTorch/PEFT.
if (nn::parallel::global::GetPipelineParallelSize() == 1) {
// https://paperswithcode.com/method/weight-tying
*mutable_module(kTransformerLayerName)
Expand Down
52 changes: 8 additions & 44 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/nn/lora/lora_utils.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
Expand Down Expand Up @@ -73,16 +72,13 @@ DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
//------modify-start------------------------------------------
DEFINE_bool(flash, false, "enable fused scaled-dot-product attention (BF16 only)");
//---------modify-end-----------------------------------------
// precision check
DEFINE_string(
precision_check, "",
"precision check config: level=N,format=simple|table,output_md5=true|false,output_path=PATH,baseline=PATH");
// LoRA parameters
DEFINE_int32(lora_rank, 0, "LoRA rank (0 = disabled)");
DEFINE_double(lora_alpha, 16.0, "LoRA alpha scaling factor");
DEFINE_string(lora_target_modules, "c_attn,c_proj,c_fc,c_fc2", "LoRA target modules (comma-separated)");
DEFINE_string(lora_save_path, "", "Path to save LoRA weights after training");
DEFINE_string(lora_load_path, "", "Path to load LoRA weights from");

using namespace infini_train;

Expand Down Expand Up @@ -168,6 +164,9 @@ void Train(const nn::parallel::Rank &rank) {
// ManualSeed(42);

LLaMA3Config model_config = LLaMA3Config();
//------modify-start------------------------------------------
model_config.flash = FLAGS_flash;
//---------modify-end-----------------------------------------
std::shared_ptr<nn::Module> model = nullptr;
if (!FLAGS_llmc_filepath.empty()) {
model = LLaMA3::FromLLMC(FLAGS_llmc_filepath);
Expand All @@ -179,25 +178,6 @@ void Train(const nn::parallel::Rank &rank) {

utils::PrecisionChecker::BuildNameMap(model.get());

// Apply LoRA using GetLoRAModel (in-place injection)
bool lora_enabled = FLAGS_lora_rank > 0;
if (lora_enabled) {
nn::lora::LoRAConfig lora_config{FLAGS_lora_rank, static_cast<float>(FLAGS_lora_alpha), 0.0f,
nn::lora::ParseLoRATargetModules(FLAGS_lora_target_modules)};

// GetLoRAModel: in-place injection, modifies module tree directly
model = nn::lora::GetLoRAModel(model, lora_config);

// Load LoRA weights if specified
if (!FLAGS_lora_load_path.empty()) {
LOG(INFO) << "Loading LoRA weights from: " << FLAGS_lora_load_path;
nn::lora::LoadLoRAWeights(model, FLAGS_lora_load_path);
}

// Print LoRA summary
nn::lora::PrintLoRASummary(model, rank.GlobalRank());
}

LOG(INFO) << "Rank " << rank.GlobalRank() << ": Model loaded to device.";

DataType dtype;
Expand Down Expand Up @@ -263,23 +243,14 @@ void Train(const nn::parallel::Rank &rank) {
auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate);
std::shared_ptr<Optimizer> optimizer = nullptr;

std::vector<std::shared_ptr<Tensor>> params_to_optimize;
if (lora_enabled) {
params_to_optimize = nn::lora::GetLoRAParameters(model);
LOG(INFO) << "Optimizing " << params_to_optimize.size() << " LoRA parameters";
} else {
params_to_optimize = model->Parameters();
LOG(INFO) << "Optimizing " << params_to_optimize.size() << " model parameters";
}

if (FLAGS_use_distributed_optimizer) {
auto model_chunks = (pp_world_size > 1)
? *(dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks())
: std::vector<std::shared_ptr<nn::Module>>{model};
optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(optimizer_creator, params_to_optimize,
optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(optimizer_creator, model->Parameters(),
model_chunks, ddp_world_size, ddp_rank);
} else {
optimizer = optimizer_creator(params_to_optimize);
optimizer = optimizer_creator(model->Parameters());
}

auto train_iter = train_loader.begin();
Expand Down Expand Up @@ -405,13 +376,6 @@ void Train(const nn::parallel::Rank &rank) {
}
}
}

// Save LoRA weights if enabled and path specified
if (lora_enabled && !FLAGS_lora_save_path.empty()) {
LOG(INFO) << "Saving LoRA weights to: " << FLAGS_lora_save_path;
nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path);
}

#ifdef PROFILE_MODE
Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage);
Profiler::Instance().PrintRecords("llama3.records.log");
Expand Down
Loading
Loading