41 changes: 32 additions & 9 deletions example/llama3/main.cc
@@ -13,6 +13,7 @@
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/distributed_data_parallel.h"
#include "infini_train/include/nn/parallel/distributed_optimizer.h"
#include "infini_train/include/nn/parallel/parallel_functional.h"
#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
#include "infini_train/include/nn/parallel/rank.h"
@@ -48,6 +49,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-5, "learning rate");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer (only takes effect when DP > 1)");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -171,8 +173,11 @@ void Train(const nn::parallel::Rank &rank) {
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.

// FIXME(zbl): set as argument
if (ddp_world_size > 1) {
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
}

auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
@@ -197,7 +202,14 @@
}

// TODO(dcj): support more complex optimizer later
auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate);
// auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate);
auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate);
std::shared_ptr<Optimizer> optimizer
= FLAGS_use_distributed_optimizer ? std::make_unique<nn::parallel::DistributedOptimizer>(
optimizer_creator, model->Parameters(),
dynamic_cast<DistributedDataParallel *>(model.get())->param_grad_buffers(),
dynamic_cast<DistributedDataParallel *>(model.get())->bucket_groups(), ddp_pg, ddp_world_size, ddp_rank)
: optimizer_creator(model->Parameters());

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
@@ -213,13 +225,18 @@
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, std::make_shared<optimizers::Adam>(optimizer),
rank.thread_rank());
pp_rank, optimizer, rank.thread_rank());
}

auto cuda_device = device->IsCUDA() ? dynamic_cast<const CudaDevice *>(device) : nullptr;

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
const bool last_step = step == FLAGS_num_iteration;

if (cuda_device) {
cuda_device->ResetMemPoolHighWatermarks();
}

const auto iter_start = std::chrono::high_resolution_clock::now();

// once in a while evaluate the validation dataset
@@ -246,7 +263,7 @@
float lossf = 0.0f;
if (pp_world_size == 1) {
// model->Train();
optimizer.ZeroGrad();
optimizer->ZeroGrad();

// if we are trying to overfit a single batch, we reset the loader here
if (FLAGS_overfit_single_batch) {
@@ -284,7 +301,7 @@
LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward";
}

optimizer.Step();
optimizer->Step();
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
@@ -308,10 +325,16 @@
const double tps = FLAGS_total_batch_size / (duration_us / 1e6);

if (rank.IsLastRank()) {
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
"DP={}, TP={}, SP={}, PP={})",
size_t used_mb = 0, reserved_mb = 0;
if (cuda_device) {
std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB();
}

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
// FIXME(jym): to support PP
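Note on the optimizer wiring above: `optimizers::Adam::Create(FLAGS_learning_rate)` produces a creator that is invoked later with whichever parameter list applies, so the same factory can build the optimizer over the full parameter set (plain DDP) or over a rank-local shard inside `DistributedOptimizer`. A minimal sketch of that creator pattern, with stand-in types and a hypothetical `OptimizerCreator` alias since the real `Optimizer`/`Tensor` headers are not part of this diff:

```cpp
#include <functional>
#include <memory>
#include <vector>

// Hypothetical stand-ins for infini_train types, for illustration only.
struct Tensor {};
struct Optimizer {
    virtual ~Optimizer() = default;
};
struct Adam : Optimizer {
    Adam(const std::vector<std::shared_ptr<Tensor>> & /*params*/, float /*lr*/) {}
};

// A creator captures the hyper-parameters now and receives the parameter list
// later, so the caller can decide whether to pass the full params (plain DDP)
// or only a rank-local shard (DistributedOptimizer).
using OptimizerCreator
    = std::function<std::shared_ptr<Optimizer>(const std::vector<std::shared_ptr<Tensor>> &)>;

OptimizerCreator MakeAdamCreator(float lr) {
    return [lr](const std::vector<std::shared_ptr<Tensor>> &params) {
        return std::make_shared<Adam>(params, lr);
    };
}

int main() {
    auto creator = MakeAdamCreator(1e-5f);
    std::vector<std::shared_ptr<Tensor>> params{std::make_shared<Tensor>()};
    auto opt = creator(params); // built only when the (possibly sharded) params are known
}
```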
3 changes: 3 additions & 0 deletions infini_train/include/device.h
@@ -68,6 +68,9 @@ class CudaDevice : public Device {

nn::parallel::Rank rank() const override;

void ResetMemPoolHighWatermarks() const;
std::pair<size_t, size_t> GetMemPoolPeakMB() const;

private:
CudaDevice(int8_t index);

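`ResetMemPoolHighWatermarks()` / `GetMemPoolPeakMB()` follow a reset-then-read pattern per training step, as used in main.cc above. The pool internals are not part of this diff; the sketch below only illustrates the bookkeeping contract, with hypothetical names:

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>

// Hypothetical watermark bookkeeping illustrating the contract of
// ResetMemPoolHighWatermarks() / GetMemPoolPeakMB() in this PR.
class MemPoolStats {
public:
    void OnAlloc(size_t bytes) {
        used_ += bytes;
        peak_used_ = std::max(peak_used_, used_);
    }
    void OnReserve(size_t bytes) {
        reserved_ += bytes;
        peak_reserved_ = std::max(peak_reserved_, reserved_);
    }
    void OnFree(size_t bytes) { used_ -= bytes; }

    // Start a new measurement window (e.g., at the top of a training step).
    void ResetHighWatermarks() {
        peak_used_ = used_;
        peak_reserved_ = reserved_;
    }

    // {peak used, peak reserved} in MB since the last reset.
    std::pair<size_t, size_t> PeakMB() const {
        constexpr size_t kMB = 1024 * 1024;
        return {peak_used_ / kMB, peak_reserved_ / kMB};
    }

private:
    size_t used_ = 0, reserved_ = 0;
    size_t peak_used_ = 0, peak_reserved_ = 0;
};
```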
27 changes: 25 additions & 2 deletions infini_train/include/nn/parallel/distributed_data_parallel.h
@@ -3,6 +3,8 @@
#include <memory>

#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/distributed_data_parallel_config.h"
#include "infini_train/include/nn/parallel/param_and_grad_buffer.h"
#include "infini_train/include/nn/parallel/reducer.h"

namespace infini_train {
@@ -14,13 +16,34 @@ namespace infini_train::nn::parallel {

class DistributedDataParallel : public nn::Module {
public:
DistributedDataParallel(std::shared_ptr<nn::Module> module, int device_id,
const ReducerOptions &opts = ReducerOptions{});
DistributedDataParallel(std::shared_ptr<nn::Module> module, int thread_rank,
DistributedDataParallelConfig ddp_config = DistributedDataParallelConfig());

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

DistributedDataParallelConfig ddp_config() const { return ddp_config_; }

const std::vector<std::shared_ptr<ParamAndGradBuffer>> &param_grad_buffers() const { return param_grad_buffers_; }

const std::vector<std::shared_ptr<ParamAndGradBucketGroup>> &bucket_groups() const { return bucket_groups_; }

private:
void BuildParamAndGradBuffers();
void RegisterBackwardHooks();
void OnGradReady(const std::shared_ptr<Tensor> &param);

private:
std::shared_ptr<Reducer> reducer_ = nullptr;

DistributedDataParallelConfig ddp_config_;
const ProcessGroup *ddp_pg_ = nullptr;

std::vector<std::shared_ptr<ParamAndGradBuffer>> param_grad_buffers_;
std::vector<std::shared_ptr<ParamAndGradBucketGroup>> bucket_groups_;
std::unordered_map<Tensor *, std::shared_ptr<ParamAndGradBucketGroup>> param_to_bucket_group_;

std::atomic<size_t> num_params_ready_{0};
size_t total_params_{0};
};

} // namespace infini_train::nn::parallel
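The private hooks outline the overlap path: `RegisterBackwardHooks()` installs a per-parameter gradient hook, `OnGradReady()` records readiness in that parameter's bucket group, and a group launches its collective once all of its parameters have gradients. A simplified, self-contained sketch of that counting logic; the types and methods here are hypothetical illustrations, not the real `ParamAndGradBucketGroup` API:

```cpp
#include <atomic>
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <vector>

// Hypothetical sketch of grad-ready bookkeeping for bucketed reduction.
struct Param {};

struct BucketGroup {
    std::vector<Param *> params;
    std::atomic<size_t> ready{0};
    std::function<void()> launch_reduce; // e.g., an async reduce-scatter/all-reduce

    void OnGradReady() {
        // Launch the collective as soon as the last grad in this group lands,
        // overlapping communication with the rest of backward.
        if (ready.fetch_add(1) + 1 == params.size()) {
            launch_reduce();
            ready = 0; // re-arm for the next iteration
        }
    }
};

class GradReadyDispatcher {
public:
    void Register(Param *p, BucketGroup *group) { param_to_group_[p] = group; }

    // Called from the per-parameter backward hook.
    void OnGradReady(Param *p) { param_to_group_.at(p)->OnGradReady(); }

private:
    std::unordered_map<Param *, BucketGroup *> param_to_group_;
};
```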
64 changes: 64 additions & 0 deletions infini_train/include/nn/parallel/distributed_data_parallel_config.h
@@ -0,0 +1,64 @@
#pragma once

#include <limits>

namespace infini_train::nn::parallel {
namespace {
// Default bucket capacities (in MB), in alignment with PyTorch's DDP defaults
constexpr int kFirstBucketCapMB = 1;
constexpr int kNormalBucketCapMB = 25;
} // namespace

class DistributedDataParallelConfig {
public:
// ======================================================
// Reducer-related args
// Ref: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
// ======================================================
// Max capacity for each bucket (in MB).
size_t first_bucket_cap_mb = kFirstBucketCapMB;
size_t normal_bucket_cap_mb = kNormalBucketCapMB;

// When set to true, map param.grad directly to its slice of the bucket's flat buffer (same memory address) instead of copying with memcpy.
bool gradient_as_bucket_view = true;

// Whether to enable gradient bucketing.
bool gradient_bucketing_enabled = true;

// ======================================================
// DistributedOptimizer-related args
// Ref:
// https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/distributed_data_parallel_config.py
// ======================================================
// Whether to enable DistributedOptimizer (ZeRO-1 equivalent).
// When set to true:
// 1) Gradients/params are managed by ParamAndGradBuffer and reduced in groups.
// 2) The classic DDP reducer path is not used (i.e., reducer/bucketing in the DDP sense is disabled).
bool use_distributed_optimizer = false;

// Whether to overlap gradient reduce-scatter/all-reduce with backward compute.
// When enabled, grad reduce is triggered as soon as grads become ready instead of waiting until all grads are ready.
bool overlap_grad_reduce = true;

// Whether to overlap parameter all-gather with forward compute.
bool overlap_param_gather = true;

// Whether to average values inside collectives (divide by world size) instead of summing.
bool average_in_collective = true;

// Whether to check for NaNs/Infs/unusually large values in gradients before collectives.
bool check_for_nan_in_grad = false;
bool check_for_large_grads = false;

// Number of DistributedOptimizer instances.
// Multiple instances are used to build hierarchical collective groups for params/grads.
int num_distributed_optimizer_instances = 1;

// Maximum number of parameters in each ParamAndGradBucket.
// This is distinct from DDP Reducer's MB-based bucket caps.
size_t bucket_size_in_elements = std::numeric_limits<size_t>::max();

// Whether to pad bucket sizes to improve NCCL bus bandwidth utilization.
bool pad_buckets_for_high_nccl_busbw = false;
};
} // namespace infini_train::nn::parallel
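As main.cc shows above, the config is an aggregate meant to be built with designated initializers, with omitted fields keeping their defaults. A brief usage sketch, assuming the repo's usual include path; the chosen fields are examples only:

```cpp
#include "infini_train/include/nn/parallel/distributed_data_parallel_config.h"

using infini_train::nn::parallel::DistributedDataParallelConfig;

DistributedDataParallelConfig MakeDdpConfig(bool use_dist_opt) {
    // Designated initializers must follow declaration order; unspecified fields
    // keep their defaults (bucketing on, 1 MB / 25 MB caps, grads as bucket views).
    return DistributedDataParallelConfig{
        .use_distributed_optimizer = use_dist_opt, // ZeRO-1 path: buffers + grouped reduce
        .overlap_grad_reduce = true,               // reduce buckets while backward runs
        .average_in_collective = true,             // divide inside the collective, not afterwards
    };
}
```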
52 changes: 52 additions & 0 deletions infini_train/include/nn/parallel/distributed_optimizer.h
@@ -0,0 +1,52 @@
#pragma once

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

#include "infini_train/include/nn/parallel/param_and_grad_buffer.h"
#include "infini_train/include/optimizer.h"

namespace infini_train::nn::parallel {

class DistributedOptimizer final : public infini_train::Optimizer {
public:
DistributedOptimizer(OptimizerCreator inner_optimizer_creator,
const std::vector<std::shared_ptr<Tensor>> &full_params,
const std::vector<std::shared_ptr<ParamAndGradBuffer>> &buffers,
const std::vector<std::shared_ptr<ParamAndGradBucketGroup>> &bucket_groups,
const ProcessGroup *dp_pg, size_t dp_world_size, size_t ddp_rank);

void Step() override;

void ZeroGrad(bool set_to_none = true) override;

void StartGradSync();
void FinishGradSync();

void StartParamSync(bool force_sync = false);
void FinishParamSync(bool skip_next_bucket_dispatch = false);

private:
void BuildShardParamsAndBindGrads();

private:
// Shared with the DDP-wrapped model
std::vector<std::shared_ptr<ParamAndGradBuffer>> param_grad_buffers_;
std::vector<std::shared_ptr<ParamAndGradBucketGroup>> bucket_groups_;

// DP info
const ProcessGroup *dp_pg_;
size_t dp_world_size_;
size_t dp_rank_;

// Parameter shard owned by this DP rank
std::vector<std::shared_ptr<Tensor>> shard_params_;

// Base optimizer (e.g., SGD, Adam)
OptimizerCreator creator_;
std::shared_ptr<Optimizer> base_optimizer_;
};

} // namespace infini_train::nn::parallel
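The class follows the ZeRO-1 recipe: gradients are reduce-scattered across the DP group, each rank runs the base optimizer only on its 1/dp_world_size shard of the parameters, and the updated values are all-gathered back into the full buffers (`StartParamSync`/`FinishParamSync`). A toy single-process sketch of that partition-update-gather cycle; plain arrays stand in for the real buffers and process group, and the plain-SGD update is a stand-in for the base optimizer:

```cpp
#include <cstddef>
#include <vector>

// Toy ZeRO-1 step over a flat parameter/grad buffer, no real communication.
// Shard i of `world_size` owns elements [i * shard, (i + 1) * shard).
void DistributedOptimizerStepSketch(std::vector<float> &params, const std::vector<float> &grads,
                                    size_t rank, size_t world_size, float lr) {
    const size_t shard = params.size() / world_size; // assume divisible (real code pads buckets)
    const size_t begin = rank * shard;
    const size_t end = begin + shard;

    // 1) In the real path, grads are reduce-scattered so each rank holds the
    //    summed/averaged grads for its shard only. Here grads are already local.
    // 2) The base optimizer updates only the owned shard.
    for (size_t i = begin; i < end; ++i) {
        params[i] -= lr * grads[i]; // SGD stand-in for the base optimizer
    }
    // 3) An all-gather would then broadcast each rank's updated shard back into
    //    the full param buffer (StartParamSync/FinishParamSync in the header above).
}
```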