Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions transformer_lens/HookedTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,21 @@ def __init__(
else:
# Hugging Face defaults to use_fast to True
use_fast = True
# Phi model's fast tokenizer does not support adding a BOS token, use_fast
# Phi & Baichuan model's fast tokenizer does not support adding a BOS token, use_fast
# should be False
if "phi" in self.cfg.tokenizer_name.lower():
tokenizer_name = self.cfg.tokenizer_name.lower()
if "phi" in tokenizer_name:
use_fast = False
huggingface_token = os.environ.get("HF_TOKEN", None)
tokenizer = AutoTokenizer.from_pretrained(
self.cfg.tokenizer_name,
# add_bos_token=True,
trust_remote_code=self.cfg.trust_remote_code,
use_fast=use_fast,
token=huggingface_token,
)
self.set_tokenizer(
AutoTokenizer.from_pretrained(
self.cfg.tokenizer_name,
add_bos_token=True,
trust_remote_code=self.cfg.trust_remote_code,
use_fast=use_fast,
token=huggingface_token,
),
tokenizer,
default_padding_side=default_padding_side,
)
else:
Expand Down
27 changes: 27 additions & 0 deletions transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import transformer_lens.utils as utils
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions import (
convert_baichuan_weights,
convert_bert_weights,
convert_bloom_weights,
convert_coder_weights,
Expand Down Expand Up @@ -218,6 +219,9 @@
"google-t5/t5-base",
"google-t5/t5-large",
"ai-forever/mGPT",
"baichuan-inc/Baichuan-7B",
"baichuan-inc/Baichuan-13B-Base",
"baichuan-inc/Baichuan-13B-Chat",
]
"""Official model names for models on HuggingFace."""

Expand Down Expand Up @@ -640,6 +644,9 @@
"google-t5/t5-base": ["t5-base"],
"google-t5/t5-large": ["t5-large"],
"ai-forever/mGPT": ["mGPT"],
"baichuan-inc/Baichuan-7B": ["Baichuan-7B"],
"baichuan-inc/Baichuan-13B-Base": ["Baichuan-13B-Base"],
"baichuan-inc/Baichuan-13B-Chat": ["Baichuan-13B-Chat"],
}
"""Model aliases for models on HuggingFace."""

Expand Down Expand Up @@ -1293,6 +1300,24 @@ def convert_hf_model_config(model_name: str, **kwargs):
"use_attn_scale": False,
"tie_word_embeddings": hf_config.tie_word_embeddings,
}
elif architecture.startswith("Bai"):
cfg_dict = {
"d_model": hf_config.hidden_size,
"d_head": hf_config.hidden_size // hf_config.num_attention_heads,
"n_heads": hf_config.num_attention_heads,
"d_mlp": hf_config.intermediate_size,
"n_layers": hf_config.num_hidden_layers,
"n_ctx": 2048, # Capped due to HF Tokenizer Constraints
"d_vocab": hf_config.vocab_size,
"eps": hf_config.rms_norm_eps,
"trust_remote_code": True,
"act_fn": hf_config.hidden_act,
"initializer_range": hf_config.initializer_range,
"normalization_type": "RMS",
"post_embedding_ln": True,
"positional_embedding_type": "alibi",
"tie_word_embeddings": hf_config.tie_word_embeddings,
}
else:
raise NotImplementedError(f"{architecture} is not currently supported.")
# All of these models use LayerNorm
Expand Down Expand Up @@ -1654,6 +1679,8 @@ def get_pretrained_state_dict(
state_dict = convert_neox_weights(hf_model, cfg)
elif cfg.original_architecture == "LlamaForCausalLM":
state_dict = convert_llama_weights(hf_model, cfg)
elif cfg.original_architecture.startswith("Bai"):
state_dict = convert_baichuan_weights(hf_model, cfg)
elif cfg.original_architecture == "BertForMaskedLM":
state_dict = convert_bert_weights(hf_model, cfg)
elif cfg.original_architecture == "T5ForConditionalGeneration":
Expand Down
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .bert import convert_bert_weights
from .mistral import convert_mistral_weights
from .mixtral import convert_mixtral_weights
from .baichuan import convert_baichuan_weights
from .bloom import convert_bloom_weights
from .coder import convert_coder_weights
from .qwen import convert_qwen_weights
Expand Down
66 changes: 66 additions & 0 deletions transformer_lens/pretrained/weight_conversions/baichuan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_baichuan_weights(baichuan, cfg: "HookedTransformerConfig"):
    """Convert a HuggingFace Baichuan model's weights to the TransformerLens layout.

    Args:
        baichuan: A HuggingFace ``BaichuanForCausalLM``-style model exposing
            ``.model.embed_tokens``, ``.model.layers[i].self_attn.W_pack`` /
            ``.o_proj``, gated-MLP projections, ``.model.norm`` and ``.lm_head``.
        cfg: TransformerLens config describing the same architecture.

    Returns:
        dict mapping TransformerLens parameter names to tensors.
    """
    state_dict = {}

    state_dict["embed.W_E"] = baichuan.model.embed_tokens.weight

    assert cfg.d_mlp is not None  # keep mypy happy

    for l in range(cfg.n_layers):
        layer = baichuan.model.layers[l]

        state_dict[f"blocks.{l}.ln1.w"] = layer.input_layernorm.weight

        # W_pack fuses the Q, K and V projections, stacked along the *output*
        # dimension: rows [0, d_model) are Q, [d_model, 2*d_model) are K, and
        # [2*d_model, 3*d_model) are V. The Q/K/V split axis must therefore
        # come before the head axis when reshaping — reshaping the transposed
        # weight as (d_model, n_heads, 3, d_head) would interleave heads
        # across the three projections and scramble the weights.
        W = layer.self_attn.W_pack.weight
        qkv = W.reshape(3, cfg.n_heads, cfg.d_head, cfg.d_model)
        # (n_heads, d_head, d_model) -> TransformerLens layout (n_heads, d_model, d_head)
        W_Q, W_K, W_V = (w.permute(0, 2, 1) for w in qkv)
        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.W_K"] = W_K
        state_dict[f"blocks.{l}.attn.W_V"] = W_V

        # Baichuan attention has no bias terms; TransformerLens still expects
        # the entries, so fill them with zeros of a consistent dtype/device.
        for letter in ("Q", "K", "V"):
            state_dict[f"blocks.{l}.attn.b_{letter}"] = torch.zeros(
                cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W.device
            )

        # o_proj: (d_model, n_heads * d_head) -> (n_heads, d_head, d_model)
        W_O = layer.self_attn.o_proj.weight
        W_O = W_O.reshape(cfg.d_model, cfg.n_heads, cfg.d_head).permute(1, 2, 0)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O
        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=W_O.device
        )

        state_dict[f"blocks.{l}.ln2.w"] = layer.post_attention_layernorm.weight

        # Gated MLP: HF stores projections as (out_features, in_features);
        # TransformerLens stores them transposed.
        state_dict[f"blocks.{l}.mlp.W_in"] = layer.mlp.up_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_gate"] = layer.mlp.gate_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
        state_dict[f"blocks.{l}.mlp.W_out"] = layer.mlp.down_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

    state_dict["ln_final.w"] = baichuan.model.norm.weight
    # NOTE: no "pos_embed.W_pos" entry — Baichuan uses ALiBi-style position
    # handling and has no learned positional-embedding table. The previous
    # `baichuan.model.transformer.wpe.weight` access referenced an attribute
    # that does not exist on Baichuan models and raised AttributeError.
    state_dict["unembed.W_U"] = baichuan.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)

    return state_dict