Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion server/cmd/api/api/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,6 @@ func (s *ApiService) ProcessResize(ctx context.Context, request oapi.ProcessResi
return oapi.ProcessResize200JSONResponse(oapi.OkResponse{Ok: true}), nil
}


// writeJSON writes a JSON response with the given status code.
// Unlike http.Error, this sets the correct Content-Type for JSON.
func writeJSON(w http.ResponseWriter, status int, body string) {
Expand Down
22 changes: 22 additions & 0 deletions server/lib/scaletozero/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,24 @@ package scaletozero

import (
"context"
"net"
"net/http"

"github.com/onkernel/kernel-images/server/lib/logger"
)

// Middleware returns a standard net/http middleware that disables scale-to-zero
// at the start of each request and re-enables it after the handler completes.
// Connections from loopback addresses are ignored and do not affect the
// scale-to-zero state.
func Middleware(ctrl Controller) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isLoopbackAddr(r.RemoteAddr) {
next.ServeHTTP(w, r)
return
}

if err := ctrl.Disable(r.Context()); err != nil {
logger.FromContext(r.Context()).Error("failed to disable scale-to-zero", "error", err)
http.Error(w, "failed to disable scale-to-zero", http.StatusInternalServerError)
Expand All @@ -23,3 +31,17 @@ func Middleware(ctrl Controller) func(http.Handler) http.Handler {
})
}
}

// isLoopbackAddr reports whether addr is a loopback address.
// addr may be an "ip:port" pair or a bare IP.
func isLoopbackAddr(addr string) bool {
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
ip := net.ParseIP(host)
if ip == nil {
return false
}
return ip.IsLoopback()
}
114 changes: 114 additions & 0 deletions server/lib/scaletozero/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package scaletozero

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestMiddlewareDisablesAndEnablesForExternalAddr(t *testing.T) {
t.Parallel()
mock := &mockScaleToZeroer{}
handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "203.0.113.50:12345"
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, 1, mock.disableCalls)
assert.Equal(t, 1, mock.enableCalls)
}

func TestMiddlewareSkipsLoopbackAddrs(t *testing.T) {
t.Parallel()

loopbackAddrs := []struct {
name string
addr string
}{
{"loopback-v4", "127.0.0.1:8080"},
{"loopback-v6", "[::1]:8080"},
}

for _, tc := range loopbackAddrs {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
mock := &mockScaleToZeroer{}
var called bool
handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = tc.addr
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

assert.True(t, called, "handler should still be called")
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, 0, mock.disableCalls, "should not disable for loopback addr")
assert.Equal(t, 0, mock.enableCalls, "should not enable for loopback addr")
})
}
}

func TestMiddlewareDisableError(t *testing.T) {
t.Parallel()
mock := &mockScaleToZeroer{disableErr: assert.AnError}
var called bool
handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
}))

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "203.0.113.50:12345"
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

assert.False(t, called, "handler should not be called on disable error")
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, 0, mock.enableCalls)
}

func TestIsLoopbackAddr(t *testing.T) {
t.Parallel()

tests := []struct {
addr string
loopback bool
}{
// Loopback
{"127.0.0.1:80", true},
{"[::1]:80", true},
{"127.0.0.1", true},
{"::1", true},
// Non-loopback
{"10.0.0.1:80", false},
{"172.16.0.1:80", false},
{"192.168.1.1:80", false},
{"203.0.113.50:80", false},
{"8.8.8.8:53", false},
{"[2001:db8::1]:80", false},
// Unparseable
{"not-an-ip:80", false},
{"", false},
}

for _, tc := range tests {
t.Run(tc.addr, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.loopback, isLoopbackAddr(tc.addr))
})
}
}
2 changes: 2 additions & 0 deletions server/lib/scaletozero/scaletozero.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (c *unikraftCloudController) Enable(ctx context.Context) error {
func (c *unikraftCloudController) write(ctx context.Context, char string) error {
if _, err := os.Stat(c.path); err != nil {
if os.IsNotExist(err) {
logger.FromContext(ctx).Info("scale-to-zero control file not found, skipping write", "path", c.path, "value", char)
return nil
}
logger.FromContext(ctx).Error("failed to stat scale-to-zero control file", "path", c.path, "err", err)
Expand All @@ -54,6 +55,7 @@ func (c *unikraftCloudController) write(ctx context.Context, char string) error
logger.FromContext(ctx).Error("failed to write scale-to-zero control file", "path", c.path, "err", err)
return err
}
logger.FromContext(ctx).Info("scale-to-zero control file written", "path", c.path, "value", char)
return nil
}

Expand Down
Loading