
issue/349 - add GLM4 causal LM model support #352

Open

JoeZhang-0x000 wants to merge 1 commit into InfiniTensor:main from JoeZhang-0x000:issue/349

Conversation

@JoeZhang-0x000

Summary

  • Add GLM4 model config adapter (csrc/models/glm4/)
  • Add partial rotary support (llama_utils.hpp, llama_attention.*); see the sketch after this summary
  • Add forward_naive + GLM4 post-norm support (llama_decoder_layer.*)
  • Add RoPE algo selection for GLM4 (llama_model.cpp)
  • Add GLM4 special model creation path (rank_worker.cpp)
  • Register "glm4" in config_factory.cpp classic_models list and auto_config.py
  • Add GLM4 weight remapping (gate_up_proj split) in modeling_utils.py

Closes #349
Parent issue: #332
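For context on the partial rotary bullet above: GLM4 applies RoPE to only the first rotary_dim channels of each attention head (rotary_dim = head_dim * partial_rotary_factor, 0.5 in GLM4's config), using the GPT-J interleaved layout. Below is a minimal standalone sketch of the idea; apply_partial_rope and the flat per-head buffer are illustrative assumptions, not InfiniLM's tensor API.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical illustration of partial rotary embedding, GPT-J interleaved
// style: only the first rotary_dim channels of one head are rotated; the
// remaining head_dim - rotary_dim channels pass through unchanged.
// rotary_dim = head_dim * partial_rotary_factor (0.5 for GLM4).
void apply_partial_rope(std::vector<float>& head,  // one head, length head_dim
                        std::size_t rotary_dim,    // channels to rotate
                        std::size_t pos,           // token position
                        float theta = 10000.0f) {
    // GPT-J style pairs adjacent channels (2i, 2i+1) rather than splitting
    // the rotated span into two halves (GPT-NeoX style).
    for (std::size_t i = 0; i + 1 < rotary_dim; i += 2) {
        float freq = std::pow(theta, -static_cast<float>(i) / rotary_dim);
        float angle = pos * freq;
        float c = std::cos(angle), s = std::sin(angle);
        float x0 = head[i], x1 = head[i + 1];
        head[i]     = x0 * c - x1 * s;
        head[i + 1] = x0 * s + x1 * c;
    }
    // Channels [rotary_dim, head_dim) are intentionally left untouched.
}
```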

The review comments below refer to this snippet from the diff:

```cpp
rotary_emb_->forward(q_rope->narrow({{2, 0, rotary_dim_}}), pos_ids_for_rope, true);
rotary_emb_->forward(k_rope->narrow({{2, 0, rotary_dim_}}), pos_ids_for_rope, true);
q_reshaped = q_rope;
k_reshaped = k_rope;
```
Collaborator

Two new variables, q_rope and k_rope, are created here and then assigned back to q_reshaped and k_reshaped.

Could rotary_emb_->forward operate on q_reshaped and k_reshaped directly, without creating the new variables?
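If narrow() returns a view sharing storage with its source tensor (which the original snippet already relies on, since q_rope is mutated through q_rope->narrow(...)), the in-place variant the reviewer suggests could look like this. A sketch only, assuming view semantics and that q_reshaped / k_reshaped are safe to mutate:

```cpp
// Sketch of the reviewer's suggestion: apply RoPE in place to the first
// rotary_dim_ channels through a narrowed view of the existing tensors.
rotary_emb_->forward(q_reshaped->narrow({{2, 0, rotary_dim_}}), pos_ids_for_rope, true);
rotary_emb_->forward(k_reshaped->narrow({{2, 0, rotary_dim_}}), pos_ids_for_rope, true);
// No q_rope / k_rope temporaries, and no trailing reassignment needed.
```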

@pengcheng888 (Collaborator) left a comment

Please provide tests and a screenshot showing the model responding correctly.

Collaborator

As a general principle, try not to modify or depend on llama legacy; that old code could be deleted at any time.

If you can drop the dependency on llama legacy, the changes in config factory, rank worker, and auto config should no longer be necessary.

@wooway777 (Collaborator)

> Please provide tests and a screenshot showing the model responding correctly.

[two screenshots of the model's output attached]

@pengcheng888 (Collaborator) left a comment

Suggested changes for reference:

(1) Adding a new model should not modify existing model code; do not change the code in the llama_legacy files.
(2) Please remove the changes to config_factory.cpp and rank_worker.cpp.
(3) Follow the existing implementations (outside the llama_legacy folder); the MLP, model, and causal_lm parts should be able to reuse the existing modules.
(4) Please add the following files under glm4: glm4_decoder_layer.cpp/hpp and glm4_for_causal_lm.cpp/hpp. If necessary, add glm4_attention.cpp/hpp to handle the RoPE changes.
(5) In csrc/models/glm4/glm4_for_causal_lm.cpp, define a dedicated Glm4ForCausalLM class; do not reuse infinilm::models::llama::LlamaForCausalLM.

(6) RoPE type issue: please extend the get_rope function in https://github.com/InfiniTensor/InfiniLM/blob/main/csrc/layers/rotary_embedding/rotary_embedding.cpp so that it handles the GPT_J type and the "partial_rotary_factor" hyperparameter (see the sketch after this list).
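A possible shape for point (6), sketched under stated assumptions: the RopeAlgo enum, the RopeConfig fields, and the get_rope signature below are illustrative, not InfiniLM's actual rotary_embedding API.

```cpp
#include <string>

// All names here are assumptions for illustration only.
enum class RopeAlgo { GPT_NEOX, GPT_J };

struct RopeConfig {
    std::string model_type;              // e.g. "glm4"
    int head_dim = 128;
    double partial_rotary_factor = 1.0;  // 0.5 for GLM4
};

struct RopeParams {
    RopeAlgo algo;
    int rotary_dim;  // number of channels actually rotated per head
};

// Sketch of point (6): choose the RoPE algorithm per model type and scale
// the rotated width by partial_rotary_factor.
RopeParams get_rope(const RopeConfig& cfg) {
    RopeParams p;
    // GLM4 uses the GPT-J interleaved layout; everything else keeps the
    // default GPT-NeoX half-split layout.
    p.algo = (cfg.model_type == "glm4") ? RopeAlgo::GPT_J : RopeAlgo::GPT_NEOX;
    // Rotate only a prefix of each head when partial_rotary_factor < 1.
    p.rotary_dim = static_cast<int>(cfg.head_dim * cfg.partial_rotary_factor);
    return p;
}
```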

