Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export function DialogAgent() {
return (
<DialogSelect
title="Select agent"
current={local.agent.current().name}
current={local.agent.current()?.name ?? options()[0]?.value ?? ""}
options={options()}
onSelect={(option) => {
local.agent.set(option.value)
Expand Down
27 changes: 21 additions & 6 deletions packages/opencode/src/cli/cmd/tui/component/prompt/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,20 @@ export function Prompt(props: PromptProps) {
// Capture mode before it gets reset
const currentMode = store.mode
const variant = local.model.variant.current()
const agent = local.agent.current()?.name
if (!agent) {
toast.show({
variant: "warning",
message: "No agent is available",
duration: 3000,
})
return
}

if (store.mode === "shell") {
sdk.client.session.shell({
sessionID,
agent: local.agent.current().name,
agent,
model: {
providerID: selectedModel.providerID,
modelID: selectedModel.modelID,
Expand All @@ -555,7 +564,7 @@ export function Prompt(props: PromptProps) {
sessionID,
command: command.slice(1),
arguments: args.join(" "),
agent: local.agent.current().name,
agent,
model: `${selectedModel.providerID}/${selectedModel.modelID}`,
messageID,
variant,
Expand All @@ -571,7 +580,7 @@ export function Prompt(props: PromptProps) {
sessionID,
...selectedModel,
messageID,
agent: local.agent.current().name,
agent,
model: selectedModel,
variant,
parts: [
Expand Down Expand Up @@ -688,10 +697,12 @@ export function Prompt(props: PromptProps) {
return
}

const currentAgentName = createMemo(() => local.agent.current()?.name ?? "")

const highlight = createMemo(() => {
if (keybind.leader) return theme.border
if (store.mode === "shell") return theme.primary
return local.agent.color(local.agent.current().name)
return local.agent.color(currentAgentName())
})

const showVariant = createMemo(() => {
Expand All @@ -702,7 +713,7 @@ export function Prompt(props: PromptProps) {
})

const spinnerDef = createMemo(() => {
const color = local.agent.color(local.agent.current().name)
const color = local.agent.color(currentAgentName())
return {
frames: createFrames({
color,
Expand Down Expand Up @@ -933,7 +944,11 @@ export function Prompt(props: PromptProps) {
/>
<box flexDirection="row" flexShrink={0} paddingTop={1} gap={1}>
<text fg={highlight()}>
{store.mode === "shell" ? "Shell" : Locale.titlecase(local.agent.current().name)}{" "}
{store.mode === "shell"
? "Shell"
: currentAgentName()
? Locale.titlecase(currentAgentName())
: "Agent"}{" "}
</text>
<Show when={store.mode === "normal"}>
<box flexDirection="row" gap={1}>
Expand Down
54 changes: 40 additions & 14 deletions packages/opencode/src/cli/cmd/tui/context/local.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,21 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({

const agent = iife(() => {
const agents = createMemo(() => sync.data.agent.filter((x) => x.mode !== "subagent" && !x.hidden))
const [agentStore, setAgentStore] = createStore<{
current: string
}>({
current: agents()[0].name,
const [agentStore, setAgentStore] = createStore<{ current: string | undefined }>({
current: undefined,
})

createEffect(() => {
const list = agents()
if (list.length === 0) {
if (agentStore.current !== undefined) setAgentStore("current", undefined)
return
}
if (!agentStore.current || !list.some((x) => x.name === agentStore.current)) {
setAgentStore("current", list[0].name)
}
})

const { theme } = useTheme()
const colors = createMemo(() => [
theme.secondary,
Expand All @@ -54,7 +64,10 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
return agents()
},
current() {
return agents().find((x) => x.name === agentStore.current)!
const list = agents()
if (list.length === 0) return undefined
if (!agentStore.current) return list[0]
return list.find((x) => x.name === agentStore.current) ?? list[0]
},
set(name: string) {
if (!agents().some((x) => x.name === name))
Expand All @@ -66,11 +79,15 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
setAgentStore("current", name)
},
move(direction: 1 | -1) {
const list = agents()
if (list.length === 0) return
batch(() => {
let next = agents().findIndex((x) => x.name === agentStore.current) + direction
if (next < 0) next = agents().length - 1
if (next >= agents().length) next = 0
const value = agents()[next]
const current = agentStore.current
const index = current ? list.findIndex((x) => x.name === current) : -1
let next = (index === -1 ? 0 : index) + direction
if (next < 0) next = list.length - 1
if (next >= list.length) next = 0
const value = list[next]
setAgentStore("current", value.name)
})
},
Expand Down Expand Up @@ -181,8 +198,8 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
const a = agent.current()
return (
getFirstValidModel(
() => modelStore.model[a.name],
() => a.model,
() => (a ? modelStore.model[a.name] : undefined),
() => a?.model,
fallbackModel,
) ?? undefined
)
Expand Down Expand Up @@ -227,7 +244,9 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
if (next >= recent.length) next = 0
const val = recent[next]
if (!val) return
setModelStore("model", agent.current().name, { ...val })
const a = agent.current()
if (!a) return
setModelStore("model", a.name, { ...val })
},
cycleFavorite(direction: 1 | -1) {
const favorites = modelStore.favorite.filter((item) => isModelValid(item))
Expand All @@ -253,7 +272,10 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
}
const next = favorites[index]
if (!next) return
setModelStore("model", agent.current().name, { ...next })
const a = agent.current()
if (!a) return
setModelStore("model", a.name, { ...next })

const uniq = uniqueBy([next, ...modelStore.recent], (x) => `${x.providerID}/${x.modelID}`)
if (uniq.length > 10) uniq.pop()
setModelStore(
Expand All @@ -272,7 +294,10 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
})
return
}
setModelStore("model", agent.current().name, model)
const a = agent.current()
if (!a) return
setModelStore("model", a.name, model)

if (options?.recent) {
const uniq = uniqueBy([model, ...modelStore.recent], (x) => `${x.providerID}/${x.modelID}`)
if (uniq.length > 10) uniq.pop()
Expand Down Expand Up @@ -368,6 +393,7 @@ export const { use: useLocal, provider: LocalProvider } = createSimpleContext({
// Automatically update model when agent changes
createEffect(() => {
const value = agent.current()
if (!value) return
if (value.model) {
if (isModelValid(value.model))
model.set({
Expand Down
44 changes: 25 additions & 19 deletions packages/opencode/src/provider/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export namespace ProviderAuth {

export async function methods() {
const s = await state().then((x) => x.methods)
return mapValues(s, (x) =>
return mapValues(s, (x: NonNullable<Hooks["auth"]>) =>
x.methods.map(
(y): Method => ({
type: y.type,
Expand Down Expand Up @@ -78,42 +78,48 @@ export namespace ProviderAuth {
code: z.string().optional(),
}),
async (input) => {
const clearPending = () => state().then((s) => delete s.pending[input.providerID])
const match = await state().then((s) => s.pending[input.providerID])
if (!match) throw new OauthMissing({ providerID: input.providerID })
let result
if (match.method === "code" && !input.code) throw new OauthCodeMissing({ providerID: input.providerID })

if (match.method === "code") {
if (!input.code) throw new OauthCodeMissing({ providerID: input.providerID })
result = await match.callback(input.code)
}
return (async () => {
const result = await (match.method === "code" ? match.callback(input.code!) : match.callback())

if (match.method === "auto") {
result = await match.callback()
}
if (!result || result.type !== "success") {
throw new OauthCallbackFailed({})
}

const providerID =
"provider" in result && typeof result.provider === "string" && result.provider
? result.provider
: input.providerID

if (result?.type === "success") {
if ("key" in result) {
await Auth.set(input.providerID, {
await Auth.set(providerID, {
type: "api",
key: result.key,
})
}

if ("refresh" in result) {
const accountId =
"accountId" in result && typeof result.accountId === "string" ? result.accountId : undefined
const enterpriseUrl =
"enterpriseUrl" in result && typeof result.enterpriseUrl === "string" ? result.enterpriseUrl : undefined

const info: Auth.Info = {
type: "oauth",
access: result.access,
refresh: result.refresh,
expires: result.expires,
...(accountId ? { accountId } : {}),
...(enterpriseUrl ? { enterpriseUrl } : {}),
}
if (result.accountId) {
info.accountId = result.accountId
}
await Auth.set(input.providerID, info)
}
return
}

throw new OauthCallbackFailed({})
await Auth.set(providerID, info)
}
})().finally(clearPending)
},
)

Expand Down
66 changes: 66 additions & 0 deletions packages/opencode/test/provider/auth-extra-fields.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import { expect, mock, test } from "bun:test"

mock.module("../../src/plugin", () => ({
Plugin: {
list: async () => [
{
auth: {
provider: "openai-test",
methods: [
{
label: "Mock OAuth",
type: "oauth",
authorize: async () => {
return {
url: "https://example.com/oauth",
method: "auto",
instructions: "Complete auth in your browser",
callback: async () => {
return {
type: "success",
refresh: "refresh-token",
access: "access-token",
expires: 123,
accountId: "acct_123",
enterpriseUrl: "https://ghe.example.com",
}
},
}
},
},
],
},
},
],
},
}))

const { tmpdir } = await import("../fixture/fixture")
const { Instance } = await import("../../src/project/instance")
const { ProviderAuth } = await import("../../src/provider/auth")
const { Auth } = await import("../../src/auth")

test("ProviderAuth oauth callback persists accountId and enterpriseUrl", async () => {
await using tmp = await tmpdir()

await Instance.provide({
directory: tmp.path,
fn: async () => {
const auth = await ProviderAuth.authorize({
providerID: "openai-test",
method: 0,
})
expect(auth).toBeDefined()

await ProviderAuth.callback({
providerID: "openai-test",
method: 0,
})

const saved = await Auth.get("openai-test")
expect(saved?.type).toBe("oauth")
expect(saved && saved.type === "oauth" ? saved.accountId : undefined).toBe("acct_123")
expect(saved && saved.type === "oauth" ? saved.enterpriseUrl : undefined).toBe("https://ghe.example.com")
},
})
})
2 changes: 2 additions & 0 deletions packages/plugin/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ export type AuthOuathResult = { url: string; instructions: string } & (
access: string
expires: number
accountId?: string
enterpriseUrl?: string
}
| { key: string }
))
Expand All @@ -135,6 +136,7 @@ export type AuthOuathResult = { url: string; instructions: string } & (
access: string
expires: number
accountId?: string
enterpriseUrl?: string
}
| { key: string }
))
Expand Down