Skip to content

Commit ed69ec4

Browse files
committed
sdk: implement promptAiSdkStructured
1 parent 7c6c3fa commit ed69ec4

File tree

6 files changed

+84
-19
lines changed

6 files changed

+84
-19
lines changed

common/src/types/contracts/llm.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ export type PromptAiSdkFn = (
6464
) => Promise<string>
6565

6666
export type PromptAiSdkStructuredInput<T> = {
67+
apiKey: string
6768
messages: Message[]
6869
schema: z.ZodType<T>
6970
clientSessionId: string

evals/git-evals/judge-git-eval.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ export async function judgeEvalRun(evalRun: EvalRunLog) {
198198
sessionConnections: {},
199199
logger: console,
200200
trackEvent: () => {},
201+
apiKey: 'unused-api-key',
201202
}).catch((error) => {
202203
console.warn(`Judge ${index + 1} failed:`, error)
203204
return null

evals/git-evals/pick-commits.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ async function screenCommitsWithGpt5(
384384
sessionConnections: {},
385385
logger: console,
386386
trackEvent: () => {},
387+
apiKey: 'unused-api-key',
387388
})
388389

389390
// Handle empty or invalid response

evals/git-evals/post-eval-analysis.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,5 +195,6 @@ export async function analyzeEvalResults(
195195
sessionConnections: {},
196196
logger: console,
197197
trackEvent: () => {},
198+
apiKey: 'unused-api-key',
198199
})
199200
}

evals/git-evals/run-git-evals.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ Explain your reasoning in detail. Do not ask Codebuff to git commit changes.`,
163163
sessionConnections: {},
164164
logger: console,
165165
trackEvent: () => {},
166+
apiKey: 'unused-api-key',
166167
})
167168
} catch (agentError) {
168169
throw new Error(

sdk/src/impl/llm.ts

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,23 @@ import { buildArray } from '@codebuff/common/util/array'
88
import { getErrorObject } from '@codebuff/common/util/error'
99
import { convertCbToModelMessages } from '@codebuff/common/util/messages'
1010
import { StopSequenceHandler } from '@codebuff/common/util/stop-sequence'
11-
import { streamText, APICallError, generateText } from 'ai'
11+
import { streamText, APICallError, generateText, generateObject } from 'ai'
1212

1313
import { WEBSITE_URL } from '../constants'
1414

1515
import type { LanguageModelV2 } from '@ai-sdk/provider'
1616
import type {
1717
PromptAiSdkFn,
1818
PromptAiSdkStreamFn,
19+
PromptAiSdkStructuredInput,
20+
PromptAiSdkStructuredOutput,
1921
} from '@codebuff/common/types/contracts/llm'
2022
import type { ParamsOf } from '@codebuff/common/types/function-params'
2123
import type {
2224
OpenRouterProviderOptions,
2325
OpenRouterUsageAccounting,
2426
} from '@openrouter/ai-sdk-provider'
27+
import type z from 'zod/v4'
2528

2629
function getAiSdkModel(params: {
2730
apiKey: string
@@ -286,22 +289,79 @@ export async function promptAiSdk(
286289
return content
287290
}
288291

289-
console.log(
290-
await promptAiSdk({
291-
apiKey: '12345',
292-
messages: [{ role: 'user', content: 'Hello' }],
293-
clientSessionId: 'test-session',
294-
fingerprintId: 'test-fingerprint',
295-
model: 'openai/gpt-5',
296-
userId: 'test-user-id',
297-
userInputId: '64a2e61f-1fab-4701-8651-7ff7a473e97a',
298-
sendAction: () => {},
299-
logger: console,
300-
trackEvent: () => {},
301-
liveUserInputRecord: {
302-
'test-user-id': ['64a2e61f-1fab-4701-8651-7ff7a473e97a'],
292+
export async function promptAiSdkStructured<T>(
293+
params: PromptAiSdkStructuredInput<T>,
294+
): PromptAiSdkStructuredOutput<T> {
295+
const { logger } = params
296+
297+
if (!checkLiveUserInput(params)) {
298+
logger.info(
299+
{
300+
userId: params.userId,
301+
userInputId: params.userInputId,
302+
liveUserInputId: getLiveUserInputIds(params),
303+
},
304+
'Skipping structured prompt due to canceled user input',
305+
)
306+
return {} as T
307+
}
308+
let aiSDKModel = getAiSdkModel(params)
309+
310+
const response = await generateObject<z.ZodType<T>, 'object'>({
311+
...params,
312+
prompt: undefined,
313+
model: aiSDKModel,
314+
output: 'object',
315+
messages: convertCbToModelMessages(params),
316+
providerOptions: {
317+
codebuff: {
318+
codebuff_metadata: {
319+
run_id: params.userInputId,
320+
client_id: params.clientSessionId,
321+
},
322+
},
303323
},
304-
sessionConnections: { 'test-session': true },
305-
}),
306-
'asdf',
307-
)
324+
})
325+
326+
const content = response.object
327+
328+
const messageId = response.response.id
329+
const providerMetadata = response.providerMetadata ?? {}
330+
const usage = response.usage
331+
let inputTokens = usage.inputTokens || 0
332+
const outputTokens = usage.outputTokens || 0
333+
let cacheReadInputTokens: number = 0
334+
let cacheCreationInputTokens: number = 0
335+
let costOverrideDollars: number | undefined
336+
if (providerMetadata.anthropic) {
337+
cacheReadInputTokens =
338+
typeof providerMetadata.anthropic.cacheReadInputTokens === 'number'
339+
? providerMetadata.anthropic.cacheReadInputTokens
340+
: 0
341+
cacheCreationInputTokens =
342+
typeof providerMetadata.anthropic.cacheCreationInputTokens === 'number'
343+
? providerMetadata.anthropic.cacheCreationInputTokens
344+
: 0
345+
}
346+
if (providerMetadata.openrouter) {
347+
if (providerMetadata.openrouter.usage) {
348+
const openrouterUsage = providerMetadata.openrouter
349+
.usage as OpenRouterUsageAccounting
350+
cacheReadInputTokens =
351+
openrouterUsage.promptTokensDetails?.cachedTokens ?? 0
352+
inputTokens = openrouterUsage.promptTokens - cacheReadInputTokens
353+
354+
costOverrideDollars =
355+
(openrouterUsage.cost ?? 0) +
356+
(openrouterUsage.costDetails?.upstreamInferenceCost ?? 0)
357+
}
358+
}
359+
360+
// Call the cost callback if provided
361+
if (params.onCostCalculated && costOverrideDollars) {
362+
const creditsUsed = costOverrideDollars * (1 + PROFIT_MARGIN)
363+
await params.onCostCalculated(creditsUsed)
364+
}
365+
366+
return content
367+
}

0 commit comments

Comments
 (0)