Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,29 @@ link_infini_train_exe(test_precision_check)
add_executable(test_lora test/lora/test_lora.cc)
link_infini_train_exe(test_lora)

# LR-scheduler unit tests.
# Every target below follows the same pattern: target "test_<x>" is built
# from "test/lr_scheduler/test_<x>.cc" and linked via link_infini_train_exe.
# A foreach loop keeps the list in one place so adding a new scheduler test
# only requires appending its target name here.
foreach(lr_test
        test_lr_scheduler
        test_constant_lr
        test_step_lr
        test_linear_lr
        test_lambda_lr
        test_sequential_lr
        test_chained_lr
        test_training_lr_scheduler
        test_lr_scheduler_validation)
    add_executable(${lr_test} test/lr_scheduler/${lr_test}.cc)
    link_infini_train_exe(${lr_test})
endforeach()
35 changes: 31 additions & 4 deletions example/gpt2/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/lr_scheduler.h"
#include "infini_train/include/nn/lora/lora_utils.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
Expand Down Expand Up @@ -54,8 +55,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run");
DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
DEFINE_double(learning_rate, 1e-4, "Peak learning rate.");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
// lr scheduler
DEFINE_double(min_lr, 0.0, "Minimum learning rate.");
DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root");
DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations.");
DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup.");
DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration).");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -98,6 +105,8 @@ constexpr char kDeviceCPU[] = "cpu";
constexpr char kDeviceCUDA[] = "cuda";
constexpr char kDtypeFP32[] = "float32";
constexpr char kDtypeBF16[] = "bfloat16";
const std::unordered_set<std::string> kSupportedLRDecayStyles
= {"none", "constant", "linear", "cosine", "inverse-square-root"};

//
const std::unordered_map<std::string, GPT2Config> kModelToConfigs = {
Expand All @@ -118,6 +127,8 @@ const std::unordered_map<std::string, GPT2::ModelType> kStrToModelType = {
DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(lr_decay_style,
[](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -310,6 +321,16 @@ void Train(const nn::parallel::Rank &rank) {
optimizer = optimizer_creator(params_to_optimize);
}

const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration;
TrainingLRSchedulerConfig sched_config;
sched_config.lr = static_cast<float>(FLAGS_learning_rate);
sched_config.min_lr = static_cast<float>(FLAGS_min_lr);
sched_config.lr_decay_style = FLAGS_lr_decay_style;
sched_config.lr_decay_iters = lr_decay_iters;
sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters;
sched_config.lr_warmup_init = static_cast<float>(FLAGS_lr_warmup_init);
auto scheduler = CreateLRScheduler(optimizer, sched_config);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
Expand Down Expand Up @@ -353,6 +374,7 @@ void Train(const nn::parallel::Rank &rank) {
Profiler::Instance().SetTag("Step_" + std::to_string(step));
#endif

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
float lossf = 0.0f;
// model->Train();
if (pp_world_size == 1) {
Expand Down Expand Up @@ -396,6 +418,9 @@ void Train(const nn::parallel::Rank &rank) {
}

optimizer->Step();
if (scheduler) {
scheduler->Step();
}
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
Expand All @@ -405,6 +430,9 @@ void Train(const nn::parallel::Rank &rank) {
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
if (scheduler) {
scheduler->Step();
}
}

if (ddp_world_size > 1) {
Expand All @@ -420,11 +448,10 @@ void Train(const nn::parallel::Rank &rank) {
if (rank.IsLastRank()) {
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
Expand Down
35 changes: 31 additions & 4 deletions example/llama3/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "infini_train/include/core/runtime/device_guard.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/device.h"
#include "infini_train/include/lr_scheduler.h"
#include "infini_train/include/nn/lora/lora_utils.h"
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
Expand Down Expand Up @@ -53,8 +54,14 @@ DEFINE_uint32(num_iteration, 10, "number of iterations to run");
DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
DEFINE_double(learning_rate, 1e-5, "Peak learning rate.");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)");
// lr scheduler
DEFINE_double(min_lr, 0.0, "Minimum learning rate.");
DEFINE_string(lr_decay_style, "constant", "LR decay style: none|constant|linear|cosine|inverse-square-root");
DEFINE_int64(lr_warmup_iters, 0, "Number of linear warmup iterations.");
DEFINE_double(lr_warmup_init, 0.0, "Initial learning rate at the start of warmup.");
DEFINE_int64(lr_decay_iters, 0, "Number of iterations to decay LR over (0 = num_iteration).");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
Expand Down Expand Up @@ -93,11 +100,15 @@ constexpr char kDeviceCPU[] = "cpu";
constexpr char kDeviceCUDA[] = "cuda";
constexpr char kDtypeFP32[] = "float32";
constexpr char kDtypeBF16[] = "bfloat16";
const std::unordered_set<std::string> kSupportedLRDecayStyles
= {"none", "constant", "linear", "cosine", "inverse-square-root"};
} // namespace

DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
DEFINE_validator(device,
[](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
DEFINE_validator(lr_decay_style,
[](const char *, const std::string &value) { return kSupportedLRDecayStyles.contains(value); });

void Train(const nn::parallel::Rank &rank) {
using namespace nn::parallel;
Expand Down Expand Up @@ -282,6 +293,16 @@ void Train(const nn::parallel::Rank &rank) {
optimizer = optimizer_creator(params_to_optimize);
}

const int64_t lr_decay_iters = FLAGS_lr_decay_iters > 0 ? FLAGS_lr_decay_iters : FLAGS_num_iteration;
TrainingLRSchedulerConfig sched_config;
sched_config.lr = static_cast<float>(FLAGS_learning_rate);
sched_config.min_lr = static_cast<float>(FLAGS_min_lr);
sched_config.lr_decay_style = FLAGS_lr_decay_style;
sched_config.lr_decay_iters = lr_decay_iters;
sched_config.lr_warmup_iters = FLAGS_lr_warmup_iters;
sched_config.lr_warmup_init = static_cast<float>(FLAGS_lr_warmup_init);
auto scheduler = CreateLRScheduler(optimizer, sched_config);

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
Expand Down Expand Up @@ -322,6 +343,7 @@ void Train(const nn::parallel::Rank &rank) {
Profiler::Instance().SetTag("Step_" + std::to_string(step));
#endif

const float current_lr = scheduler ? scheduler->GetLR() : static_cast<float>(FLAGS_learning_rate);
float lossf = 0.0f;
if (pp_world_size == 1) {
// model->Train();
Expand Down Expand Up @@ -365,6 +387,9 @@ void Train(const nn::parallel::Rank &rank) {
}

optimizer->Step();
if (scheduler) {
scheduler->Step();
}
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
Expand All @@ -374,6 +399,9 @@ void Train(const nn::parallel::Rank &rank) {
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
if (scheduler) {
scheduler->Step();
}
}

if (ddp_world_size > 1) {
Expand All @@ -389,11 +417,10 @@ void Train(const nn::parallel::Rank &rank) {
if (rank.IsLastRank()) {
size_t used_mb = 0, reserved_mb = 0;
std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
step + 1, FLAGS_num_iteration, lossf, current_lr, duration_us / 1e3f, tps,
used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
Expand Down
Loading
Loading