feat: extract the common module of Transformer #115

Open
JYMiracle305 wants to merge 1 commit into master from feat/transformer

JYMiracle305 (Contributor) commented Mar 13, 2026

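This PR abstracts a common construction framework for Transformer-style models and unifies the GPT2 and LLaMA3 build processes into a single flow.

  1. Directory structure
    …/core/
    ├── models/decode_only_transformer/
    │   ├── layer_specs.h/.cc # declarations and implementations of the model-building functions
    │   └── model.h # declarations of components such as RMSNorm/NewGELU/SwiGLU
    └── transformer/
        ├── spec_utils.h/.cc # ModuleSpec construction utilities and module registration macros
        ├── transformer_block.h/.cc # basic components such as TransformerBlock, plus their registration
        ├── transformer_builders.h/.cc # declarations and implementations of the spec builders (BuildNormSpec, BuildMLPSpec, etc.)
        ├── transformer_config.h # TransformerConfig struct, replacing the former GPT2Config and LLaMA3Config
        └── transformer_layer.h/.cc # TransformerFirstStage/Chunk/LastStage, replacing the former GPT2FirstStage/Chunk/LastStage and LLaMA3FirstStage/Chunk/LastStage

  2. Core mechanism
    The ModuleSpec data structure declares a module's type and parameters. Concrete implementations are registered centrally through ModuleRegistry; when the model is built, build_module() instantiates them dynamically by looking up the registered implementation that matches the spec.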
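
For readers skimming the thread, the mechanism described above can be sketched roughly as follows. This is a minimal illustration under assumptions, not the PR's code: `Config`, `Module`, `Describe()`, `Lookup()`, `Block`, `BuildExample()` and the submodule name `"ln_1"` are hypothetical stand-ins, and only `ModuleSpec`, `ModuleRegistry`, `Register`, `with_submodule`, `submodules_` and `build_module` mirror names that appear in this PR.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <utility>

// Hypothetical stand-ins for illustration; the real infini_train types are richer.
struct Config {
    int n_embd = 768;
};

struct Module {
    virtual ~Module() = default;
    virtual std::string Describe() const = 0;
};

// A spec only declares what to build: a module type plus named child specs.
struct ModuleSpec {
    std::type_index type;
    std::map<std::string, std::shared_ptr<ModuleSpec>> submodules_;

    explicit ModuleSpec(std::type_index t) : type(t) {}

    ModuleSpec &with_submodule(const std::string &name, const ModuleSpec &child) {
        submodules_[name] = std::make_shared<ModuleSpec>(child);
        return *this;
    }
};

using ModuleCreator = std::function<std::shared_ptr<Module>(const Config &, const ModuleSpec &)>;

// Maps a module type to the function that knows how to construct it.
class ModuleRegistry {
public:
    static ModuleRegistry &Instance() {
        static ModuleRegistry instance;
        return instance;
    }
    void Register(std::type_index type, ModuleCreator creator) { registry_[type] = std::move(creator); }
    const ModuleCreator &Lookup(std::type_index type) const { return registry_.at(type); }

private:
    std::unordered_map<std::type_index, ModuleCreator> registry_;
};

// Resolves the spec to its registered creator and instantiates the module.
std::shared_ptr<Module> build_module(const Config &config, const ModuleSpec &spec) {
    return ModuleRegistry::Instance().Lookup(spec.type)(config, spec);
}

struct LayerNorm : Module {
    explicit LayerNorm(int dim) : dim_(dim) {}
    std::string Describe() const override { return "LayerNorm(" + std::to_string(dim_) + ")"; }
    int dim_;
};

struct Block : Module {
    explicit Block(std::shared_ptr<Module> norm) : norm_(std::move(norm)) {}
    std::string Describe() const override { return "Block[" + norm_->Describe() + "]"; }
    std::shared_ptr<Module> norm_;
};

// Registers two creators, declares a nested spec, and builds it.
std::string BuildExample() {
    auto &registry = ModuleRegistry::Instance();
    registry.Register(typeid(LayerNorm), [](const Config &config, const ModuleSpec &) -> std::shared_ptr<Module> {
        return std::make_shared<LayerNorm>(config.n_embd);
    });
    registry.Register(typeid(Block), [](const Config &config, const ModuleSpec &spec) -> std::shared_ptr<Module> {
        // The parent creator builds its children recursively from the child specs.
        return std::make_shared<Block>(build_module(config, *spec.submodules_.at("ln_1")));
    });

    ModuleSpec block_spec(typeid(Block));
    block_spec.with_submodule("ln_1", ModuleSpec(typeid(LayerNorm)));
    return build_module(Config{}, block_spec)->Describe();
}
```

In this sketch the spec carries no construction logic of its own, so swapping LayerNorm for another registered norm would only change the spec, not the code that builds the block.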

JYMiracle305 (Contributor, Author) commented Mar 16, 2026

Single node, multiple GPUs:
GPT2:
image

LLaMA3:
image

Multi-node training results:
GPT2:
image

LLaMA3:
image

JYMiracle305 requested review from Chamberlain0w0, chen2021673 and kilinchange and removed request for Chamberlain0w0 March 16, 2026 05:42
JYMiracle305 force-pushed the feat/transformer branch 3 times, most recently from dfdd913 to d833ec2 March 16, 2026 08:10
first_stage.with_submodule(TransformerFirstStage::kWTELayerName, BuildVocabEmbeddingSpec(gpt2_config))
.with_submodule(TransformerFirstStage::kWPELayerName,
BuildPositionEmbeddingSpec(gpt2_config.block_size, gpt2_config.n_embd));
spec.with_submodule("first_stage", first_stage);

This comment was marked as resolved.


namespace infini_train::nn {

void ModuleRegistry::Register(std::type_index type, ModuleCreator creator) { registry_[type] = std::move(creator); }
Contributor commented:

Does Register need to guard against duplicate registration?

Author replied:

Right, duplicate registration is not allowed; I'll add a CHECK here.
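
One way the duplicate check discussed here could look is sketched below. It is only an illustration: the reply mentions a CHECK, whereas this sketch throws `std::logic_error` so it stays self-contained and testable. The simplified `ModuleCreator` signature, `Size()` and `SecondRegistrationRejected()` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <utility>

// Simplified creator signature (assumption); the PR's creator builds a module.
using ModuleCreator = std::function<int()>;

class ModuleRegistry {
public:
    // Rejects a second registration for the same type instead of overwriting it.
    void Register(std::type_index type, ModuleCreator creator) {
        const bool inserted = registry_.emplace(type, std::move(creator)).second;
        if (!inserted) {
            throw std::logic_error(std::string("duplicate module registration: ") + type.name());
        }
    }
    std::size_t Size() const { return registry_.size(); }

private:
    std::unordered_map<std::type_index, ModuleCreator> registry_;
};

struct RMSNorm {};

// Returns true when the second Register call is rejected and the first entry survives.
bool SecondRegistrationRejected() {
    ModuleRegistry registry;
    registry.Register(typeid(RMSNorm), [] { return 1; });
    try {
        registry.Register(typeid(RMSNorm), [] { return 2; });
    } catch (const std::logic_error &) {
        return registry.Size() == 1;
    }
    return false;
}
```

Using `emplace` rather than `operator[]` matters here: `operator[]` silently replaces the existing creator, which is the behavior the review comment questions.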

auto tok_emb = (*modules_[kWTELayerName])({x1});

// Add position embedding only for models that use absolute position encoding
if (config_.attention_type == AttentionType::kStandard) {
Contributor commented:

Is WTE being tied to attention_type here?
This can stay as it is for now, but semantically it seems to move a responsibility that belongs to the spec into the generic layer. The two current models are both supported; if an exception comes up later, it should not be handled by adding another branch here but pushed down into the spec.
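
A rough sketch of what "pushing the decision down into the spec" might mean: the generic stage adds the position embedding only when the spec declares that submodule, instead of branching on `attention_type`. Everything here is hypothetical and heavily simplified (scalar "embeddings", the `FirstStageSpec` and `FirstStage` types, and the string values `"wte"`/`"wpe"` assumed for `kWTELayerName`/`kWPELayerName`).

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical spec: just the set of submodule names the model declares.
struct FirstStageSpec {
    std::set<std::string> submodules;
};

class FirstStage {
public:
    // Assumed string values; the PR only shows the constant names.
    static constexpr const char *kWTELayerName = "wte";
    static constexpr const char *kWPELayerName = "wpe";

    explicit FirstStage(const FirstStageSpec &spec) : has_wpe_(spec.submodules.count(kWPELayerName) > 0) {}

    // Toy forward pass on scalars: the stage never inspects the attention type.
    float Forward(float tok_emb, float pos_emb) const { return has_wpe_ ? tok_emb + pos_emb : tok_emb; }

private:
    bool has_wpe_;
};

// GPT2-like spec declares a learned position embedding; LLaMA3-like spec does not.
inline FirstStageSpec Gpt2LikeSpec() { return FirstStageSpec{{"wte", "wpe"}}; }
inline FirstStageSpec Llama3LikeSpec() { return FirstStageSpec{{"wte"}}; }
```

With this shape, a future model with a different positional scheme would only need a different spec, and the generic stage would need no new branch.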

// ManualSeed(42);

LLaMA3Config model_config = LLaMA3Config();
nn::TransformerConfig model_config;
Contributor commented:

The defaults declared for nn::TransformerConfig model_config; all follow the GPT2 architecture (use_bias, use_rope and so on are set the GPT2 way), so the else branch below actually constructs a GPT2 model.

Author replied:

Added static initialization methods; each model's main.cc now calls the corresponding one.
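
The static initialization methods mentioned in this reply are not shown in the thread; a minimal sketch of the idea, assuming hypothetical method names `GPT2()` and `LLaMA3()`, might look like this. `attention_type`, `mlp_type`, `use_bias`, `use_rope`, `kStandard` and `kGELU` appear in the PR discussion; the enumerators `kRoPE` and `kSwiGLU` and the default values are assumptions.

```cpp
#include <cassert>

enum class AttentionType { kStandard, kRoPE };  // kRoPE is a hypothetical enumerator
enum class MLPType { kGELU, kSwiGLU };          // kSwiGLU is a hypothetical enumerator

struct TransformerConfig {
    // Defaults follow GPT2, which is exactly why a bare default-constructed
    // config silently produces a GPT2-shaped model.
    AttentionType attention_type = AttentionType::kStandard;
    MLPType mlp_type = MLPType::kGELU;
    bool use_bias = true;
    bool use_rope = false;

    // Named factories make the intended architecture explicit at the call site.
    static TransformerConfig GPT2() { return TransformerConfig{}; }

    static TransformerConfig LLaMA3() {
        TransformerConfig config;
        config.attention_type = AttentionType::kRoPE;
        config.mlp_type = MLPType::kSwiGLU;
        config.use_bias = false;
        config.use_rope = true;
        return config;
    }
};
```

A LLaMA3 `main.cc` would then start from `TransformerConfig::LLaMA3()` rather than from a default-constructed config.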


// ========== GPT2 Model Definition ==========
// Uses LayerNorm, GELU activation, standard multi-head attention
class GPT2 : public nn::TransformerLayer {
Contributor commented:

The Layer/Block naming seems to be the reverse of megatron's: in megatron, Layer is a single transformer block, and Block is a sequence of transformer blocks. This has no effect on correctness, though; it is only a matter of naming.

Also worth considering: is it appropriate for the GPT2 model to inherit directly from TransformerLayer? In megatron, GPTModel should still be something that inherits directly from nn.Module, with a member self.decoder built as the corresponding TransformerBlock object (megatron's name, corresponding to TransformerLayer in this PR).
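
The composition alternative raised in this comment could be sketched as below. This is a toy illustration only: `Module`, the integer `Forward` signature, and the `GPTModel` constructor are hypothetical, and `TransformerLayer` is used in this PR's sense (a stack of blocks).

```cpp
#include <cassert>
#include <memory>

// Hypothetical minimal module interface.
struct Module {
    virtual ~Module() = default;
    virtual int Forward(int x) = 0;
};

// PR naming: TransformerLayer is the stack of transformer blocks.
class TransformerLayer : public Module {
public:
    explicit TransformerLayer(int n_layer) : n_layer_(n_layer) {}
    // Toy computation: every block adds 1 to the input.
    int Forward(int x) override { return x + n_layer_; }

private:
    int n_layer_;
};

// The model is a Module that owns a decoder, rather than being a TransformerLayer itself.
class GPTModel : public Module {
public:
    explicit GPTModel(int n_layer) : decoder_(std::make_unique<TransformerLayer>(n_layer)) {}
    int Forward(int x) override { return decoder_->Forward(x); }

private:
    std::unique_ptr<TransformerLayer> decoder_;
};
```

Holding the decoder as a member leaves room for the model to own other parts (embeddings, output head) without those leaking into the generic transformer stack.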

Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
AttentionType attention_type_;

This comment was marked as resolved.


// Architecture choices
AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type
MLPType mlp_type = MLPType::kGELU; // MLP activation type

This comment was marked as resolved.


namespace infini_train::nn {

class RMSNorm : public infini_train::nn::CloneableModule<RMSNorm> {
Contributor commented:

We could discuss whether modules for fused operators like this should be split out into their own files. First, the rmsnorm/layernorm selection logic could be wrapped in one unified norm module, as megatron does. Second, if flash attn is integrated later, there will be similar algorithm-selection logic. I think they could be pulled out.
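
As a rough picture of the "unified norm module" idea, the sketch below wraps the LayerNorm/RMSNorm choice behind one class. It is an assumption-laden simplification: `NormType` and `Norm` are hypothetical names, it works on a plain `std::vector<float>`, and it omits the learnable weight and bias that the real modules have.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

enum class NormType { kLayerNorm, kRMSNorm };  // hypothetical selector

class Norm {
public:
    explicit Norm(NormType type, float eps = 1e-5f) : type_(type), eps_(eps) {}

    std::vector<float> Forward(const std::vector<float> &x) const {
        if (x.empty()) {
            return {};
        }
        const float n = static_cast<float>(x.size());

        // LayerNorm centers on the mean; RMSNorm skips centering (mean stays 0).
        float mean = 0.0f;
        if (type_ == NormType::kLayerNorm) {
            for (float v : x) mean += v;
            mean /= n;
        }

        // Variance for LayerNorm, mean of squares for RMSNorm.
        float sq = 0.0f;
        for (float v : x) sq += (v - mean) * (v - mean);
        const float inv = 1.0f / std::sqrt(sq / n + eps_);

        std::vector<float> y;
        y.reserve(x.size());
        for (float v : x) y.push_back((v - mean) * inv);
        return y;
    }

private:
    NormType type_;
    float eps_;
};
```

Callers would pick the variant once at construction time, which is the kind of selection logic the comment suggests keeping in a single place.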

Author replied:

The original design did split these out; since there are only a few of them for now, they were left together. We can change the other parts first and split these files last.

modules_[kCFcLayerName] = build_module(config, spec.submodules_.at(kCFcLayerName));

// For SwiGLU, add second projection
if (spec.submodules_.count(kCFc2LayerName) > 0) {

This comment was marked as resolved.


// ========== LLaMA3 Model Definition ==========
// Uses RMSNorm, SwiGLU activation, GQA attention, RoPE positional encoding
class LLaMA3 : public nn::TransformerLayer {
Contributor commented:

If we align with megatron, gpt2 and llama3 are essentially both GPTModel. I don't think they need to be split into two class definitions here; could it simply be called something like DecoderOnlyTransformer?

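Chamberlain0w0 (Contributor) commented:

I've raised a few comments for now. A lot of things suddenly came up, so I don't think I'll have time to look at the rest over the next couple of days.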

3 participants