2 changes: 1 addition & 1 deletion csrc/config/config_factory.cpp
@@ -17,7 +17,7 @@ std::shared_ptr<infinilm::config::ModelConfig> ConfigFactory::createConfig(const
if (it != config_map.end()) {
it->second(model_config);
} else {
std::vector<std::string> classic_models = {"llama", "qwen2", "minicpm", "fm9g", "fm9g7b"};
std::vector<std::string> classic_models = {"llama", "qwen2", "minicpm", "fm9g", "fm9g7b", "baichuan"};
Collaborator: A newly added model does not require changes here; please remove the modification to csrc/config/config_factory.cpp.

const std::string &model_type = model_config->get<std::string>("model_type");
if (std::find(classic_models.begin(), classic_models.end(), model_type) == classic_models.end()) {
throw std::invalid_argument("infinilm::config::ConfigFactory::createConfig: Unsupported model config type: " + model_type);
50 changes: 50 additions & 0 deletions csrc/models/baichuan/baichuan_for_causal_lm.cpp
@@ -0,0 +1,50 @@
#include "baichuan_for_causal_lm.hpp"
#include "../llama/llama_for_causal_lm.hpp"
#include "../models_registry.hpp"

namespace infinilm::models::baichuan {

std::shared_ptr<infinilm::config::ModelConfig> create_baichuan_model_config(
std::shared_ptr<infinilm::config::ModelConfig> model_config) {
const std::string &model_type = model_config->get<std::string>("model_type");
if ("baichuan" != model_type) {
throw std::runtime_error(
"infinilm::models::baichuan::create_baichuan_model_config: model_type is not baichuan");
}

nlohmann::json &config_json = model_config->get_config_json();

if (!config_json.contains("num_key_value_heads")) {
config_json["num_key_value_heads"] = model_config->get<size_t>("num_attention_heads");
}

if (!config_json.contains("head_dim")) {
config_json["head_dim"] = model_config->get<size_t>("hidden_size")
/ model_config->get<size_t>("num_attention_heads");
}

if (!config_json.contains("rope_theta")) {
config_json["rope_theta"] = 10000.0;
}

if (!config_json.contains("attention_bias")) {
config_json["attention_bias"] = false;
}

return model_config;
}

} // namespace infinilm::models::baichuan
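
For illustration, assuming a Baichuan-7B style config with hidden_size 4096 and num_attention_heads 32 (example values, not taken from this diff), the defaults filled in above would resolve as follows:

// Sketch only — example of the fields create_baichuan_model_config fills in
// when they are absent from the checkpoint's config (values are illustrative).
nlohmann::json config_json = {
    {"model_type", "baichuan"},
    {"hidden_size", 4096},
    {"num_attention_heads", 32},
};
// After the function runs, the missing fields would be populated as:
//   num_key_value_heads = 32        (copied from num_attention_heads)
//   head_dim            = 128       (hidden_size / num_attention_heads)
//   rope_theta          = 10000.0
//   attention_bias      = false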

namespace {

#ifndef USE_CLASSIC_LLAMA
Collaborator: A newly added model does not need to be placed inside the USE_CLASSIC_LLAMA macro. Please remove USE_CLASSIC_LLAMA from csrc/models/baichuan/baichuan_for_causal_lm.cpp.


INFINILM_REGISTER_CAUSAL_LM_MODEL(
baichuan,
infinilm::models::llama::LlamaForCausalLM,
Collaborator: Remove #include "../llama/llama_for_causal_lm.hpp" from csrc/models/baichuan/baichuan_for_causal_lm.cpp.

Change infinilm::models::llama::LlamaForCausalLM to infinilm::models::baichuan::BaichuanForCausalLM (a sketch of the resulting registration follows this file's diff).

infinilm::models::baichuan::create_baichuan_model_config);

#endif

} // namespace
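
A minimal sketch of what the registration in baichuan_for_causal_lm.cpp could look like after the comments above are applied (the USE_CLASSIC_LLAMA guard and the llama include removed, the BaichuanForCausalLM alias from the header used instead); this assumes INFINILM_REGISTER_CAUSAL_LM_MODEL keeps the same three arguments as in the diff:

// Sketch only — registration after the requested changes.
#include "baichuan_for_causal_lm.hpp"
#include "../models_registry.hpp"

namespace {

// Register the Baichuan causal LM unconditionally, using the alias from the header.
INFINILM_REGISTER_CAUSAL_LM_MODEL(
    baichuan,
    infinilm::models::baichuan::BaichuanForCausalLM,
    infinilm::models::baichuan::create_baichuan_model_config);

} // namespace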
11 changes: 11 additions & 0 deletions csrc/models/baichuan/baichuan_for_causal_lm.hpp
@@ -0,0 +1,11 @@
#pragma once

#include "../../layers/common_modules.hpp"
#include <memory>

namespace infinilm::models::baichuan {

std::shared_ptr<infinilm::config::ModelConfig> create_baichuan_model_config(
Collaborator (@pengcheng888, May 7, 2026): The definition of BaichuanForCausalLM needs to be stated explicitly: add using BaichuanForCausalLM = infinilm::models::llama::LlamaForCausalLM; (a sketch of the resulting header follows this file's diff).

std::shared_ptr<infinilm::config::ModelConfig> model_config);

} // namespace infinilm::models::baichuan
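
A minimal sketch of baichuan_for_causal_lm.hpp after adding the alias requested above (this assumes LlamaForCausalLM is declared in ../llama/llama_for_causal_lm.hpp, which would then be included here rather than in the .cpp):

// Sketch only — header with an explicit BaichuanForCausalLM definition.
#pragma once

#include "../../layers/common_modules.hpp"
#include "../llama/llama_for_causal_lm.hpp"
#include <memory>

namespace infinilm::models::baichuan {

// Baichuan reuses the LLaMA architecture, so the causal-LM class is an alias.
using BaichuanForCausalLM = infinilm::models::llama::LlamaForCausalLM;

std::shared_ptr<infinilm::config::ModelConfig> create_baichuan_model_config(
    std::shared_ptr<infinilm::config::ModelConfig> model_config);

} // namespace infinilm::models::baichuan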
59 changes: 44 additions & 15 deletions examples/test_infer.py
Collaborator: Does this change mean only single-turn inference works, so that the benchmark, accuracy tests, and serving cannot use it?
We also need to spend some time looking into how to make this generic.

@@ -6,6 +6,7 @@
from infinilm.distributed import DistConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
import argparse
import json
import sys
import time
import os
@@ -22,6 +23,37 @@
_PAGED_KV_BLOCK_SIZE = 256


def _get_baichuan_role_token_ids(model_path: str) -> tuple[int, int]:
user_token_id = 195
assistant_token_id = 196
generation_config_path = os.path.join(model_path, "generation_config.json")
if os.path.exists(generation_config_path):
with open(generation_config_path, "r") as f:
generation_config = json.load(f)
user_token_id = int(generation_config.get("user_token_id", user_token_id))
assistant_token_id = int(
generation_config.get("assistant_token_id", assistant_token_id)
)
return user_token_id, assistant_token_id


def _encode_baichuan_chat_prompts(
prompts: list[str],
tokenizer: AutoTokenizer,
model_path: str,
max_length: int,
) -> list[list[int]]:
user_token_id, assistant_token_id = _get_baichuan_role_token_ids(model_path)
max_content_length = max(0, max_length - 2)
input_ids_list = []
for prompt in prompts:
content_ids = tokenizer.encode(prompt, add_special_tokens=False)
if len(content_ids) > max_content_length:
content_ids = content_ids[-max_content_length:]
input_ids_list.append([user_token_id, *content_ids, assistant_token_id])
return input_ids_list
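
For reference, a minimal usage sketch of the helper above (illustrative only, not part of the diff; model_path is assumed to point at a Baichuan checkpoint whose tokenizer loads via AutoTokenizer):

# Sketch only — how the helper would typically be called.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
input_ids_list = _encode_baichuan_chat_prompts(
    ["How are you?"], tokenizer, model_path, max_length=2048
)
# Each entry has the form [user_token_id, *content_ids, assistant_token_id],
# e.g. [195, ..., 196] with the default Baichuan role token ids.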


def test(
prompts: str | list[str],
model_path,
@@ -104,7 +136,10 @@ def test(
updated_prompts.append(prompt)
prompts = updated_prompts

if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None:
used_chat_template = (
hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None
)
if used_chat_template:
input_contents = [
tokenizer.apply_chat_template(
conversation=[{"role": "user", "content": prompt}],
@@ -139,20 +174,14 @@ def test(
else:
raise ValueError(f"Unsupported multimodal model_type: {model.model_type}")
else:
if hasattr(tokenizer, "batch_encode_plus"):
input_ids_list = tokenizer.batch_encode_plus(input_contents)["input_ids"]
elif hasattr(tokenizer, "_encode_plus"):
input_ids_list = tokenizer._encode_plus(input_contents)["input_ids"]
else:
input_ids_list = tokenizer(input_contents)[
"input_ids"
] # List: [[1, 1128, 526, 366, 29892]]

# input_ids_list = tokenizer.batch_encode_plus(input_contents)[
# "input_ids"
# ] # List: [[1, 1128, 526, 366, 29892]]
if version.parse(transformers.__version__) < version.parse("5.0.0"):
# Ideally this is solved by upgrading transformers. However, doing so causes version mismatch between transformers and mlu pytorch on devices with Phytium CPU. So a branch is temporarily used.
if model.model_type == "baichuan" and not used_chat_template:
input_ids_list = _encode_baichuan_chat_prompts(
prompts,
tokenizer,
model_path,
max_length=2048,
)
elif version.parse(transformers.__version__) < version.parse("5.0.0"):
input_ids_list = [
tokenizer.encode_plus(
text, truncation=True, max_length=2048, add_special_tokens=True
3 changes: 3 additions & 0 deletions python/infinilm/auto_config.py
@@ -47,4 +47,7 @@ def from_pretrained(model_path):
cfg.model_type = "minicpmv"
return cfg

Collaborator: A newly added model does not require changes here; please remove the modification to python/infinilm/auto_config.py.

elif config_dict["model_type"] == "baichuan":
return LlamaConfig(**config_dict)

raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
46 changes: 46 additions & 0 deletions python/infinilm/modeling_utils.py
@@ -1,4 +1,5 @@
import os
import re
from typing import Dict, Union
import time
import torch
@@ -41,6 +42,48 @@ def parse_dtype(dtype_str: str):
}


def _split_first_dim(tensor, sizes, name):
Collaborator: Move this _split_first_dim function into _remap_baichuan_weights, as a helper dedicated to _remap_baichuan_weights (a sketch of the nested version appears after _remap_baichuan_weights below).

if tensor.dim() not in (1, 2):
raise ValueError(f"Cannot split {name} with shape {tensor.shape}")
return torch.split(tensor, sizes, dim=0)


def _remap_baichuan_weights(state_dict, hf_config):
hidden_size = hf_config.get("hidden_size", 4096)
num_heads = hf_config.get("num_attention_heads", 32)
per_head_dim = num_heads * (hidden_size // num_heads)
new_sd = {}

for key, tensor in state_dict.items():
wpack_match = re.match(r"(.*\.)W_pack\.(weight|bias)", key)
if not wpack_match:
new_sd[key] = tensor
continue

prefix = wpack_match.group(1)
suffix = wpack_match.group(2)
q, k, v = _split_first_dim(
tensor,
[per_head_dim, per_head_dim, tensor.shape[0] - 2 * per_head_dim],
"W_pack",
)
new_sd[f"{prefix}q_proj.{suffix}"] = q
new_sd[f"{prefix}k_proj.{suffix}"] = k
new_sd[f"{prefix}v_proj.{suffix}"] = v
return new_sd
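
A minimal sketch of _remap_baichuan_weights with the split helper nested inside it, as suggested in the comment above (behavior unchanged from the diff):

# Sketch only — _split_first_dim folded into _remap_baichuan_weights as a local helper.
def _remap_baichuan_weights(state_dict, hf_config):
    def _split_first_dim(tensor, sizes, name):
        # Only 1-D (bias) and 2-D (weight) tensors are expected for W_pack.
        if tensor.dim() not in (1, 2):
            raise ValueError(f"Cannot split {name} with shape {tensor.shape}")
        return torch.split(tensor, sizes, dim=0)

    hidden_size = hf_config.get("hidden_size", 4096)
    num_heads = hf_config.get("num_attention_heads", 32)
    per_head_dim = num_heads * (hidden_size // num_heads)
    new_sd = {}

    for key, tensor in state_dict.items():
        wpack_match = re.match(r"(.*\.)W_pack\.(weight|bias)", key)
        if not wpack_match:
            new_sd[key] = tensor
            continue

        prefix, suffix = wpack_match.group(1), wpack_match.group(2)
        # Split the packed QKV projection back into separate q/k/v tensors.
        q, k, v = _split_first_dim(
            tensor,
            [per_head_dim, per_head_dim, tensor.shape[0] - 2 * per_head_dim],
            "W_pack",
        )
        new_sd[f"{prefix}q_proj.{suffix}"] = q
        new_sd[f"{prefix}k_proj.{suffix}"] = k
        new_sd[f"{prefix}v_proj.{suffix}"] = v
    return new_sd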


def maybe_remap_weights(state_dict, model):
Collaborator: Rename this function to adjust_state_dict, or simply remap_weights; "maybe" reads a bit casual.

if not hasattr(model, "hf_config"):
return state_dict

hf_config = model.hf_config
model_type = hf_config.get("model_type", "")
if model_type == "baichuan":
return _remap_baichuan_weights(state_dict, hf_config)
return state_dict


def check_parameters(model_keys: list, already_loaded_keys: list):
model_keys = set(model_keys)
already_loaded_keys = set(already_loaded_keys)
@@ -165,6 +208,7 @@ def load_model_state_dict_by_file(
model_param = load_state_dict(
file_path, device=torch_device, dtype=torch_dtype
)
model_param = maybe_remap_weights(model_param, model)
already_loaded_keys.extend(model_param.keys())

# --------------------------------------------------------- #
@@ -181,6 +225,8 @@ def load_model_state_dict_by_file(
file_path = os.path.join(model_path, "pytorch_model.bin")
model_params = torch.load(file_path, weights_only=True, map_location="cpu")

model_params = maybe_remap_weights(model_params, model)

model_param_infini = {}
for key in model_params.keys():
model_param_infini[key] = infinicore.from_torch(