Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 88 additions & 8 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@ import (
"context"
"fmt"
"net/http"
"strconv"
"strings"

"github.com/tiny-go/codec"
"github.com/tiny-go/errors"
)

const (
acceptHeader = "Accept"
contentTypeHeader = "Content-Type"
acceptHeader = "Accept"
contentTypeHeader = "Content-Type"
contentLengthHeader = "Content-Length"
transferEncodingHeader = "Transfer-Encoding"
defaultTransferEncoding = "identity"
)

// codecKey is a private unique key that is used to put/get codec from the context.
Expand All @@ -35,16 +40,22 @@ func Codec(fn errors.HandlerFunc, codecs Codecs) Middleware {
var reqCodec, resCodec codec.Codec
// get request codec
if reqCodec = codecs.Lookup(r.Header.Get(contentTypeHeader)); reqCodec == nil {
fn(w, fmt.Sprintf("unsupported request codec: %q", r.Header.Get(contentTypeHeader)), http.StatusBadRequest)
return
if isContentTypeHeaderRequired(r) {
fn(w, fmt.Sprintf("unsupported request codec: %q", r.Header.Get(contentTypeHeader)), http.StatusBadRequest)
return
}
} else {
r = r.WithContext(context.WithValue(r.Context(), codecKey{"req"}, reqCodec))
}
r = r.WithContext(context.WithValue(r.Context(), codecKey{"req"}, reqCodec))
// get response codec
if resCodec = codecs.Lookup(r.Header.Get(acceptHeader)); resCodec == nil {
fn(w, fmt.Sprintf("unsupported response codec: %q", r.Header.Get(acceptHeader)), http.StatusBadRequest)
return
if isAcceptHeaderRequired(r) {
fn(w, fmt.Sprintf("unsupported response codec: %q", r.Header.Get(acceptHeader)), http.StatusBadRequest)
return
}
} else {
r = r.WithContext(context.WithValue(r.Context(), codecKey{"res"}, resCodec))
}
r = r.WithContext(context.WithValue(r.Context(), codecKey{"res"}, resCodec))
// call the next handler
next.ServeHTTP(w, r)
})
Expand All @@ -62,3 +73,72 @@ func ResponseCodecFromContext(ctx context.Context) codec.Codec {
codec, _ := ctx.Value(codecKey{"res"}).(codec.Codec)
return codec
}

// isContentTypeHeaderRequired returns the HTTP method request body type requirement.
// By RFC7231 (https://tools.ietf.org/html/rfc7231) only POST, PUT and PATCH methods
// should contain a request body. DELETE method body is optional.
func isContentTypeHeaderRequired(r *http.Request) bool {
switch r.Method {
// Body is required
case http.MethodPost: fallthrough
case http.MethodPut: fallthrough
case http.MethodPatch:
return shouldRequestBodyBeProcessed(r, true)
// May have body, but not required
case http.MethodDelete:
return shouldRequestBodyBeProcessed(r, false)
// No body
case http.MethodGet: fallthrough
case http.MethodHead: fallthrough
case http.MethodConnect: fallthrough
case http.MethodOptions: fallthrough
case http.MethodTrace: fallthrough
default:
return false
}
}

// isAcceptHeaderRequired returns the HTTP method response body type requirement.
// By RFC7231 (https://tools.ietf.org/html/rfc7231) only GET, POST, CONNECT,
// OPTIONS and PATCH methods should indicate the details of a response body.
// DELETE method response body is optional.
func isAcceptHeaderRequired(r *http.Request) bool {
switch r.Method {
// Body is required
case http.MethodGet: fallthrough
case http.MethodPost: fallthrough
case http.MethodConnect: fallthrough
case http.MethodOptions: fallthrough
case http.MethodPatch:
return true
// May have body, but not required
case http.MethodDelete: fallthrough
// No body
case http.MethodHead: fallthrough
case http.MethodPut: fallthrough
case http.MethodTrace: fallthrough
default:
return false
}
}

func shouldRequestBodyBeProcessed(r *http.Request, required bool) bool {
transferEncoding := r.Header.Get(transferEncodingHeader)
hasRequestBody := transferEncoding != "" && !strings.EqualFold(transferEncoding, defaultTransferEncoding)

hasRequestBody = hasRequestBody || func() bool {
contentLengthStr := r.Header.Get(contentLengthHeader)
if contentLengthStr != "" {
contentLength, err := strconv.Atoi(contentLengthStr)
if err != nil || contentLength < 0 {
return false
}

return contentLength > 0
}

return required
}()

return hasRequestBody
}
86 changes: 80 additions & 6 deletions codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,45 +21,119 @@ func TestCodecFromList(t *testing.T) {
body string
}

type Data struct {
Test string
}

cases := []testCase{
{
title: "should throw an error if request codec in not supported",
title: "should throw an error if a request codec is required but not supported",
handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})(nil),
request: func() *http.Request {
r, _ := http.NewRequest(http.MethodGet, "", nil)
r, _ := http.NewRequest(http.MethodPost, "", nil)
r.Header.Set(contentTypeHeader, "unknown")
r.Header.Set(contentLengthHeader, "1")
return r
}(),
code: http.StatusBadRequest,
body: "unsupported request codec: \"unknown\"\n",
},
{
title: "should throw an error if response codec in not supported",
title: "should ignore a request codec if not supported but not required",
handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("passed"))
}),
),
request: func() *http.Request {
r, _ := http.NewRequest(http.MethodDelete, "", nil)
r.Header.Set(contentTypeHeader, "unknown")
return r
}(),
code: http.StatusOK,
body: "passed",
},
{
title: "should use a request codec if supported but not required",
handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})(
BodyClose(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var data Data
RequestCodecFromContext(r.Context()).Decoder(r.Body).Decode(&data)
w.Write([]byte(data.Test))
}),
),
),
request: func() *http.Request {
r, _ := http.NewRequest(http.MethodDelete, "", strings.NewReader("{\"test\":\"passed\"}\n"))
r.Header.Set(contentTypeHeader, "application/json")
r.Header.Set(contentLengthHeader, "1")
return r
}(),
code: http.StatusOK,
body: "passed",
},
{
title: "should throw an error if response codec is required but not supported",
handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})(nil),
request: func() *http.Request {
r, _ := http.NewRequest(http.MethodGet, "", nil)
r, _ := http.NewRequest(http.MethodPost, "", nil)
r.Header.Set(contentTypeHeader, "application/json")
r.Header.Set(contentLengthHeader, "0")
r.Header.Set(acceptHeader, "unknown")
return r
}(),
code: http.StatusBadRequest,
body: "unsupported response codec: \"unknown\"\n",
},
{
title: "should ignore a response codec if not supported but not required",
handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("passed"))
}),
),
request: func() *http.Request {
r, _ := http.NewRequest(http.MethodDelete, "", nil)
r.Header.Set(acceptHeader, "unknown")
return r
}(),
code: http.StatusOK,
body: "passed",
},
{
title: "should use a response codec if supported but not required",
handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data := Data{
Test: "passed",
}
ResponseCodecFromContext(r.Context()).Encoder(w).Encode(data)
}),
),
request: func() *http.Request {
r, _ := http.NewRequest(http.MethodDelete, "", nil)
r.Header.Set(acceptHeader, "application/xml")
return r
}(),
code: http.StatusOK,
body: "<Data><Test>passed</Test></Data>",
},
{
title: "should find corresponding codecs and handle the request successfully",
handler: Codec(nil, driver.DummyRegistry{&json.JSON{}, &xml.XML{}})(
BodyClose(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
type Data struct{ Test string }
var data Data
RequestCodecFromContext(r.Context()).Decoder(r.Body).Decode(&data)
ResponseCodecFromContext(r.Context()).Encoder(w).Encode(data)
}),
),
),
request: func() *http.Request {
r, _ := http.NewRequest(http.MethodGet, "", strings.NewReader("{\"test\":\"passed\"}\n"))
r, _ := http.NewRequest(http.MethodPost, "", strings.NewReader("{\"test\":\"passed\"}\n"))
r.Header.Set(contentTypeHeader, "application/json")
r.Header.Set(contentLengthHeader, "1")
r.Header.Set(acceptHeader, "application/xml")
return r
}(),
Expand Down