feat: extract the common module of Transformer #115
JYMiracle305 wants to merge 1 commit into master from
Conversation
first_stage.with_submodule(TransformerFirstStage::kWTELayerName, BuildVocabEmbeddingSpec(gpt2_config))
    .with_submodule(TransformerFirstStage::kWPELayerName,
                    BuildPositionEmbeddingSpec(gpt2_config.block_size, gpt2_config.n_embd));
spec.with_submodule("first_stage", first_stage);
namespace infini_train::nn {

void ModuleRegistry::Register(std::type_index type, ModuleCreator creator) { registry_[type] = std::move(creator); }
Agreed, duplicate registration should not be allowed; a CHECK has been added here.
auto tok_emb = (*modules_[kWTELayerName])({x1});

// Add position embedding only for models that use absolute position encoding
if (config_.attention_type == AttentionType::kStandard) {
Is it right to bind WTE to attention_type here?
It can stay as-is for now, but semantically this moves a responsibility that belongs to the spec into the generic layer. It works for the current two models, but if an exception shows up later we should not add another branch here; the logic should sink down into the spec instead.
// ManualSeed(42);

LLaMA3Config model_config = LLaMA3Config();
nn::TransformerConfig model_config;
The defaults declared by nn::TransformerConfig model_config; all follow the GPT-2 architecture (use_bias/use_rope and so on are GPT-2 settings), so the else branch below actually constructs a GPT-2 model.
Added static initialization methods; each model's main.cc now calls its corresponding initializer.
// ========== GPT2 Model Definition ==========
// Uses LayerNorm, GELU activation, standard multi-head attention
class GPT2 : public nn::TransformerLayer {
The Layer/Block naming seems to be the reverse of Megatron's: in Megatron a Layer is a single transformer block, while a Block is a stack of transformer blocks. This has no impact on correctness; it's just a question of naming.
Also worth considering whether it is appropriate for the GPT2 model to inherit directly from TransformerLayer. In Megatron, GPTModel inherits directly from nn.Module, and holds a self.decoder member built as the corresponding TransformerBlock (Megatron's name; this PR's TransformerLayer) object.
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
    AttentionType attention_type_;
// Architecture choices
AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type
MLPType mlp_type = MLPType::kGELU;                       // MLP activation type
namespace infini_train::nn {

class RMSNorm : public infini_train::nn::CloneableModule<RMSNorm> {
Worth discussing whether fused-operator modules like this should be split out into their own files. First, the rmsnorm/layernorm selection logic could be wrapped in a unified norm module, as in Megatron; second, once flash attention is integrated later there will also be algorithm-selection logic of the same kind. I'd lean toward pulling these out.
The original design did split these out; since there are few of them for now they were kept together. We can change the other parts first and split these files at the end.
modules_[kCFcLayerName] = build_module(config, spec.submodules_.at(kCFcLayerName));

// For SwiGLU, add second projection
if (spec.submodules_.count(kCFc2LayerName) > 0) {
// ========== LLaMA3 Model Definition ==========
// Uses RMSNorm, SwiGLU activation, GQA attention, RoPE positional encoding
class LLaMA3 : public nn::TransformerLayer {
If we align with Megatron, gpt2/llama3 both essentially use GPTModel, so I don't think we need two separate class definitions here either; it could just be called something like DecoderOnlyTransformer.
I've raised a few points for now; a bunch of things suddenly came up, so I probably won't get to the rest in the next couple of days.
This PR abstracts the construction architecture of Transformer-style models, unifying the GPT2 and LLaMA3 build processes into a single flow.
Directory structure
…/core/
├── models/decode_only_transformer/
│   ├── layer_specs.h/.cc # declarations and implementations of the model-building functions
│   └── model.h # declarations of components such as RMSNorm/NewGELU/SwiGLU
└── transformer/
    ├── spec_utils.h/.cc # ModuleSpec construction utilities and module-registration macros
    ├── transformer_block.h/.cc # basic components such as TransformerBlock, with registered implementations
    ├── transformer_builders.h/.cc # spec builder declarations and implementations (BuildNormSpec, BuildMLPSpec, etc.)
    ├── transformer_config.h # TransformerConfig struct, replacing the old GPT2Config and LLaMA3Config
    └── transformer_layer.h/.cc # TransformerFirstStage/Chunk/LastStage, replacing the old GPT2FirstStage/Chunk/LastStage and LLaMA3FirstStage/Chunk/LastStage
Core mechanism
The ModuleSpec data structure declares a module's type and parameters; concrete module implementations are registered once through ModuleRegistry, and at model-build time build_module() instantiates them dynamically, resolving the registered implementation from the spec.