Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,11 @@ func (r *responseCopier) readAll() ([]byte, error) {

// forwardResp writes whole response as received to ResponseWriter
func (r *responseCopier) forwardResp(w http.ResponseWriter) error {
// no response was received, nothing to forward
if !r.responseReceived.Load() {
return nil
}

w.Header().Set("Content-Type", r.responseHeaders.Get("Content-Type"))
w.WriteHeader(r.responseStatus)

Expand Down
42 changes: 42 additions & 0 deletions intercept/responses/base_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package responses

import (
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -434,3 +435,44 @@ func TestRecordTokenUsage(t *testing.T) {
})
}
}

type mockResponseWriter struct {
headerCalled bool
writeCalled bool
writeHeaderCalled bool
}

func (mrw *mockResponseWriter) Header() http.Header {
mrw.headerCalled = true
return http.Header{}
}

func (mrw *mockResponseWriter) Write([]byte) (int, error) {
mrw.writeCalled = true
return 0, nil
}

func (mrw *mockResponseWriter) WriteHeader(statusCode int) {
mrw.writeHeaderCalled = true
}

func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) {
mrw := mockResponseWriter{}

respCopy := responseCopier{}
body := "test_body"
respCopy.buff.Write([]byte(body))

respCopy.forwardResp(&mrw)
require.False(t, mrw.headerCalled)
require.False(t, mrw.writeCalled)
require.False(t, mrw.writeHeaderCalled)

// after response is received data is forwarded
respCopy.responseReceived.Store(true)

respCopy.forwardResp(&mrw)
require.True(t, mrw.headerCalled)
require.True(t, mrw.writeCalled)
require.True(t, mrw.writeHeaderCalled)
}
2 changes: 2 additions & 0 deletions intercept/responses/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package responses
import (
"context"
"errors"
"fmt"
"net/http"
"time"

Expand Down Expand Up @@ -102,6 +103,7 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
if upstreamErr != nil && !respCopy.responseReceived.Load() {
// no response received from upstream, return custom error
i.sendCustomErr(ctx, w, http.StatusInternalServerError, upstreamErr)
return fmt.Errorf("failed to connect to upstream: %w", upstreamErr)
}

err = respCopy.forwardResp(w)
Expand Down