-
Notifications
You must be signed in to change notification settings - Fork 84
feat: live model switching for running sessions #1239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d543287
77e7e24
ce60fc0
aa44efc
8a21d51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13,6 +13,7 @@ import ( | |||||||||||||||||||||||||||||||||
| "net/url" | ||||||||||||||||||||||||||||||||||
| "os" | ||||||||||||||||||||||||||||||||||
| "path/filepath" | ||||||||||||||||||||||||||||||||||
| "regexp" | ||||||||||||||||||||||||||||||||||
| "sort" | ||||||||||||||||||||||||||||||||||
| "strings" | ||||||||||||||||||||||||||||||||||
| "sync" | ||||||||||||||||||||||||||||||||||
|
|
@@ -1400,6 +1401,172 @@ func UpdateSessionDisplayName(c *gin.Context) { | |||||||||||||||||||||||||||||||||
| c.JSON(http.StatusOK, session) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // SwitchModel switches the LLM model for a running session | ||||||||||||||||||||||||||||||||||
| // POST /api/projects/:projectName/agentic-sessions/:sessionName/model | ||||||||||||||||||||||||||||||||||
| func SwitchModel(c *gin.Context) { | ||||||||||||||||||||||||||||||||||
| project := c.GetString("project") | ||||||||||||||||||||||||||||||||||
| sessionName := c.Param("sessionName") | ||||||||||||||||||||||||||||||||||
| _, k8sDyn := GetK8sClientsForRequest(c) | ||||||||||||||||||||||||||||||||||
| if k8sDyn == nil { | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or missing token"}) | ||||||||||||||||||||||||||||||||||
| c.Abort() | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| var req struct { | ||||||||||||||||||||||||||||||||||
| Model string `json:"model" binding:"required"` | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| if err := c.ShouldBindJSON(&req); err != nil { | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body: model is required"}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if req.Model == "" { | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusBadRequest, gin.H{"error": "model must not be empty"}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| gvr := GetAgenticSessionV1Alpha1Resource() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // Get current session | ||||||||||||||||||||||||||||||||||
| item, err := k8sDyn.Resource(gvr).Namespace(project).Get(context.TODO(), sessionName, v1.GetOptions{}) | ||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| if errors.IsNotFound(err) { | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get session"}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // Ensure session is Running | ||||||||||||||||||||||||||||||||||
| if err := ensureRuntimeMutationAllowed(item); err != nil { | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusConflict, gin.H{"error": err.Error()}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // Get current model for comparison | ||||||||||||||||||||||||||||||||||
| spec, ok := item.Object["spec"].(map[string]interface{}) | ||||||||||||||||||||||||||||||||||
| if !ok { | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid session spec"}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| llmSettings, _, _ := unstructured.NestedMap(spec, "llmSettings") | ||||||||||||||||||||||||||||||||||
| previousModel, _ := llmSettings["model"].(string) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // No-op if same model | ||||||||||||||||||||||||||||||||||
| if previousModel == req.Model { | ||||||||||||||||||||||||||||||||||
| session := types.AgenticSession{ | ||||||||||||||||||||||||||||||||||
| APIVersion: item.GetAPIVersion(), | ||||||||||||||||||||||||||||||||||
| Kind: item.GetKind(), | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| if meta, ok := item.Object["metadata"].(map[string]interface{}); ok { | ||||||||||||||||||||||||||||||||||
| session.Metadata = meta | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| session.Spec = parseSpec(spec) | ||||||||||||||||||||||||||||||||||
| if status, ok := item.Object["status"].(map[string]interface{}); ok { | ||||||||||||||||||||||||||||||||||
| session.Status = parseStatus(status) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusOK, session) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // Update the CR first to validate RBAC (user needs update permission). | ||||||||||||||||||||||||||||||||||
| // This ensures a user with only get access cannot trigger a runner-side | ||||||||||||||||||||||||||||||||||
| // model switch without also being allowed to persist the change. | ||||||||||||||||||||||||||||||||||
| if llmSettings == nil { | ||||||||||||||||||||||||||||||||||
| llmSettings = map[string]interface{}{} | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| llmSettings["model"] = req.Model | ||||||||||||||||||||||||||||||||||
| spec["llmSettings"] = llmSettings | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| updated, err := k8sDyn.Resource(gvr).Namespace(project).Update(context.TODO(), item, v1.UpdateOptions{}) | ||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| log.Printf("Failed to update session CR %s for model switch: %v", sessionName, err) | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update session record"}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+1483
to
+1488
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Concurrent switch conflicts are returned as 500 instead of conflict At Line 1483, a resource-version conflict during rapid switches is returned as 500. This should be surfaced as a conflict/retry condition, not an internal error. Suggested fix updated, err := k8sDyn.Resource(gvr).Namespace(project).Update(context.TODO(), item, v1.UpdateOptions{})
if err != nil {
+ if errors.IsConflict(err) {
+ c.JSON(http.StatusConflict, gin.H{"error": "Model switch conflict, please retry"})
+ return
+ }
log.Printf("Failed to update session CR %s for model switch: %v", sessionName, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update session record"})
return
}📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // Proxy to runner — if runner rejects (e.g., agent is mid-generation), revert the CR. | ||||||||||||||||||||||||||||||||||
| // Sanitize the CR name against a strict allowlist to prevent SSRF. | ||||||||||||||||||||||||||||||||||
| sanitizedName, err := sanitizeK8sName(item.GetName()) | ||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| log.Printf("Invalid session name %q for model switch: %v", item.GetName(), err) | ||||||||||||||||||||||||||||||||||
| revertModelSwitch(updated, previousModel, k8sDyn, gvr, project) | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid session name"}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| sanitizedProject, err := sanitizeK8sName(project) | ||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| log.Printf("Invalid project name %q for model switch: %v", project, err) | ||||||||||||||||||||||||||||||||||
| revertModelSwitch(updated, previousModel, k8sDyn, gvr, project) | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid project name"}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| serviceName := getRunnerServiceName(sanitizedName) | ||||||||||||||||||||||||||||||||||
| runnerURL := fmt.Sprintf("http://%s.%s.svc.cluster.local:8001/model", serviceName, sanitizedProject) | ||||||||||||||||||||||||||||||||||
| runnerReq := map[string]string{"model": req.Model} | ||||||||||||||||||||||||||||||||||
| reqBody, _ := json.Marshal(runnerReq) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| httpReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", runnerURL, bytes.NewReader(reqBody)) | ||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create runner request"}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| httpReq.Header.Set("Content-Type", "application/json") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| client := &http.Client{Timeout: 30 * time.Second} | ||||||||||||||||||||||||||||||||||
| resp, err := client.Do(httpReq) | ||||||||||||||||||||||||||||||||||
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show fixed
Hide fixed
|
||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| log.Printf("Failed to proxy model switch to runner for session %s: %v", sessionName, err) | ||||||||||||||||||||||||||||||||||
| // Revert the CR update on the server-returned object | ||||||||||||||||||||||||||||||||||
| revertModelSwitch(updated, previousModel, k8sDyn, gvr, project) | ||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusBadGateway, gin.H{"error": "Failed to reach session runner"}) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| defer resp.Body.Close() | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if resp.StatusCode != http.StatusOK { | ||||||||||||||||||||||||||||||||||
| body, _ := io.ReadAll(resp.Body) | ||||||||||||||||||||||||||||||||||
| log.Printf("Runner rejected model switch for session %s: %d %s", sessionName, resp.StatusCode, string(body)) | ||||||||||||||||||||||||||||||||||
| // Revert the CR update on the server-returned object | ||||||||||||||||||||||||||||||||||
| revertModelSwitch(updated, previousModel, k8sDyn, gvr, project) | ||||||||||||||||||||||||||||||||||
| // Forward runner's status code and error | ||||||||||||||||||||||||||||||||||
| c.Data(resp.StatusCode, "application/json", body) | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| session := types.AgenticSession{ | ||||||||||||||||||||||||||||||||||
| APIVersion: updated.GetAPIVersion(), | ||||||||||||||||||||||||||||||||||
| Kind: updated.GetKind(), | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| if meta, ok := updated.Object["metadata"].(map[string]interface{}); ok { | ||||||||||||||||||||||||||||||||||
| session.Metadata = meta | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| if s, ok := updated.Object["spec"].(map[string]interface{}); ok { | ||||||||||||||||||||||||||||||||||
| session.Spec = parseSpec(s) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| if status, ok := updated.Object["status"].(map[string]interface{}); ok { | ||||||||||||||||||||||||||||||||||
| session.Status = parseStatus(status) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| c.JSON(http.StatusOK, session) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // revertModelSwitch restores the previous model on the server-returned CR object. | ||||||||||||||||||||||||||||||||||
| // Called when the runner rejects a model switch after the CR was already updated. | ||||||||||||||||||||||||||||||||||
| func revertModelSwitch(updated *unstructured.Unstructured, previousModel string, k8sDyn dynamic.Interface, gvr schema.GroupVersionResource, namespace string) { | ||||||||||||||||||||||||||||||||||
| if updatedSpec, ok := updated.Object["spec"].(map[string]interface{}); ok { | ||||||||||||||||||||||||||||||||||
| if updatedLLM, ok := updatedSpec["llmSettings"].(map[string]interface{}); ok { | ||||||||||||||||||||||||||||||||||
| updatedLLM["model"] = previousModel | ||||||||||||||||||||||||||||||||||
| _, err := k8sDyn.Resource(gvr).Namespace(namespace).Update(context.TODO(), updated, v1.UpdateOptions{}) | ||||||||||||||||||||||||||||||||||
| if err != nil { | ||||||||||||||||||||||||||||||||||
| log.Printf("Failed to revert model switch for session %s: %v", updated.GetName(), err) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+1558
to
+1566
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rollback is not durable under concurrent updates
Suggested fix func revertModelSwitch(updated *unstructured.Unstructured, previousModel string, k8sDyn dynamic.Interface, gvr schema.GroupVersionResource, namespace string) {
- if updatedSpec, ok := updated.Object["spec"].(map[string]interface{}); ok {
- if updatedLLM, ok := updatedSpec["llmSettings"].(map[string]interface{}); ok {
- updatedLLM["model"] = previousModel
- _, err := k8sDyn.Resource(gvr).Namespace(namespace).Update(context.TODO(), updated, v1.UpdateOptions{})
- if err != nil {
- log.Printf("Failed to revert model switch for session %s: %v", updated.GetName(), err)
- }
- }
- }
+ for i := 0; i < 3; i++ {
+ current, err := k8sDyn.Resource(gvr).Namespace(namespace).Get(context.TODO(), updated.GetName(), v1.GetOptions{})
+ if err != nil {
+ log.Printf("Failed to fetch session %s for model revert: %v", updated.GetName(), err)
+ return
+ }
+ spec, found, _ := unstructured.NestedMap(current.Object, "spec")
+ if !found || spec == nil {
+ return
+ }
+ llm, _, _ := unstructured.NestedMap(spec, "llmSettings")
+ if llm == nil {
+ llm = map[string]interface{}{}
+ }
+ llm["model"] = previousModel
+ spec["llmSettings"] = llm
+ current.Object["spec"] = spec
+
+ if _, err = k8sDyn.Resource(gvr).Namespace(namespace).Update(context.TODO(), current, v1.UpdateOptions{}); err == nil {
+ return
+ }
+ if !errors.IsConflict(err) {
+ log.Printf("Failed to revert model switch for session %s: %v", updated.GetName(), err)
+ return
+ }
+ }
+ log.Printf("Failed to revert model switch for session %s after retries", updated.GetName())
}🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // SelectWorkflow sets the active workflow for a session | ||||||||||||||||||||||||||||||||||
| // POST /api/projects/:projectName/agentic-sessions/:sessionName/workflow | ||||||||||||||||||||||||||||||||||
| func SelectWorkflow(c *gin.Context) { | ||||||||||||||||||||||||||||||||||
|
|
@@ -1871,6 +2038,20 @@ func RemoveRepo(c *gin.Context) { | |||||||||||||||||||||||||||||||||
| c.JSON(http.StatusOK, gin.H{"message": "Repository removed", "session": session}) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // k8sNameRegexp matches valid Kubernetes resource names (RFC 1123 DNS label). | ||||||||||||||||||||||||||||||||||
| var k8sNameRegexp = regexp.MustCompile(`^[a-z0-9]([a-z0-9\-]*[a-z0-9])?$`) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // sanitizeK8sName validates that name is a valid Kubernetes resource name | ||||||||||||||||||||||||||||||||||
| // and returns it unchanged if valid, or returns an error. This breaks the | ||||||||||||||||||||||||||||||||||
| // taint chain for static analysis (CodeQL SSRF) by proving the value matches | ||||||||||||||||||||||||||||||||||
| // a strict allowlist before it reaches any network call. | ||||||||||||||||||||||||||||||||||
| func sanitizeK8sName(name string) (string, error) { | ||||||||||||||||||||||||||||||||||
| if len(name) == 0 || len(name) > 253 || !k8sNameRegexp.MatchString(name) { | ||||||||||||||||||||||||||||||||||
| return "", fmt.Errorf("invalid Kubernetes resource name: %q", name) | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| return name, nil | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // getRunnerServiceName returns the K8s Service name for a session's runner. | ||||||||||||||||||||||||||||||||||
| // The runner serves both AG-UI and content endpoints on port 8001. | ||||||||||||||||||||||||||||||||||
| func getRunnerServiceName(session string) string { | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| import { describe, it, expect, vi, beforeEach } from 'vitest'; | ||
| import { render, screen } from '@testing-library/react'; | ||
| import { LiveModelSelector } from '../live-model-selector'; | ||
| import type { ListModelsResponse } from '@/types/api'; | ||
|
|
||
| const mockAnthropicModels: ListModelsResponse = { | ||
| models: [ | ||
| { id: 'claude-haiku-4-5', label: 'Claude Haiku 4.5', provider: 'anthropic', isDefault: false }, | ||
| { id: 'claude-sonnet-4-5', label: 'Claude Sonnet 4.5', provider: 'anthropic', isDefault: true }, | ||
| { id: 'claude-opus-4-6', label: 'Claude Opus 4.6', provider: 'anthropic', isDefault: false }, | ||
| ], | ||
| defaultModel: 'claude-sonnet-4-5', | ||
| }; | ||
|
|
||
| const mockUseModels = vi.fn(() => ({ data: mockAnthropicModels })); | ||
|
|
||
| vi.mock('@/services/queries/use-models', () => ({ | ||
| useModels: () => mockUseModels(), | ||
| })); | ||
|
|
||
| describe('LiveModelSelector', () => { | ||
| const defaultProps = { | ||
| projectName: 'test-project', | ||
| currentModel: 'claude-sonnet-4-5', | ||
| onSelect: vi.fn(), | ||
| }; | ||
|
|
||
| beforeEach(() => { | ||
| vi.clearAllMocks(); | ||
| mockUseModels.mockReturnValue({ data: mockAnthropicModels }); | ||
| }); | ||
|
|
||
| it('renders with current model name displayed', () => { | ||
| render(<LiveModelSelector {...defaultProps} />); | ||
| const button = screen.getByRole('button'); | ||
| expect(button.textContent).toContain('Claude Sonnet 4.5'); | ||
| }); | ||
|
|
||
| it('renders with model id fallback when model not in list', () => { | ||
| render( | ||
| <LiveModelSelector | ||
| {...defaultProps} | ||
| currentModel="unknown-model-id" | ||
| /> | ||
| ); | ||
| const button = screen.getByRole('button'); | ||
| expect(button.textContent).toContain('unknown-model-id'); | ||
| }); | ||
|
|
||
| it('shows spinner when switching', () => { | ||
| render(<LiveModelSelector {...defaultProps} switching />); | ||
| const spinner = document.querySelector('.animate-spin'); | ||
| expect(spinner).not.toBeNull(); | ||
| }); | ||
|
|
||
| it('button is disabled when disabled prop is true', () => { | ||
| render(<LiveModelSelector {...defaultProps} disabled />); | ||
| const button = screen.getByRole('button'); | ||
| expect((button as HTMLButtonElement).disabled).toBe(true); | ||
| }); | ||
|
|
||
| it('button is disabled when switching prop is true', () => { | ||
| render(<LiveModelSelector {...defaultProps} switching />); | ||
| const button = screen.getByRole('button'); | ||
| expect((button as HTMLButtonElement).disabled).toBe(true); | ||
| }); | ||
| }); |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| "use client"; | ||
|
|
||
| import { useMemo } from "react"; | ||
| import { ChevronDown, Loader2 } from "lucide-react"; | ||
| import { Button } from "@/components/ui/button"; | ||
| import { | ||
| DropdownMenu, | ||
| DropdownMenuContent, | ||
| DropdownMenuRadioGroup, | ||
| DropdownMenuRadioItem, | ||
| DropdownMenuTrigger, | ||
| } from "@/components/ui/dropdown-menu"; | ||
| import { useModels } from "@/services/queries/use-models"; | ||
|
|
||
| type LiveModelSelectorProps = { | ||
| projectName: string; | ||
| currentModel: string; | ||
| provider?: string; | ||
| disabled?: boolean; | ||
| switching?: boolean; | ||
| onSelect: (model: string) => void; | ||
| }; | ||
|
|
||
| export function LiveModelSelector({ | ||
| projectName, | ||
| currentModel, | ||
| provider, | ||
| disabled, | ||
| switching, | ||
| onSelect, | ||
| }: LiveModelSelectorProps) { | ||
| const { data: modelsData, isLoading, isError } = useModels(projectName, true, provider); | ||
|
|
||
| const models = useMemo(() => { | ||
| return modelsData?.models.map((m) => ({ id: m.id, name: m.label })) ?? []; | ||
| }, [modelsData]); | ||
|
|
||
| const currentModelName = | ||
| models.find((m) => m.id === currentModel)?.name ?? currentModel; | ||
|
|
||
| return ( | ||
| <DropdownMenu> | ||
| <DropdownMenuTrigger asChild> | ||
| <Button | ||
| variant="ghost" | ||
| size="sm" | ||
| className="gap-1 text-xs text-muted-foreground hover:text-foreground h-7 px-2" | ||
| disabled={disabled || switching} | ||
| > | ||
| {switching ? ( | ||
| <Loader2 className="h-3 w-3 animate-spin" /> | ||
| ) : null} | ||
| <span className="truncate max-w-[160px]"> | ||
| {currentModelName} | ||
| </span> | ||
| <ChevronDown className="h-3 w-3 opacity-50 flex-shrink-0" /> | ||
| </Button> | ||
| </DropdownMenuTrigger> | ||
| <DropdownMenuContent align="end" side="top" sideOffset={4}> | ||
| {isLoading ? ( | ||
| <div className="flex items-center justify-center px-2 py-4"> | ||
| <Loader2 className="h-4 w-4 animate-spin text-muted-foreground" /> | ||
| </div> | ||
| ) : isError ? ( | ||
| <div className="px-2 py-4 text-center text-sm text-destructive"> | ||
| Failed to load models | ||
| </div> | ||
| ) : models.length > 0 ? ( | ||
| <DropdownMenuRadioGroup | ||
| value={currentModel} | ||
| onValueChange={(modelId) => { | ||
| if (modelId !== currentModel) { | ||
| onSelect(modelId); | ||
| } | ||
| }} | ||
| > | ||
| {models.map((model) => ( | ||
| <DropdownMenuRadioItem key={model.id} value={model.id}> | ||
| {model.name} | ||
| </DropdownMenuRadioItem> | ||
| ))} | ||
| </DropdownMenuRadioGroup> | ||
| ) : ( | ||
| <div className="px-2 py-4 text-center text-sm text-muted-foreground"> | ||
| No models available | ||
| </div> | ||
| )} | ||
| </DropdownMenuContent> | ||
| </DropdownMenu> | ||
| ); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Validate the requested model before proxying.
After Line 1423, this path only rejects empty input. Unlike
CreateSession(Lines 678-692), it never verifies thatreq.Modelis valid for the session's runner/provider, so an unsupported model can be accepted, persisted, and only fail on the next LLM call instead of returning 400 here.🤖 Prompt for AI Agents