Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
8ee5167
removed wrap function
bryce13950 Jun 11, 2025
312a847
added deep seek architecture
bryce13950 Jun 11, 2025
8f50a22
registered deep seek
bryce13950 Jun 11, 2025
5be70ee
Merge branch 'dev-3.x' into model-deepseek
bryce13950 Jun 16, 2025
0c6e9be
ran format
bryce13950 Jun 16, 2025
549860c
fixed typing
bryce13950 Jun 16, 2025
5095d30
updated loading
bryce13950 Jun 17, 2025
b7daa18
Merge branch 'dev-3.x' into model-deepseek
bryce13950 Jun 19, 2025
07b5613
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Aug 5, 2025
fd9fd43
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Aug 12, 2025
aba7d6c
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Aug 15, 2025
5e03527
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Aug 15, 2025
3ccefea
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Aug 15, 2025
7daa8ec
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Aug 20, 2025
8fa5030
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Aug 22, 2025
a9b9c9c
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Aug 26, 2025
dc71ff5
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Sep 5, 2025
d7533a3
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Sep 6, 2025
7dea9f9
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Sep 7, 2025
a836a44
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Sep 10, 2025
b3c9be3
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Sep 10, 2025
57cc140
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Sep 12, 2025
e6fc57e
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Sep 12, 2025
0e149bb
Merge remote-tracking branch 'origin/dev-3.x' into model-deepseek
bryce13950 Sep 12, 2025
312eb1d
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Oct 15, 2025
74ba068
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Oct 16, 2025
265682d
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Oct 16, 2025
813fc9e
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Oct 16, 2025
14de0e4
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Oct 16, 2025
ef2a68f
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Oct 16, 2025
e7311ba
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Oct 17, 2025
8cdbb1d
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Oct 23, 2025
430f714
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Nov 12, 2025
0724732
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Nov 12, 2025
3ef8cda
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Nov 12, 2025
7d9ca71
Merge remote-tracking branch 'origin/dev-3.x-folding' into model-deep…
bryce13950 Nov 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from transformer_lens.model_bridge.supported_architectures import (
BertArchitectureAdapter,
BloomArchitectureAdapter,
DeepseekArchitectureAdapter,
Gemma1ArchitectureAdapter,
Gemma2ArchitectureAdapter,
Gemma3ArchitectureAdapter,
Expand Down Expand Up @@ -35,6 +36,7 @@
SUPPORTED_ARCHITECTURES = {
"BertForMaskedLM": BertArchitectureAdapter,
"BloomForCausalLM": BloomArchitectureAdapter,
"DeepseekV3ForCausalLM": DeepseekArchitectureAdapter,
"GemmaForCausalLM": Gemma1ArchitectureAdapter, # Default to Gemma1 as it's the original version
"Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
"Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from transformer_lens.model_bridge.supported_architectures.bloom import (
BloomArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.deepseek import (
DeepseekArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.gemma1 import (
Gemma1ArchitectureAdapter,
)
Expand Down Expand Up @@ -77,6 +80,7 @@
__all__ = [
"BertArchitectureAdapter",
"BloomArchitectureAdapter",
"DeepseekArchitectureAdapter",
"Gemma1ArchitectureAdapter",
"Gemma2ArchitectureAdapter",
"Gemma3ArchitectureAdapter",
Expand Down
70 changes: 70 additions & 0 deletions transformer_lens/model_bridge/supported_architectures/deepseek.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""DeepSeek architecture adapter."""

from typing import Any

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.conversion_utils.conversion_steps import (
WeightConversionSet,
)
from transformer_lens.model_bridge.generalized_components import (
AttentionBridge,
BlockBridge,
EmbeddingBridge,
LayerNormBridge,
MLPBridge,
MoEBridge,
UnembeddingBridge,
)


class DeepseekArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for DeepSeek models.

    Maps TransformerLens component names onto the Hugging Face
    ``DeepseekV3ForCausalLM`` module tree. DeepSeek-V3 mixes dense MLP layers
    and MoE layers; both live under the same ``model.layers.{i}.mlp`` path in
    the HF checkpoint, which is why both ``mlp`` and ``moe`` entries below
    point at the same submodule.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the DeepSeek architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        # Weight-name translation table: TransformerLens key -> HF state-dict
        # key. "{i}" is the layer index; "{j}" is the expert index within an
        # MoE layer. Dense-layer and MoE-layer MLP rules coexist because only
        # one set of HF keys will be present for any given layer.
        self.conversion_rules = WeightConversionSet(
            {
                "embed.W_E": "model.embed_tokens.weight",
                "blocks.{i}.ln1.w": "model.layers.{i}.input_layernorm.weight",
                # Attention weights
                # NOTE(review): DeepSeek-V3 uses Multi-head Latent Attention
                # (MLA); in the HF implementation K/V are produced via
                # kv_a_proj_with_mqa / kv_b_proj (and Q via q_a_proj/q_b_proj
                # when q_lora_rank is set), so flat q_proj/k_proj/v_proj keys
                # may not exist in V3 checkpoints — confirm against the target
                # model's state dict before relying on these rules.
                "blocks.{i}.attn.W_Q": "model.layers.{i}.self_attn.q_proj.weight",
                "blocks.{i}.attn.W_K": "model.layers.{i}.self_attn.k_proj.weight",
                "blocks.{i}.attn.W_V": "model.layers.{i}.self_attn.v_proj.weight",
                "blocks.{i}.attn.W_O": "model.layers.{i}.self_attn.o_proj.weight",
                "blocks.{i}.ln2.w": "model.layers.{i}.post_attention_layernorm.weight",
                # MLP weights for dense layers
                "blocks.{i}.mlp.W_gate": "model.layers.{i}.mlp.gate_proj.weight",
                "blocks.{i}.mlp.W_in": "model.layers.{i}.mlp.up_proj.weight",
                "blocks.{i}.mlp.W_out": "model.layers.{i}.mlp.down_proj.weight",
                # MoE weights
                # NOTE(review): DeepSeek-V3 MoE layers also carry
                # mlp.shared_experts.* weights and a gate
                # e_score_correction_bias; neither is mapped here — verify
                # whether the bridge needs them or folds them elsewhere.
                "blocks.{i}.moe.gate.w": "model.layers.{i}.mlp.gate.weight",
                "blocks.{i}.moe.experts.W_gate.{j}": "model.layers.{i}.mlp.experts.{j}.gate_proj.weight",
                "blocks.{i}.moe.experts.W_in.{j}": "model.layers.{i}.mlp.experts.{j}.up_proj.weight",
                "blocks.{i}.moe.experts.W_out.{j}": "model.layers.{i}.mlp.experts.{j}.down_proj.weight",
                "ln_final.w": "model.norm.weight",
                "unembed.W_U": "lm_head.weight",
            }
        )

        # Structural mapping: TransformerLens component name ->
        # (HF submodule path, bridge class[, nested sub-components]).
        self.component_mapping = {
            "embed": ("model.embed_tokens", EmbeddingBridge),
            "blocks": (
                "model.layers",
                BlockBridge,
                {
                    "ln1": ("input_layernorm", LayerNormBridge),
                    "ln2": ("post_attention_layernorm", LayerNormBridge),
                    "attn": ("self_attn", AttentionBridge),
                    # "mlp" and "moe" both resolve to the layer's "mlp"
                    # submodule: dense layers are handled by MLPBridge, MoE
                    # layers by MoEBridge (presumably selected per-layer by
                    # the bridge machinery — TODO confirm).
                    "mlp": ("mlp", MLPBridge),
                    "moe": ("mlp", MoEBridge),
                },
            ),
            "ln_final": ("model.norm", LayerNormBridge),
            "unembed": ("lm_head", UnembeddingBridge),
        }
Loading