182 changes: 182 additions & 0 deletions backend/src/llm-apis/prompt-cache.ts
@@ -0,0 +1,182 @@
/**
* Simple in-memory cache for LLM prompts and responses
* Cost optimization: Cache common system prompts to leverage provider caching
*/

import crypto from 'crypto'
import { logger } from '../util/logger'

interface CacheEntry<T> {
value: T
timestamp: number
hits: number
}

interface CacheStats {
hits: number
misses: number
entries: number
hitRate: number
}

export class PromptCache<T = any> {
private cache = new Map<string, CacheEntry<T>>()
private defaultTtl: number
private maxSize: number
private stats = { hits: 0, misses: 0 }

constructor(ttlMs: number = 30 * 60 * 1000, maxSize: number = 1000) { // 30 min default
this.defaultTtl = ttlMs
this.maxSize = maxSize
}

/**
* Generate cache key from content
*/
private generateKey(content: string | object): string {
const str = typeof content === 'string' ? content : JSON.stringify(content)
return crypto.createHash('sha256').update(str).digest('hex').substring(0, 16)
}

/**
* Check if cache entry is expired
*/
private isExpired(entry: CacheEntry<T>, ttl?: number): boolean {
const maxAge = ttl || this.defaultTtl
return Date.now() - entry.timestamp > maxAge
}

/**
* Evict oldest entries if cache is full
*/
private evictIfNeeded(): void {
if (this.cache.size >= this.maxSize) {
// Remove oldest entries (simple FIFO eviction)
const oldestKey = this.cache.keys().next().value
if (oldestKey) {
this.cache.delete(oldestKey)
}
}
}

/**
* Get value from cache
*/
get(key: string | object, ttl?: number): T | null {
const cacheKey = typeof key === 'string' ? key : this.generateKey(key)
const entry = this.cache.get(cacheKey)

if (!entry) {
this.stats.misses++
return null
}

if (this.isExpired(entry, ttl)) {
this.cache.delete(cacheKey)
this.stats.misses++
return null
}

entry.hits++
this.stats.hits++
return entry.value
}

/**
* Set value in cache
*/
set(key: string | object, value: T, ttl?: number): void {
const cacheKey = typeof key === 'string' ? key : this.generateKey(key)

this.evictIfNeeded()

this.cache.set(cacheKey, {
value,
timestamp: Date.now(),
hits: 0
})
}

/**
* Get or compute value with automatic caching
*/
async getOrCompute<R = T>(
key: string | object,
computeFn: () => Promise<R>,
ttl?: number
): Promise<R> {
const cached = this.get(key, ttl) as R
if (cached !== null) {
return cached
}

const computed = await computeFn()
this.set(key, computed as unknown as T, ttl)
return computed
}

/**
* Clear cache
*/
clear(): void {
this.cache.clear()
this.stats = { hits: 0, misses: 0 }
}

/**
* Get cache statistics
*/
getStats(): CacheStats {
return {
hits: this.stats.hits,
misses: this.stats.misses,
entries: this.cache.size,
hitRate: this.stats.hits + this.stats.misses > 0
? this.stats.hits / (this.stats.hits + this.stats.misses)
: 0
}
}

/**
* Clean expired entries
*/
cleanup(): number {
let cleaned = 0
for (const [key, entry] of this.cache.entries()) {
if (this.isExpired(entry)) {
this.cache.delete(key)
cleaned++
}
}
return cleaned
}
}

// Global cache instances for different types of content
export const systemPromptCache = new PromptCache<string>(60 * 60 * 1000) // 1 hour TTL
export const fileTreeCache = new PromptCache<string>(30 * 60 * 1000) // 30 min TTL
export const responseCache = new PromptCache<any>(15 * 60 * 1000) // 15 min TTL

// Periodic cleanup
setInterval(() => {
const cleaned = systemPromptCache.cleanup() +
fileTreeCache.cleanup() +
responseCache.cleanup()

if (cleaned > 0) {
logger.debug(`Cleaned ${cleaned} expired cache entries`)
}
}, 5 * 60 * 1000) // Every 5 minutes

// Log cache stats periodically
setInterval(() => {
const systemStats = systemPromptCache.getStats()
const fileTreeStats = fileTreeCache.getStats()
const responseStats = responseCache.getStats()

logger.info({
systemPromptCache: systemStats,
fileTreeCache: fileTreeStats,
responseCache: responseStats
}, 'Cache performance stats')
}, 30 * 60 * 1000) // Every 30 minutes
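
For reference, a minimal usage sketch of the getOrCompute helper above at a hypothetical call site (the import path and buildFileTreeSummary are illustrative, not part of this change):

import { fileTreeCache } from './prompt-cache'

// Hypothetical helper, assumed to exist for illustration only.
declare function buildFileTreeSummary(files: string[]): Promise<string>

// Memoize an expensive file-tree summary; fileTreeCache defaults to a 30 min TTL.
async function getFileTreeSummary(projectId: string, files: string[]): Promise<string> {
  return fileTreeCache.getOrCompute(
    { kind: 'file-tree', projectId, files }, // object keys are hashed with SHA-256
    async () => buildFileTreeSummary(files), // only invoked on a miss or an expired entry
  )
}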
153 changes: 149 additions & 4 deletions backend/src/llm-apis/vercel-ai-sdk/ai-sdk.ts
@@ -21,6 +21,7 @@ import { checkLiveUserInput, getLiveUserInputIds } from '../../live-user-inputs'
import { logger } from '../../util/logger'
import { saveMessage } from '../message-cost-tracker'
import { openRouterLanguageModel } from '../openrouter'
import { systemPromptCache, responseCache } from '../prompt-cache'
import { vertexFinetuned } from './vertex-finetuned'

import type {
@@ -36,6 +37,93 @@ import type {
import type { LanguageModel } from 'ai'
import type { z } from 'zod/v4'

// Cost optimization: Task-based parameter optimization
interface TaskBasedParameters {
temperature: number
maxTokens: number
}

type TaskType = 'file-operations' | 'simple-query' | 'code-generation' | 'analysis' | 'creative' | 'complex-reasoning' | 'default'

const getOptimalParametersByTask = (taskType: TaskType): TaskBasedParameters => {
const paramConfigs: Record<TaskType, TaskBasedParameters> = {
'file-operations': { temperature: 0.0, maxTokens: 1000 }, // Deterministic file ops
'simple-query': { temperature: 0.0, maxTokens: 500 }, // Quick factual responses
'code-generation': { temperature: 0.1, maxTokens: 2000 }, // Consistent code output
'analysis': { temperature: 0.3, maxTokens: 1500 }, // Balanced analysis
'creative': { temperature: 0.8, maxTokens: 4000 }, // High creativity
'complex-reasoning': { temperature: 0.4, maxTokens: 3000 }, // Deep thinking
'default': { temperature: 0.3, maxTokens: 2000 } // Balanced default
}

return paramConfigs[taskType] || paramConfigs['default']
}

const detectTaskTypeFromMessages = (messages: Message[]): TaskType => {
const lastMessage = messages[messages.length - 1]
const content = typeof lastMessage?.content === 'string'
? lastMessage.content.toLowerCase()
: JSON.stringify(lastMessage?.content || '').toLowerCase()

// Tool-based detection
if (content.includes('write_file') || content.includes('str_replace') || content.includes('read_files')) {
return 'file-operations'
}
if (content.includes('run_terminal_command') || content.includes('browser_logs')) {
return 'file-operations'
}
if (content.includes('spawn_agents') || content.includes('think_deeply')) {
return 'complex-reasoning'
}
if (content.includes('code_search') || content.includes('create_plan')) {
return 'analysis'
}

// Content-based detection
if (content.length < 100) {
return 'simple-query'
}
if (content.includes('write') && (content.includes('code') || content.includes('function') || content.includes('class'))) {
return 'code-generation'
}
if (content.includes('analyze') || content.includes('explain') || content.includes('review')) {
return 'analysis'
}
if (content.includes('creative') || content.includes('story') || content.includes('poem')) {
return 'creative'
}
if (content.includes('complex') || content.includes('architecture') || content.includes('design')) {
return 'complex-reasoning'
}

return 'default'
}

// Cost optimization: Cache system prompts and common responses
const isCacheableSystemPrompt = (messages: Message[]): boolean => {
// Cache system prompts (first message is usually system)
if (messages.length > 0 && messages[0].role === 'system') {
const content = typeof messages[0].content === 'string'
? messages[0].content
: JSON.stringify(messages[0].content || '')

// Cache if it's a system prompt > 500 chars (likely to be reused)
return content.length > 500
}
return false
}

const generateCacheKey = (messages: Message[], model: string, options: any): string => {
// Create cache key from messages + model + key parameters
const cacheableContent = {
messages: messages.slice(0, 2), // Only first 2 messages (system + first user)
model,
temperature: (options as any).temperature,
maxTokens: (options as any).maxTokens
}
return JSON.stringify(cacheableContent)
}
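
// Illustrative note, not part of this change: the key only covers messages.slice(0, 2)
// plus model, temperature, and maxTokens, so conversations that diverge after the first
// user message map to the same cache key. Sketch (placeholder values):
//
//   const base: Message[] = [
//     { role: 'system', content: longSystemPrompt },    // > 500 chars, so cacheable
//     { role: 'user', content: 'Summarize this repo' },
//   ]
//   const keyA = generateCacheKey([...base, { role: 'user', content: 'Refactor utils.ts' }], model, opts)
//   const keyB = generateCacheKey([...base, { role: 'user', content: 'Remove dead code' }], model, opts)
//   // keyA === keyB — turns after the first two messages do not affect the key.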

// TODO: We'll want to add all our models here!
const modelToAiSDKModel = (model: Model): LanguageModel => {
if (
Expand Down Expand Up @@ -100,8 +188,19 @@ export const promptAiSdkStream = async function* (

let aiSDKModel = modelToAiSDKModel(options.model)

- const response = streamText({
// Cost optimization: Apply task-based parameter optimization
const taskType = detectTaskTypeFromMessages(options.messages)
const optimalParams = getOptimalParametersByTask(taskType)

// Only override if not explicitly set by caller
const finalOptions = {
...options,
temperature: (options as any).temperature ?? optimalParams.temperature,
maxTokens: (options as any).maxTokens ?? optimalParams.maxTokens,
}

const response = streamText({
...finalOptions,
model: aiSDKModel,
maxRetries: options.maxRetries,
messages: convertCbToModelMessages(options),
Expand Down Expand Up @@ -262,14 +361,49 @@ export const promptAiSdk = async function (
const startTime = Date.now()
let aiSDKModel = modelToAiSDKModel(options.model)

- const response = await generateText({
// Cost optimization: Apply task-based parameter optimization
const taskType = detectTaskTypeFromMessages(options.messages)
const optimalParams = getOptimalParametersByTask(taskType)

// Only override if not explicitly set by caller
const finalOptions = {
...options,
temperature: (options as any).temperature ?? optimalParams.temperature,
maxTokens: (options as any).maxTokens ?? optimalParams.maxTokens,
}

// Cost optimization: Check cache for similar requests
const cacheKey = generateCacheKey(options.messages, options.model, finalOptions)
const cachedResponse = responseCache.get(cacheKey)

if (cachedResponse && isCacheableSystemPrompt(options.messages)) {
logger.debug({ cacheKey: cacheKey.substring(0, 32) + '...' }, 'Cache hit for prompt')

// Return cached response but still track for cost accounting
const creditsUsed = 0 // Cache hits are free!
if (options.onCostCalculated) {
await options.onCostCalculated(creditsUsed)
}

return cachedResponse
}

const response = await generateText({
...finalOptions,
model: aiSDKModel,
messages: convertCbToModelMessages(options),
})

const content = response.text

// Cache successful responses for cacheable system prompts
if (isCacheableSystemPrompt(options.messages) && content.length > 0) {
responseCache.set(cacheKey, content, 15 * 60 * 1000) // 15 min cache
logger.debug({ cacheKey: cacheKey.substring(0, 32) + '...' }, 'Cached prompt response')
}

const inputTokens = response.usage.inputTokens || 0
- const outputTokens = response.usage.inputTokens || 0
+ const outputTokens = response.usage.outputTokens || 0

const creditsUsedPromise = saveMessage({
messageId: generateCompactId(),
Expand Down Expand Up @@ -334,8 +468,19 @@ export const promptAiSdkStructured = async function <T>(options: {
const startTime = Date.now()
let aiSDKModel = modelToAiSDKModel(options.model)

- const responsePromise = generateObject<z.ZodType<T>, 'object'>({
// Cost optimization: Apply task-based parameter optimization
const taskType = detectTaskTypeFromMessages(options.messages)
const optimalParams = getOptimalParametersByTask(taskType)

// Only override if not explicitly set by caller
const finalOptions = {
...options,
temperature: (options as any).temperature ?? optimalParams.temperature,
maxTokens: (options as any).maxTokens ?? optimalParams.maxTokens,
}

const responsePromise = generateObject<z.ZodType<T>, 'object'>({
...finalOptions,
model: aiSDKModel,
output: 'object',
messages: convertCbToModelMessages(options),
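
A sketch of a unit test for the new PromptCache, exercising TTL expiry and FIFO eviction (assumes a vitest-style runner; the import path is illustrative and this test is not part of the diff):

import { describe, expect, it, vi } from 'vitest'
import { PromptCache } from './prompt-cache' // adjust the path to the module's location

describe('PromptCache', () => {
  it('expires entries after the TTL and evicts the oldest entry when full', () => {
    vi.useFakeTimers()
    vi.setSystemTime(new Date(0))

    const cache = new PromptCache<string>(1000, 2) // 1 s TTL, max 2 entries

    cache.set('a', 'A')
    expect(cache.get('a')).toBe('A') // fresh entry → hit

    vi.setSystemTime(new Date(2000)) // jump past the 1 s TTL
    expect(cache.get('a')).toBeNull() // expired entries are deleted on read

    cache.set('b', 'B')
    cache.set('c', 'C')
    cache.set('d', 'D') // exceeds maxSize, so the oldest key ('b') is evicted
    expect(cache.get('b')).toBeNull()
    expect(cache.get('c')).toBe('C')

    vi.useRealTimers()
  })
})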