packages/apple-llm/src/ai-sdk.ts (69 additions, 41 deletions)
@@ -41,55 +41,62 @@ export function createAppleProvider({
   }
   provider.isAvailable = () => NativeAppleLLM.isAvailable()
   provider.languageModel = createLanguageModel
-  provider.textEmbeddingModel = (modelId: string = 'NLContextualEmbedding') => {
-    if (modelId !== 'NLContextualEmbedding') {
-      throw new Error('Only the default model is supported')
-    }
-    return new AppleTextEmbeddingModel()
+  provider.textEmbeddingModel = (options: AppleEmbeddingOptions = {}) => {
+    return new AppleTextEmbeddingModel(options)
   }
   provider.imageModel = () => {
     throw new Error('Image generation models are not supported by Apple LLM')
   }
-  provider.transcriptionModel = (modelId: string = 'SpeechTranscriber') => {
-    if (modelId !== 'SpeechTranscriber') {
-      throw new Error('Only the default model is supported')
-    }
-    return new AppleTranscriptionModel()
+  provider.transcriptionModel = (options: AppleTranscriptionOptions = {}) => {
+    return new AppleTranscriptionModel(options)
   }
-  provider.speechModel = (modelId: string = 'AVSpeechSynthesizer') => {
-    if (modelId !== 'AVSpeechSynthesizer') {
-      throw new Error('Only the default model is supported')
-    }
-    return new AppleSpeechModel()
+  provider.speechModel = (options: AppleSpeechOptions = {}) => {
+    return new AppleSpeechModel(options)
   }
   return provider
 }
 
 export const apple = createAppleProvider()
 
+export interface AppleTranscriptionOptions {
+  language?: string
+}
+
 class AppleTranscriptionModel implements TranscriptionModelV2 {
   readonly specificationVersion = 'v2'
   readonly provider = 'apple'
 
   readonly modelId = 'SpeechTranscriber'
 
+  private prepared = false
+  private language: string
+
+  constructor(options: AppleTranscriptionOptions = {}) {
+    this.language = options.language ?? NativeAppleUtils.getCurrentLocale()
+  }
+
+  async prepare(): Promise<void> {
+    await NativeAppleTranscription.prepare(this.language)
+    this.prepared = true
+  }
+
   async doGenerate(options: TranscriptionModelV2CallOptions) {
     try {
       let audio = options.audio
       if (typeof audio === 'string') {
         audio = this.base64ToArrayBuffer(audio)
       }
 
-      const language = String(
-        options.providerOptions?.apple?.language ??
-          NativeAppleUtils.getCurrentLocale()
-      )
-
-      await NativeAppleTranscription.prepare(language)
+      if (!this.prepared) {
+        console.warn(
+          '[apple-llm] Model not prepared. Call prepare() ahead of time to optimize performance.'
+        )
+        await this.prepare()
+      }
 
       const transcriptionResult = await NativeAppleTranscription.transcribe(
         audio.buffer,
-        language
+        this.language
       )
 
       const transcriptionText = transcriptionResult.segments
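
Note on usage: with this change the transcription language moves from per-call `providerOptions` to a constructor option, and warm-up becomes an explicit `prepare()` call with a lazy fallback. A minimal sketch of the new call pattern, assuming the AI SDK's `experimental_transcribe` helper; the import path for this package and the audio variable are illustrative:

```ts
import { experimental_transcribe as transcribe } from 'ai'
// Import path is illustrative; use the package's actual entry point.
import { apple } from 'react-native-apple-llm'

declare const recordedAudioBase64: string // e.g. from a recorder, base64-encoded

// Language is fixed at construction instead of read on every call.
const model = apple.transcriptionModel({ language: 'en-US' })

// Optional: warm the model up front. Skipping this still works, but
// doGenerate() logs a warning and prepares lazily on the first call.
await model.prepare()

const { text, segments } = await transcribe({
  model,
  audio: recordedAudioBase64, // base64 string or binary data
})
```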
@@ -99,7 +106,7 @@ class AppleTranscriptionModel implements TranscriptionModelV2 {
       return {
         text: transcriptionText,
         segments: transcriptionResult.segments,
-        language,
+        language: this.language,
         durationInSeconds: transcriptionResult.duration,
         warnings: [],
         response: {
@@ -124,19 +131,27 @@
   }
 }
 
+export interface AppleSpeechOptions {
+  language?: string
+}
+
 class AppleSpeechModel implements SpeechModelV2 {
   readonly specificationVersion = 'v2'
   readonly provider = 'apple'
 
   readonly modelId = 'AVSpeechSynthesizer'
 
-  async doGenerate(options: SpeechModelV2CallOptions) {
-    const language = String(
-      options.language ?? NativeAppleUtils.getCurrentLocale()
-    )
+  private language: string
+
+  constructor(options: AppleSpeechOptions = {}) {
+    this.language = options.language ?? NativeAppleUtils.getCurrentLocale()
+  }
+
+  async prepare(): Promise<void> {}
+
+  async doGenerate(options: SpeechModelV2CallOptions) {
     const speechOptions = {
-      language,
+      language: this.language,
       voice: options.voice,
     }
 
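The speech model follows the same pattern: the locale is bound at construction and the old per-call `options.language` override goes away. A sketch under the same assumptions (illustrative import path; `experimental_generateSpeech` from the AI SDK):

```ts
import { experimental_generateSpeech as generateSpeech } from 'ai'
import { apple } from 'react-native-apple-llm' // illustrative import path

const result = await generateSpeech({
  model: apple.speechModel({ language: 'en-US' }),
  text: 'Hello from AVSpeechSynthesizer',
})
// result.audio holds the synthesized audio returned by the native side.
```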
@@ -162,6 +177,10 @@ class AppleSpeechModel implements SpeechModelV2 {
   }
 }
 
+export interface AppleEmbeddingOptions {
+  language?: string
+}
+
 class AppleTextEmbeddingModel implements EmbeddingModelV2<string> {
   readonly specificationVersion = 'v2'
   readonly provider = 'apple'
@@ -170,22 +189,29 @@ class AppleTextEmbeddingModel implements EmbeddingModelV2<string> {
   readonly maxEmbeddingsPerCall = Infinity
   readonly supportsParallelCalls = false
 
-  async doEmbed(options: {
-    values: string[]
-    providerOptions?: {
-      apple?: {
-        language?: string
-      }
-    }
-  }) {
-    const language = String(
-      options.providerOptions?.apple?.language ??
-        NativeAppleUtils.getCurrentLocale()
-    )
-    await NativeAppleEmbeddings.prepare(language)
+  private prepared = false
+  private language: string
+
+  constructor(options: AppleEmbeddingOptions = {}) {
+    this.language = options.language ?? NativeAppleUtils.getCurrentLocale()
+  }
+
+  async prepare(): Promise<void> {
+    await NativeAppleEmbeddings.prepare(this.language)
+    this.prepared = true
+  }
+
+  async doEmbed(options: { values: string[] }) {
+    if (!this.prepared) {
+      console.warn(
+        '[apple-llm] Model not prepared. Call prepare() ahead of time to optimize performance.'
+      )
+      await this.prepare()
+    }
 
     const embeddings = await NativeAppleEmbeddings.generateEmbeddings(
       options.values,
-      language
+      this.language
    )
    return {
      embeddings,
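
Same shape for embeddings: `providerOptions.apple.language` no longer exists, and the previously implicit per-call prepare becomes an explicit method with a lazy fallback. A sketch assuming the AI SDK's `embedMany` helper:

```ts
import { embedMany } from 'ai'
import { apple } from 'react-native-apple-llm' // illustrative import path

const model = apple.textEmbeddingModel({ language: 'en' })
await model.prepare() // avoids the lazy-prepare warning on the first call

const { embeddings } = await embedMany({
  model,
  values: ['on-device embeddings', 'no network required'],
})
```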
@@ -206,6 +232,8 @@ class AppleLLMChatLanguageModel implements LanguageModelV2 {
     this.tools = tools
   }
 
+  async prepare(): Promise<void> {}
+
   private prepareMessages(messages: LanguageModelV2Prompt): AppleMessage[] {
     return messages.map((message): AppleMessage => {
       const content = Array.isArray(message.content)
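
The chat model gains a no-op `prepare()` as well, so every model the provider returns shares one warm-up surface. That enables an app-startup warm-up loop; a sketch, assuming the factories below can all be called without arguments:

```ts
// Hypothetical startup warm-up; adjust factory arguments to your setup.
const models = [
  apple.transcriptionModel(),
  apple.speechModel(),
  apple.textEmbeddingModel(),
]
await Promise.all(models.map((model) => model.prepare()))
```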