Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 24 additions & 33 deletions rpcserver/jsonrpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ var (
)

const (
maxOriginIDLength = 255
requestSizeThreshold = 50_000

highPriorityHeader = "high_prio"
builderNetSentAtHeader = "X-BuilderNet-SentAtUs"
flashbotsSignatureHeader = "X-Flashbots-Signature"
flashbotsOriginHeader = "X-Flashbots-Origin"
FlashbotsOriginHeader = "X-Flashbots-Origin"
FlashbotsUserAgentHeader = "X-Flashbots-User-Agent"
)

type (
highPriorityKey struct{}
builderNetSentAtKey struct{}
signerKey struct{}
urlKey struct{}
originKey struct{}
sizeKey struct{}
requestKey struct{}
)

type jsonRPCRequest struct {
Expand Down Expand Up @@ -107,9 +107,6 @@ type JSONRPCHandlerOpts struct {
ExtractPriorityFromHeader bool
// If true, extracts the `X-BuilderNet-SentAtUs` header value and sets it in the context.
ExtractBuilderNetSentAtFromHeader bool
// If true extract value from x-flashbots-origin header
// Result can be extracted from the context using GetOrigin
ExtractOriginFromHeader bool
// GET response content
GetResponseContent []byte
// Custom handler for /readyz endpoint. If not nil then it is expected to write the response to the provided ResponseWriter.
Expand Down Expand Up @@ -183,6 +180,7 @@ func (h *JSONRPCHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
methodForMetrics := unknownMethodLabel
bigRequest := false
ctx := r.Context()
ctx = context.WithValue(ctx, requestKey{}, r)

defer func() {
incRequestCount(methodForMetrics, h.ServerName, bigRequest)
Expand Down Expand Up @@ -355,18 +353,6 @@ func (h *JSONRPCHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

if h.ExtractOriginFromHeader {
origin := r.Header.Get(flashbotsOriginHeader)
if origin != "" {
if len(origin) > maxOriginIDLength {
h.writeJSONRPCError(w, req.ID, CodeInvalidRequest, "x-flashbots-origin header is too long")
incIncorrectRequest(h.ServerName)
return
}
ctx = context.WithValue(ctx, originKey{}, origin)
}
}

// get method
method, ok := h.methods[req.Method]
if !ok {
Expand Down Expand Up @@ -442,25 +428,30 @@ func GetBuilderNetSentAt(ctx context.Context) time.Time {
return value
}

func GetOrigin(ctx context.Context) string {
value, ok := ctx.Value(originKey{}).(string)
if !ok {
return ""
}
return value
}

// WithOrigin returns a new request with the origin set in its context.
// Use this in middleware to set the origin for downstream handlers to read via GetOrigin.
func WithOrigin(r *http.Request, origin string) *http.Request {
ctx := context.WithValue(r.Context(), originKey{}, origin)
return r.WithContext(ctx)
}

func GetRequestSize(ctx context.Context) int {
return ctx.Value(sizeKey{}).(int)
}

func GetURL(ctx context.Context) *url.URL {
return ctx.Value(urlKey{}).(*url.URL)
}

// GetRequest returns the HTTP request from the context.
// This allows handlers to access any HTTP request data including headers.
// Returns nil if no request is in context.
func GetRequest(ctx context.Context) *http.Request {
req, _ := ctx.Value(requestKey{}).(*http.Request)
return req
}

// TestContextWithRequest returns a context with the given HTTP request.
// This is intended for use in tests only.
func TestContextWithRequest(ctx context.Context, req *http.Request) context.Context {
return context.WithValue(ctx, requestKey{}, req)
}

// TestContextWithSigner returns a context with the given signer address.
// This is intended for use in tests only.
func TestContextWithSigner(ctx context.Context, signer common.Address) context.Context {
return context.WithValue(ctx, signerKey{}, signer)
}
30 changes: 30 additions & 0 deletions rpcserver/jsonrpc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,33 @@ func TestURLExtraction(t *testing.T) {
require.Equal(t, "/fast?", capturedURL)
})
}

func TestGetRequest(t *testing.T) {
var capturedUserAgent string
var capturedCustomHeader string
testHandler := func(ctx context.Context) (string, error) {
req := GetRequest(ctx)
capturedUserAgent = req.Header.Get("User-Agent")
capturedCustomHeader = req.Header.Get("X-Custom-Header")
return "ok", nil
}

handler, err := NewJSONRPCHandler(map[string]interface{}{
"test": testHandler,
}, JSONRPCHandlerOpts{})
require.NoError(t, err)

body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"test","params":[]}`))
request, err := http.NewRequest(http.MethodPost, "/", body)
require.NoError(t, err)
request.Header.Add("Content-Type", "application/json")
request.Header.Add("User-Agent", "test-client/1.0")
request.Header.Add("X-Custom-Header", "custom-value")

rr := httptest.NewRecorder()
handler.ServeHTTP(rr, request)

require.Equal(t, http.StatusOK, rr.Code)
require.Equal(t, "test-client/1.0", capturedUserAgent)
require.Equal(t, "custom-value", capturedCustomHeader)
}