Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions pkg/inference/models/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ import (
"github.com/docker/model-runner/pkg/middleware"
)

// parseBoolQueryParam parses a boolean query parameter from the request.
// The bare presence of the parameter with an empty value (e.g. ?force) is
// treated as true, matching common web API conventions. An absent parameter
// yields false; an unparseable value yields false and logs a warning that
// includes the offending value.
func parseBoolQueryParam(r *http.Request, log logging.Logger, name string) bool {
	// Parse the query string once instead of re-parsing it per access.
	q := r.URL.Query()
	if !q.Has(name) {
		return false
	}
	valStr := q.Get(name)
	// Treat presence of the key with an empty value as true (e.g. `?force`).
	if valStr == "" {
		return true
	}
	val, err := strconv.ParseBool(valStr)
	if err != nil {
		log.Warn("error while parsing query parameter", "param", name, "value", valStr, "error", err)
		return false
	}
	return val
}
Comment on lines +29 to +39
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of parseBoolQueryParam treats a query parameter present without a value (e.g., ?force) as false because r.URL.Query().Get("force") returns an empty string, and strconv.ParseBool("") results in an error. This is likely not the expected behavior for users of the API, who would probably expect the presence of the flag to mean true.

I suggest modifying the logic to treat the presence of the key with an empty value as true. This would make the API more intuitive and align with common web API conventions. The updated implementation also avoids calling r.URL.Query() multiple times and improves the warning log message to include the unparseable value.

func parseBoolQueryParam(r *http.Request, log logging.Logger, name string) bool {
	q := r.URL.Query()
	if !q.Has(name) {
		return false
	}
	valStr := q.Get(name)
	// Treat presence of key with empty value as true (e.g. `?force`)
	if valStr == "" {
		return true
	}
	val, err := strconv.ParseBool(valStr)
	if err != nil {
		log.Warn("error while parsing query parameter", "param", name, "value", valStr, "error", err)
		return false
	}
	return val
}


// HTTPHandler manages inference model pulls and storage.
type HTTPHandler struct {
// log is the associated logger.
Expand Down Expand Up @@ -195,16 +210,7 @@ func (h *HTTPHandler) handleGetModel(w http.ResponseWriter, r *http.Request) {
}

func (h *HTTPHandler) handleGetModelByRef(w http.ResponseWriter, r *http.Request, modelRef string) {
// Parse remote query parameter
remote := false
if r.URL.Query().Has("remote") {
val, err := strconv.ParseBool(r.URL.Query().Get("remote"))
if err != nil {
h.log.Warn("error while parsing remote query parameter", "error", err)
} else {
remote = val
}
}
remote := parseBoolQueryParam(r, h.log, "remote")

var (
apiModel *Model
Expand Down Expand Up @@ -309,14 +315,7 @@ func (h *HTTPHandler) handleDeleteModel(w http.ResponseWriter, r *http.Request)

modelRef := r.PathValue("name")

var force bool
if r.URL.Query().Has("force") {
if val, err := strconv.ParseBool(r.URL.Query().Get("force")); err != nil {
h.log.Warn("error while parsing force query parameter", "error", err)
} else {
force = val
}
}
force := parseBoolQueryParam(r, h.log, "force")

// First try to delete without normalization (as ID), then with normalization if not found
resp, err := h.manager.Delete(modelRef, force)
Expand Down
59 changes: 26 additions & 33 deletions pkg/inference/scheduling/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ import (

type contextKey bool

// readRequestBody reads up to maxSize bytes from the request body and writes
// an appropriate HTTP error if reading fails. Returns (body, true) on success
// or (nil, false) after writing the error response.
func readRequestBody(w http.ResponseWriter, r *http.Request, maxSize int64) ([]byte, bool) {
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxSize))
if err != nil {
var maxBytesError *http.MaxBytesError
if errors.As(err, &maxBytesError) {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "failed to read request body", http.StatusInternalServerError)
}
return nil, false
}
return body, true
}

const preloadOnlyKey contextKey = false

// HTTPHandler handles HTTP requests for the scheduler.
Expand Down Expand Up @@ -132,14 +149,8 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque

// Read the entire request body. We put some basic size constraints in place
// to avoid DoS attacks. We do this early to avoid client write timeouts.
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
if err != nil {
var maxBytesError *http.MaxBytesError
if errors.As(err, &maxBytesError) {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "failed to read request body", http.StatusInternalServerError)
}
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
if !ok {
return
}

Expand Down Expand Up @@ -338,14 +349,8 @@ func (h *HTTPHandler) GetDiskUsage(w http.ResponseWriter, _ *http.Request) {
// Unload unloads the specified runners (backend, model) from the backend.
// Currently, this doesn't work for runners that are handling an OpenAI request.
func (h *HTTPHandler) Unload(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
if err != nil {
var maxBytesError *http.MaxBytesError
if errors.As(err, &maxBytesError) {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "failed to read request body", http.StatusInternalServerError)
}
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
if !ok {
return
}

Expand All @@ -371,14 +376,8 @@ type installBackendRequest struct {
// InstallBackend handles POST <inference-prefix>/install-backend requests.
// It triggers on-demand installation of a deferred backend.
func (h *HTTPHandler) InstallBackend(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
if err != nil {
var maxBytesError *http.MaxBytesError
if errors.As(err, &maxBytesError) {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "failed to read request body", http.StatusInternalServerError)
}
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
if !ok {
return
}

Expand Down Expand Up @@ -414,14 +413,8 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
return
}

body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
if err != nil {
var maxBytesError *http.MaxBytesError
if errors.As(err, &maxBytesError) {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "failed to read request body", http.StatusInternalServerError)
}
body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
if !ok {
return
}

Expand All @@ -433,7 +426,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
return
}

backend, err = h.scheduler.ConfigureRunner(r.Context(), backend, configureRequest, r.UserAgent())
backend, err := h.scheduler.ConfigureRunner(r.Context(), backend, configureRequest, r.UserAgent())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Avoid shadowing the outer backend/err here to prevent subtle bugs

This line used to reassign the existing backend/err; switching to := creates new inner-scope variables and leaves the outer ones unchanged. Any code after this block will still see the old values, which changes the previous behavior and can introduce hard-to-spot bugs. Unless a new inner scope is intended, please use backend, err = ... to match the original semantics.

if err != nil {
if errors.Is(err, errRunnerAlreadyActive) {
http.Error(w, err.Error(), http.StatusConflict)
Expand Down
Loading