Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions internal/cache/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"crypto/sha256"
"encoding/hex"
"io"
"net/textproto"
"net/http"
"time"

"github.com/alecthomas/errors"
Expand Down Expand Up @@ -97,8 +97,8 @@ func (k *Key) MarshalText() ([]byte, error) {

// FilterTransportHeaders returns a copy of the given headers with standard HTTP transport headers removed.
// These headers are typically added by HTTP clients/servers and should not be cached.
func FilterTransportHeaders(headers textproto.MIMEHeader) textproto.MIMEHeader {
filtered := make(textproto.MIMEHeader)
func FilterTransportHeaders(headers http.Header) http.Header {
filtered := make(http.Header)
for key, values := range headers {
// Skip standard HTTP headers added by transport layer or that shouldn't be cached
if key == "Content-Length" || key == "Date" || key == "Accept-Encoding" ||
Expand All @@ -120,21 +120,21 @@ type Cache interface {
//
// Expired files MUST not be returned.
// Must return os.ErrNotExist if the file does not exist.
Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error)
Stat(ctx context.Context, key Key) (http.Header, error)
// Open an existing file in the cache.
//
// Expired files MUST NOT be returned.
// The returned headers MUST include a Last-Modified header.
// Must return os.ErrNotExist if the file does not exist.
Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error)
Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error)
// Create a new file in the cache.
//
// If "ttl" is zero, a maximum TTL MUST be used by the implementation.
//
// The file MUST NOT be available for read until completely written and closed.
//
// If the context is cancelled the object MUST NOT be made available in the cache.
Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error)
Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error)
// Delete a file from the cache.
//
// MUST be atomic.
Expand Down
7 changes: 3 additions & 4 deletions internal/cache/cachetest/suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"io"
"net/http"
"net/textproto"
"os"
"testing"
"time"
Expand Down Expand Up @@ -215,7 +214,7 @@ func testHeaders(t *testing.T, c cache.Cache) {
key := cache.NewKey("test-key-with-headers")

// Create headers to store
headers := textproto.MIMEHeader{
headers := http.Header{
"Content-Type": []string{"application/json"},
"Cache-Control": []string{"max-age=3600"},
"X-Custom-Field": []string{"custom-value"},
Expand Down Expand Up @@ -258,7 +257,7 @@ func testContextCancellation(t *testing.T, c cache.Cache) {

// Create an object with the cancellable context
key := cache.NewKey("test-cancelled")
writer, err := c.Create(cancelledCtx, key, textproto.MIMEHeader{}, time.Hour)
writer, err := c.Create(cancelledCtx, key, http.Header{}, time.Hour)
assert.NoError(t, err)

// Write some data
Expand Down Expand Up @@ -310,7 +309,7 @@ func testLastModified(t *testing.T, c cache.Cache) {
// Test with explicit Last-Modified header
key2 := cache.NewKey("test-last-modified-explicit")
explicitTime := time.Date(2023, 1, 15, 12, 30, 0, 0, time.UTC)
explicitHeaders := textproto.MIMEHeader{
explicitHeaders := http.Header{
"Last-Modified": []string{explicitTime.Format(http.TimeFormat)},
}

Expand Down
11 changes: 5 additions & 6 deletions internal/cache/disk.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"log/slog"
"maps"
"net/http"
"net/textproto"
"os"
"path/filepath"
"sort"
Expand Down Expand Up @@ -130,14 +129,14 @@ func (d *Disk) Size() int64 {
return d.size.Load()
}

func (d *Disk) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
func (d *Disk) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) {
if ttl > d.config.MaxTTL || ttl == 0 {
ttl = d.config.MaxTTL
}

now := time.Now()
// Clone headers to avoid concurrent map writes
clonedHeaders := make(textproto.MIMEHeader)
clonedHeaders := make(http.Header)
maps.Copy(clonedHeaders, headers)
if clonedHeaders.Get("Last-Modified") == "" {
clonedHeaders.Set("Last-Modified", now.UTC().Format(http.TimeFormat))
Expand Down Expand Up @@ -204,7 +203,7 @@ func (d *Disk) Delete(_ context.Context, key Key) error {
return nil
}

func (d *Disk) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) {
func (d *Disk) Stat(ctx context.Context, key Key) (http.Header, error) {
path := d.keyToPath(key)
fullPath := filepath.Join(d.config.Root, path)

Expand All @@ -229,7 +228,7 @@ func (d *Disk) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error)
return headers, nil
}

func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) {
func (d *Disk) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) {
path := d.keyToPath(key)
fullPath := filepath.Join(d.config.Root, path)

Expand Down Expand Up @@ -378,7 +377,7 @@ type diskWriter struct {
path string
tempPath string
expiresAt time.Time
headers textproto.MIMEHeader
headers http.Header
size int64
ctx context.Context
}
Expand Down
8 changes: 4 additions & 4 deletions internal/cache/disk_metadb.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cache

import (
"encoding/json"
"net/textproto"
"net/http"
"time"

"github.com/alecthomas/errors"
Expand Down Expand Up @@ -55,7 +55,7 @@ func (s *diskMetaDB) setTTL(key Key, expiresAt time.Time) error {
}))
}

func (s *diskMetaDB) set(key Key, expiresAt time.Time, headers textproto.MIMEHeader) error {
func (s *diskMetaDB) set(key Key, expiresAt time.Time, headers http.Header) error {
ttlBytes, err := expiresAt.MarshalBinary()
if err != nil {
return errors.Errorf("failed to marshal TTL: %w", err)
Expand Down Expand Up @@ -90,8 +90,8 @@ func (s *diskMetaDB) getTTL(key Key) (time.Time, error) {
return expiresAt, errors.WithStack(err)
}

func (s *diskMetaDB) getHeaders(key Key) (textproto.MIMEHeader, error) {
var headers textproto.MIMEHeader
func (s *diskMetaDB) getHeaders(key Key) (http.Header, error) {
var headers http.Header
err := s.db.View(func(tx *bbolt.Tx) error {
bucket := tx.Bucket(headersBucketName)
headersBytes := bucket.Get(key[:])
Expand Down
5 changes: 2 additions & 3 deletions internal/cache/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"io"
"maps"
"net/http"
"net/textproto"
"os"

"github.com/alecthomas/errors"
Expand All @@ -27,7 +26,7 @@ func Fetch(client *http.Client, r *http.Request, c Cache) (*http.Response, error
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header(headers),
Header: headers,
Body: cr,
ContentLength: -1,
Request: r,
Expand All @@ -53,7 +52,7 @@ func FetchDirect(client *http.Client, r *http.Request, c Cache, key Key) (*http.
return resp, nil
}

responseHeaders := textproto.MIMEHeader(maps.Clone(resp.Header))
responseHeaders := maps.Clone(resp.Header)
cw, err := c.Create(r.Context(), key, responseHeaders, 0)
if err != nil {
_ = resp.Body.Close()
Expand Down
13 changes: 6 additions & 7 deletions internal/cache/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io"
"maps"
"net/http"
"net/textproto"
"os"
"sync"
"time"
Expand All @@ -33,7 +32,7 @@ type MemoryConfig struct {
type memoryEntry struct {
data []byte
expiresAt time.Time
headers textproto.MIMEHeader
headers http.Header
}

type Memory struct {
Expand All @@ -53,7 +52,7 @@ func NewMemory(ctx context.Context, config MemoryConfig) (*Memory, error) {

func (m *Memory) String() string { return fmt.Sprintf("memory:%dMB", m.config.LimitMB) }

func (m *Memory) Stat(_ context.Context, key Key) (textproto.MIMEHeader, error) {
func (m *Memory) Stat(_ context.Context, key Key) (http.Header, error) {
m.mu.RLock()
defer m.mu.RUnlock()

Expand All @@ -69,7 +68,7 @@ func (m *Memory) Stat(_ context.Context, key Key) (textproto.MIMEHeader, error)
return entry.headers, nil
}

func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) {
func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, http.Header, error) {
m.mu.RLock()
defer m.mu.RUnlock()

Expand All @@ -85,14 +84,14 @@ func (m *Memory) Open(_ context.Context, key Key) (io.ReadCloser, textproto.MIME
return io.NopCloser(bytes.NewReader(entry.data)), entry.headers, nil
}

func (m *Memory) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
func (m *Memory) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) {
if ttl == 0 {
ttl = m.config.MaxTTL
}

now := time.Now()
// Clone headers to avoid concurrent map writes
clonedHeaders := make(textproto.MIMEHeader)
clonedHeaders := make(http.Header)
maps.Copy(clonedHeaders, headers)
if clonedHeaders.Get("Last-Modified") == "" {
clonedHeaders.Set("Last-Modified", now.UTC().Format(http.TimeFormat))
Expand Down Expand Up @@ -136,7 +135,7 @@ type memoryWriter struct {
key Key
buf *bytes.Buffer
expiresAt time.Time
headers textproto.MIMEHeader
headers http.Header
closed bool
ctx context.Context
}
Expand Down
11 changes: 5 additions & 6 deletions internal/cache/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"io"
"maps"
"net/http"
"net/textproto"
"os"
"time"

Expand All @@ -32,7 +31,7 @@ func NewRemote(baseURL string) *Remote {
func (c *Remote) String() string { return "remote:" + c.baseURL }

// Open retrieves an object from the remote.
func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) {
func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) {
url := fmt.Sprintf("%s/%s", c.baseURL, key.String())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
Expand All @@ -53,13 +52,13 @@ func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MI
}

// Filter out HTTP transport headers
headers := FilterTransportHeaders(textproto.MIMEHeader(resp.Header))
headers := FilterTransportHeaders(resp.Header)

return resp.Body, headers, nil
}

// Stat retrieves headers for an object from the remote.
func (c *Remote) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) {
func (c *Remote) Stat(ctx context.Context, key Key) (http.Header, error) {
url := fmt.Sprintf("%s/%s", c.baseURL, key.String())
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err != nil {
Expand All @@ -81,13 +80,13 @@ func (c *Remote) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error
}

// Filter out HTTP transport headers
headers := FilterTransportHeaders(textproto.MIMEHeader(resp.Header))
headers := FilterTransportHeaders(resp.Header)

return headers, nil
}

// Create stores a new object in the remote.
func (c *Remote) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
func (c *Remote) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) {
pr, pw := io.Pipe()

url := fmt.Sprintf("%s/%s", c.baseURL, key.String())
Expand Down
15 changes: 7 additions & 8 deletions internal/cache/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"log/slog"
"maps"
"net/http"
"net/textproto"
"os"
"runtime"
"time"
Expand Down Expand Up @@ -162,7 +161,7 @@ func (s *S3) keyToPath(key Key) string {
return hexKey[:2] + "/" + hexKey
}

func (s *S3) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) {
func (s *S3) Stat(ctx context.Context, key Key) (http.Header, error) {
objectName := s.keyToPath(key)

// Get object info to check metadata
Expand Down Expand Up @@ -190,7 +189,7 @@ func (s *S3) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) {

// Retrieve headers from metadata
// Note: UserMetadata keys are returned WITHOUT the "X-Amz-Meta-" prefix by minio-go
headers := make(textproto.MIMEHeader)
headers := make(http.Header)
if headersJSON := objInfo.UserMetadata["Headers"]; headersJSON != "" {
if err := json.Unmarshal([]byte(headersJSON), &headers); err != nil {
return nil, errors.Errorf("failed to unmarshal headers: %w", err)
Expand All @@ -205,7 +204,7 @@ func (s *S3) Stat(ctx context.Context, key Key) (textproto.MIMEHeader, error) {
return headers, nil
}

func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHeader, error) {
func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, error) {
objectName := s.keyToPath(key)

// Get object info to retrieve metadata and check expiration
Expand All @@ -230,7 +229,7 @@ func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHe
}

// Retrieve headers from metadata
headers := make(textproto.MIMEHeader)
headers := make(http.Header)
if headersJSON := objInfo.UserMetadata["Headers"]; headersJSON != "" {
if err := json.Unmarshal([]byte(headersJSON), &headers); err != nil {
return nil, nil, errors.Errorf("failed to unmarshal headers: %w", err)
Expand All @@ -251,13 +250,13 @@ func (s *S3) Open(ctx context.Context, key Key) (io.ReadCloser, textproto.MIMEHe
return obj, headers, nil
}

func (s *S3) Create(ctx context.Context, key Key, headers textproto.MIMEHeader, ttl time.Duration) (io.WriteCloser, error) {
func (s *S3) Create(ctx context.Context, key Key, headers http.Header, ttl time.Duration) (io.WriteCloser, error) {
if ttl > s.config.MaxTTL || ttl == 0 {
ttl = s.config.MaxTTL
}

// Clone headers to avoid concurrent access issues
clonedHeaders := make(textproto.MIMEHeader)
clonedHeaders := make(http.Header)
maps.Copy(clonedHeaders, headers)

expiresAt := time.Now().Add(ttl)
Expand Down Expand Up @@ -296,7 +295,7 @@ type s3Writer struct {
key Key
pipe *io.PipeWriter
expiresAt time.Time
headers textproto.MIMEHeader
headers http.Header
ctx context.Context
errCh chan error
}
Expand Down
Loading