13 changes: 11 additions & 2 deletions model2vec/distill/distillation.py
@@ -8,6 +8,7 @@
import numpy as np
from huggingface_hub.hf_api import model_info
from skeletoken import TokenizerModel
from skeletoken.external.transformers import reshape_embeddings
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
@@ -77,17 +78,22 @@ def distill_from_model(

device = select_optimal_device(device)
original_tokenizer_model = TokenizerModel.from_transformers_tokenizer(tokenizer)
original_tokenizer_model = original_tokenizer_model.prune_added_tokens()

# Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary.
# Copy the original tokenizer model.
tokenizer_model = original_tokenizer_model.model_copy(deep=True)
tokenizer_model = original_tokenizer_model.deep_copy()
if tokenizer_model.adds_prefix_space is not None:
tokenizer_model.adds_prefix_space = True

# Create the vocabulary in the new tokenizer.
tokenizer_model = clean_and_create_vocabulary(tokenizer_model, vocabulary, token_remove_regex=token_remove_regex)
# Remove the post processor; it is no longer needed.
tokenizer_model.post_processor = None
# Prune again now that the post processor is gone.
# We can't do this earlier because we need the post processor and its
# associated tokens in order to add the eos/bos tokens.
tokenizer_model = tokenizer_model.prune_added_tokens()

# All tokens in a single list.
all_tokens = tokenizer_model.sorted_vocabulary
@@ -97,12 +103,15 @@ def distill_from_model(
# Turn all _new_ tokens into ids using the original tokenizer
token_ids = turn_tokens_into_ids(all_tokens, original_tokenizer_model)

# Reshape the transformer's embeddings to match the pruned original tokenizer.
model = reshape_embeddings(model, original_tokenizer_model)

# Create the embeddings using the ids from the original tokenizer.
embeddings = create_embeddings(
tokenized=token_ids,
model=model,
device=device,
pad_token_id=tokenizer_model.pad_token_id or 0,
pad_token_id=original_tokenizer_model.pad_token_id or 0,
pooling=pooling,
)

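For orientation, the new tokenizer handling above can be sketched as below. This is a condensed, hypothetical sketch, not the full `distill_from_model` implementation: it assumes skeletoken and transformers are installed, uses `bert-base-uncased` as a stand-in model, and elides the vocabulary cleaning and embedding-creation steps shown in the diff.

```python
# Condensed sketch of the new tokenizer handling in this PR.
from skeletoken import TokenizerModel
from skeletoken.external.transformers import reshape_embeddings
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("bert-base-uncased")      # placeholder model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Build the skeletoken representation and prune added tokens up front.
original_tokenizer_model = TokenizerModel.from_transformers_tokenizer(tokenizer)
original_tokenizer_model = original_tokenizer_model.prune_added_tokens()

# Work on a deep copy; the original is kept for id lookups and for resizing
# the transformer's embedding matrix below.
tokenizer_model = original_tokenizer_model.deep_copy()

# ... vocabulary cleaning (clean_and_create_vocabulary) happens here ...

# The post processor is only needed while eos/bos tokens are being added;
# once it is dropped, the added tokens can be pruned a second time.
tokenizer_model.post_processor = None
tokenizer_model = tokenizer_model.prune_added_tokens()

# Resize the transformer embeddings so they line up with the pruned tokenizer,
# and take the pad id from the *original* tokenizer model when pooling.
model = reshape_embeddings(model, original_tokenizer_model)
pad_token_id = original_tokenizer_model.pad_token_id or 0
```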
3 changes: 2 additions & 1 deletion model2vec/tokenizer/tokenizer.py
@@ -89,7 +89,8 @@ def turn_tokens_into_ids(tokens: list[str], model: TokenizerModel) -> list[list[

token_ids: list[list[int]] = []
for token in tokens:
if token_id := vocabulary.get(token):
token_id = vocabulary.get(token)
if token_id is not None:
token_ids.append([*prefix, token_id, *suffix])
else:
token_ids.append(tokenizer.encode(token).ids)
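The reason for dropping the walrus pattern is that `if token_id := vocabulary.get(token):` treats a valid token id of `0` as a miss (0 is falsy), so that token would be re-encoded instead of looked up; the explicit `is not None` check only falls back on a genuine miss. A minimal, standalone illustration (toy vocabulary, not model2vec code):

```python
# Toy vocabulary: the pad token legitimately has id 0.
vocabulary = {"[PAD]": 0, "hello": 7}

def found_with_walrus(token: str) -> bool:
    # Old pattern: id 0 is falsy, so "[PAD]" is wrongly treated as missing.
    if token_id := vocabulary.get(token):
        return True
    return False

def found_with_none_check(token: str) -> bool:
    # New pattern: only an actual miss (None) falls through to re-encoding.
    token_id = vocabulary.get(token)
    return token_id is not None

assert found_with_walrus("[PAD]") is False      # the bug
assert found_with_none_check("[PAD]") is True   # the fix
assert found_with_walrus("hello") is True       # non-zero ids were always fine
```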
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -60,7 +60,7 @@ dev = [
"ruff",
]

distill = ["torch", "transformers", "scikit-learn", "skeletoken>=0.3.0"]
distill = ["torch", "transformers", "scikit-learn", "skeletoken @ git+https://github.com/stephantul/skeletoken.git"]
onnx = ["onnx", "torch"]
# train also installs inference
train = ["torch", "lightning", "scikit-learn", "skops"]
36 changes: 33 additions & 3 deletions tests/conftest.py
@@ -62,16 +62,24 @@ def mock_tokenizermodel() -> TokenizerModel:


@pytest.fixture
def mock_transformer() -> PreTrainedModel:
def mock_transformer(request: pytest.FixtureRequest) -> PreTrainedModel:
"""Create a mock transformer model."""
params = getattr(request, "param", {}) or {}
# Default parameters for the mock model
vocab_size: int = params.get("vocab_size", 30522)
dim: int = params.get("dim", 768)
with_pooler: bool = params.get("with_pooler", True)
pooler_value: float = params.get("pooler_value", 7.0)

class MockPreTrainedModel:
def __init__(self, dim: int = 768, with_pooler: bool = True, pooler_value: float = 7.0) -> None:
def __init__(self, vocab_size: int, dim: int, with_pooler: bool, pooler_value: float) -> None:
self.device = "cpu"
self.name_or_path = "mock-model"
self.dim = dim
self.with_pooler = with_pooler
self.pooler_value = pooler_value
self.input_embs = torch.nn.Embedding(vocab_size, dim)
self.config: dict[str, Any] = {}

def to(self, device: str) -> MockPreTrainedModel:
self.device = device
@@ -91,7 +99,29 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:

__call__ = forward

return cast(PreTrainedModel, MockPreTrainedModel())
def get_input_embeddings(self) -> torch.nn.Embedding:
return self.input_embs

def resize_token_embeddings(self, vocab_size: int) -> None:
curr_size = len(self.input_embs.weight)
if vocab_size == curr_size:
return
if vocab_size < curr_size:
self.input_embs.weight.data = self.input_embs.weight.data[:vocab_size]
else:
self.input_embs.weight.data = torch.cat(
[self.input_embs.weight, torch.zeros(vocab_size - curr_size, self.dim)], dim=0
)

return cast(
PreTrainedModel,
MockPreTrainedModel(
dim=dim,
with_pooler=with_pooler,
pooler_value=pooler_value,
vocab_size=vocab_size,
),
)


@pytest.fixture(scope="session")
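Because the fixture now reads `request.param`, individual tests can size the mock via indirect parametrization. A hypothetical usage sketch (the test name and sizes are made up; the real usage is in tests/test_distillation.py):

```python
import pytest

@pytest.mark.parametrize("mock_transformer", [{"vocab_size": 128, "dim": 16}], indirect=True)
def test_mock_transformer_resize(mock_transformer) -> None:
    # The mock exposes an embedding matrix of the requested size.
    assert mock_transformer.get_input_embeddings().weight.shape == (128, 16)

    # Growing pads the matrix with zero rows, mirroring the Hugging Face
    # resize_token_embeddings API.
    mock_transformer.resize_token_embeddings(130)
    assert mock_transformer.get_input_embeddings().weight.shape == (130, 16)
```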
17 changes: 10 additions & 7 deletions tests/test_distillation.py
@@ -104,6 +104,7 @@ def test_distill_removal_pattern_all_tokens(

@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@patch("transformers.AutoModel.from_pretrained")
@pytest.mark.parametrize("mock_transformer", [{"vocab_size": 35022}], indirect=True)
def test_distill_removal_pattern(
mock_auto_model: MagicMock,
mock_model_info: MagicMock,
@@ -114,7 +115,8 @@ def test_distill_removal_pattern(
mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
mock_auto_model.return_value = mock_transformer

expected_vocab_size = mock_berttokenizer.vocab_size
# Because the added [MASK], [CLS] and [SEP] tokens get removed
expected_vocab_size = mock_berttokenizer.vocab_size - 3

static_model = distill_from_model(
model=mock_transformer,
@@ -159,18 +161,19 @@
@pytest.mark.parametrize(
"vocabulary, pca_dims, sif_coefficient, expected_shape",
[
(None, 256, None, (30522, 256)), # PCA applied, SIF off
(None, "auto", None, (30522, 768)), # PCA 'auto', SIF off
(None, "auto", 1e-4, (30522, 768)), # PCA 'auto', SIF on
(None, 256, None, (30519, 256)), # PCA applied, SIF off
(None, "auto", None, (30519, 768)), # PCA 'auto', SIF off
(None, "auto", 1e-4, (30519, 768)), # PCA 'auto', SIF on
(None, "auto", 0, None), # invalid SIF (too low) -> raises
(None, "auto", 1, None), # invalid SIF (too high) -> raises
(None, 1024, None, (30522, 768)), # PCA set high (no reduction)
(["wordA", "wordB"], 4, None, (30524, 4)), # Custom vocab, PCA applied
(None, None, None, (30522, 768)), # No PCA, SIF off
(None, 1024, None, (30519, 768)), # PCA set high (no reduction)
(["wordA", "wordB"], 4, None, (30521, 4)), # Custom vocab, PCA applied
(None, None, None, (30519, 768)), # No PCA, SIF off
],
)
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@patch("transformers.AutoModel.from_pretrained")
@pytest.mark.parametrize("mock_transformer", [{"vocab_size": 30522}], indirect=True)
def test_distill(
mock_auto_model: MagicMock,
mock_model_info: MagicMock,
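The updated expected shapes follow from the pruning introduced above. A quick sanity check of the arithmetic (the -3 assumes exactly the three added tokens named in the comment are pruned):

```python
base_vocab = 30522           # mock BERT tokenizer vocabulary size
pruned = base_vocab - 3      # [MASK], [CLS] and [SEP] are pruned -> 30519
with_custom = pruned + 2     # plus "wordA" and "wordB"           -> 30521
assert (pruned, with_custom) == (30519, 30521)
```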
2 changes: 0 additions & 2 deletions tests/test_model.py
@@ -1,9 +1,7 @@
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import pytest
import safetensors
from tokenizers import Tokenizer

from model2vec import StaticModel
8 changes: 2 additions & 6 deletions uv.lock
