typescript/src/chat.ts (31 changes: 20 additions & 11 deletions)

@@ -16,15 +16,16 @@ const messageSchema = z.object({
 });
 
 export const chatRequestSchema = z.object({
-    messages: z.array(messageSchema),
-    temperature: z.number().optional(),
-    max_tokens: z.number().optional(),
+  messages: z.array(messageSchema),
+  temperature: z.number().optional(),
+  max_tokens: z.number().optional(),
 });
 type ChatRequest = z.infer<typeof chatRequestSchema>;
 
 interface ChatResponse {
   model: string;
   content: string;
+  confidence?: number;
 }
 
 export interface Chunk {
@@ -45,29 +46,36 @@ export async function streamChatResponse(
 
   return {
     model: "gpt-4o",
+    confidence: response.confidence,
     content,
   };
 }
 
-export async function streamChatResponse(
-  chatRequest: ChatRequest
-): Promise<{ model: string; stream: AsyncIterable<Chunk> }> {
+export async function streamChatResponse(chatRequest: ChatRequest): Promise<{
+  model: string;
+  confidence?: number;
+  stream: AsyncIterable<Chunk>;
+}> {
   try {
-    const stream = await openAiChatCompletion({
+    const { confidence, stream } = await openAiChatCompletion({
       model: "gpt-4o",
       messages: chatRequest.messages,
       stream: true,
       temperature: chatRequest.temperature,
       max_tokens: chatRequest.max_tokens,
     });
-    return { model: "gpt-4o", stream: chunksFromOpenAiStream(stream) };
+    return {
+      model: "gpt-4o",
+      confidence,
+      stream: chunksFromOpenAiStream(stream),
+    };
   } catch (e) {
     console.warn("Error streaming chat response from OpenAI", e);
 
     const systemMessage = chatRequest.messages.find(
       (message) => message.role === "system"
     );
-    const iterator = await bedrockChatCompletion({
+    const { confidence, stream } = await bedrockChatCompletion({
       modelId: "anthropic.claude-3-sonnet-20240229-v1:0",
       system: systemMessage
         ? [
@@ -90,11 +98,12 @@ export async function streamChatResponse(
       inferenceConfig: {
         temperature: chatRequest.temperature,
         maxTokens: chatRequest.max_tokens,
-      }
+      },
     });
     return {
       model: "claude-3-sonnet",
-      stream: chunksFromBedrockStream(iterator),
+      confidence,
+      stream: chunksFromBedrockStream(stream),
     };
   }
 }
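
For context, a minimal sketch of how a caller might consume the new return shape of streamChatResponse. The handler below is hypothetical and not part of this PR; Chunk's fields are elided in the diff, so chunks are logged as-is rather than unpacked.

// Hypothetical caller, not part of this PR: consumes the new
// { model, confidence, stream } shape from streamChatResponse.
import { chatRequestSchema, streamChatResponse } from "./chat";

async function handleChat(body: unknown): Promise<void> {
  // Validate the raw payload with the exported zod schema.
  const chatRequest = chatRequestSchema.parse(body);
  const { model, confidence, stream } = await streamChatResponse(chatRequest);
  // confidence is optional in the return type, so it may be undefined.
  console.log(`model=${model} confidence=${confidence ?? "n/a"}`);
  for await (const chunk of stream) {
    console.log(chunk); // Chunk's exact fields are not shown in this diff.
  }
}
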
typescript/src/stubs/stub-bedrock-client.ts (12 changes: 8 additions & 4 deletions)

@@ -18,8 +18,12 @@ async function* stubGenerator(): AsyncGenerator<ConverseStreamOutput> {
   }
 }
 
-export function bedrockChatCompletion(
-  _: ConverseStreamCommandInput
-): Promise<AsyncIterable<ConverseStreamOutput>> {
-  return Promise.resolve(stubGenerator());
+export function bedrockChatCompletion(_: ConverseStreamCommandInput): {
+  confidence: number;
+  stream: AsyncIterable<ConverseStreamOutput>;
+} {
+  return {
+    confidence: (Math.floor(Math.random() * 10) + 1) / 10,
+    stream: stubGenerator(),
+  };
 }
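
The stubbed confidence expression (Math.floor(Math.random() * 10) + 1) / 10 yields one of the ten values 0.1 through 1.0, each equally likely. A hypothetical usage sketch (not part of the PR) of the stub's new synchronous shape; note that chat.ts can still await it, since awaiting a non-Promise value passes it through unchanged:

// Hypothetical usage, not part of this PR.
import { bedrockChatCompletion } from "./stub-bedrock-client";

async function demoBedrockStub(): Promise<void> {
  // The stub ignores its input, but modelId is the required field.
  const { confidence, stream } = bedrockChatCompletion({
    modelId: "anthropic.claude-3-sonnet-20240229-v1:0",
  });
  console.log(`stub confidence: ${confidence}`); // 0.1 .. 1.0 in steps of 0.1
  for await (const event of stream) {
    console.log(event); // raw ConverseStreamOutput events
  }
}
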
typescript/src/stubs/stub-openai-client.ts (30 changes: 20 additions & 10 deletions)

@@ -3,18 +3,22 @@ import { Stream } from "openai/streaming";
 
 const chunks = ["Hello ", "from ", "OpenAI!"];
 
-async function* stubGenerator(model: string): AsyncGenerator<OpenAI.Chat.ChatCompletionChunk> {
+async function* stubGenerator(
+  model: string
+): AsyncGenerator<OpenAI.Chat.ChatCompletionChunk> {
   for (const [index, output] of chunks.entries()) {
     yield {
       id: index.toString(),
-      choices: [{
-        index: 0,
-        delta: {
-          role: "assistant",
-          content: output,
+      choices: [
+        {
+          index: 0,
+          delta: {
+            role: "assistant",
+            content: output,
+          },
+          finish_reason: null,
         },
-        finish_reason: null,
-      }],
+      ],
       model: model,
       object: "chat.completion.chunk",
       created: Date.now(),
@@ -24,6 +28,12 @@ async function* stubGenerator(model: string): AsyncGenerator<OpenAI.Chat.ChatCompletionChunk> {
 
 export function openAiChatCompletion(
   input: OpenAI.Chat.ChatCompletionCreateParamsStreaming
-): Promise<Stream<OpenAI.Chat.ChatCompletionChunk>> {
-  return Promise.resolve(new Stream(() => stubGenerator(input.model), new AbortController()));
+): {
+  confidence: number;
+  stream: Stream<OpenAI.Chat.ChatCompletionChunk>;
+} {
+  return {
+    confidence: (Math.floor(Math.random() * 10) + 1) / 10,
+    stream: new Stream(() => stubGenerator(input.model), new AbortController()),
+  };
 }
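
And a matching hypothetical sketch (again, not part of the PR) for the OpenAI stub, which wraps the generator in OpenAI's Stream helper and yields "Hello ", "from ", "OpenAI!" as ChatCompletionChunk objects:

// Hypothetical usage, not part of this PR.
import { openAiChatCompletion } from "./stub-openai-client";

async function demoOpenAiStub(): Promise<void> {
  const { confidence, stream } = openAiChatCompletion({
    model: "gpt-4o",
    messages: [{ role: "user", content: "Hi" }],
    stream: true, // required by ChatCompletionCreateParamsStreaming
  });
  console.log(`stub confidence: ${confidence}`);
  let text = "";
  for await (const chunk of stream) {
    text += chunk.choices[0]?.delta?.content ?? "";
  }
  console.log(text); // "Hello from OpenAI!"
}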