4 changes: 3 additions & 1 deletion example/gpt2/main.cc
@@ -64,6 +64,7 @@ DEFINE_int32(
DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
@@ -293,6 +294,7 @@ void Train(const nn::parallel::Rank &rank) {
auto logits = model->Forward({x, y})[0];
LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";
auto loss = loss_fn->Forward({logits, y})[0];
// FIXME(jym): verify gradient accumulation precision
loss = loss / grad_accum_steps;

// disable autocast for the current step (backward is not under autocast)
@@ -356,7 +358,7 @@ int main(int argc, char *argv[]) {
google::InitGoogleLogging(argv[0]);

nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
FLAGS_pipeline_parallel);
FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);

LOG(INFO) << nn::parallel::global::ProcessGroupOverview();

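The Train loop above scales each micro-step loss by 1/grad_accum_steps before Backward so that the summed per-micro-batch gradients match the gradient of the mean loss over the accumulated batch (the FIXME notes this scaling still needs a precision check). A minimal sketch of the pattern, using placeholder types rather than the InfiniTrain API:

```cpp
// Minimal sketch of loss scaling under gradient accumulation.
// The float "gradient" below is a placeholder, not an InfiniTrain tensor.
#include <vector>

float AccumulateStep(const std::vector<float> &micro_losses, float &accumulated_grad) {
    const int grad_accum_steps = static_cast<int>(micro_losses.size());
    float reported_loss = 0.0f;
    for (float loss : micro_losses) {
        // Dividing each micro-batch loss by grad_accum_steps makes the sum of
        // per-micro-batch gradients equal the gradient of the averaged loss.
        const float scaled = loss / grad_accum_steps;
        reported_loss += scaled;
        accumulated_grad += scaled; // stands in for scaled_loss.Backward() accumulating into .grad
    }
    // optimizer.Step() and ZeroGrad() would run once here, after all micro-batches.
    return reported_loss;
}
```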
236 changes: 175 additions & 61 deletions example/gpt2/net.cc

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions example/gpt2/net.h
@@ -92,9 +92,31 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

std::vector<std::shared_ptr<infini_train::nn::Module>> BuildChunks(int pp_rank) override;

static std::shared_ptr<GPT2> FromPretrained(ModelType model_type);
static std::shared_ptr<GPT2> FromLLMC(const std::string &filepath);

private:
GPT2Config config_;
};

class GPT2Chunk : public infini_train::nn::CloneableModule<GPT2Chunk> {
public:
GPT2Chunk(GPT2 *parent, int layer_begin, int chunk_layers, bool has_embedding, bool has_lm_head,
const GPT2Config &config)
: parent_(parent), layer_begin_(layer_begin), chunk_layers_(chunk_layers), has_embedding_(has_embedding),
has_lm_head_(has_lm_head), config_(config) {}

std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
GPT2 *parent_ = nullptr;
int layer_begin_ = 0;
int chunk_layers_ = 0;
bool has_embedding_ = false;
bool has_lm_head_ = false;

GPT2Config config_;
};
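The net.cc diff is not rendered, so the following is only a rough sketch of what GPT2::BuildChunks could look like under virtual pipeline parallelism: the transformer blocks are split into pp_size * vpp chunks and each pipeline rank receives the interleaved subset described by StageInfo. The namespace qualification and the config_.n_layer field are assumptions, not code from this PR.

```cpp
// Hypothetical sketch of GPT2::BuildChunks; the real implementation lives in
// example/gpt2/net.cc, which this diff does not render.
std::vector<std::shared_ptr<infini_train::nn::Module>> GPT2::BuildChunks(int pp_rank) {
    namespace parallel = infini_train::nn::parallel;

    const int pp_size = parallel::global::GetPipelineParallelSize();
    const int vpp_size = parallel::global::GetVirtualPipelineParallelSize();

    // One (inclusive_begin, exclusive_end) layer range per local chunk on this rank.
    const auto stage_info = parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, pp_rank, vpp_size);

    std::vector<std::shared_ptr<infini_train::nn::Module>> chunks;
    for (const auto &[layer_begin, layer_end] : stage_info.layer_ranges_per_chunk) {
        // The chunk that owns layer 0 carries the embedding; the chunk that owns
        // the final block carries the LM head.
        const bool has_embedding = stage_info.is_first_stage && layer_begin == 0;
        const bool has_lm_head = stage_info.is_last_stage && layer_end == config_.n_layer;
        chunks.push_back(std::make_shared<GPT2Chunk>(this, layer_begin, layer_end - layer_begin,
                                                     has_embedding, has_lm_head, config_));
    }
    return chunks;
}
```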
6 changes: 4 additions & 2 deletions example/llama3/main.cc
@@ -62,7 +62,8 @@ DEFINE_int32(
"When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, , specified the number of PP stages.");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the number of PP stages.");
DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage.");
// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");

@@ -270,6 +271,7 @@ void Train(const nn::parallel::Rank &rank) {
auto logits = model->Forward({x, y})[0];
LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";
auto loss = loss_fn->Forward({logits, y})[0];
// FIXME(jym): verify gradient accumulation precision
loss = loss / grad_accum_steps;

// disable autocast for the current step (backward is not under autocast)
@@ -333,7 +335,7 @@ int main(int argc, char *argv[]) {
google::InitGoogleLogging(argv[0]);

nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
FLAGS_pipeline_parallel);
FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);

LOG(INFO) << nn::parallel::global::ProcessGroupOverview();

203 changes: 155 additions & 48 deletions example/llama3/net.cc

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions example/llama3/net.h
@@ -132,9 +132,31 @@ class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {
std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

std::vector<std::shared_ptr<infini_train::nn::Module>> BuildChunks(int pp_rank) override;

static std::shared_ptr<LLaMA3> FromPretrained(ModelType model_type);
static std::shared_ptr<LLaMA3> FromLLMC(const std::string &filepath);

private:
LLaMA3Config config_;
};

class LLaMA3Chunk : public infini_train::nn::CloneableModule<LLaMA3Chunk> {
public:
LLaMA3Chunk(LLaMA3 *parent, int layer_begin, int chunk_layers, bool has_embedding, bool has_lm_head,
const LLaMA3Config &config)
: parent_(parent), layer_begin_(layer_begin), chunk_layers_(chunk_layers), has_embedding_(has_embedding),
has_lm_head_(has_lm_head), config_(config){};

std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
LLaMA3 *parent_ = nullptr;
int layer_begin_ = 0;
int chunk_layers_ = 0;
bool has_embedding_ = false;
bool has_lm_head_ = false;

LLaMA3Config config_;
};
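A chunk's Forward only runs the slice of the model it owns. The sketch below shows the likely control flow for LLaMA3Chunk::Forward; the accessors on parent_ (Embedding(), Layers(), LMHead()) are invented for illustration and are not the API in example/llama3/net.cc.

```cpp
// Illustrative chunk forward pass; accessor names on `parent_` are hypothetical.
std::vector<std::shared_ptr<infini_train::Tensor>>
LLaMA3Chunk::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
    auto hidden = x;

    // Only the chunk that owns layer 0 applies the token embedding.
    if (has_embedding_) {
        hidden = parent_->Embedding()->Forward(hidden);
    }

    // Run exactly the decoder layers assigned to this chunk.
    for (int i = layer_begin_; i < layer_begin_ + chunk_layers_; ++i) {
        hidden = parent_->Layers()[i]->Forward(hidden);
    }

    // Only the chunk that owns the final layer normalizes and projects to logits.
    if (has_lm_head_) {
        hidden = parent_->LMHead()->Forward(hidden);
    }
    return hidden;
}
```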
2 changes: 2 additions & 0 deletions infini_train/include/nn/modules/container.h
@@ -40,6 +40,8 @@ class ModuleList : public CloneableModule<ModuleList> {
auto begin() const { return module_list_.begin(); }
auto end() const { return module_list_.end(); }

std::shared_ptr<Module> &operator[](std::size_t idx) { return module_list_.at(idx); }

private:
std::vector<std::shared_ptr<Module>> module_list_;
};
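The new operator[] is what lets chunk construction pull an arbitrary slice of layers out of a ModuleList by index; a small usage example (the free function and namespace qualification below are illustrative, not part of the PR):

```cpp
// Illustrative helper: collect layers [begin, end) from a ModuleList for one chunk.
std::vector<std::shared_ptr<infini_train::nn::Module>>
SliceLayers(infini_train::nn::ModuleList &layers, std::size_t begin, std::size_t end) {
    std::vector<std::shared_ptr<infini_train::nn::Module>> slice;
    for (std::size_t i = begin; i < end; ++i) {
        slice.push_back(layers[i]); // bounds-checked: operator[] forwards to module_list_.at(idx)
    }
    return slice;
}
```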
2 changes: 2 additions & 0 deletions infini_train/include/nn/modules/module.h
@@ -54,6 +54,8 @@ class Module : public std::enable_shared_from_this<Module> {
return 0.0f;
};

virtual std::vector<std::shared_ptr<Module>> BuildChunks(int pp_rank);

virtual void To(const Device *device);

virtual void To(DataType dtype);
2 changes: 2 additions & 0 deletions infini_train/include/nn/parallel/distributed_data_parallel.h
@@ -19,6 +19,8 @@ class DistributedDataParallel : public nn::Module {

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

std::vector<std::shared_ptr<nn::Module>> BuildChunks(int pp_rank) override;

private:
std::shared_ptr<Reducer> reducer_ = nullptr;
};
10 changes: 7 additions & 3 deletions infini_train/include/nn/parallel/global.h
@@ -27,7 +27,7 @@ class GlobalEnv {
static GlobalEnv &Instance();

void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
int pipeline_parallel_size);
int pipeline_parallel_size, int virtual_pipeline_parallel_size);

int nnodes() const;

@@ -51,6 +51,8 @@

int pipeline_parallel_size() const;

int virtual_pipeline_parallel_size() const;

Layout layout() const;

private:
@@ -75,6 +77,7 @@ class GlobalEnv {
int data_parallel_size_ = 1;

int pipeline_parallel_size_ = 1;
int virtual_pipeline_parallel_size_ = 1;

mutable std::mutex mutex_;
bool initialized_ = false;
@@ -83,9 +86,9 @@
};

inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
int pipeline_parallel_size) {
int pipeline_parallel_size, int virtual_pipeline_parallel) {
GlobalEnv::Instance().Init(nthread_per_process, tensor_parallel_size, sequence_parallel_enabled,
pipeline_parallel_size);
pipeline_parallel_size, virtual_pipeline_parallel);
}

inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); }
@@ -100,6 +103,7 @@ inline int GetSequenceParallelSize() { return GlobalEnv::Instance().sequence_par
inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence_parallel_enabled(); }
inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); }
inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); }
inline int GetVirtualPipelineParallelSize() { return GlobalEnv::Instance().virtual_pipeline_parallel_size(); }

// =========================
// Layout Helper Functions
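With the new accessor, schedulers can derive the total number of model chunks and map a global chunk index back to a stage. The helpers below are a sketch built on the accessors added here; the interleaved chunk-to-stage convention and the namespace qualification are assumptions:

```cpp
// Illustrative helpers on top of the new accessors; not part of this PR.
namespace pp_global = infini_train::nn::parallel::global;

// Total number of model chunks across the whole pipeline.
inline int TotalChunks() {
    return pp_global::GetPipelineParallelSize() * pp_global::GetVirtualPipelineParallelSize();
}

// Under interleaved placement, global chunk c lives on stage (c % pp_size)
// and is that stage's (c / pp_size)-th local chunk.
inline int StageOfChunk(int global_chunk_id) {
    return global_chunk_id % pp_global::GetPipelineParallelSize();
}

inline int LocalChunkIndex(int global_chunk_id) {
    return global_chunk_id / pp_global::GetPipelineParallelSize();
}
```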
11 changes: 10 additions & 1 deletion infini_train/include/nn/parallel/pp/pipeline_parallel.h
@@ -18,6 +18,15 @@ class PipelineSchedule;

extern thread_local int pp_rank;

struct StageInfo {
bool is_first_stage;
bool is_last_stage;

// Layer index ranges for chunks assigned to this pipeline stage.
// Each element is a pair: (inclusive_start_layer, exclusive_end_layer)
std::vector<std::pair<int, int>> layer_ranges_per_chunk;
};

class PipelineParallel : public Module {
public:
PipelineParallel(const std::shared_ptr<nn::Module> module, int num_stages, int num_micro_batches,
@@ -28,7 +37,7 @@ class PipelineParallel {
const std::vector<std::shared_ptr<Tensor>> &target, const std::shared_ptr<nn::Module> &loss_fn,
DataType dtype);

static std::tuple<bool, bool, int, int> GetStageInfo(int total_layers, int pp_size, int pp_rank);
static StageInfo GetStageInfo(int total_layers, int pp_size, int pp_rank, int chunks_per_stage = 1);

private:
int num_stages_ = -1;
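GetStageInfo now returns the StageInfo struct above and takes chunks_per_stage, so a single call describes every chunk a stage owns. A possible shape of that computation, assuming an even split of layers over pp_size * chunks_per_stage global chunks with round-robin (interleaved) assignment to stages; the PR's actual remainder handling may differ:

```cpp
// Sketch of GetStageInfo under interleaved chunk placement (assumes the
// declarations from pipeline_parallel.h above; remainder handling is a guess).
StageInfo PipelineParallel::GetStageInfo(int total_layers, int pp_size, int pp_rank, int chunks_per_stage) {
    StageInfo info;
    info.is_first_stage = (pp_rank == 0);
    info.is_last_stage = (pp_rank == pp_size - 1);

    const int total_chunks = pp_size * chunks_per_stage;
    const int base = total_layers / total_chunks;
    const int remainder = total_layers % total_chunks;

    // Layer ranges for every global chunk; earlier chunks absorb the remainder.
    std::vector<std::pair<int, int>> all_ranges;
    int begin = 0;
    for (int c = 0; c < total_chunks; ++c) {
        const int len = base + (c < remainder ? 1 : 0);
        all_ranges.emplace_back(begin, begin + len);
        begin += len;
    }

    // Interleaved assignment: stage s owns global chunks s, s + pp_size, s + 2*pp_size, ...
    for (int local = 0; local < chunks_per_stage; ++local) {
        info.layer_ranges_per_chunk.push_back(all_ranges[local * pp_size + pp_rank]);
    }
    return info;
}
```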
41 changes: 20 additions & 21 deletions infini_train/include/nn/parallel/pp/pipeline_schedule.h
@@ -28,35 +28,34 @@ class PipelineSchedule {

virtual float StepMicroBatches(const std::vector<std::shared_ptr<Tensor>> &arg_mbs,
const std::vector<std::shared_ptr<Tensor>> &target_mbs,
const std::shared_ptr<nn::Module> &loss_fn, DataType dtype)
= 0;
const std::shared_ptr<nn::Module> &loss_fn, DataType dtype);

std::vector<std::shared_ptr<Tensor>> ReceiveFromPrev();
std::vector<std::shared_ptr<Tensor>> SendToNext(const std::vector<std::shared_ptr<Tensor>> &tensors);
std::vector<std::shared_ptr<Tensor>> ReceiveFromPrev(int peer_rank);
std::vector<std::shared_ptr<Tensor>> SendToNext(const std::vector<std::shared_ptr<Tensor>> &tensors, int peer_rank);

protected:
int num_micro_batches_ = -1;
std::shared_ptr<PipelineStage> stage_ = nullptr;
};

class ScheduleGPipe : public PipelineSchedule {
class PipelineParallelScheduler {
public:
ScheduleGPipe(std::shared_ptr<PipelineStage> stage, int num_stages, int num_micro_batches)
: PipelineSchedule(std::move(stage), num_stages, num_micro_batches){};

float StepMicroBatches(const std::vector<std::shared_ptr<Tensor>> &arg_mbs,
const std::vector<std::shared_ptr<Tensor>> &target_mbs,
const std::shared_ptr<nn::Module> &loss_fn, DataType dtype) override;
};

class Schedule1F1B : public PipelineSchedule {
public:
Schedule1F1B(std::shared_ptr<PipelineStage> stage, int num_stages, int num_micro_batches)
: PipelineSchedule(std::move(stage), num_stages, num_micro_batches){};

float StepMicroBatches(const std::vector<std::shared_ptr<Tensor>> &arg_mbs,
const std::vector<std::shared_ptr<Tensor>> &target_mbs,
const std::shared_ptr<nn::Module> &loss_fn, DataType dtype) override;
struct Task {
int step;
int microbatch_id;
int global_chunk_id;
int local_chunk_idx;
bool is_forward;
int stage_id;
bool is_first_chunk;
bool is_last_chunk;
};

static Task CreateTask(int step, int mb, int global_chunk, int num_stages, int total_chunks, bool is_forward);

static std::vector<Task> GenerateGPipeSchedule(int n, int num_stages, int vpp_size);

static std::vector<Task> GenerateInterleaved1F1BSchedule(int n, int num_stages, int vpp_size);
};

} // namespace infini_train::nn::parallel
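The GPipe and 1F1B subclasses are gone; schedules are now emitted as flat lists of Tasks. The sketch below shows the general idea for a GPipe-style ordering (a free function with an invented name, not the PR's GenerateGPipeSchedule); the interleaved 1F1B generator would weave warm-up forwards, alternating forward/backward steps, and cool-down backwards over the same Task records to shrink the pipeline bubble.

```cpp
// Illustrative GPipe-style task list built from PipelineParallelScheduler::Task;
// the PR's GenerateGPipeSchedule/GenerateInterleaved1F1BSchedule may order and
// filter tasks differently (e.g. per stage).
std::vector<PipelineParallelScheduler::Task>
GenerateGPipeScheduleSketch(int num_micro_batches, int num_stages, int vpp_size) {
    const int total_chunks = num_stages * vpp_size;
    std::vector<PipelineParallelScheduler::Task> tasks;
    int step = 0;

    // All forwards first: every micro-batch flows through every model chunk in order.
    for (int chunk = 0; chunk < total_chunks; ++chunk) {
        for (int mb = 0; mb < num_micro_batches; ++mb) {
            tasks.push_back(PipelineParallelScheduler::CreateTask(step++, mb, chunk, num_stages,
                                                                  total_chunks, /*is_forward=*/true));
        }
    }
    // Then all backwards, walking the chunks in reverse.
    for (int chunk = total_chunks - 1; chunk >= 0; --chunk) {
        for (int mb = 0; mb < num_micro_batches; ++mb) {
            tasks.push_back(PipelineParallelScheduler::CreateTask(step++, mb, chunk, num_stages,
                                                                  total_chunks, /*is_forward=*/false));
        }
    }
    return tasks;
}
```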
4 changes: 3 additions & 1 deletion infini_train/include/nn/parallel/pp/pipeline_stage.h
@@ -20,7 +20,8 @@ class PipelineStage {
const std::vector<std::vector<int64_t>> &recv_shape, std::shared_ptr<Optimizer> optimizer,
int device_id);

std::vector<std::shared_ptr<Tensor>> ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs);
std::vector<std::shared_ptr<Tensor>> ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs,
int local_chunk_idx = 0);

bool IsFirstStage() const { return stage_index_ == 0; }
bool IsLastStage() const { return stage_index_ == num_stages_ - 1; }
@@ -39,6 +40,7 @@ class PipelineStage {
int next_rank_ = -1;
const Device *device_ = nullptr;
std::shared_ptr<nn::Module> model_ = nullptr;
std::vector<std::shared_ptr<Module>> chunks_;
std::shared_ptr<Optimizer> optimizer_ = nullptr;
std::vector<std::vector<int64_t>> recv_shape_;
};
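ForwardOneChunk now takes local_chunk_idx so a stage holding several model chunks can pick which one to run; a minimal sketch of that dispatch (the real method may also handle device placement and activation bookkeeping):

```cpp
// Minimal sketch of the new chunk dispatch inside PipelineStage; details such as
// autocast or activation caching in the real method are omitted here.
std::vector<std::shared_ptr<Tensor>>
PipelineStage::ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs, int local_chunk_idx) {
    CHECK_GE(local_chunk_idx, 0);
    CHECK_LT(local_chunk_idx, static_cast<int>(chunks_.size()));
    // chunks_ holds the model chunks owned by this stage (built via Module::BuildChunks).
    return chunks_[local_chunk_idx]->Forward(inputs);
}
```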
5 changes: 5 additions & 0 deletions infini_train/src/nn/modules/module.cc
@@ -110,6 +110,11 @@ std::vector<std::shared_ptr<Tensor>> Module::Forward(const std::vector<std::shar
return {};
}

std::vector<std::shared_ptr<Module>> Module::BuildChunks(int pp_rank) {
LOG(FATAL) << "BuildChunks function not implemented for this module";
return {};
}

void Module::To(const Device *device) {
CHECK_NOTNULL(device);
if (device == device_) {
3 changes: 3 additions & 0 deletions infini_train/src/nn/parallel/distributed_data_parallel.cc
@@ -57,4 +57,7 @@ DistributedDataParallel::Forward(const std::vector<std::shared_ptr<Tensor>> &inp
return outputs;
}

std::vector<std::shared_ptr<nn::Module>> DistributedDataParallel::BuildChunks(int pp_rank) {
return modules_[kModuleName]->BuildChunks(pp_rank);
}
} // namespace infini_train::nn::parallel
8 changes: 7 additions & 1 deletion infini_train/src/nn/parallel/global.cc
@@ -90,7 +90,7 @@ GlobalEnv &GlobalEnv::Instance() {
}

void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
int pipeline_parallel_size) {
int pipeline_parallel_size, int virtual_pipeline_parallel_size) {
std::lock_guard<std::mutex> lock(mutex_);

CHECK(!initialized_) << "Repeated initialization of GlobalEnv!";
@@ -106,6 +106,7 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq
tensor_parallel_size_ = tensor_parallel_size;
sequence_parallel_enabled_ = sequence_parallel_enabled;
pipeline_parallel_size_ = pipeline_parallel_size;
virtual_pipeline_parallel_size_ = virtual_pipeline_parallel_size;
data_parallel_size_ = world_size_ / tensor_parallel_size_ / pipeline_parallel_size_;

layout_.sizes[DP] = data_parallel_size_;
@@ -171,6 +172,11 @@ int GlobalEnv::pipeline_parallel_size() const {
return pipeline_parallel_size_;
}

int GlobalEnv::virtual_pipeline_parallel_size() const {
CHECK(initialized_) << "GlobalEnv is not initialized!";
return virtual_pipeline_parallel_size_;
}

Layout GlobalEnv::layout() const {
CHECK(initialized_) << "GlobalEnv is not initialized!";
return layout_;