Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions Sources/AgentRunKit/LLM/GoogleAuthService.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import Foundation

// MARK: - ADC Credential File

private struct ADCCredentials: Decodable {
let type: String
let clientId: String
let clientSecret: String
let refreshToken: String

enum CodingKeys: String, CodingKey {
case type
case clientId = "client_id"
case clientSecret = "client_secret"
case refreshToken = "refresh_token"
}
}

// MARK: - Token Response

private struct TokenResponse: Decodable {
let accessToken: String
let expiresIn: Int
let tokenType: String

enum CodingKeys: String, CodingKey {
case accessToken = "access_token"
case expiresIn = "expires_in"
case tokenType = "token_type"
}
}

/// Manages Google OAuth2 tokens from Application Default Credentials (ADC).
///
/// Reads `~/.config/gcloud/application_default_credentials.json` (created by
/// `gcloud auth application-default login`) and transparently refreshes access
/// tokens as needed.
///
/// Thread-safe via `actor` isolation — only one refresh request can be in
/// flight at a time.
public actor GoogleAuthService {
// MARK: - Errors

public enum GoogleAuthError: Error, LocalizedError, Sendable {
case credentialsFileNotFound(path: String)
case unsupportedCredentialType(String)
case refreshFailed(statusCode: Int, body: String)
case decodingFailed(String)

public var errorDescription: String? {
switch self {
case let .credentialsFileNotFound(path):
"Google ADC credentials not found at \(path). Run `gcloud auth application-default login`."
case let .unsupportedCredentialType(type):
"Unsupported ADC credential type: \(type). Only 'authorized_user' is supported."
case let .refreshFailed(code, body):
"Token refresh failed (HTTP \(code)): \(body)"
case let .decodingFailed(message):
"Failed to decode ADC credentials: \(message)"
}
}
}

// MARK: - State

private let clientID: String
private let clientSecret: String
private let refreshToken: String
private let session: URLSession

private var cachedAccessToken: String?
private var tokenExpiry: Date?

/// Refresh the token when it has fewer than this many seconds remaining.
private let refreshMargin: TimeInterval = 300 // 5 minutes

private static let tokenEndpoint = URL(string: "https://oauth2.googleapis.com/token")!

// MARK: - Init

/// Creates an auth service by reading the ADC file at the default path.
public init(session: URLSession = .shared) throws {
try self.init(credentialsPath: Self.defaultCredentialsPath(), session: session)
}

/// Creates an auth service by reading the ADC file at a custom path.
public init(credentialsPath: String, session: URLSession = .shared) throws {
guard FileManager.default.fileExists(atPath: credentialsPath) else {
throw GoogleAuthError.credentialsFileNotFound(path: credentialsPath)
}
let data: Data
do {
data = try Data(contentsOf: URL(fileURLWithPath: credentialsPath))
} catch {
throw GoogleAuthError.decodingFailed("Failed to read file: \(error.localizedDescription)")
}
let credentials: ADCCredentials
do {
credentials = try JSONDecoder().decode(ADCCredentials.self, from: data)
} catch {
throw GoogleAuthError.decodingFailed(error.localizedDescription)
}
guard credentials.type == "authorized_user" else {
throw GoogleAuthError.unsupportedCredentialType(credentials.type)
}
clientID = credentials.clientId
clientSecret = credentials.clientSecret
refreshToken = credentials.refreshToken
self.session = session
}

// MARK: - Public API

/// Returns a valid access token, refreshing if necessary.
public func accessToken() async throws -> String {
if let token = cachedAccessToken,
let expiry = tokenExpiry,
Date() < expiry.addingTimeInterval(-refreshMargin) {
return token
}
return try await refreshAccessToken()
}

// MARK: - Private

private func refreshAccessToken() async throws -> String {
var request = URLRequest(url: Self.tokenEndpoint)
request.httpMethod = "POST"
request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type")

let body = [
"client_id=\(urlEncode(clientID))",
"client_secret=\(urlEncode(clientSecret))",
"refresh_token=\(urlEncode(refreshToken))",
"grant_type=refresh_token",
].joined(separator: "&")
request.httpBody = Data(body.utf8)

let (data, response) = try await session.data(for: request)

guard let httpResponse = response as? HTTPURLResponse else {
throw GoogleAuthError.refreshFailed(statusCode: 0, body: "Invalid response")
}
guard httpResponse.statusCode == 200 else {
let responseBody = String(data: data, encoding: .utf8) ?? "<unreadable>"
throw GoogleAuthError.refreshFailed(statusCode: httpResponse.statusCode, body: responseBody)
}

let tokenResponse = try JSONDecoder().decode(TokenResponse.self, from: data)
cachedAccessToken = tokenResponse.accessToken
tokenExpiry = Date().addingTimeInterval(TimeInterval(tokenResponse.expiresIn))
return tokenResponse.accessToken
}

private func urlEncode(_ string: String) -> String {
string.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed) ?? string
}

/// The default path to the ADC credentials file.
public static func defaultCredentialsPath() -> String {
let home = FileManager.default.homeDirectoryForCurrentUser.path
return "\(home)/.config/gcloud/application_default_credentials.json"
}

/// Whether an ADC credentials file exists at the default path.
public static func credentialsAvailable() -> Bool {
FileManager.default.fileExists(atPath: defaultCredentialsPath())
}
}
214 changes: 214 additions & 0 deletions Sources/AgentRunKit/LLM/VertexAnthropicClient.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import Foundation

/// An LLM client for Anthropic Claude models served via Vertex AI.
///
/// Uses OAuth2 Bearer token authentication (via ``GoogleAuthService`` or a custom
/// token provider closure) instead of Anthropic API key authentication.
///
/// The wire format is the standard Anthropic Messages API with a
/// `"anthropic_version": "vertex-2023-10-16"` field injected into the request body.
/// Response parsing and SSE streaming are delegated to an internal ``AnthropicClient``.
///
/// ```swift
/// let auth = try GoogleAuthService()
/// let client = VertexAnthropicClient(
/// projectID: "my-project",
/// location: "us-east5",
/// model: "claude-sonnet-4-6",
/// authService: auth
/// )
/// ```
public struct VertexAnthropicClient: LLMClient, Sendable {
public let contextWindowSize: Int?

let anthropic: AnthropicClient
private let projectID: String
private let location: String
private let model: String
private let tokenProvider: @Sendable () async throws -> String
private let session: URLSession
private let retryPolicy: RetryPolicy

public init(
projectID: String,
location: String,
model: String,
tokenProvider: @Sendable @escaping () async throws -> String,
maxTokens: Int = 8192,
contextWindowSize: Int? = nil,
session: URLSession = .shared,
retryPolicy: RetryPolicy = .default,
reasoningConfig: ReasoningConfig? = nil,
interleavedThinking: Bool = true,
cachingEnabled: Bool = false
) {
self.projectID = projectID
self.location = location
self.model = model
self.tokenProvider = tokenProvider
self.session = session
self.retryPolicy = retryPolicy
self.contextWindowSize = contextWindowSize
anthropic = AnthropicClient(
apiKey: "",
model: model,
maxTokens: maxTokens,
contextWindowSize: contextWindowSize,
session: session,
retryPolicy: retryPolicy,
reasoningConfig: reasoningConfig,
interleavedThinking: interleavedThinking,
cachingEnabled: cachingEnabled
)
}

/// Convenience initializer that uses a ``GoogleAuthService`` for authentication.
public init(
projectID: String,
location: String,
model: String,
authService: GoogleAuthService,
maxTokens: Int = 8192,
contextWindowSize: Int? = nil,
session: URLSession = .shared,
retryPolicy: RetryPolicy = .default,
reasoningConfig: ReasoningConfig? = nil,
interleavedThinking: Bool = true,
cachingEnabled: Bool = false
) {
self.init(
projectID: projectID,
location: location,
model: model,
tokenProvider: { try await authService.accessToken() },
maxTokens: maxTokens,
contextWindowSize: contextWindowSize,
session: session,
retryPolicy: retryPolicy,
reasoningConfig: reasoningConfig,
interleavedThinking: interleavedThinking,
cachingEnabled: cachingEnabled
)
}

// MARK: - LLMClient

public func generate(
messages: [ChatMessage],
tools: [ToolDefinition],
responseFormat: ResponseFormat?,
requestContext: RequestContext?
) async throws -> AssistantMessage {
if responseFormat != nil {
throw AgentError.llmError(.other("VertexAnthropicClient does not support responseFormat"))
}
let request = try anthropic.buildRequest(
messages: messages,
tools: tools,
extraFields: requestContext?.extraFields ?? [:]
)
let token = try await tokenProvider()
let urlRequest = try buildVertexURLRequest(
VertexAnthropicRequest(inner: request), stream: false, token: token
)
let (data, httpResponse) = try await HTTPRetry.performData(
urlRequest: urlRequest, session: session, retryPolicy: retryPolicy
)
requestContext?.onResponse?(httpResponse)
return try anthropic.parseResponse(data)
}

public func stream(
messages: [ChatMessage],
tools: [ToolDefinition],
requestContext: RequestContext?
) -> AsyncThrowingStream<StreamDelta, Error> {
AsyncThrowingStream { continuation in
let task = Task {
do {
try await performStreamRequest(
messages: messages,
tools: tools,
extraFields: requestContext?.extraFields ?? [:],
onResponse: requestContext?.onResponse,
continuation: continuation
)
} catch {
continuation.finish(throwing: error)
}
}
continuation.onTermination = { _ in task.cancel() }
}
}

// MARK: - Streaming

private func performStreamRequest(
messages: [ChatMessage],
tools: [ToolDefinition],
extraFields: [String: JSONValue],
onResponse: (@Sendable (HTTPURLResponse) -> Void)?,
continuation: AsyncThrowingStream<StreamDelta, Error>.Continuation
) async throws {
let request = try anthropic.buildRequest(
messages: messages, tools: tools,
stream: true, extraFields: extraFields
)
let token = try await tokenProvider()
let urlRequest = try buildVertexURLRequest(
VertexAnthropicRequest(inner: request), stream: true, token: token
)
let (bytes, httpResponse) = try await HTTPRetry.performStream(
urlRequest: urlRequest, session: session, retryPolicy: retryPolicy
)
onResponse?(httpResponse)

let state = AnthropicStreamState()

try await processSSEStream(
bytes: bytes,
stallTimeout: retryPolicy.streamStallTimeout
) { line in
try await anthropic.handleSSELine(
line, state: state, continuation: continuation
)
}
continuation.finish()
}

// MARK: - URL Construction

func buildVertexURLRequest(
_ request: VertexAnthropicRequest,
stream: Bool,
token: String
) throws -> URLRequest {
let action = stream ? "streamRawPredict" : "rawPredict"
let basePath = "v1/projects/\(projectID)/locations/\(location)"
+ "/publishers/anthropic/models/\(model):\(action)"
let baseURL = URL(string: "https://\(location)-aiplatform.googleapis.com")!
let url = baseURL.appendingPathComponent(basePath)

let headers = ["Authorization": "Bearer \(token)"]
return try buildJSONPostRequest(url: url, body: request, headers: headers)
}
}

// MARK: - Vertex Anthropic Request Wrapper

/// Wraps an ``AnthropicRequest`` and injects `"anthropic_version": "vertex-2023-10-16"`
/// into the encoded JSON body for Vertex AI compatibility.
struct VertexAnthropicRequest: Encodable {
static let vertexAnthropicVersion = "vertex-2023-10-16"

let inner: AnthropicRequest

func encode(to encoder: any Encoder) throws {
try inner.encode(to: encoder)
var container = encoder.container(keyedBy: DynamicCodingKey.self)
try container.encode(
Self.vertexAnthropicVersion,
forKey: DynamicCodingKey("anthropic_version")
)
}
}
Loading