Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions components/backend/handlers/sessions.go
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/url"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
Expand Down Expand Up @@ -1400,6 +1401,172 @@ func UpdateSessionDisplayName(c *gin.Context) {
c.JSON(http.StatusOK, session)
}

// SwitchModel switches the LLM model for a running session.
// POST /api/projects/:projectName/agentic-sessions/:sessionName/model
//
// The JSON request body must contain a non-empty "model" field. The handler
// reads the session CR, writes the requested model to spec.llmSettings.model,
// and then POSTs the model to the session's runner service. When the runner
// cannot be reached or answers with a non-200 status, the CR change is undone
// through revertModelSwitch.
//
// Responses produced by this handler:
//   - 200 with the session (also when the requested model equals the current one)
//   - 400 for an invalid body, or when the session/project name fails sanitizeK8sName
//   - 401 when no dynamic client is available for the request
//   - 404 when the session CR does not exist
//   - 409 when ensureRuntimeMutationAllowed rejects the session
//   - 500 when the CR cannot be read or updated, or the runner request cannot be built
//   - 502 when the runner cannot be reached
//   - the runner's own status code and body when the runner answers non-200
func SwitchModel(c *gin.Context) {
// The project is read from the gin context; the session name comes from the
// route parameter.
project := c.GetString("project")
sessionName := c.Param("sessionName")
// Only the second (dynamic) client returned for this request is used.
_, k8sDyn := GetK8sClientsForRequest(c)
if k8sDyn == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or missing token"})
c.Abort()
return
}

var req struct {
Model string `json:"model" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body: model is required"})
return
}

// NOTE(review): binding:"required" presumably already rejects an empty
// string, which would make this check redundant — confirm before removing.
if req.Model == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "model must not be empty"})
return
}

gvr := GetAgenticSessionV1Alpha1Resource()

// Get current session
item, err := k8sDyn.Resource(gvr).Namespace(project).Get(context.TODO(), sessionName, v1.GetOptions{})
if err != nil {
if errors.IsNotFound(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "Session not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get session"})
return
}

// Ensure session is Running (the exact rule lives in
// ensureRuntimeMutationAllowed; its error text is returned to the caller).
if err := ensureRuntimeMutationAllowed(item); err != nil {
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
return
}

// Get current model for comparison
spec, ok := item.Object["spec"].(map[string]interface{})
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid session spec"})
return
}
// The found flag and error are discarded: llmSettings may be nil here (that
// case is handled before the write further down), and previousModel is then
// the empty string.
llmSettings, _, _ := unstructured.NestedMap(spec, "llmSettings")
previousModel, _ := llmSettings["model"].(string)

// No-op if same model: answer with the session exactly as it was read,
// without updating the CR or contacting the runner.
if previousModel == req.Model {
session := types.AgenticSession{
APIVersion: item.GetAPIVersion(),
Kind: item.GetKind(),
}
if meta, ok := item.Object["metadata"].(map[string]interface{}); ok {
session.Metadata = meta
}
session.Spec = parseSpec(spec)
if status, ok := item.Object["status"].(map[string]interface{}); ok {
session.Status = parseStatus(status)
}
c.JSON(http.StatusOK, session)
return
}

Comment on lines +1416 to +1473
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate the requested model before proxying.

After Line 1423, this path only rejects empty input. Unlike CreateSession (Lines 678-692), it never verifies that req.Model is valid for the session's runner/provider, so an unsupported model can be accepted, persisted, and only fail on the next LLM call instead of returning 400 here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@components/backend/handlers/sessions.go` around lines 1415 - 1466, After
binding req.Model and before comparing to previousModel, run the same model
validation performed in CreateSession to ensure the requested model is supported
by the session's runner/provider: read spec and llmSettings from item (same way
you do into spec and llmSettings), extract the runner/provider info, and reuse
the validation logic from CreateSession (the model-checking block around lines
678-692) to check req.Model; if validation fails, return HTTP 400 with a clear
error message; otherwise continue with the existing mutation flow (variables:
req.Model, spec, llmSettings, previousModel, and the CreateSession validation
routine).

// Update the CR first to validate RBAC (user needs update permission).
// This ensures a user with only get access cannot trigger a runner-side
// model switch without also being allowed to persist the change.
if llmSettings == nil {
llmSettings = map[string]interface{}{}
}
// spec is the map held in item.Object["spec"], so storing llmSettings back
// into it changes item itself before it is sent to Update.
llmSettings["model"] = req.Model
spec["llmSettings"] = llmSettings

// updated is the object returned by the API server; it is what later steps
// hand to revertModelSwitch and use to build the response.
// NOTE(review): every Update failure is reported as 500 here, which would
// include a resource-version conflict from a concurrent writer — confirm
// whether that case should be surfaced as 409 instead.
updated, err := k8sDyn.Resource(gvr).Namespace(project).Update(context.TODO(), item, v1.UpdateOptions{})
if err != nil {
log.Printf("Failed to update session CR %s for model switch: %v", sessionName, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update session record"})
return
}
Comment on lines +1483 to +1488
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Concurrent switch conflicts are returned as 500 instead of conflict

At Line 1483, a resource-version conflict during rapid switches is returned as 500. This should be surfaced as a conflict/retry condition, not an internal error.

Suggested fix
 	updated, err := k8sDyn.Resource(gvr).Namespace(project).Update(context.TODO(), item, v1.UpdateOptions{})
 	if err != nil {
+		if errors.IsConflict(err) {
+			c.JSON(http.StatusConflict, gin.H{"error": "Model switch conflict, please retry"})
+			return
+		}
 		log.Printf("Failed to update session CR %s for model switch: %v", sessionName, err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update session record"})
 		return
 	}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
updated, err := k8sDyn.Resource(gvr).Namespace(project).Update(context.TODO(), item, v1.UpdateOptions{})
if err != nil {
log.Printf("Failed to update session CR %s for model switch: %v", sessionName, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update session record"})
return
}
updated, err := k8sDyn.Resource(gvr).Namespace(project).Update(context.TODO(), item, v1.UpdateOptions{})
if err != nil {
if errors.IsConflict(err) {
c.JSON(http.StatusConflict, gin.H{"error": "Model switch conflict, please retry"})
return
}
log.Printf("Failed to update session CR %s for model switch: %v", sessionName, err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update session record"})
return
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@components/backend/handlers/sessions.go` around lines 1483 - 1488, the Update
call to k8sDyn.Resource(gvr).Namespace(project).Update(...) can return a
resource-version conflict, which must be surfaced as HTTP 409 rather than 500.
Modify the error handling around that Update (the block referencing sessionName,
project, item and v1.UpdateOptions) to detect a Kubernetes conflict using
errors.IsConflict(err) (k8s.io/apimachinery/pkg/api/errors, which this file
already uses under the name `errors`, e.g. errors.IsNotFound) and return
c.JSON(http.StatusConflict, ...) with a concise conflict/retry message, while
preserving the existing 500 path for all other errors.


// Proxy to runner — if runner rejects (e.g., agent is mid-generation), revert the CR.
// Sanitize the CR name against a strict allowlist to prevent SSRF.
// Both names end up in the runner URL below, so each is validated first; a
// validation failure reverts the CR change that was already persisted.
sanitizedName, err := sanitizeK8sName(item.GetName())
if err != nil {
log.Printf("Invalid session name %q for model switch: %v", item.GetName(), err)
revertModelSwitch(updated, previousModel, k8sDyn, gvr, project)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid session name"})
return
}
sanitizedProject, err := sanitizeK8sName(project)
if err != nil {
log.Printf("Invalid project name %q for model switch: %v", project, err)
revertModelSwitch(updated, previousModel, k8sDyn, gvr, project)
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid project name"})
return
}
// The runner is addressed through its in-cluster Service DNS name on port
// 8001, path /model.
serviceName := getRunnerServiceName(sanitizedName)
runnerURL := fmt.Sprintf("http://%s.%s.svc.cluster.local:8001/model", serviceName, sanitizedProject)
runnerReq := map[string]string{"model": req.Model}
// The Marshal error is discarded; the payload is a plain map[string]string.
reqBody, _ := json.Marshal(runnerReq)

// The outbound request is tied to the incoming request's context.
httpReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", runnerURL, bytes.NewReader(reqBody))
if err != nil {
// NOTE(review): unlike the other failure paths in this section, this branch
// returns without calling revertModelSwitch, so the CR keeps the new model —
// confirm whether a revert is intended here.
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create runner request"})
return
}
httpReq.Header.Set("Content-Type", "application/json")

// A new client with a 30 second timeout is created for each call.
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(httpReq)
if err != nil {
log.Printf("Failed to proxy model switch to runner for session %s: %v", sessionName, err)
// Revert the CR update on the server-returned object
revertModelSwitch(updated, previousModel, k8sDyn, gvr, project)
c.JSON(http.StatusBadGateway, gin.H{"error": "Failed to reach session runner"})
return
}
defer resp.Body.Close()

// Any status other than 200 counts as a rejection by the runner.
if resp.StatusCode != http.StatusOK {
// The read error is discarded; whatever was read is logged and forwarded.
body, _ := io.ReadAll(resp.Body)
log.Printf("Runner rejected model switch for session %s: %d %s", sessionName, resp.StatusCode, string(body))
// Revert the CR update on the server-returned object
revertModelSwitch(updated, previousModel, k8sDyn, gvr, project)
// Forward runner's status code and error
c.Data(resp.StatusCode, "application/json", body)
return
}

// Success: build the response from the object returned by the CR update,
// which already carries the new model.
session := types.AgenticSession{
APIVersion: updated.GetAPIVersion(),
Kind: updated.GetKind(),
}
if meta, ok := updated.Object["metadata"].(map[string]interface{}); ok {
session.Metadata = meta
}
if s, ok := updated.Object["spec"].(map[string]interface{}); ok {
session.Spec = parseSpec(s)
}
if status, ok := updated.Object["status"].(map[string]interface{}); ok {
session.Status = parseStatus(status)
}

c.JSON(http.StatusOK, session)
}

// revertModelSwitch restores the previous model on the server-returned CR object.
// Called when the runner rejects a model switch after the CR was already updated.
//
// Parameters:
//   - updated: the object returned by the earlier Update call; it is modified
//     in place and sent back to the API server.
//   - previousModel: the value written to spec.llmSettings.model.
//   - k8sDyn, gvr, namespace: client, resource and namespace used for the Update.
//
// The revert is best-effort: it makes one Update attempt and only logs a
// failure; nothing is returned to the caller. If spec or spec.llmSettings is
// missing or is not a map, no Update is attempted at all.
func revertModelSwitch(updated *unstructured.Unstructured, previousModel string, k8sDyn dynamic.Interface, gvr schema.GroupVersionResource, namespace string) {
if updatedSpec, ok := updated.Object["spec"].(map[string]interface{}); ok {
if updatedLLM, ok := updatedSpec["llmSettings"].(map[string]interface{}); ok {
// updatedLLM is the map stored inside updated, so this assignment changes
// the object that is submitted on the next line.
updatedLLM["model"] = previousModel
_, err := k8sDyn.Resource(gvr).Namespace(namespace).Update(context.TODO(), updated, v1.UpdateOptions{})
if err != nil {
log.Printf("Failed to revert model switch for session %s: %v", updated.GetName(), err)
}
}
Comment on lines +1558 to +1566
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Rollback is not durable under concurrent updates

revertModelSwitch does a single Update on a previously returned object. If that object is stale, rollback fails and the CR may remain on the new model even though the runner rejected the switch.

Suggested fix
 func revertModelSwitch(updated *unstructured.Unstructured, previousModel string, k8sDyn dynamic.Interface, gvr schema.GroupVersionResource, namespace string) {
-	if updatedSpec, ok := updated.Object["spec"].(map[string]interface{}); ok {
-		if updatedLLM, ok := updatedSpec["llmSettings"].(map[string]interface{}); ok {
-			updatedLLM["model"] = previousModel
-			_, err := k8sDyn.Resource(gvr).Namespace(namespace).Update(context.TODO(), updated, v1.UpdateOptions{})
-			if err != nil {
-				log.Printf("Failed to revert model switch for session %s: %v", updated.GetName(), err)
-			}
-		}
-	}
+	for i := 0; i < 3; i++ {
+		current, err := k8sDyn.Resource(gvr).Namespace(namespace).Get(context.TODO(), updated.GetName(), v1.GetOptions{})
+		if err != nil {
+			log.Printf("Failed to fetch session %s for model revert: %v", updated.GetName(), err)
+			return
+		}
+		spec, found, _ := unstructured.NestedMap(current.Object, "spec")
+		if !found || spec == nil {
+			return
+		}
+		llm, _, _ := unstructured.NestedMap(spec, "llmSettings")
+		if llm == nil {
+			llm = map[string]interface{}{}
+		}
+		llm["model"] = previousModel
+		spec["llmSettings"] = llm
+		current.Object["spec"] = spec
+
+		if _, err = k8sDyn.Resource(gvr).Namespace(namespace).Update(context.TODO(), current, v1.UpdateOptions{}); err == nil {
+			return
+		}
+		if !errors.IsConflict(err) {
+			log.Printf("Failed to revert model switch for session %s: %v", updated.GetName(), err)
+			return
+		}
+	}
+	log.Printf("Failed to revert model switch for session %s after retries", updated.GetName())
 }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@components/backend/handlers/sessions.go` around lines 1558 - 1566,
revertModelSwitch currently issues a single Update against an object that may be
stale; change it to a bounded retry loop: fetch the latest object with
k8sDyn.Resource(gvr).Namespace(namespace).Get(...), set its
spec.llmSettings.model to previousModel, then attempt Update. On a conflict
(resourceVersion mismatch), re-fetch and retry a limited number of times, with
exponential backoff between attempts, until success or final failure. Always
modify and submit the freshly fetched object (which carries the current
resourceVersion) rather than the stale Unstructured that was passed in; use
updated.GetName() only to identify the session.

}
}

// SelectWorkflow sets the active workflow for a session
// POST /api/projects/:projectName/agentic-sessions/:sessionName/workflow
func SelectWorkflow(c *gin.Context) {
Expand Down Expand Up @@ -1871,6 +2038,20 @@ func RemoveRepo(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Repository removed", "session": session})
}

// k8sNameRegexp is the allowlist pattern for resource names accepted by
// sanitizeK8sName: lowercase letters, digits and '-', beginning and ending
// with a letter or digit (the character set of an RFC 1123 DNS label).
var k8sNameRegexp = regexp.MustCompile(`^[a-z0-9]([a-z0-9\-]*[a-z0-9])?$`)

// sanitizeK8sName checks name against the k8sNameRegexp allowlist and a
// length bound of 1 to 253 bytes.
//
// It returns name unchanged with a nil error when both checks pass, and an
// empty string with a descriptive error otherwise. Routing the value through
// this function gives static analysis (CodeQL SSRF) a proven allowlist match
// before the value reaches any network call.
//
// NOTE(review): a single RFC 1123 label is limited to 63 characters, while 253
// is the limit for a full DNS subdomain — confirm which bound is intended.
func sanitizeK8sName(name string) (string, error) {
	lengthOK := len(name) >= 1 && len(name) <= 253
	if lengthOK && k8sNameRegexp.MatchString(name) {
		return name, nil
	}
	return "", fmt.Errorf("invalid Kubernetes resource name: %q", name)
}

// getRunnerServiceName returns the K8s Service name for a session's runner.
// The runner serves both AG-UI and content endpoints on port 8001.
func getRunnerServiceName(session string) string {
Expand Down
1 change: 1 addition & 0 deletions components/backend/routes.go
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func registerRoutes(r *gin.Engine) {
projectGroup.GET("/agentic-sessions/:sessionName/repos/status", handlers.GetReposStatus)
projectGroup.DELETE("/agentic-sessions/:sessionName/repos/:repoName", handlers.RemoveRepo)
projectGroup.PUT("/agentic-sessions/:sessionName/displayname", handlers.UpdateSessionDisplayName)
projectGroup.POST("/agentic-sessions/:sessionName/model", handlers.SwitchModel)

// OAuth integration - requires user auth like all other session endpoints
projectGroup.GET("/agentic-sessions/:sessionName/oauth/:provider/url", handlers.GetOAuthURL)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { render, screen } from '@testing-library/react';
import { LiveModelSelector } from '../live-model-selector';
import type { ListModelsResponse } from '@/types/api';

// Fixture returned by the mocked useModels hook: three Anthropic models, with
// claude-sonnet-4-5 flagged as the default.
const mockAnthropicModels: ListModelsResponse = {
models: [
{ id: 'claude-haiku-4-5', label: 'Claude Haiku 4.5', provider: 'anthropic', isDefault: false },
{ id: 'claude-sonnet-4-5', label: 'Claude Sonnet 4.5', provider: 'anthropic', isDefault: true },
{ id: 'claude-opus-4-6', label: 'Claude Opus 4.6', provider: 'anthropic', isDefault: false },
],
defaultModel: 'claude-sonnet-4-5',
};

// Stand-in for the hook result. Only `data` is provided, so `isLoading` and
// `isError` are undefined when the component destructures them.
const mockUseModels = vi.fn(() => ({ data: mockAnthropicModels }));

// Replace the real query hook. The wrapper ignores the arguments the component
// passes and defers to mockUseModels at call time.
vi.mock('@/services/queries/use-models', () => ({
useModels: () => mockUseModels(),
}));

// Rendering tests for LiveModelSelector: trigger label, spinner, and the
// disabled state of the trigger button.
describe('LiveModelSelector', () => {
  const defaultProps = {
    projectName: 'test-project',
    currentModel: 'claude-sonnet-4-5',
    onSelect: vi.fn(),
  };

  type SelectorProps = Parameters<typeof LiveModelSelector>[0];

  // Renders the selector with defaultProps, letting a test override single props.
  const renderSelector = (overrides: Partial<SelectorProps> = {}) =>
    render(<LiveModelSelector {...defaultProps} {...overrides} />);

  // The dropdown trigger is the only element with the button role that is queried.
  const getTrigger = () => screen.getByRole('button') as HTMLButtonElement;

  beforeEach(() => {
    // Reset call history, then restore the default hook result for each test.
    vi.clearAllMocks();
    mockUseModels.mockReturnValue({ data: mockAnthropicModels });
  });

  it('renders with current model name displayed', () => {
    renderSelector();
    expect(getTrigger().textContent).toContain('Claude Sonnet 4.5');
  });

  it('renders with model id fallback when model not in list', () => {
    renderSelector({ currentModel: 'unknown-model-id' });
    expect(getTrigger().textContent).toContain('unknown-model-id');
  });

  it('shows spinner when switching', () => {
    renderSelector({ switching: true });
    // The spinner icon is located by its animation class.
    expect(document.querySelector('.animate-spin')).not.toBeNull();
  });

  it.each(['disabled', 'switching'] as const)(
    'button is disabled when %s prop is true',
    (flag) => {
      renderSelector({ [flag]: true });
      expect(getTrigger().disabled).toBe(true);
    },
  );
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"use client";

import { useMemo } from "react";
import { ChevronDown, Loader2 } from "lucide-react";
import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuRadioGroup,
DropdownMenuRadioItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { useModels } from "@/services/queries/use-models";

// Props accepted by LiveModelSelector.
type LiveModelSelectorProps = {
// Project name forwarded to the useModels query.
projectName: string;
// Id of the active model; shown by its label when found in the loaded list,
// otherwise the raw id is displayed.
currentModel: string;
// Optional provider forwarded to the useModels query.
provider?: string;
// Disables the trigger button when true.
disabled?: boolean;
// Shows a spinner in the trigger and disables it when true.
switching?: boolean;
// Called with the chosen model id, only when it differs from currentModel.
onSelect: (model: string) => void;
};

/**
 * Dropdown for changing the model of a live session.
 *
 * The trigger shows the label of `currentModel` (or the raw id when the model
 * is not in the loaded list) and is disabled while `disabled` or `switching`
 * is set. The menu shows a spinner while models load, an error message when
 * loading fails, a radio list of models otherwise, and an empty-state message
 * when the list is empty. `onSelect` fires only for a model other than
 * `currentModel`.
 */
export function LiveModelSelector({
  projectName,
  currentModel,
  provider,
  disabled,
  switching,
  onSelect,
}: LiveModelSelectorProps) {
  const modelsQuery = useModels(projectName, true, provider);
  const { data: modelsData, isLoading, isError } = modelsQuery;

  // Reduce each API model to the id/name pair the menu needs.
  const models = useMemo(() => {
    if (!modelsData) {
      return [];
    }
    return modelsData.models.map(({ id, label }) => ({ id, name: label }));
  }, [modelsData]);

  const activeModel = models.find((candidate) => candidate.id === currentModel);
  const currentModelName = activeModel?.name ?? currentModel;

  // Re-selecting the active model is ignored.
  const handleValueChange = (modelId: string) => {
    if (modelId === currentModel) {
      return;
    }
    onSelect(modelId);
  };

  // Chooses the menu body: loading, error, model list, or empty state.
  const renderMenuBody = () => {
    if (isLoading) {
      return (
        <div className="flex items-center justify-center px-2 py-4">
          <Loader2 className="h-4 w-4 animate-spin text-muted-foreground" />
        </div>
      );
    }
    if (isError) {
      return (
        <div className="px-2 py-4 text-center text-sm text-destructive">
          Failed to load models
        </div>
      );
    }
    if (models.length > 0) {
      return (
        <DropdownMenuRadioGroup value={currentModel} onValueChange={handleValueChange}>
          {models.map(({ id, name }) => (
            <DropdownMenuRadioItem key={id} value={id}>
              {name}
            </DropdownMenuRadioItem>
          ))}
        </DropdownMenuRadioGroup>
      );
    }
    return (
      <div className="px-2 py-4 text-center text-sm text-muted-foreground">
        No models available
      </div>
    );
  };

  return (
    <DropdownMenu>
      <DropdownMenuTrigger asChild>
        <Button
          variant="ghost"
          size="sm"
          className="gap-1 text-xs text-muted-foreground hover:text-foreground h-7 px-2"
          disabled={disabled || switching}
        >
          {switching && <Loader2 className="h-3 w-3 animate-spin" />}
          <span className="truncate max-w-[160px]">
            {currentModelName}
          </span>
          <ChevronDown className="h-3 w-3 opacity-50 flex-shrink-0" />
        </Button>
      </DropdownMenuTrigger>
      <DropdownMenuContent align="end" side="top" sideOffset={4}>
        {renderMenuBody()}
      </DropdownMenuContent>
    </DropdownMenu>
  );
}
Loading
Loading