Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 79 additions & 4 deletions src/api/providers/__tests__/native-ollama.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
import { NativeOllamaHandler } from "../native-ollama"
import { ApiHandlerOptions } from "../../../shared/api"
import { getOllamaModels } from "../fetchers/ollama"
import { getApiRequestTimeout } from "../utils/timeout-config"

// Mock the ollama package
const mockChat = vitest.fn()
const { mockChat, MockOllama } = vitest.hoisted(() => {
const mockChat = vitest.fn()
const MockOllama = vitest.fn().mockImplementation(() => ({
chat: mockChat,
}))
return { mockChat, MockOllama }
})
vitest.mock("ollama", () => {
return {
Ollama: vitest.fn().mockImplementation(() => ({
chat: mockChat,
})),
Ollama: MockOllama,
Message: vitest.fn(),
}
})
Expand All @@ -20,6 +25,13 @@ vitest.mock("../fetchers/ollama", () => ({
getOllamaModels: vitest.fn(),
}))

// Mock the timeout config
vitest.mock("../utils/timeout-config", () => ({
getApiRequestTimeout: vitest.fn(),
}))

const mockGetApiRequestTimeout = vitest.mocked(getApiRequestTimeout)

const mockGetOllamaModels = vitest.mocked(getOllamaModels)

describe("NativeOllamaHandler", () => {
Expand All @@ -28,6 +40,9 @@ describe("NativeOllamaHandler", () => {
beforeEach(() => {
vitest.clearAllMocks()

// Default mock for timeout config (600s = 600000ms)
mockGetApiRequestTimeout.mockReturnValue(600_000)

// Default mock for getOllamaModels
mockGetOllamaModels.mockResolvedValue({
llama2: {
Expand Down Expand Up @@ -605,4 +620,64 @@ describe("NativeOllamaHandler", () => {
expect(firstEndIndex).toBeGreaterThan(lastPartialIndex)
})
})

describe("timeout configuration", () => {
it("should pass a custom fetch with timeout to the Ollama client", async () => {
mockGetApiRequestTimeout.mockReturnValue(900_000) // 900s

// Create a new handler to trigger ensureClient with the mocked timeout
const options: ApiHandlerOptions = {
apiModelId: "llama2",
ollamaModelId: "llama2",
ollamaBaseUrl: "http://localhost:11434",
}

const timeoutHandler = new NativeOllamaHandler(options)

mockChat.mockImplementation(async function* () {
yield { message: { content: "Response" } }
})

const stream = timeoutHandler.createMessage("System", [{ role: "user" as const, content: "Test" }])
for await (const _ of stream) {
// consume stream
}

// Verify Ollama constructor was called with a fetch option
expect(MockOllama).toHaveBeenCalledWith(
expect.objectContaining({
host: "http://localhost:11434",
fetch: expect.any(Function),
}),
)
})

it("should not pass custom fetch when timeout is undefined", async () => {
mockGetApiRequestTimeout.mockReturnValue(undefined)

const options: ApiHandlerOptions = {
apiModelId: "llama2",
ollamaModelId: "llama2",
ollamaBaseUrl: "http://localhost:11434",
}

const timeoutHandler = new NativeOllamaHandler(options)

mockChat.mockImplementation(async function* () {
yield { message: { content: "Response" } }
})

const stream = timeoutHandler.createMessage("System", [{ role: "user" as const, content: "Test" }])
for await (const _ of stream) {
// consume stream
}

// Verify Ollama constructor was called WITHOUT a fetch option
expect(MockOllama).toHaveBeenCalledWith(
expect.not.objectContaining({
fetch: expect.any(Function),
}),
)
})
})
})
16 changes: 15 additions & 1 deletion src/api/providers/native-ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { BaseProvider } from "./base-provider"
import type { ApiHandlerOptions } from "../../shared/api"
import { getOllamaModels } from "./fetchers/ollama"
import { TagMatcher } from "../../utils/tag-matcher"
import { getApiRequestTimeout } from "./utils/timeout-config"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"

interface OllamaChatOptions {
Expand Down Expand Up @@ -160,7 +161,20 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio
try {
const clientOptions: OllamaOptions = {
host: this.options.ollamaBaseUrl || "http://localhost:11434",
// Note: The ollama npm package handles timeouts internally
}

// Apply configurable timeout via custom fetch wrapper.
// The ollama npm package uses Node.js native fetch (Undici) which
// defaults to a 300s (5 minute) timeout. This respects the user's
// apiRequestTimeout setting (default 600s) to support slow inference.
const timeoutMs = getApiRequestTimeout()
if (timeoutMs) {
clientOptions.fetch = ((url: RequestInfo | URL, init?: RequestInit) => {
return fetch(url, {
...init,
signal: init?.signal ?? AbortSignal.timeout(timeoutMs),
})
}) as typeof fetch
}

// Add API key if provided (for Ollama cloud or authenticated instances)
Expand Down
Loading