Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions embeddings/src/model/create_model_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use super::{create_model, Model, ModelOptions};

/// An `openai/`-prefixed model id that is paired with a custom `api_url`
/// must be accepted, and the stored model name must have the `openai/`
/// prefix removed.
#[test]
fn test_create_model_allows_custom_openai_model_when_custom_api_url_is_set() {
    let options = ModelOptions {
        model_id: "openai/rubert-tiny-turbo".to_string(),
        cache_path: None,
        api_key: Some("test-key".to_string()),
        api_url: Some("http://localhost:8080/v1/embeddings".to_string()),
        api_timeout: None,
        use_gpu: None,
    };

    // `expect` (rather than `assert!(is_ok())` + `unwrap()`) keeps the
    // underlying error in the failure output.
    let created = create_model(options)
        .expect("create_model should accept a custom OpenAI model when api_url is set");

    let Model::OpenAI(openai) = created else {
        panic!("expected OpenAI model");
    };
    assert_eq!(openai.model, "rubert-tiny-turbo");
}

/// A `jina/` prefix must still select the Jina remote provider when a custom
/// `api_url` is supplied, and the stored model name must have the `jina/`
/// prefix removed.
#[test]
fn test_create_model_with_custom_url_still_uses_prefixed_jina_as_remote_signal() {
    let options = ModelOptions {
        model_id: "jina/custom-model".to_string(),
        cache_path: None,
        api_key: Some("test-key".to_string()),
        api_url: Some("http://localhost:8080/v1/embeddings".to_string()),
        api_timeout: None,
        use_gpu: None,
    };

    // `expect` (rather than `assert!(is_ok())` + `unwrap()`) keeps the
    // underlying error in the failure output.
    let created = create_model(options)
        .expect("create_model should accept a custom Jina model when api_url is set");

    let Model::Jina(jina) = created else {
        panic!("expected Jina model");
    };
    assert_eq!(jina.model, "custom-model");
}

/// With the explicit `openai:` provider syntax only the `openai:` part is
/// removed, so a model name that itself contains `openai/` is kept intact.
#[test]
fn test_create_model_supports_explicit_openai_colon_syntax() {
    let options = ModelOptions {
        model_id: "openai:openai/text-embedding-ada-002".to_string(),
        cache_path: None,
        api_key: Some("test-key".to_string()),
        api_url: Some("http://localhost:8080/v1/embeddings".to_string()),
        api_timeout: None,
        use_gpu: None,
    };

    // `expect` (rather than `assert!(is_ok())` + `unwrap()`) keeps the
    // underlying error in the failure output.
    let created = create_model(options)
        .expect("create_model should accept the explicit `openai:` provider syntax");

    let Model::OpenAI(openai) = created else {
        panic!("expected OpenAI model");
    };
    assert_eq!(openai.model, "openai/text-embedding-ada-002");
}

/// The explicit `openai:` provider syntax with a plain model name must yield
/// an OpenAI model whose stored name is everything after `openai:`.
#[test]
fn test_create_model_supports_explicit_openai_colon_syntax_with_simple_model() {
    let options = ModelOptions {
        model_id: "openai:text-embedding-ada-002".to_string(),
        cache_path: None,
        api_key: Some("test-key".to_string()),
        api_url: Some("http://localhost:8080/v1/embeddings".to_string()),
        api_timeout: None,
        use_gpu: None,
    };

    // `expect` (rather than `assert!(is_ok())` + `unwrap()`) keeps the
    // underlying error in the failure output.
    let created = create_model(options)
        .expect("create_model should accept the explicit `openai:` syntax with a simple model");

    let Model::OpenAI(openai) = created else {
        panic!("expected OpenAI model");
    };
    assert_eq!(openai.model, "text-embedding-ada-002");
}
78 changes: 59 additions & 19 deletions embeddings/src/model/jina.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use super::TextModel;
use super::{ModelValidationMode, TextModel};
use crate::LibError;
use reqwest::blocking::Client;
use std::sync::OnceLock;

/// Remote embedding model handle for the Jina provider.
#[derive(Debug)]
pub struct JinaModel {
/// Blocking `reqwest` HTTP client held by this model.
pub client: Client,
/// Model name with the `jina:` / `jina/` provider prefix removed.
pub model: String,
/// API key exactly as passed to the constructor.
pub api_key: String,
/// Optional endpoint URL supplied by the caller; `None` when no URL was given.
pub api_url: Option<String>,
/// Embedding dimension learned at runtime: set (at most once) to the length
/// of the first embedding of a response, and consulted by `get_hidden_size`
/// when the model name is not in the built-in dimension table.
hidden_size_cache: OnceLock<usize>,
}

pub fn validate_model(model: &str) -> Result<(), String> {
Expand Down Expand Up @@ -50,8 +52,32 @@ impl JinaModel {
api_url: Option<&str>,
api_timeout: Option<u64>,
) -> Result<Self, Box<dyn std::error::Error>> {
let model = model_id.trim_start_matches("jina/").to_string();
validate_model(&model).map_err(|_| LibError::RemoteUnsupportedModel { status: None })?;
let validation_mode = if api_url.is_some() {
ModelValidationMode::Passthrough
} else {
ModelValidationMode::StrictBuiltInList
};

Self::new_with_validation_mode(model_id, api_key, api_url, api_timeout, validation_mode)
}

pub fn new_with_validation_mode(
model_id: &str,
api_key: &str,
api_url: Option<&str>,
api_timeout: Option<u64>,
validation_mode: ModelValidationMode,
) -> Result<Self, Box<dyn std::error::Error>> {
let model = if let Some(model) = model_id.strip_prefix("jina:") {
model.to_string()
} else {
model_id.trim_start_matches("jina/").to_string()
};

if validation_mode == ModelValidationMode::StrictBuiltInList {
validate_model(&model)
.map_err(|_| LibError::RemoteUnsupportedModel { status: None })?;
}
// Only validate basic requirements (non-empty, no whitespace)
// Real validation happens via actual API request in validate_api_key()
validate_api_key_basic(api_key)
Expand All @@ -62,8 +88,26 @@ impl JinaModel {
model,
api_key: api_key.to_string(),
api_url: api_url.map(|s| s.to_string()),
hidden_size_cache: OnceLock::new(),
})
}

/// Looks up the embedding dimension of a built-in Jina model by name.
///
/// Returns `None` when `self.model` is not one of the built-in names.
fn known_hidden_size(&self) -> Option<usize> {
    let dimensions = match self.model.as_str() {
        // 32K context
        "jina-embeddings-v4" => 2048,
        // 8K context (jina-clip-v2 is multimodal)
        "jina-clip-v2" | "jina-embeddings-v3" => 1024,
        // 8K context (jina-clip-v1 is multimodal)
        "jina-clip-v1"
        | "jina-embeddings-v2-base-es"
        | "jina-embeddings-v2-base-code"
        | "jina-embeddings-v2-base-de"
        | "jina-embeddings-v2-base-zh"
        | "jina-embeddings-v2-base-en" => 768,
        // Multi-vector models, 8K context
        "jina-colbert-v2" | "jina-colbert-v1-en" => 128,
        _ => return None,
    };

    Some(dimensions)
}
}

impl TextModel for JinaModel {
Expand Down Expand Up @@ -254,15 +298,17 @@ impl TextModel for JinaModel {
}));
}

let inferred_dim = embeddings[0].len();
let _ = self.hidden_size_cache.set(inferred_dim);

// Validate embedding dimensions and handle empty individual embeddings
let expected_dim = self.get_hidden_size();
for embedding in embeddings.iter() {
if embedding.is_empty() {
return Err(Box::new(LibError::RemoteHttpError {
status: status_code,
}));
}
if embedding.len() != expected_dim {
if embedding.len() != inferred_dim {
// Some models might return different dimensions, but we should validate
// For now, we'll be lenient but could add stricter validation later
}
Expand All @@ -272,20 +318,14 @@ impl TextModel for JinaModel {
}

fn get_hidden_size(&self) -> usize {
match self.model.as_str() {
"jina-embeddings-v4" => 2048, // 32K context, 2048 dimensions
"jina-clip-v2" => 1024, // 8K context, 1024 dimensions, multimodal
"jina-embeddings-v3" => 1024, // 8K context, 1024 dimensions
"jina-colbert-v2" => 128, // Multi-vector model, 8K context
"jina-clip-v1" => 768, // 8K context, 768 dimensions, multimodal
"jina-colbert-v1-en" => 128, // Multi-vector model, 8K context
"jina-embeddings-v2-base-es" => 768, // 8K context, 768 dimensions
"jina-embeddings-v2-base-code" => 768, // 8K context, 768 dimensions
"jina-embeddings-v2-base-de" => 768, // 8K context, 768 dimensions
"jina-embeddings-v2-base-zh" => 768, // 8K context, 768 dimensions
"jina-embeddings-v2-base-en" => 768, // 8K context, 768 dimensions
_ => panic!("Unknown model"),
}
// Invariant: cache is populated by new_with_validation_mode() — either
// implicitly via known_hidden_size() for built-ins or explicitly via a
// probe predict() for passthrough models. A miss here is a construction
// bug; the catch_unwind at the FFI boundary stops the panic from
// crossing into C++.
self.known_hidden_size()
.or_else(|| self.hidden_size_cache.get().copied())
.expect("hidden size must be populated during model construction")
}

fn get_max_input_len(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion embeddings/src/model/jina_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ mod tests {
}
}

#[should_panic(expected = "Unknown model")]
#[should_panic(expected = "hidden size must be populated during model construction")]
#[test]
fn test_get_hidden_size_unknown_model() {
// This test verifies the panic behavior for unknown models
Expand Down
26 changes: 23 additions & 3 deletions embeddings/src/model/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,16 @@ impl BertEmbeddingModel {
let summed = emb.sum(1)?.to_dtype(DType::F32)?;
let divisor = Tensor::new(seq_len as f32, &self.device)?;
let mean_emb = summed.broadcast_div(&divisor)?;
mean_emb.get(0)?.to_vec1::<f32>()?
// .contiguous() forces candle's to_vec1 to take its
// contiguous-offsets path (slice::to_vec, cap == len).
// The strided path uses Iterator::collect, which can
// produce Vec with cap > len from FromIterator growth
// doubling — that would mean the (ptr, len, cap) we
// hand across FFI doesn't match the canonical layout
// glibc expects when Vec::from_raw_parts drops on the
// C++ side via free_vec_result. Eliminate the path
// dependency entirely.
mean_emb.get(0)?.contiguous()?.to_vec1::<f32>()?
};
normalize(&mut emb_vec);
all_embeddings.push(emb_vec);
Expand Down Expand Up @@ -527,7 +536,10 @@ impl BertEmbeddingModel {

let mut out = Vec::with_capacity(batch_size);
for i in 0..batch_size {
out.push(mean_emb.get(i)?.to_vec1::<f32>()?);
// See contiguous() rationale on the batch-of-1 fast path
// above — same FFI cap/len invariant requirement applies
// to each row pulled out of the batched mean_emb.
out.push(mean_emb.get(i)?.contiguous()?.to_vec1::<f32>()?);
}
out
};
Expand Down Expand Up @@ -1236,7 +1248,10 @@ impl TextModel for LocalModel {
let summed = emb.sum(1)?.to_dtype(DType::F32)?;
let divisor = Tensor::new(seq_len as f32, &m.device)?;
let mean_emb = summed.broadcast_div(&divisor)?;
mean_emb.get(0)?.to_vec1::<f32>()?
// See contiguous() rationale on
// BertEmbeddingModel::predict_chunks above. Same FFI
// canonical-layout invariant required here.
mean_emb.get(0)?.contiguous()?.to_vec1::<f32>()?
};
normalize(&mut emb_vec);
return Ok(vec![emb_vec]);
Expand Down Expand Up @@ -1361,7 +1376,12 @@ impl TextModel for LocalModel {
};

if let Ok(e_j) = embeddings.get(0) {
// See contiguous() rationale on BertEmbeddingModel above.
// Same FFI canonical-layout invariant for T5 / Causal /
// Quantized sequential output.
let emb_vec: Vec<f32> = e_j
.contiguous()
.map_err(|e| -> Box<dyn Error> { Box::new(e) })?
.to_vec1::<f32>()
.map_err(|e| -> Box<dyn Error> { Box::new(e) })?;
let mut emb = emb_vec;
Expand Down
66 changes: 50 additions & 16 deletions embeddings/src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ mod local_test;
#[cfg(test)]
mod ffi_test;

#[cfg(test)]
mod create_model_test;

use std::error::Error;
use std::path::PathBuf;

Expand All @@ -41,6 +44,12 @@ pub struct ModelOptions {
pub use_gpu: Option<bool>,
}

/// Controls whether a remote provider constructor checks the requested model
/// name against the provider's built-in model list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelValidationMode {
/// Validate the model name against the built-in list and reject unknown
/// names (the Jina constructor maps a failure to
/// `LibError::RemoteUnsupportedModel`).
StrictBuiltInList,
/// Skip the built-in list check and accept the model name as given. Used
/// for the explicit `provider:` model-id syntax and, in `JinaModel::new`,
/// whenever a custom `api_url` is supplied.
Passthrough,
}

/// Unified model enum
///
/// Architecture:
Expand Down Expand Up @@ -96,34 +105,55 @@ impl TextModel for Model {

pub fn create_model(options: ModelOptions) -> Result<Model, Box<dyn Error>> {
let model_id = options.model_id.as_str();
let api_key = options.api_key.unwrap_or_default();
let api_url = options.api_url;
let api_timeout = options.api_timeout;

// Remote providers (HTTP APIs)
if model_id.starts_with("openai/") {
let model = openai::OpenAIModel::new(
if model_id.starts_with("openai:") {
let model = openai::OpenAIModel::new_with_validation_mode(
model_id,
options.api_key.unwrap_or_default().as_str(),
options.api_url.as_deref(),
options.api_timeout,
api_key.as_str(),
api_url.as_deref(),
api_timeout,
ModelValidationMode::Passthrough,
)?;

Ok(Model::OpenAI(Box::new(model)))
} else if model_id.starts_with("voyage/") {
let model = voyage::VoyageModel::new(
} else if model_id.starts_with("openai/") {
let model =
openai::OpenAIModel::new(model_id, api_key.as_str(), api_url.as_deref(), api_timeout)?;

Ok(Model::OpenAI(Box::new(model)))
} else if model_id.starts_with("voyage:") {
let model = voyage::VoyageModel::new_with_validation_mode(
model_id,
options.api_key.unwrap_or_default().as_str(),
options.api_url.as_deref(),
options.api_timeout,
api_key.as_str(),
api_url.as_deref(),
api_timeout,
ModelValidationMode::Passthrough,
)?;

Ok(Model::Voyage(Box::new(model)))
} else if model_id.starts_with("jina/") {
let model = jina::JinaModel::new(
} else if model_id.starts_with("voyage/") {
let model =
voyage::VoyageModel::new(model_id, api_key.as_str(), api_url.as_deref(), api_timeout)?;

Ok(Model::Voyage(Box::new(model)))
} else if model_id.starts_with("jina:") {
let model = jina::JinaModel::new_with_validation_mode(
model_id,
options.api_key.unwrap_or_default().as_str(),
options.api_url.as_deref(),
options.api_timeout,
api_key.as_str(),
api_url.as_deref(),
api_timeout,
ModelValidationMode::Passthrough,
)?;

Ok(Model::Jina(Box::new(model)))
} else if model_id.starts_with("jina/") {
let model =
jina::JinaModel::new(model_id, api_key.as_str(), api_url.as_deref(), api_timeout)?;

Ok(Model::Jina(Box::new(model)))
} else {
// Local models - auto-detect architecture from config
Expand All @@ -135,7 +165,11 @@ pub fn create_model(options: ModelOptions) -> Result<Model, Box<dyn Error>> {
.unwrap_or(String::from(".cache/manticore")),
);

let hf_token = options.api_key.as_deref();
let hf_token = if api_key.is_empty() {
None
} else {
Some(api_key.as_str())
};
let model = local::LocalModel::new(
model_id,
cache_path,
Expand Down
Loading
Loading