
Conversation

@JYMiracle305
Contributor

No description provided.

@JYMiracle305 JYMiracle305 force-pushed the add_1F1B branch 3 times, most recently from 496bbfd to 7108a12 Compare December 16, 2025 14:54
@JYMiracle305 JYMiracle305 force-pushed the add_1F1B branch 2 times, most recently from 3726518 to 9af4751 Compare December 22, 2025 09:04
@JYMiracle305
Contributor Author

JYMiracle305 commented Dec 22, 2025

Added a new hyperparameter virtual_pipeline_parallel (vpp_size), the number of virtual chunks each pipeline stage is split into in the PP scenario: the model is split into pp_size * vpp_size chunks, which are assigned to the corresponding devices. The refactor also unifies the interface the different scheduling strategies expose to the upper layer. When the PipelineParallelScheduler is constructed, the task table is filled according to the chosen strategy; each entry in the table is a sub-task (associated with a chunk, a micro-batch, and whether it is a forward or backward pass). During training the upper layer calls StepMicroBatches, which internally iterates over the task table.

When virtual_pipeline_parallel is 1, the schedule looks like this:
[image: schedule diagram]

When virtual_pipeline_parallel is greater than 1, the schedule looks like this:
[image: schedule diagram]
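
A minimal standalone sketch of the task-table idea described above. The names (Task, TaskType, BuildTaskTable) are illustrative, not the identifiers used in this PR, and the order shown is a plain all-forward-then-all-backward schedule rather than the 1F1B order the scheduler actually emits.

// Illustrative sketch only: each table entry ties together a chunk index,
// a micro-batch index, and the pass direction, as described above.
#include <cstdio>
#include <vector>

enum class TaskType { kForward, kBackward };

struct Task {
    TaskType type;       // forward or backward pass
    int chunk_idx;       // local chunk on this stage (0 .. vpp_size - 1)
    int micro_batch_idx; // micro-batch this sub-task operates on
};

// Fill a task table with a simple all-forward-then-all-backward order.
std::vector<Task> BuildTaskTable(int num_micro_batches, int vpp_size) {
    std::vector<Task> table;
    for (int c = 0; c < vpp_size; ++c) {
        for (int m = 0; m < num_micro_batches; ++m) { table.push_back({TaskType::kForward, c, m}); }
    }
    for (int c = vpp_size - 1; c >= 0; --c) {
        for (int m = 0; m < num_micro_batches; ++m) { table.push_back({TaskType::kBackward, c, m}); }
    }
    return table;
}

int main() {
    // StepMicroBatches would walk a table like this one, dispatching each sub-task.
    for (const auto &t : BuildTaskTable(/*num_micro_batches=*/4, /*vpp_size=*/2)) {
        std::printf("%s chunk=%d mb=%d\n", t.type == TaskType::kForward ? "F" : "B", t.chunk_idx, t.micro_batch_idx);
    }
    return 0;
}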

float lossf = StepMicroBatches(micro_batches, target_mbs, loss_fn, dtype);
LOG(INFO) << "=== Schedule Table ===";
LOG(INFO) << "n=" << n << ", stages=" << num_stages << ", vpp=" << vpp_size
          << ", total_chunks=" << total_global_chunks;
Collaborator

To improve readability, please use format to build the string.
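
A minimal sketch of this suggestion, assuming C++20 std::format is available (fmt::format would be the drop-in alternative); in the actual code the formatted string would be streamed into LOG(INFO).

// Illustrative sketch: build the whole message in one format call
// instead of chaining stream insertions.
#include <format>
#include <iostream>

int main() {
    // Illustrative values; in the PR these come from the scheduler.
    int n = 8, num_stages = 4, vpp_size = 2, total_global_chunks = 8;
    std::cout << std::format("n={}, stages={}, vpp={}, total_chunks={}", n, num_stages, vpp_size, total_global_chunks)
              << '\n';
    return 0;
}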

Contributor Author

OK

struct StageInfo {
    bool is_first_stage;
    bool is_last_stage;
    std::vector<std::pair<int, int>> layer_chunks;
Collaborator

Add a comment explaining that this vector stores the start/end layer positions of each chunk, and consider a more intuitive name, e.g. chunk_layer_ranges.
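
A small declaration sketch of the proposed rename and documenting comment; only the field name and the comment are new, everything else mirrors the struct under review.

// Illustrative sketch of the suggested change.
#include <utility>
#include <vector>

struct StageInfo {
    bool is_first_stage = false;
    bool is_last_stage = false;
    // One (begin, end) pair of layer indices for each chunk owned by this stage.
    std::vector<std::pair<int, int>> chunk_layer_ranges;
};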

Contributor Author

OK


std::vector<std::shared_ptr<nn::Module>> chunk_blocks;
int current_index = 0;
for (auto it = h_layers->begin(); it != h_layers->end(); ++it, ++current_index) {
Collaborator

Overload an index operator for the ModuleList type, so the corresponding layer can be obtained here directly by index.
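
A standalone sketch of what is being asked for: an operator[] on a ModuleList-like container so the loop can fetch a layer by index. Module and ModuleList here are simplified stand-ins, not the framework classes.

// Illustrative sketch with stand-in types.
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

struct Module {
    virtual ~Module() = default;
};

class ModuleList {
public:
    void Append(std::shared_ptr<Module> m) { modules_.push_back(std::move(m)); }

    // Index access: return the idx-th sub-module directly.
    std::shared_ptr<Module> operator[](std::size_t idx) const { return modules_.at(idx); }

    std::size_t size() const { return modules_.size(); }

private:
    std::vector<std::shared_ptr<Module>> modules_;
};

int main() {
    ModuleList h_layers;
    h_layers.Append(std::make_shared<Module>());
    // Instead of advancing an iterator, callers can now index directly.
    auto layer = h_layers[0];
    (void)layer;
    return 0;
}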

Contributor Author

OK

Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
};

class GPT2Chunk {
Collaborator

Have GPT2Chunk inherit from Module:

class GPT2Chunk : public Module {
public:
  GPT2Chunk(
    GPT2* parent,
    int layer_begin,
    int layer_end,
    bool has_embedding,
    bool has_lm_head
  );

  std::vector<std::shared_ptr<Tensor>>
  Forward(const std::vector<std::shared_ptr<Tensor>>& x) override;

private:
  GPT2* parent_ = nullptr;
  int layer_begin_ = 0;
  int layer_end_ = 0;
  bool has_embedding_ = false;
  bool has_lm_head_ = false;
};

Contributor

This definition is actually similar for all Transformer structures. It feels like it could be factored out into a class TransformerChunk : public Module, with the definition placed in the pp folder; then each model's net.cc would define a class GPT2 : public TransformerChunk and only need to override Forward.
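
A sketch of the proposed refactor, using simplified stand-in Module/Tensor types: a generic TransformerChunk base that would live under the pp folder, with the model-specific chunk in net.cc overriding only Forward.

// Illustrative sketch; not the framework's actual class hierarchy.
#include <memory>
#include <vector>

struct Tensor {};

struct Module {
    virtual ~Module() = default;
    virtual std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &x) = 0;
};

// Generic chunk: holds the layer range and whether this chunk owns the embedding / LM head.
class TransformerChunk : public Module {
public:
    TransformerChunk(int layer_begin, int layer_end, bool has_embedding, bool has_lm_head)
        : layer_begin_(layer_begin), layer_end_(layer_end), has_embedding_(has_embedding), has_lm_head_(has_lm_head) {}

protected:
    int layer_begin_ = 0;
    int layer_end_ = 0;
    bool has_embedding_ = false; // first chunk carries the embedding
    bool has_lm_head_ = false;   // last chunk carries the LM head
};

// Model-specific chunk: only Forward needs to be overridden.
class GPT2Chunk : public TransformerChunk {
public:
    using TransformerChunk::TransformerChunk;

    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &x) override {
        // GPT-2-specific forward over layers [layer_begin_, layer_end_) would go here.
        return x;
    }
};

int main() {
    GPT2Chunk chunk(/*layer_begin=*/0, /*layer_end=*/6, /*has_embedding=*/true, /*has_lm_head=*/false);
    std::vector<std::shared_ptr<Tensor>> x;
    auto y = chunk.Forward(x);
    (void)y;
    return 0;
}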

Contributor Author

OK

}

std::tuple<bool, bool, int, int> PipelineParallel::GetStageInfo(int total_layers, int pp_size, int pp_rank) {
StageInfo PipelineParallel::GetStageInfo(int total_layers, int pp_size, int chunks_per_stage) {
Collaborator

pp_rank should still be passed in from net.cc; try to keep the scope in which thread_local variables are used as small as possible.

Storing pp_rank/tp_rank in thread_local variables is only a temporary solution: if a thread spawns child threads, those threads do not inherit these variables, so this approach is unsafe. Everywhere else in the framework we obtain this information from the rank data structure stored in the device whenever possible, but at model-initialization time the device has not been created yet, so we only do it this way here. To avoid the unsafety brought by having thread_local variables everywhere, we will later need to develop a thread pool that takes over all threads the framework spawns and, in a unified way, propagates these thread_local variables to the threads that need them.
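
A sketch of the requested change with illustrative signatures: net.cc reads the thread_local pp_rank once and passes it in explicitly, so GetStageInfo itself never touches the thread_local.

// Illustrative sketch; signatures and the StageInfo layout are assumptions.
#include <utility>
#include <vector>

struct StageInfo {
    bool is_first_stage = false;
    bool is_last_stage = false;
    std::vector<std::pair<int, int>> chunk_layer_ranges;
};

namespace nn::parallel {
thread_local int pp_rank = 0; // temporary mechanism, per the discussion above

// pp_rank is an explicit parameter rather than being read from the thread_local here.
StageInfo GetStageInfo(int total_layers, int pp_size, int pp_rank, int vpp_size) {
    StageInfo info;
    info.is_first_stage = (pp_rank == 0);
    info.is_last_stage = (pp_rank == pp_size - 1);
    // ... fill chunk_layer_ranges from total_layers, pp_size and vpp_size ...
    (void)total_layers;
    (void)vpp_size;
    return info;
}
} // namespace nn::parallel

// In net.cc: read the thread_local once at the call site and pass it down.
int main() {
    const int pp_rank = nn::parallel::pp_rank;
    auto info = nn::parallel::GetStageInfo(/*total_layers=*/12, /*pp_size=*/4, pp_rank, /*vpp_size=*/2);
    (void)info;
    return 0;
}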

Contributor Author

OK

std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

void BuildChunks();
Collaborator

BuildChunks should return all the chunks obtained from splitting the stage; when constructing the PipelineStage, call the module's BuildChunks method and store all the chunks inside the PipelineStage.
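
A sketch of this flow with simplified stand-in types: the module's BuildChunks returns the chunk list and PipelineStage stores it at construction time.

// Illustrative sketch; not the framework's actual classes.
#include <cstddef>
#include <memory>
#include <vector>

struct Module {
    virtual ~Module() = default;
    // Returns the chunks this stage owns after virtual-pipeline splitting.
    virtual std::vector<std::shared_ptr<Module>> BuildChunks() = 0;
};

class PipelineStage {
public:
    explicit PipelineStage(const std::shared_ptr<Module> &model) : chunks_(model->BuildChunks()) {}

    std::size_t num_chunks() const { return chunks_.size(); }

private:
    std::vector<std::shared_ptr<Module>> chunks_; // one entry per local chunk
};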

Contributor Author

OK

return model_->Forward(inputs);
std::vector<std::shared_ptr<Tensor>> PipelineStage::ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs,
int local_chunk_idx) {
return model_->ForwardChunk(local_chunk_idx, inputs);
Collaborator

Here, use local_chunk_idx to index directly into the chunks stored on the stage and call that chunk's Forward method.
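
A sketch of this suggestion with simplified stand-in types: ForwardOneChunk indexes the chunks stored on the stage and calls that chunk's Forward.

// Illustrative sketch; Tensor and Chunk are stand-ins for the framework types.
#include <memory>
#include <vector>

struct Tensor {};

struct Chunk {
    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &x) { return x; }
};

class PipelineStage {
public:
    std::vector<std::shared_ptr<Tensor>> ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs,
                                                         int local_chunk_idx) {
        // Index directly into the chunks owned by this stage.
        return chunks_[local_chunk_idx]->Forward(inputs);
    }

private:
    std::vector<std::shared_ptr<Chunk>> chunks_;
};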

Contributor Author

OK

std::vector<std::shared_ptr<infini_train::Tensor>>
GPT2::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
int pp_rank = nn::parallel::pp_rank;
void GPT2::BuildChunks() {
Contributor

For Transformer models, BuildChunks can also be merged: gpt2/llama differ only in the pos_emb, so a single if check is enough.
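
A sketch of the suggested merge, under the assumption that the only model-specific difference is whether the first chunk carries a learned positional embedding; the names and the config struct are illustrative.

// Illustrative sketch; not the actual BuildChunks implementation.
#include <memory>
#include <vector>

struct Module {
    virtual ~Module() = default;
};

struct ChunkBuildConfig {
    int vpp_size = 1;
    bool has_learned_pos_emb = false; // the gpt2/llama difference mentioned in the review
};

std::vector<std::shared_ptr<Module>> BuildChunks(const ChunkBuildConfig &cfg) {
    std::vector<std::shared_ptr<Module>> chunks;
    for (int c = 0; c < cfg.vpp_size; ++c) {
        if (c == 0 && cfg.has_learned_pos_emb) {
            // Only the first chunk of a model with a learned positional embedding owns it.
        }
        // ... assemble the layers belonging to chunk c ...
        chunks.push_back(std::make_shared<Module>());
    }
    return chunks;
}

int main() {
    auto gpt2_chunks = BuildChunks({/*vpp_size=*/2, /*has_learned_pos_emb=*/true});
    auto llama_chunks = BuildChunks({/*vpp_size=*/2, /*has_learned_pos_emb=*/false});
    (void)gpt2_chunks;
    (void)llama_chunks;
    return 0;
}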

@JYMiracle305 JYMiracle305 force-pushed the add_1F1B branch 3 times, most recently from 0f5628b to aeb8ee0 Compare December 25, 2025 04:59
if (tp_world_size > 1) {
    auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get(
        nn::parallel::GetTensorParallelProcessGroupName(device->rank().GlobalRank()));
    tp_rank = tp_group->GetGroupRank(device->rank().GlobalRank());
Collaborator

@kilinchange kilinchange Dec 25, 2025

In the multi-node case, the global rank needs to be used to get the communication group.

Contributor Author

OK

int tp_rank = 0;
if (tp_world_size > 1) {
    auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get(
        nn::parallel::GetTensorParallelProcessGroupName(device->rank().thread_rank()));
Collaborator

GlobalRank. The same goes for the other places in this file: except in main.cc, where thread_rank is used when the device_id needs to be passed, GlobalRank must be passed everywhere a communication group is obtained.

Contributor Author

OK

auto [is_first_stage, is_last_stage, layer_chunks]
    = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, vpp_size);
// ========== layer to chunk ==========
std::unordered_map<int, bool> owned_layers;
Collaborator

A map doesn't seem necessary here; a vector is enough, and lookups are faster:
std::vector owned_layers(n_layer, false)
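
A minimal sketch of the suggested replacement: a flat boolean vector indexed by layer id instead of an unordered_map (the loop bounds are illustrative).

// Illustrative sketch of the vector-based ownership lookup.
#include <vector>

int main() {
    const int n_layer = 12; // illustrative value
    std::vector<bool> owned_layers(n_layer, false);
    // Mark the layers owned by this stage's chunks, e.g. layers [3, 6).
    for (int l = 3; l < 6; ++l) { owned_layers[l] = true; }
    return 0;
}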

Contributor Author

OK

@JYMiracle305 JYMiracle305 force-pushed the add_1F1B branch 2 times, most recently from f8b086c to c22da40 Compare December 26, 2025 03:21