-
Notifications
You must be signed in to change notification settings - Fork 292
Add MTEBEvaluator for embedding model evaluation #2409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+393
−0
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,263 @@ | ||
| # ------------------------------------------------------------------------- | ||
|
github-advanced-security[bot] marked this conversation as resolved.
Fixed
|
||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
| """MTEB-compatible wrappers for ONNX Runtime embedding models. | ||
|
|
||
| Mirrors the pattern in ``lmeval_ort.py``: thin adapters that bridge an | ||
| exported ONNX (or ORT-GenAI) model into the interface expected by the | ||
| evaluation library — in this case MTEB's ``EncoderProtocol``. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| from abc import ABC, abstractmethod | ||
| from pathlib import Path | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from transformers import AutoTokenizer | ||
|
|
||
| from olive.common.onnx_io import get_io_config | ||
|
|
||
| try: | ||
| import onnxruntime_genai as og | ||
| except ImportError: | ||
| og = None | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class MTEBOnnxBase(ABC): | ||
| """Base class for MTEB-compatible ONNX embedding model wrappers. | ||
|
|
||
| Subclasses must implement :meth:`_encode_batch` which takes tokenised | ||
| inputs and returns a numpy embedding matrix. | ||
| """ | ||
|
|
||
| def __init__(self, tokenizer_path: str, batch_size: int = 32, max_length: int | None = None): | ||
| self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) | ||
| self.batch_size = batch_size | ||
| self._max_length = max_length | ||
|
|
||
| @property | ||
| def max_length(self) -> int: | ||
| if self._max_length: | ||
| return self._max_length | ||
| if hasattr(self.tokenizer, "model_max_length") and self.tokenizer.model_max_length < 1_000_000: | ||
| return self.tokenizer.model_max_length | ||
| return 512 | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # MTEB EncoderProtocol interface | ||
| # ------------------------------------------------------------------ | ||
|
|
||
| def encode(self, inputs, *, task_metadata=None, hf_split=None, hf_subset=None, prompt_type=None, **kwargs): | ||
| """Encode sentences into embeddings — MTEB ``EncoderProtocol.encode``.""" | ||
| # Handle string input (single sentence) | ||
| if isinstance(inputs, str): | ||
| inputs = [inputs] | ||
|
|
||
| # Flatten DataLoader batches into a plain list of strings | ||
| sentences: list[str] = [] | ||
| for batch in inputs: | ||
| if isinstance(batch, dict) and "text" in batch: | ||
| sentences.extend(batch["text"]) | ||
| elif isinstance(batch, (list, tuple)): | ||
| sentences.extend(batch) | ||
| elif isinstance(batch, str): | ||
| sentences.append(batch) | ||
| else: | ||
| sentences.extend(list(batch)) | ||
|
|
||
|
natke marked this conversation as resolved.
|
||
| all_embeddings: list[np.ndarray] = [] | ||
| for start in range(0, len(sentences), self.batch_size): | ||
| batch_sentences = sentences[start : start + self.batch_size] | ||
| encoded = self.tokenizer( | ||
| batch_sentences, | ||
| padding=True, | ||
| truncation=True, | ||
| max_length=self.max_length, | ||
| return_tensors="np", | ||
| ) | ||
| embeddings = self._encode_batch(encoded) | ||
| all_embeddings.append(embeddings) | ||
|
|
||
| return np.concatenate(all_embeddings, axis=0) | ||
|
|
||
| @staticmethod | ||
| def similarity(embeddings1, embeddings2): | ||
| """Cosine similarity — default for MTEB.""" | ||
| if isinstance(embeddings1, np.ndarray): | ||
| embeddings1 = torch.from_numpy(embeddings1) | ||
| if isinstance(embeddings2, np.ndarray): | ||
| embeddings2 = torch.from_numpy(embeddings2) | ||
| embeddings1 = torch.nn.functional.normalize(embeddings1.float(), p=2, dim=-1) | ||
| embeddings2 = torch.nn.functional.normalize(embeddings2.float(), p=2, dim=-1) | ||
| return embeddings1 @ embeddings2.T | ||
|
|
||
| @staticmethod | ||
| def similarity_pairwise(embeddings1, embeddings2): | ||
| if isinstance(embeddings1, np.ndarray): | ||
| embeddings1 = torch.from_numpy(embeddings1) | ||
| if isinstance(embeddings2, np.ndarray): | ||
| embeddings2 = torch.from_numpy(embeddings2) | ||
| embeddings1 = torch.nn.functional.normalize(embeddings1.float(), p=2, dim=-1) | ||
| embeddings2 = torch.nn.functional.normalize(embeddings2.float(), p=2, dim=-1) | ||
| return (embeddings1 * embeddings2).sum(dim=-1) | ||
|
|
||
| @property | ||
| def mteb_model_meta(self): | ||
| from mteb.models.model_meta import ModelMeta | ||
|
|
||
| return ModelMeta.create_empty() | ||
|
|
||
| @abstractmethod | ||
| def _encode_batch(self, encoded_input: dict) -> np.ndarray: | ||
| """Run model inference on a tokenised batch and return embeddings. | ||
|
|
||
| Args: | ||
| encoded_input: Dictionary with at least ``input_ids`` and | ||
| ``attention_mask`` as numpy arrays of shape ``[batch, seqlen]``. | ||
|
|
||
| Returns: | ||
|
github-advanced-security[bot] marked this conversation as resolved.
Fixed
|
||
| Embeddings array of shape ``[batch, embed_dim]``. | ||
|
|
||
| """ | ||
| raise NotImplementedError | ||
|
|
||
|
|
||
| # ------------------------------------------------------------------ | ||
| # ORT (plain ONNX) variant | ||
| # ------------------------------------------------------------------ | ||
|
|
||
|
|
||
| class MTEBORTEvaluator(MTEBOnnxBase): | ||
| """MTEB wrapper for a plain ONNX embedding model run via ORT.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_path: str, | ||
| batch_size: int = 32, | ||
| max_length: int | None = None, | ||
| ep: str | None = None, | ||
| ep_options: dict | None = None, | ||
| ): | ||
| import onnxruntime as ort | ||
|
|
||
| model_dir = str(Path(model_path).parent) | ||
| super().__init__(tokenizer_path=model_dir, batch_size=batch_size, max_length=max_length) | ||
|
|
||
| providers = [] | ||
| if ep: | ||
| providers.append((ep, ep_options or {})) | ||
| providers.append(("CPUExecutionProvider", {})) | ||
|
|
||
| self.session = ort.InferenceSession(model_path, providers=providers) | ||
| self.io_config = get_io_config(model_path) | ||
| self._output_names = self.io_config["output_names"] | ||
|
|
||
| def _encode_batch(self, encoded_input: dict) -> np.ndarray: | ||
| input_ids = encoded_input["input_ids"] | ||
| attention_mask = encoded_input["attention_mask"] | ||
|
|
||
| feeds = {"input_ids": input_ids.astype(np.int64), "attention_mask": attention_mask.astype(np.int64)} | ||
| # Some models also accept token_type_ids | ||
| if "token_type_ids" in list(self.io_config["input_names"]): | ||
| feeds["token_type_ids"] = encoded_input.get("token_type_ids", np.zeros_like(input_ids, dtype=np.int64)) | ||
|
|
||
| outputs = self.session.run(None, feeds) | ||
|
|
||
| # Determine which output contains the embeddings. | ||
| # Common patterns: "last_hidden_state" (index 0) or a dedicated "sentence_embedding" output. | ||
| if len(outputs) == 1: | ||
| hidden_states = outputs[0] | ||
| else: | ||
| # Try to find a sentence-level output first | ||
| for i, name in enumerate(self._output_names): | ||
| if "sentence" in name.lower() or "pooler" in name.lower() or "embedding" in name.lower(): | ||
| return outputs[i] # Already pooled | ||
| # Fall back to the first output (last_hidden_state) | ||
| hidden_states = outputs[0] | ||
|
|
||
| # Mean pooling over the sequence dimension, masked by attention_mask | ||
| return self._mean_pool(hidden_states, attention_mask) | ||
|
|
||
| @staticmethod | ||
| def _mean_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray: | ||
| """Mean pooling: average token embeddings weighted by attention mask.""" | ||
| mask_expanded = np.expand_dims(attention_mask, axis=-1).astype(hidden_states.dtype) | ||
| sum_embeddings = np.sum(hidden_states * mask_expanded, axis=1) | ||
| sum_mask = np.clip(mask_expanded.sum(axis=1), a_min=1e-9, a_max=None) | ||
| return sum_embeddings / sum_mask | ||
|
|
||
|
|
||
| # ------------------------------------------------------------------ | ||
| # ORT-GenAI variant (for models built with ModelBuilder) | ||
| # ------------------------------------------------------------------ | ||
|
|
||
|
|
||
| class MTEBORTGenAIEvaluator(MTEBOnnxBase): | ||
| """MTEB wrapper for an ORT-GenAI embedding model (ModelBuilder output).""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| pretrained: str, | ||
| batch_size: int = 32, | ||
| max_length: int | None = None, | ||
| ep: str = "follow_config", | ||
| ep_options: dict | None = None, | ||
| ): | ||
| if og is None: | ||
| raise ImportError("onnxruntime-genai is not installed.") | ||
|
|
||
| super().__init__(tokenizer_path=pretrained, batch_size=batch_size, max_length=max_length) | ||
|
|
||
| self.config = og.Config(pretrained) | ||
| if ep != "follow_config": | ||
| ep_clean = ep.lower().replace("executionprovider", "") | ||
| self.config.clear_providers() | ||
| if ep_clean != "cpu": | ||
| self.config.append_provider(ep_clean) | ||
| for key, value in (ep_options or {}).items(): | ||
| self.config.set_provider_option(ep_clean, key, value) | ||
|
|
||
| self.model = og.Model(self.config) | ||
| self.og_tokenizer = og.Tokenizer(self.model) | ||
| self.pretrained = pretrained | ||
|
|
||
| def _encode_batch(self, encoded_input: dict) -> np.ndarray: | ||
| input_ids = encoded_input["input_ids"].astype(np.int64) | ||
| attention_mask = encoded_input["attention_mask"].astype(np.int64) | ||
|
|
||
| # Use the GeneratorParams / Generator API to get hidden states | ||
| batch_size, seq_len = input_ids.shape | ||
| params = og.GeneratorParams(self.model) | ||
| params.set_search_options(max_length=seq_len + 1, past_present_share_buffer=False, batch_size=batch_size) | ||
|
|
||
| generator = og.Generator(self.model, params) | ||
| generator.append_tokens(input_ids.tolist()) | ||
|
|
||
| # Try to get hidden_states output (enabled via include_hidden_states=1) | ||
| try: | ||
| hidden_states = generator.get_output("hidden_states") | ||
| hidden_states = np.array(hidden_states, copy=False) | ||
| if hidden_states.ndim == 2: | ||
| # Shape might be [batch*seq, dim] — reshape | ||
| embed_dim = hidden_states.shape[-1] | ||
| hidden_states = hidden_states.reshape(batch_size, seq_len, embed_dim) | ||
| return self._mean_pool(hidden_states, attention_mask) | ||
| except Exception as e: | ||
| raise RuntimeError( | ||
| "hidden_states output not available from GenAI model. " | ||
| "Ensure the model was built with include_hidden_states=1 in ModelBuilder." | ||
| ) from e | ||
|
|
||
|
natke marked this conversation as resolved.
|
||
| @staticmethod | ||
| def _mean_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray: | ||
| """Mean pooling: average token embeddings weighted by attention mask.""" | ||
| mask_expanded = np.expand_dims(attention_mask, axis=-1).astype(hidden_states.dtype) | ||
| sum_embeddings = np.sum(hidden_states * mask_expanded, axis=1) | ||
| sum_mask = np.clip(mask_expanded.sum(axis=1), a_min=1e-9, a_max=None) | ||
| return sum_embeddings / sum_mask | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.