issue/348 - add Baichuan causal LM model support #351
csrc/models/baichuan/baichuan_for_causal_lm.cpp (new file)

@@ -0,0 +1,50 @@

```cpp
#include "baichuan_for_causal_lm.hpp"
#include "../llama/llama_for_causal_lm.hpp"
#include "../models_registry.hpp"

namespace infinilm::models::baichuan {

std::shared_ptr<infinilm::config::ModelConfig> create_baichuan_model_config(
    std::shared_ptr<infinilm::config::ModelConfig> model_config) {
    const std::string &model_type = model_config->get<std::string>("model_type");
    if ("baichuan" != model_type) {
        throw std::runtime_error(
            "infinilm::models::baichuan::create_baichuan_model_config: model_type is not baichuan");
    }

    nlohmann::json &config_json = model_config->get_config_json();

    if (!config_json.contains("num_key_value_heads")) {
        config_json["num_key_value_heads"] = model_config->get<size_t>("num_attention_heads");
    }

    if (!config_json.contains("head_dim")) {
        config_json["head_dim"] = model_config->get<size_t>("hidden_size")
                                / model_config->get<size_t>("num_attention_heads");
    }

    if (!config_json.contains("rope_theta")) {
        config_json["rope_theta"] = 10000.0;
    }

    if (!config_json.contains("attention_bias")) {
        config_json["attention_bias"] = false;
    }

    return model_config;
}

} // namespace infinilm::models::baichuan

namespace {

#ifndef USE_CLASSIC_LLAMA
```
Collaborator
Newly added models do not need to be placed inside the USE_CLASSIC_LLAMA macro. Please remove USE_CLASSIC_LLAMA from csrc/models/baichuan/baichuan_for_causal_lm.cpp.
```cpp
INFINILM_REGISTER_CAUSAL_LM_MODEL(
    baichuan,
    infinilm::models::llama::LlamaForCausalLM,
```
Collaborator
Remove #include "../llama/llama_for_causal_lm.hpp" from csrc/models/baichuan/baichuan_for_causal_lm.cpp, and change infinilm::models::llama::LlamaForCausalLM to infinilm::models::baichuan::BaichuanForCausalLM.
```cpp
    infinilm::models::baichuan::create_baichuan_model_config);

#endif

} // namespace
```
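Taken together, the two comments above point toward a registration block along these lines — a sketch only, assuming the INFINILM_REGISTER_CAUSAL_LM_MODEL macro keeps the signature shown in this PR and that the header defines the BaichuanForCausalLM alias (see the next file):

```cpp
// Sketch of the registration after applying the review feedback:
// no USE_CLASSIC_LLAMA guard and no direct include of the llama header.
#include "baichuan_for_causal_lm.hpp"
#include "../models_registry.hpp"

namespace {

INFINILM_REGISTER_CAUSAL_LM_MODEL(
    baichuan,
    infinilm::models::baichuan::BaichuanForCausalLM,
    infinilm::models::baichuan::create_baichuan_model_config);

} // namespace
```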
csrc/models/baichuan/baichuan_for_causal_lm.hpp (new file)

@@ -0,0 +1,11 @@

```cpp
#pragma once

#include "../../layers/common_modules.hpp"
#include <memory>

namespace infinilm::models::baichuan {

std::shared_ptr<infinilm::config::ModelConfig> create_baichuan_model_config(
```
Collaborator
BaichuanForCausalLM needs an explicit definition: add using BaichuanForCausalLM = infinilm::models::llama::LlamaForCausalLM;
```cpp
    std::shared_ptr<infinilm::config::ModelConfig> model_config);

} // namespace infinilm::models::baichuan
```
Collaborator
Does this change mean the model can only do single-turn inference, and that bench, accuracy testing, and serving are all unusable?
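A sketch of the header with the requested alias added — assuming the llama header is included here so the alias resolves (the include path mirrors the one to be removed from the .cpp):

```cpp
#pragma once

#include "../../layers/common_modules.hpp"
#include "../llama/llama_for_causal_lm.hpp"
#include <memory>

namespace infinilm::models::baichuan {

// Baichuan reuses the llama architecture; naming the alias here keeps the
// registration site free of any direct reference to the llama namespace.
using BaichuanForCausalLM = infinilm::models::llama::LlamaForCausalLM;

std::shared_ptr<infinilm::config::ModelConfig> create_baichuan_model_config(
    std::shared_ptr<infinilm::config::ModelConfig> model_config);

} // namespace infinilm::models::baichuan
```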
python/infinilm/auto_config.py

@@ -47,4 +47,7 @@ def from_pretrained(model_path):

```python
        cfg.model_type = "minicpmv"
        return cfg
```
Collaborator
Adding a new model does not require a change here; please remove the modification to python/infinilm/auto_config.py.
```python
    elif config_dict["model_type"] == "baichuan":
        return LlamaConfig(**config_dict)

    raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
```
Python weight-loading module (path not shown in this excerpt)

@@ -1,4 +1,5 @@

```python
import os
import re
from typing import Dict, Union
import time
import torch
```

@@ -41,6 +42,48 @@ def parse_dtype(dtype_str: str):

```python
}


def _split_first_dim(tensor, sizes, name):
```
Collaborator
Move this _split_first_dim function inside _remap_baichuan_weights, as a helper dedicated to _remap_baichuan_weights.
```python
    if tensor.dim() not in (1, 2):
        raise ValueError(f"Cannot split {name} with shape {tensor.shape}")
    return torch.split(tensor, sizes, dim=0)


def _remap_baichuan_weights(state_dict, hf_config):
    hidden_size = hf_config.get("hidden_size", 4096)
    num_heads = hf_config.get("num_attention_heads", 32)
    per_head_dim = num_heads * (hidden_size // num_heads)
    new_sd = {}

    for key, tensor in state_dict.items():
        wpack_match = re.match(r"(.*\.)W_pack\.(weight|bias)", key)
        if not wpack_match:
            new_sd[key] = tensor
            continue

        prefix = wpack_match.group(1)
        suffix = wpack_match.group(2)
        q, k, v = _split_first_dim(
            tensor,
            [per_head_dim, per_head_dim, tensor.shape[0] - 2 * per_head_dim],
            "W_pack",
        )
        new_sd[f"{prefix}q_proj.{suffix}"] = q
        new_sd[f"{prefix}k_proj.{suffix}"] = k
        new_sd[f"{prefix}v_proj.{suffix}"] = v
    return new_sd


def maybe_remap_weights(state_dict, model):
```
Collaborator
Rename this function to adjust_state_dict, or simply remap_weights; "maybe" seems a bit casual.
```python
    if not hasattr(model, "hf_config"):
        return state_dict

    hf_config = model.hf_config
    model_type = hf_config.get("model_type", "")
    if model_type == "baichuan":
        return _remap_baichuan_weights(state_dict, hf_config)
    return state_dict


def check_parameters(model_keys: list, already_loaded_keys: list):
    model_keys = set(model_keys)
    already_loaded_keys = set(already_loaded_keys)
```
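Folding the two review suggestions together, the loader-side helpers might end up like this — a sketch only, with _split_first_dim nested as a private helper and the dispatcher renamed to remap_weights, both as the reviewers propose:

```python
import re

import torch


def _remap_baichuan_weights(state_dict, hf_config):
    # Local helper, per the review: only this function needs it.
    def _split_first_dim(tensor, sizes, name):
        if tensor.dim() not in (1, 2):
            raise ValueError(f"Cannot split {name} with shape {tensor.shape}")
        return torch.split(tensor, sizes, dim=0)

    hidden_size = hf_config.get("hidden_size", 4096)
    num_heads = hf_config.get("num_attention_heads", 32)
    per_head_dim = num_heads * (hidden_size // num_heads)
    new_sd = {}
    for key, tensor in state_dict.items():
        wpack_match = re.match(r"(.*\.)W_pack\.(weight|bias)", key)
        if not wpack_match:
            new_sd[key] = tensor
            continue
        prefix, suffix = wpack_match.group(1), wpack_match.group(2)
        # Baichuan packs Q, K, V into a single W_pack tensor along dim 0;
        # split it back into the three projections the llama-style code expects.
        q, k, v = _split_first_dim(
            tensor,
            [per_head_dim, per_head_dim, tensor.shape[0] - 2 * per_head_dim],
            "W_pack",
        )
        new_sd[f"{prefix}q_proj.{suffix}"] = q
        new_sd[f"{prefix}k_proj.{suffix}"] = k
        new_sd[f"{prefix}v_proj.{suffix}"] = v
    return new_sd


def remap_weights(state_dict, model):  # renamed from maybe_remap_weights, per review
    if not hasattr(model, "hf_config"):
        return state_dict
    if model.hf_config.get("model_type", "") == "baichuan":
        return _remap_baichuan_weights(state_dict, model.hf_config)
    return state_dict
```

For a checkpoint with hidden_size 4096 and 32 attention heads, per_head_dim works out to 4096, so a W_pack.weight of shape [12288, 4096] splits into three [4096, 4096] tensors for q_proj, k_proj, and v_proj.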
@@ -165,6 +208,7 @@ def load_model_state_dict_by_file(

```python
            model_param = load_state_dict(
                file_path, device=torch_device, dtype=torch_dtype
            )
            model_param = maybe_remap_weights(model_param, model)
            already_loaded_keys.extend(model_param.keys())

            # --------------------------------------------------------- #
```
@@ -181,6 +225,8 @@ def load_model_state_dict_by_file(

```python
        file_path = os.path.join(model_path, "pytorch_model.bin")
        model_params = torch.load(file_path, weights_only=True, map_location="cpu")

        model_params = maybe_remap_weights(model_params, model)

        model_param_infini = {}
        for key in model_params.keys():
            model_param_infini[key] = infinicore.from_torch(
```
Collaborator
Adding a new model does not require a change here; please remove the modification to csrc/config/config_factory.cpp.