Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions bun.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 16 additions & 21 deletions packages/embedding/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ interface EmbedOptions {
abortSignal?: AbortSignal;
}

/**
 * Options accepted by `embedMany` calls.
 *
 * Extends the base per-call embed options with a cap on how many
 * embedding requests may run concurrently when embedding many values.
 */
interface EmbedManyOptions extends EmbedOptions {
  /** Max concurrent embedding API calls (undefined = provider default). */
  maxParallelCalls?: number;
}

function getEmbedOptions(config: EmbeddingConfig): EmbedOptions {
const options: EmbedOptions = {};

Expand All @@ -26,6 +30,16 @@ function getEmbedOptions(config: EmbeddingConfig): EmbedOptions {
return options;
}

/**
 * Build options for an `embedMany` call from the embedding config.
 *
 * Starts from the base single-embed options and layers on
 * `maxParallelCalls` when the config specifies one.
 */
function getEmbedManyOptions(config: EmbeddingConfig): EmbedManyOptions {
  const base: EmbedManyOptions = { ...getEmbedOptions(config) };
  const maxParallelCalls = config.options?.maxParallelCalls;

  return maxParallelCalls === undefined
    ? base
    : { ...base, maxParallelCalls };
}

// =============================================================================
// Error Detection
// =============================================================================
Expand All @@ -50,19 +64,6 @@ function isContextLengthError(error: unknown): boolean {
return false;
}

// =============================================================================
// Provider Options
// =============================================================================

function getProviderOptions(
config: EmbeddingConfig,
): Record<string, Record<string, string>> | undefined {
if (config.provider === "cohere") {
return { cohere: { truncate: "END" } };
}
return undefined;
}

// =============================================================================
// Single Embedding (internal, no retry)
// =============================================================================
Expand All @@ -82,13 +83,11 @@ async function generateEmbeddingOnce(
config: EmbeddingConfig,
): Promise<SingleEmbedResult> {
const model = getEmbeddingModel(config);
const providerOptions = getProviderOptions(config);

const result = await embed({
model,
value: text,
...getEmbedOptions(config),
...(providerOptions && { providerOptions }),
});

if (result.embedding.length !== config.dimensions) {
Expand Down Expand Up @@ -124,8 +123,7 @@ export async function generateEmbedding(
): Promise<SingleEmbedResult> {
const maxTokens = config.options?.maxTokens ?? MAX_OPENAI_TOKENS;

// Non-OpenAI providers: truncate defensively, single attempt
// (Cohere also uses API-side truncation as backup)
// Ollama: truncate defensively, single attempt
if (config.provider !== "openai") {
const { text: truncated } = truncateText(text, maxTokens);
return generateEmbeddingOnce(truncated, config);
Expand Down Expand Up @@ -178,7 +176,6 @@ export async function generateEmbeddings(

const model = getEmbeddingModel(config);
const maxTokens = config.options?.maxTokens ?? MAX_OPENAI_TOKENS;
const providerOptions = getProviderOptions(config);

// Pre-truncate all providers defensively
const texts = rows.map((row) => truncateText(row.content, maxTokens).text);
Expand All @@ -190,8 +187,7 @@ export async function generateEmbeddings(
const { embeddings, usage } = await embedMany({
model,
values: texts,
...getEmbedOptions(config),
...(providerOptions && { providerOptions }),
...getEmbedManyOptions(config),
});

// embedMany returns aggregate token count — distribute evenly
Expand Down Expand Up @@ -250,7 +246,6 @@ export async function generateEmbeddings(
model,
value: text,
...getEmbedOptions(config),
...(providerOptions && { providerOptions }),
});

if (result.embedding.length !== config.dimensions) {
Expand Down
5 changes: 1 addition & 4 deletions packages/embedding/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
"main": "index.ts",
"dependencies": {
"ai": "^6.0.0",
"@ai-sdk/openai": "^3.0.0",
"@ai-sdk/cohere": "^3.0.0",
"@ai-sdk/mistral": "^3.0.0",
"@ai-sdk/google": "^3.0.0"
"@ai-sdk/openai": "^3.0.0"
}
}
44 changes: 4 additions & 40 deletions packages/embedding/provider.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import { createCohere } from "@ai-sdk/cohere";
import { createGoogleGenerativeAI } from "@ai-sdk/google";
import { createMistral } from "@ai-sdk/mistral";
import { createOpenAI } from "@ai-sdk/openai";
import type { EmbeddingModel } from "ai";
import type { EmbeddingConfig } from "./types";
Expand All @@ -12,11 +9,11 @@ import type { EmbeddingConfig } from "./types";
/**
* Get an embedding model for the configured provider.
*
* Supports: openai, ollama, cohere, mistral, google
* Supports: openai, ollama
*
* API key resolution:
* 1. config.apiKey if provided
* 2. Environment variable: {PROVIDER}_API_KEY (e.g., OPENAI_API_KEY)
* 2. Environment variable: OPENAI_API_KEY
*
* Ollama special handling:
* - Auto-appends /v1 to baseUrl if missing
Expand All @@ -25,12 +22,9 @@ import type { EmbeddingConfig } from "./types";
export function getEmbeddingModel(config: EmbeddingConfig): EmbeddingModel {
const provider = config.provider.toLowerCase();

// Resolve API key from config or environment
const envKey = `${provider.toUpperCase()}_API_KEY`;
const apiKey = config.apiKey ?? process.env[envKey];

switch (provider) {
case "openai": {
const apiKey = config.apiKey ?? process.env.OPENAI_API_KEY;
if (!apiKey) {
throw new Error(
`API key not found for OpenAI. Set apiKey in config or OPENAI_API_KEY environment variable.`,
Expand All @@ -57,40 +51,10 @@ export function getEmbeddingModel(config: EmbeddingConfig): EmbeddingModel {
return ollama.embedding(config.model);
}

case "cohere": {
if (!apiKey) {
throw new Error(
`API key not found for Cohere. Set apiKey in config or COHERE_API_KEY environment variable.`,
);
}
const cohere = createCohere({ apiKey });
return cohere.embedding(config.model);
}

case "mistral": {
if (!apiKey) {
throw new Error(
`API key not found for Mistral. Set apiKey in config or MISTRAL_API_KEY environment variable.`,
);
}
const mistral = createMistral({ apiKey });
return mistral.embedding(config.model);
}

case "google": {
if (!apiKey) {
throw new Error(
`API key not found for Google. Set apiKey in config or GOOGLE_API_KEY environment variable.`,
);
}
const google = createGoogleGenerativeAI({ apiKey });
return google.textEmbeddingModel(config.model);
}

default:
throw new Error(
`Unsupported embedding provider: ${config.provider}. ` +
`Supported providers: openai, ollama, cohere, mistral, google`,
`Supported providers: openai, ollama`,
);
}
}
11 changes: 2 additions & 9 deletions packages/embedding/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
// Provider
// =============================================================================

export type EmbeddingProvider =
| "openai"
| "ollama"
| "cohere"
| "mistral"
| "google";
export type EmbeddingProvider = "openai" | "ollama";

// =============================================================================
// Config
Expand All @@ -25,13 +20,11 @@ export interface EmbeddingConfig {
export interface EmbeddingOptions {
/** Max tokens per text (truncates longer inputs) */
maxTokens?: number;
/** Number of texts per embedding API call */
batchSize?: number;
/** Timeout per embedding API call in milliseconds */
timeoutMs?: number;
/** Max retries on transient failures */
maxRetries?: number;
/** Max concurrent embedding API calls */
/** Max concurrent chunk requests when embedding many values (default: Infinity) */
maxParallelCalls?: number;
}

Expand Down