1 change: 1 addition & 0 deletions .gitignore
@@ -20,3 +20,4 @@ docs/build
docs/source/generated
**.orig
.venv

49 changes: 32 additions & 17 deletions transformer_lens/HookedTransformer.py
@@ -143,7 +143,6 @@ def __init__(
)

self.cfg = HookedTransformerConfig.unwrap(cfg)

if tokenizer is not None:
self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
elif self.cfg.tokenizer_name is not None:
@@ -161,13 +160,18 @@ def __init__(
if "phi" in self.cfg.tokenizer_name.lower():
use_fast = False
huggingface_token = os.environ.get("HF_TOKEN", "")
add_bos_token = self.cfg.original_architecture not in [
"OlmoForCausalLM",
"OlmoeForCausalLM",
"Olmo2ForCausalLM",
]
self.set_tokenizer(
AutoTokenizer.from_pretrained(
self.cfg.tokenizer_name,
add_bos_token=True,
trust_remote_code=self.cfg.trust_remote_code,
use_fast=use_fast,
token=huggingface_token if len(huggingface_token) > 0 else None,
add_bos_token=add_bos_token,
),
default_padding_side=default_padding_side,
)
@@ -734,7 +738,14 @@ def set_tokenizer(
# tokenizers like LlamaTokenizer are different when bos token is automatically/manually
# prepended, and add_bos_token cannot be dynamically controlled after initialization
# (https://github.com/huggingface/transformers/issues/25886).
tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
if self.cfg.original_architecture not in [
"OlmoForCausalLM",
"OlmoeForCausalLM",
"Olmo2ForCausalLM",
]:
tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
else:
tokenizer_with_bos = tokenizer
self.tokenizer = tokenizer_with_bos
self.tokenizer.padding_side = default_padding_side

@@ -1798,18 +1809,18 @@ def fold_layer_norm(
if not self.cfg.final_rms and fold_biases:
# Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm
# pre unembed.
state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + (
state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None]
state_dict["unembed.b_U"] = state_dict["unembed.b_U"] + (
state_dict["unembed.W_U"] * state_dict["ln_final.b"][:, None]
).sum(dim=-2)
del state_dict[f"ln_final.b"]
del state_dict["ln_final.b"]

state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None]
del state_dict[f"ln_final.w"]
state_dict["unembed.W_U"] = state_dict["unembed.W_U"] * state_dict["ln_final.w"][:, None]
del state_dict["ln_final.w"]

if center_weights:
# Center the weights that read in from the LayerNormPre
state_dict[f"unembed.W_U"] -= einops.reduce(
state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
state_dict["unembed.W_U"] -= einops.reduce(
state_dict["unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
)

return state_dict
@@ -1821,13 +1832,17 @@ def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
W_out. This is done by subtracting the mean of the weights from the weights themselves. This
is done in-place. See fold_layer_norm for more details.
"""
state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
-1, keepdim=True
)
if self.cfg.positional_embedding_type != "rotary":
state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
"pos_embed.W_pos"
].mean(-1, keepdim=True)
if self.cfg.original_architecture == "Olmo2ForCausalLM":
print("Not centering embedding weights for Olmo2ForCausalLM")
pass # should not because input of attn of 1st layer is not normed
else:
state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
-1, keepdim=True
)
if self.cfg.positional_embedding_type != "rotary":
state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
"pos_embed.W_pos"
].mean(-1, keepdim=True)
for l in range(self.cfg.n_layers):
state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[
f"blocks.{l}.attn.W_O"
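Note on the tokenizer changes above: the diff avoids forcing a BOS token for OLMo-family models by deciding add_bos_token per architecture and by skipping the get_tokenizer_with_bos wrapping for them. A minimal sketch of that policy, factored into a hypothetical helper (architecture_uses_bos and NO_BOS_ARCHITECTURES are not part of the PR, purely illustrative):

# Hypothetical helper mirroring the two inline architecture checks in __init__ and set_tokenizer.
NO_BOS_ARCHITECTURES = {"OlmoForCausalLM", "OlmoeForCausalLM", "Olmo2ForCausalLM"}

def architecture_uses_bos(original_architecture: str) -> bool:
    """Return True if a BOS token should be prepended for this architecture."""
    return original_architecture not in NO_BOS_ARCHITECTURES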
4 changes: 2 additions & 2 deletions transformer_lens/HookedTransformerConfig.py
@@ -192,8 +192,7 @@ class HookedTransformerConfig:
NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that
affects the rate of change between low and high-frequency interpolation strategies.
Defaults to 8.0.


norm_topk_prob (bool): Whether to normalize the top-k probabilities in the MoE layer.
"""

n_layers: int
@@ -264,6 +263,7 @@ class HookedTransformerConfig:
NTK_by_parts_high_freq_factor: float = 4.0
NTK_by_parts_factor: float = 8.0
NTK_original_ctx_len: int = 8192
norm_topk_prob: bool = False

def __post_init__(self):
if self.n_heads == -1:
48 changes: 44 additions & 4 deletions transformer_lens/components/abstract_attention.py
@@ -154,6 +154,18 @@ def __init__(
# will be overwritten by the child T5Attention class
self.has_relative_attention_bias = False

if (
self.cfg.original_architecture == "OlmoeForCausalLM"
or self.cfg.original_architecture == "Olmo2ForCausalLM"
):
self.q_norm = RMSNorm(self.cfg, self.cfg.d_model)
k_norm_dim = (
self.cfg.d_model
if self.cfg.original_architecture == "Olmo2ForCausalLM"
else self.cfg.d_head * self.cfg.n_key_value_heads
)
self.k_norm = RMSNorm(self.cfg, k_norm_dim)

@property
def OV(self) -> FactoredMatrix:
"""
@@ -209,6 +221,32 @@ def forward(

q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

# OLMoE and OLMo 2 apply QK-norm: RMSNorm over the flattened query and key activations.
if (
self.cfg.original_architecture == "OlmoeForCausalLM"
or self.cfg.original_architecture == "Olmo2ForCausalLM"
):
q = einops.rearrange(
self.q_norm(
einops.rearrange(
q,
"batch pos head_index d_head -> batch pos (head_index d_head)",
)
),
"batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
head_index=q.shape[2],
)
k = einops.rearrange(
self.k_norm(
einops.rearrange(
k,
"batch pos head_index d_head -> batch pos (head_index d_head)",
)
),
"batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head",
head_index=k.shape[2],
)

if past_kv_cache_entry is not None:
# Appends the new keys and values to the cached values, and automatically updates the cache
kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1)
@@ -244,9 +282,10 @@
)

# Take the last query_ctx positions so it also works with past_kv_cache
attn_scores += self.alibi[
:, -query_ctx:, :key_ctx
] # [batch, head_index, query_pos, key_pos]
if self.alibi is not None:  # self.alibi is Optional; guard before indexing
attn_scores += self.alibi[
:, -query_ctx:, :key_ctx
] # [batch, head_index, query_pos, key_pos]
elif self.cfg.positional_embedding_type == "relative_positional_bias":
if position_bias is None:
if self.has_relative_attention_bias:
@@ -260,7 +299,8 @@
device=attn_scores.device,
)

attn_scores += position_bias
if position_bias is not None:  # position_bias is Optional; guard before adding
attn_scores += position_bias
if self.cfg.attention_dir == "causal":
# If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
attn_scores = self.apply_causal_mask(
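The QK-norm addition above flattens the head dimension, applies RMSNorm over the flattened activations, and reshapes back to [batch, pos, head_index, d_head]. A standalone sketch of that reshaping with plain torch, where torch.nn.RMSNorm (PyTorch >= 2.4) stands in for the TransformerLens RMSNorm component and all sizes are illustrative:

import torch
import einops

batch, pos, n_heads, d_head = 2, 5, 8, 64
q = torch.randn(batch, pos, n_heads, d_head)
q_norm = torch.nn.RMSNorm(n_heads * d_head)  # normalizes over the flattened head dims

# Flatten heads, normalize, and restore the per-head layout.
flat = einops.rearrange(q, "batch pos head d_head -> batch pos (head d_head)")
q_normed = einops.rearrange(
    q_norm(flat),
    "batch pos (head d_head) -> batch pos head d_head",
    head=n_heads,
)
assert q_normed.shape == q.shape

The key norm's width differs by architecture in the diff: d_model for Olmo2ForCausalLM, d_head * n_key_value_heads for OlmoeForCausalLM.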
3 changes: 2 additions & 1 deletion transformer_lens/components/mlps/moe.py
@@ -88,7 +88,8 @@ def forward(
# both are [batch, pos, experts_per_token]
weights = self.hook_expert_weights(F.softmax(gate_logits, dim=1, dtype=torch.float))
weights, expert_indices = torch.topk(weights, self.experts_per_token, dim=-1)
weights /= weights.sum(dim=-1, keepdim=True)
if self.cfg.norm_topk_prob:
weights /= weights.sum(dim=-1, keepdim=True)
expert_indices = self.hook_expert_indices(expert_indices)
weights = weights.to(x.dtype)

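The norm_topk_prob flag added above controls whether the routing weights selected by top-k are renormalized to sum to one, or left as raw softmax probabilities over all experts, which some MoE checkpoints expect. A small numeric sketch of the difference; the values are illustrative and the softmax/top-k dims follow the usual per-token routing convention rather than moe.py's exact shapes:

import torch
import torch.nn.functional as F

gate_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])     # one token, four experts
weights = F.softmax(gate_logits, dim=-1)                 # sums to 1.0 across all experts
topk_weights, topk_idx = torch.topk(weights, k=2, dim=-1)

print(topk_weights.sum(-1))   # < 1.0: mass assigned to the dropped experts is missing
renormalized = topk_weights / topk_weights.sum(-1, keepdim=True)
print(renormalized.sum(-1))   # == 1.0, the behavior when norm_topk_prob=True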
41 changes: 28 additions & 13 deletions transformer_lens/components/transformer_block.py
@@ -153,26 +153,37 @@ def forward(
key_input = attn_in
value_input = attn_in

attn_out = (
# hook the residual stream states that are used to calculate the
# queries, keys and values, independently.
# Then take the layer norm of these inputs, and pass these to the attention module.
self.attn(
query_input=self.ln1(query_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
key_input=self.ln1(key_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
value_input=self.ln1(value_input),
if self.cfg.original_architecture == "Olmo2ForCausalLM":
attn_out = self.attn(
query_input=query_input,
key_input=key_input,
value_input=value_input,
past_kv_cache_entry=past_kv_cache_entry,
attention_mask=attention_mask,
)
) # [batch, pos, d_model]
else:
attn_out = (
# hook the residual stream states that are used to calculate the
# queries, keys and values, independently.
# Then take the layer norm of these inputs, and pass these to the attention module.
self.attn(
query_input=self.ln1(query_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
key_input=self.ln1(key_input)
+ (0.0 if shortformer_pos_embed is None else shortformer_pos_embed),
value_input=self.ln1(value_input),
past_kv_cache_entry=past_kv_cache_entry,
attention_mask=attention_mask,
)
) # [batch, pos, d_model]
if self.cfg.use_normalization_before_and_after:
# If we use LayerNorm both before and after, then apply the second LN after the layer
# and before the hook. We do it before the hook so hook_attn_out captures "that which
# is added to the residual stream"
attn_out = self.ln1_post(attn_out)
attn_out = self.hook_attn_out(attn_out)
if self.cfg.original_architecture == "Olmo2ForCausalLM":
attn_out = self.ln1(attn_out)

if resid_pre.device != attn_out.device:
resid_pre = resid_pre.to(attn_out.device)
@@ -182,8 +193,12 @@
mlp_in = (
resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
)
normalized_resid_mid = self.ln2(mlp_in)
mlp_out = self.apply_mlp(normalized_resid_mid)
if self.cfg.original_architecture == "Olmo2ForCausalLM":
mlp_out = self.apply_mlp(mlp_in)
mlp_out = self.ln2(mlp_out)
else:
normalized_resid_mid = self.ln2(mlp_in)
mlp_out = self.apply_mlp(normalized_resid_mid)
resid_post = self.hook_resid_post(resid_mid + mlp_out) # [batch, pos, d_model]
elif self.cfg.parallel_attn_mlp:
# Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
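Taken together, the Olmo2ForCausalLM branches above implement a "norm-after" block: attention and the MLP read the raw residual stream, and RMSNorm is applied to their outputs before they are added back. A minimal sketch of that ordering (hooks, KV caching, attention masks, and device handling omitted):

def olmo2_block(resid_pre, attn, mlp, ln1, ln2):
    # Attention acts on the un-normalized residual stream; its output is
    # normalized before being added back (post-attention norm).
    resid_mid = resid_pre + ln1(attn(resid_pre))

    # Same pattern for the MLP: run it on the raw residual, then normalize its output.
    return resid_mid + ln2(mlp(resid_mid))

In the diff itself, ln1 is applied after hook_attn_out, so for this architecture hook_attn_out records the attention output before the post-attention norm.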