Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions app/lambda/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ func (s *LambdaHandler) Shutdown() {
func (s *LambdaHandler) getProxyFunction() (any, error) {
switch s.config.ProxySource {
case ProxySourceApiGatewayV1:
return httpadapter.New(s.mux).ProxyWithContext, nil
return httpadapter.New(server.NormalizePath(s.mux)).ProxyWithContext, nil
case ProxySourceApiGatewayV2:
return httpadapter.NewV2(s.mux).ProxyWithContext, nil
return httpadapter.NewV2(server.NormalizePath(s.mux)).ProxyWithContext, nil
case ProxySourceAlb:
return httpadapter.NewALB(s.mux).ProxyWithContext, nil
return httpadapter.NewALB(server.NormalizePath(s.mux)).ProxyWithContext, nil
default:
return nil, fmt.Errorf("invalid proxy source: %s", s.config.ProxySource)
}
Expand Down
4 changes: 3 additions & 1 deletion handler/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import "go.uber.org/fx"
func Module() fx.Option {
return fx.Module("common",
fx.Provide(NewCommandHandler),
fx.Provide(NewMuEdHandler),
fx.Provide(NewLegacyRoute),
fx.Provide(NewCommandRoute),
fx.Provide(NewHealthRoute),
fx.Provide(NewMuEdEvaluateRoute),
fx.Provide(NewMuEdHealthRoute),
)
}
215 changes: 215 additions & 0 deletions handler/mued.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package handler

import (
"encoding/json"
"fmt"
"io"
"net/http"

"go.uber.org/fx"
"go.uber.org/zap"

"github.com/lambda-feedback/shimmy/config"
"github.com/lambda-feedback/shimmy/runtime"
)

const muEdVersionHeader = "X-Api-Version"

type MuEdHandlerParams struct {
fx.In

Handler runtime.Handler
Runtime runtime.Runtime
Config config.Config
Log *zap.Logger
}

type MuEdHandler struct {
handler runtime.Handler
runtime runtime.Runtime
config config.Config
log *zap.Logger
}

func NewMuEdHandler(params MuEdHandlerParams) *MuEdHandler {
return &MuEdHandler{
handler: params.Handler,
runtime: params.Runtime,
config: params.Config,
log: params.Log,
}
}

// checkMuEdVersion validates the X-Api-Version request header.
// Returns (resolvedVersion, true) on success, or writes a 406 and returns ("", false).
func (h *MuEdHandler) checkMuEdVersion(w http.ResponseWriter, r *http.Request) (string, bool) {
requested := r.Header.Get(muEdVersionHeader)
if requested != "" && !runtime.MuEdIsVersionSupported(requested) {
body, _ := json.Marshal(map[string]any{
"title": "API version not supported",
"message": fmt.Sprintf(
"The requested API version '%s' is not supported. Supported versions are: %v.",
requested, runtime.SupportedMuEdVersions,
),
"code": "VERSION_NOT_SUPPORTED",
"details": map[string]any{
"requestedVersion": requested,
"supportedVersions": runtime.SupportedMuEdVersions,
},
})
w.Header().Set(muEdVersionHeader, runtime.MuEdResolveVersion(requested))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotAcceptable)
w.Write(body) //nolint:errcheck
return "", false
}
return runtime.MuEdResolveVersion(requested), true
}

func (h *MuEdHandler) checkAuth(w http.ResponseWriter, r *http.Request) bool {
if h.config.Auth.Key != "" && r.Header.Get("api-key") != h.config.Auth.Key {
h.log.Debug("unauthorized request", zap.String("path", r.URL.Path))
http.Error(w, "unauthorized", http.StatusUnauthorized)
return false
}
return true
}

// ServeEvaluate handles POST /evaluate.
func (h *MuEdHandler) ServeEvaluate(w http.ResponseWriter, r *http.Request) {
if !h.checkAuth(w, r) {
return
}

version, ok := h.checkMuEdVersion(w, r)
if !ok {
return
}

if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}

body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "failed to read body", http.StatusBadRequest)
return
}

var muEdReq runtime.MuEdEvaluateRequest
if err := json.Unmarshal(body, &muEdReq); err != nil {
http.Error(w, "invalid request body", http.StatusBadRequest)
return
}

isPreview := muEdReq.PreSubmissionFeedback != nil && muEdReq.PreSubmissionFeedback.Enabled

var legacyBody map[string]any
if isPreview {
legacyBody, err = runtime.MuEdBuildLegacyPreviewRequest(muEdReq)
} else {
legacyBody, err = runtime.MuEdBuildLegacyEvalRequest(muEdReq)
}
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

legacyBodyBytes, err := json.Marshal(legacyBody)
if err != nil {
http.Error(w, "failed to build request", http.StatusInternalServerError)
return
}

command := runtime.CommandEvaluate
if isPreview {
command = runtime.CommandPreview
}

header := http.Header{}
header.Set("Command", string(command))

req := runtime.Request{
Path: r.URL.Path,
Method: http.MethodPost,
Body: legacyBodyBytes,
Header: header,
}

resp := h.handler.Handle(r.Context(), req)

if resp.StatusCode != http.StatusOK {
for k, v := range resp.Header {
for _, vv := range v {
w.Header().Add(k, vv)
}
}
w.WriteHeader(resp.StatusCode)
w.Write(resp.Body) //nolint:errcheck
return
}

var respBody map[string]any
if err := json.Unmarshal(resp.Body, &respBody); err != nil {
http.Error(w, "failed to parse response", http.StatusInternalServerError)
return
}

result, ok := respBody["result"].(map[string]any)
if !ok {
http.Error(w, "invalid response from evaluation function", http.StatusInternalServerError)
return
}

var feedback []map[string]any
if isPreview {
feedback = runtime.MuEdToPreviewFeedback(result)
} else {
feedback = runtime.MuEdToEvalFeedback(result)
}

w.Header().Set("Content-Type", "application/json")
w.Header().Set(muEdVersionHeader, version)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(feedback) //nolint:errcheck
}

// ServeHealth handles GET /evaluate/health.
func (h *MuEdHandler) ServeHealth(w http.ResponseWriter, r *http.Request) {
if !h.checkAuth(w, r) {
return
}

version, ok := h.checkMuEdVersion(w, r)
if !ok {
return
}

if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}

resp, err := h.runtime.Handle(r.Context(), runtime.EvaluationRequest{
Command: runtime.CommandHealth,
Data: map[string]any{},
})
if err != nil {
http.Error(w, "health check failed", http.StatusInternalServerError)
return
}

legacyResult, ok := resp["result"].(map[string]any)
if !ok {
http.Error(w, "invalid health response", http.StatusInternalServerError)
return
}

result := runtime.MuEdToHealthResponse(legacyResult)

w.Header().Set("Content-Type", "application/json")
w.Header().Set(muEdVersionHeader, version)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(result) //nolint:errcheck
}
Loading
Loading