issue/348 - add Baichuan causal LM model support #351
JoeZhang-0x000 wants to merge 1 commit into InfiniTensor:main
- Add Baichuan model config adapter (csrc/models/baichuan/)
- Register baichuan in config_factory.cpp and auto_config.py
- Add Baichuan weight remapping (W_pack -> q/k/v_proj) in modeling_utils.py
- Update test_infer.py for Baichuan tokenization and chat prompt handling
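Does this change mean that only single-turn inference works, and that bench, accuracy tests, and serving can no longer be used?
We also need to spend some time looking at how to make this general.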
        it->second(model_config);
    } else {
    -    std::vector<std::string> classic_models = {"llama", "qwen2", "minicpm", "fm9g", "fm9g7b"};
    +    std::vector<std::string> classic_models = {"llama", "qwen2", "minicpm", "fm9g", "fm9g7b", "baichuan"};
Adding a new model does not require changing this code. Please revert the changes to csrc/config/config_factory.cpp.
    @@ -47,4 +47,7 @@ def from_pretrained(model_path):
            cfg.model_type = "minicpmv"
            return cfg
Adding a new model does not require changing this code. Please revert the changes to python/infinilm/auto_config.py.

    namespace infinilm::models::baichuan {

    std::shared_ptr<infinilm::config::ModelConfig> create_baichuan_model_config(
BaichuanForCausalLM must be defined explicitly. Add: using BaichuanForCausalLM = infinilm::models::llama::LlamaForCausalLM;

    INFINILM_REGISTER_CAUSAL_LM_MODEL(
        baichuan,
        infinilm::models::llama::LlamaForCausalLM,
Remove #include "../llama/llama_for_causal_lm.hpp" from csrc/models/baichuan/baichuan_for_causal_lm.cpp.
Change infinilm::models::llama::LlamaForCausalLM to infinilm::models::baichuan::BaichuanForCausalLM.

    namespace {

    #ifndef USE_CLASSIC_LLAMA
New models do not need to be guarded by the USE_CLASSIC_LLAMA macro. Please remove USE_CLASSIC_LLAMA from csrc/models/baichuan/baichuan_for_causal_lm.cpp.
        return new_sd


    def maybe_remap_weights(state_dict, model):
Rename this function to adjust_state_dict, or simply remap_weights. The "maybe" prefix reads as a bit casual.
    }


    def _split_first_dim(tensor, sizes, name):
Move _split_first_dim inside _remap_baichuan_weights, as a helper private to that function.


Summary
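- Add Baichuan model config adapter (csrc/models/baichuan/)
- Register "baichuan" in config_factory.cpp classic_models list and auto_config.py
- Add Baichuan weight remapping (W_pack → q/k/v_proj) in modeling_utils.py
- Update test_infer.py for Baichuan tokenization and chat prompt handling

Closes #348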
Parent issue: #332