Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions embeddings/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions embeddings/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ candle-core = "0.9.2"
candle-nn = "0.9.2"
candle-transformers = "0.9.2"
ort = { version = "2.0.0-rc.9", default-features = false, features = ["std"] }
rayon = "1.11"

[features]
default = []
Expand Down
5 changes: 4 additions & 1 deletion embeddings/manticoresearch_text_embeddings.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ struct StringItem {
uintptr_t len;
};

using MakeVectEmbeddingsFn = FloatVecResult(*)(const TextModelWrapper*, const StringItem*, uintptr_t);
using MakeVectEmbeddingsFn = FloatVecResult(*)(const TextModelWrapper*,
const StringItem*,
uintptr_t,
int32_t);

using FreeVecResultFn = void(*)(FloatVecResult);

Expand Down
4 changes: 2 additions & 2 deletions embeddings/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type LoadModelFn = extern "C" fn(
type FreeModelResultFn = extern "C" fn(TextModelResult);

type MakeVectEmbeddingsFn =
extern "C" fn(&TextModelWrapper, *const StringItem, usize) -> FloatVecResult;
extern "C" fn(&TextModelWrapper, *const StringItem, usize, i32) -> FloatVecResult;

type FreeVecResultFn = extern "C" fn(FloatVecResult);

Expand Down Expand Up @@ -62,7 +62,7 @@ pub struct EmbedLib {
const VERSION_STR: &[u8] = concat!(env!("EMBEDDINGS_VERSION_STR"), "\0").as_bytes();

const LIB: EmbedLib = EmbedLib {
version: 3usize,
version: 4usize,
version_str: VERSION_STR.as_ptr() as *const c_char,
load_model: TextModelWrapper::load_model,
free_model_result: TextModelWrapper::free_model_result,
Expand Down
2 changes: 1 addition & 1 deletion embeddings/src/model/ffi_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ mod tests {
)
};
let vec_result =
TextModelWrapper::make_vect_embeddings(&wrapper, items.as_ptr(), 1);
TextModelWrapper::make_vect_embeddings(&wrapper, items.as_ptr(), 1, 0);
assert!(vec_result.error.is_null());
assert_eq!(vec_result.len, 1);
TextModelWrapper::free_vec_result(vec_result);
Expand Down
8 changes: 6 additions & 2 deletions embeddings/src/model/jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ impl JinaModel {
}

impl TextModel for JinaModel {
fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn std::error::Error>> {
fn predict(
&self,
texts: &[&str],
_threads: usize,
) -> Result<Vec<Vec<f32>>, Box<dyn std::error::Error>> {
let url = self
.api_url
.as_deref()
Expand Down Expand Up @@ -308,7 +312,7 @@ impl TextModel for JinaModel {
fn validate_api_key(&self) -> Result<(), Box<dyn std::error::Error>> {
// Make a minimal test request with a single character to validate the API key
// This is cheaper than a full embedding request but still validates the key
self.predict(&["test"])?;
self.predict(&["test"], 0)?;
Ok(())
}
}
4 changes: 2 additions & 2 deletions embeddings/src/model/jina_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ mod tests {

// Test with real API key
let texts = vec!["test"];
let result = model.predict(&texts);
let result = model.predict(&texts, 0);
match result {
Ok(embeddings) => {
// Should have one embedding for one text
Expand Down Expand Up @@ -157,7 +157,7 @@ mod tests {
let model = JinaModel::new("jina/jina-embeddings-v3", &api_key, None, None).unwrap();

let empty_texts: Vec<&str> = vec![];
let result = model.predict(&empty_texts);
let result = model.predict(&empty_texts, 0);
// Empty input should succeed with empty result
match result {
Ok(embeddings) => {
Expand Down
85 changes: 69 additions & 16 deletions embeddings/src/model/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,38 @@ fn intra_threads() -> usize {
.unwrap_or(DEFAULT_INTRA_THREADS)
}

/// Run `f` inside a scoped rayon thread pool of the requested size.
///
/// `threads == 0` means "no cap": just run `f` on the existing (global) rayon pool,
/// which by default uses every available CPU. `threads > 0` builds a fresh pool
/// of that size (clamped to available CPUs) and installs it for the duration of `f`,
/// so candle's intra-op rayon work and tokenizers' parallelism both respect the limit.
///
/// Errors are stringified across the `pool.install` boundary because `Box<dyn Error>`
/// is not `Send`; the original error message is preserved.
fn with_thread_limit<F>(threads: usize, f: F) -> Result<Vec<Vec<f32>>, Box<dyn Error>>
where
F: FnOnce() -> Result<Vec<Vec<f32>>, Box<dyn Error>> + Send,
{
if threads == 0 {
return f();
}

let n = threads.min(available_cpus()).max(1);
let pool = match rayon::ThreadPoolBuilder::new().num_threads(n).build() {
Ok(p) => p,
// If the pool can't be built (extremely unlikely), fall back to the global pool
// rather than failing the whole inference.
Err(_) => return f(),
};

// pool.install requires `R: Send`, but Box<dyn Error> isn't Send. Stringify the
// error inside the closure and re-wrap it on the way out — keeps the message,
// satisfies the Send bound, and avoids forcing every model's error into Send+Sync.
pool.install(|| f().map_err(|e| e.to_string()))
.map_err(|s| -> Box<dyn Error> { s.into() })
}

/// Thread-safe session wrapper with platform-specific strategy:
/// - Linux/macOS: UnsafeCell for concurrent Run() (ORT C API is thread-safe)
/// - Windows: Mutex for serialized Run() (Windows ORT has threading issues)
Expand Down Expand Up @@ -967,7 +999,13 @@ impl OnnxEmbeddingModel {
/// - Large input (> batch_size): splits into num_cpus concurrent single-doc workers.
/// Each worker tokenizes + infers 1 doc at a time through SessionWrapper,
/// mimicking the concurrent caller pattern that gives best throughput.
fn predict_pipelined(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
///
/// `threads` caps the worker count. 0 means "use all available CPUs".
fn predict_pipelined(
&self,
texts: &[&str],
threads: usize,
) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
let bs = batch_size();
let max_input = self.max_input_len;
let session = &self.session;
Expand All @@ -984,8 +1022,14 @@ impl OnnxEmbeddingModel {

// Adaptive parallelism: scale workers with input size.
// Each worker needs at least batch_size docs to justify thread overhead.
// Cap at available CPUs — more workers than cores adds contention.
let num_workers = (texts.len() / bs).min(available_cpus()).max(1);
// Worker cap is the caller-supplied `threads` limit when > 0,
// otherwise fall back to all available CPUs.
let thread_cap = if threads > 0 {
threads.min(available_cpus())
} else {
available_cpus()
};
let num_workers = (texts.len() / bs).min(thread_cap).max(1);
let docs_per_worker = texts.len().div_ceil(num_workers);

let mut ordered_results: Vec<Vec<Vec<f32>>> = Vec::with_capacity(num_workers);
Expand Down Expand Up @@ -1146,19 +1190,15 @@ impl LocalModel {
}
}

impl TextModel for LocalModel {
fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
// BERT and ONNX: batched path (batch_size up to batch_size() per forward pass)
match self {
LocalModel::Bert(m) => {
return Self::predict_batched(&m.tokenizer, m.max_input_len, texts, |chunks| {
m.predict_chunks(chunks)
});
}
LocalModel::Onnx(m) => {
return m.predict_pipelined(texts);
}
_ => {} // fall through to sequential path
impl LocalModel {
/// Inner predict body for non-ONNX local models (BERT / T5 / Causal / Quantized).
/// Pulled out of the trait impl so the caller can wrap it in a scoped rayon pool.
fn predict_local(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
// BERT: batched path (batch_size up to batch_size() per forward pass)
if let LocalModel::Bert(m) = self {
return Self::predict_batched(&m.tokenizer, m.max_input_len, texts, |chunks| {
m.predict_chunks(chunks)
});
}

// Sequential path for T5, Causal, Quantized (these use KV caches / mutexes)
Expand Down Expand Up @@ -1288,6 +1328,19 @@ impl TextModel for LocalModel {

Ok(all_results)
}
}

impl TextModel for LocalModel {
fn predict(&self, texts: &[&str], threads: usize) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
// ONNX manages its own worker count internally — no rayon pool involved.
if let LocalModel::Onnx(m) = self {
return m.predict_pipelined(texts, threads);
}

// BERT / T5 / Causal / Quantized go through candle, which uses rayon for
// intra-op parallelism. Scope the rayon pool so threads > 0 caps the worker count.
with_thread_limit(threads, || self.predict_local(texts))
}

fn get_hidden_size(&self) -> usize {
match self {
Expand Down
36 changes: 18 additions & 18 deletions embeddings/src/model/local_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ mod tests {

for sentence in &test_sentences {
let local_model = LocalModel::new(model_id, cache_path.clone(), false, None).unwrap();
let embedding = local_model.predict(&[sentence]).unwrap();
let embedding = local_model.predict(&[sentence], 0).unwrap();
check_embedding_properties(&embedding[0], local_model.get_hidden_size());
}
}
Expand All @@ -429,8 +429,8 @@ mod tests {
let local_model = LocalModel::new(model_id, cache_path, false, None).unwrap();

let sentence = &["This is a test sentence."];
let embedding1 = local_model.predict(sentence).unwrap();
let embedding2 = local_model.predict(sentence).unwrap();
let embedding1 = local_model.predict(sentence, 0).unwrap();
let embedding2 = local_model.predict(sentence, 0).unwrap();

for (e1, e2) in embedding1[0].iter().zip(embedding2[0].iter()) {
assert_abs_diff_eq!(e1, e2, epsilon = 1e-6);
Expand Down Expand Up @@ -466,7 +466,7 @@ mod tests {

let test_text = &["This is a test sentence for Qwen embedding model."];
let embeddings = local_model
.predict(test_text)
.predict(test_text, 0)
.expect("Qwen model should generate embeddings");

check_embedding_properties(&embeddings[0], local_model.get_hidden_size());
Expand All @@ -482,7 +482,7 @@ mod tests {
.expect("Llama model should load");

let test_text = &["This is a test sentence for Llama embedding model."];
let embeddings = local_model.predict(test_text).unwrap();
let embeddings = local_model.predict(test_text, 0).unwrap();

check_embedding_properties(&embeddings[0], local_model.get_hidden_size());
}
Expand All @@ -496,7 +496,7 @@ mod tests {
let local_model = LocalModel::new(model_id, cache_path.clone(), false, None)
.expect("Mistral model should load");
let test_text = &["This is a test sentence for Mistral embedding model."];
let embeddings = local_model.predict(test_text).unwrap();
let embeddings = local_model.predict(test_text, 0).unwrap();
check_embedding_properties(&embeddings[0], local_model.get_hidden_size());
}

Expand All @@ -510,7 +510,7 @@ mod tests {
.expect("Gemma model should load");

let test_text = &["This is a test sentence for Gemma embedding model."];
let embeddings = local_model.predict(test_text).unwrap();
let embeddings = local_model.predict(test_text, 0).unwrap();
check_embedding_properties(&embeddings[0], local_model.get_hidden_size());
}

Expand All @@ -536,7 +536,7 @@ mod tests {
"Third sentence for batch processing verification.",
];

let embeddings = local_model.predict(test_texts).unwrap();
let embeddings = local_model.predict(test_texts, 0).unwrap();

assert_eq!(embeddings.len(), test_texts.len());

Expand Down Expand Up @@ -660,7 +660,7 @@ mod tests {
// Test embedding generation
let test_text = &["This is a test sentence for FRIDA embedding model."];
let embeddings = local_model
.predict(test_text)
.predict(test_text, 0)
.expect("FRIDA model should generate embeddings");

assert_eq!(embeddings.len(), 1, "Should return one embedding");
Expand Down Expand Up @@ -697,7 +697,7 @@ mod tests {
// Test embedding generation
let test_text = &["This is a test sentence for Google embeddinggemma model."];
let embeddings = local_model
.predict(test_text)
.predict(test_text, 0)
.expect("Google embeddinggemma should generate embeddings");

assert_eq!(embeddings.len(), 1, "Should return one embedding");
Expand Down Expand Up @@ -739,7 +739,7 @@ mod tests {

let test_text = &["This is a test sentence for ONNX embedding model."];
let embeddings = local_model
.predict(test_text)
.predict(test_text, 0)
.expect("ONNX model should generate embeddings");

assert_eq!(embeddings.len(), 1);
Expand All @@ -761,8 +761,8 @@ mod tests {
};

let sentence = &["This is a test sentence."];
let embedding1 = local_model.predict(sentence).unwrap();
let embedding2 = local_model.predict(sentence).unwrap();
let embedding1 = local_model.predict(sentence, 0).unwrap();
let embedding2 = local_model.predict(sentence, 0).unwrap();

for (e1, e2) in embedding1[0].iter().zip(embedding2[0].iter()) {
assert_abs_diff_eq!(e1, e2, epsilon = 1e-6);
Expand All @@ -789,7 +789,7 @@ mod tests {
"Third sentence for batch processing.",
];

let embeddings = local_model.predict(test_texts).unwrap();
let embeddings = local_model.predict(test_texts, 0).unwrap();
assert_eq!(embeddings.len(), test_texts.len());

for embedding in &embeddings {
Expand Down Expand Up @@ -850,20 +850,20 @@ mod tests {
};

// Warmup
let _ = st.predict(sentences).unwrap();
let _ = onnx.predict(sentences).unwrap();
let _ = st.predict(sentences, 0).unwrap();
let _ = onnx.predict(sentences, 0).unwrap();

// Safetensors timing
let start = Instant::now();
for _ in 0..iterations {
let _ = st.predict(sentences).unwrap();
let _ = st.predict(sentences, 0).unwrap();
}
let st_ms = start.elapsed().as_millis() as f64 / iterations as f64;

// ONNX timing
let start = Instant::now();
for _ in 0..iterations {
let _ = onnx.predict(sentences).unwrap();
let _ = onnx.predict(sentences, 0).unwrap();
}
let onnx_ms = start.elapsed().as_millis() as f64 / iterations as f64;

Expand Down
16 changes: 10 additions & 6 deletions embeddings/src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ use std::error::Error;
use std::path::PathBuf;

pub trait TextModel {
fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>>;
/// Generate embeddings for the given texts.
///
/// `threads` caps the number of CPU threads used during generation.
/// 0 means "use all available CPUs" (default).
fn predict(&self, texts: &[&str], threads: usize) -> Result<Vec<Vec<f32>>, Box<dyn Error>>;
fn get_hidden_size(&self) -> usize;
fn get_max_input_len(&self) -> usize;
/// Validates the API key by making a minimal test request to the API.
Expand Down Expand Up @@ -57,12 +61,12 @@ pub enum Model {
}

impl TextModel for Model {
fn predict(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
fn predict(&self, texts: &[&str], threads: usize) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
match self {
Model::OpenAI(m) => m.predict(texts),
Model::Voyage(m) => m.predict(texts),
Model::Jina(m) => m.predict(texts),
Model::Local(m) => m.predict(texts),
Model::OpenAI(m) => m.predict(texts, threads),
Model::Voyage(m) => m.predict(texts, threads),
Model::Jina(m) => m.predict(texts, threads),
Model::Local(m) => m.predict(texts, threads),
}
}

Expand Down
Loading
Loading