Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions aws/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,40 @@ func (s *S3) Client() *s3.Client {
// Get gets the object referred to by key and version from bucket and writes it into b.
// Version can be empty.
func (s *S3) Get(bucket, key, version string, b *bytes.Buffer) error {
_, err := s.GetWithContext(context.Background(), bucket, key, version, b)
return err
}

// Get gets the object referred to by key and version from bucket and writes it into b.
// with the provided context.
// Version can be empty.
func (s *S3) GetWithContext(
ctx context.Context,
bucket, key, version string,
w io.Writer,
) (int64, error) {

input := s3.GetObjectInput{
Key: aws.String(key),
Bucket: aws.String(bucket),
Key: aws.String(key),
}
if version != "" {
input.VersionId = aws.String(version)
}
result, err := s.client.GetObject(context.TODO(), &input)

result, err := s.client.GetObject(ctx, &input)
if err != nil {
return err
return 0, err
}
defer result.Body.Close()

_, err = b.ReadFrom(result.Body)
n, err := io.Copy(w, result.Body)

return err
// Distinguish cancellation from real errors
if ctx.Err() != nil {
return n, ctx.Err()
}
return n, err
}

// GetByteRange gets the specified byte range of an object referred to by key and version
Expand Down
149 changes: 144 additions & 5 deletions aws/s3/s3_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,23 @@ func setAwsEnv() {
os.Setenv("AWS_SECRET_ACCESS_KEY", "test")
os.Setenv("AWS_ACCESS_KEY_ID", "test")
os.Setenv("AWS_ENDPOINT_URL", customAWSEndpoint)
os.Setenv("AWS_S3_DISABLE_CHECKSUM", "true")
}

func setup() {
// setup environment variable to run AWS CLI/SDK
setAwsEnv()

// create bucket
if err := exec.Command( //nolint:gosec
cmd := exec.Command( //nolint:gosec
"aws", "s3api",
"create-bucket",
"--bucket", testBucket,
"--create-bucket-configuration", fmt.Sprintf(
"{\"LocationConstraint\": \"%v\"}", testRegion),
).Run(); err != nil {
)
if output, err := cmd.CombinedOutput(); err != nil {
fmt.Printf("Command failed: %v\nOutput: %s\n", err, string(output))
panic(err)
}
}
Expand Down Expand Up @@ -146,11 +149,12 @@ func awsCmdPutKeys(keys []string) {
testFile.Close()
}
// sync to bucket
if err := exec.Command(
cmd := exec.Command(
"aws", "s3",
"sync", tmpDir, fmt.Sprintf("s3://%v", testBucket),
).Run(); err != nil {

)
if output, err := cmd.CombinedOutput(); err != nil {
fmt.Printf("Command failed: %v\nOutput: %s\n", err, string(output))
panic(err)
}
}
Expand Down Expand Up @@ -360,6 +364,141 @@ func TestS3Get(t *testing.T) {
assert.Equal(t, testObjectData, dataObject.String())
}

func TestS3GetWithContext(t *testing.T) {
// ARRANGE
setup()
defer teardown()

awsCmdPopulateBucket()

client, err := New()
require.NoError(t, err, "error creating s3 client")

t.Run("normal", func(t *testing.T) {
var buf bytes.Buffer
ctx := context.Background()

// ACTION
written, err := client.GetWithContext(
ctx,
testBucket,
testObjectKey,
"",
&buf,
)

// ASSERT
require.NoError(t, err)
assert.Equal(t, int64(len(testObjectData)), written)
assert.Equal(t, testObjectData, buf.String())
})

t.Run("cancelled", func(t *testing.T) {
var buf bytes.Buffer
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately

// ACTION
written, err := client.GetWithContext(
ctx,
testBucket,
testObjectKey,
"",
&buf,
)

// ASSERT
require.Error(t, err)
assert.ErrorIs(t, err, context.Canceled)
assert.Equal(t, int64(0), written)
})

t.Run("cancel-during-processing", func(t *testing.T) {
// We’ll cancel after a portion of the object has been written to the buffer.
ctx, cancel := context.WithCancel(context.Background())
var buf bytes.Buffer

// Choose a threshold smaller than the total size so we cancel mid-stream.
sw := &cancelAfterNWriter{
dst: &buf,
cancel: cancel,
limit: 4, // cancel after 4 bytes are written
sleep: 0 * time.Millisecond, // optional; set to >0 to slow per-write
}

// ACTION
written, err := client.GetWithContext(
ctx,
testBucket,
testObjectKey,
"",
sw,
)
t.Log("written bytes:", written)
// ASSERT: it should end early with a context error and partial bytes written
require.Error(t, err, "expected error due to mid-run cancellation")
assert.ErrorIs(t, err, context.Canceled)
assert.GreaterOrEqual(t, written, int64(1), "should write some bytes before cancel")
assert.Equal(t, written, int64(buf.Len()), "buffer length should match reported written")
assert.Less(t, written, int64(len(testObjectData)), "should not complete full object")
})
}

// cancelAfterNWriter writes at most limit bytes to dst.
// Once limit is reached, it cancels ctx and returns context.Canceled.
// If a single Write would exceed the limit, it performs a **partial write**
// and then returns context.Canceled so the copy loop stops immediately.
type cancelAfterNWriter struct {
dst io.Writer
cancel context.CancelFunc
limit int64 // total bytes allowed before we cancel & error
sleep time.Duration
wrote int64
}

func (w *cancelAfterNWriter) Write(p []byte) (int, error) {
if w.sleep > 0 {
time.Sleep(w.sleep)
}

remaining := w.limit - w.wrote
if remaining <= 0 {
// Already reached the limit: cancel & error without writing.
if w.cancel != nil {
w.cancel()
w.cancel = nil
}
return 0, context.Canceled
}

// If the incoming chunk exceeds the remaining budget, do a **partial write**.
if int64(len(p)) > remaining {
// write only `remaining` bytes
n, err := w.dst.Write(p[:remaining])
if err != nil {
return n, err
}
w.wrote += int64(n)
// Now cancel & return error to abort the transfer
if w.cancel != nil {
w.cancel()
w.cancel = nil
}
// ignore underlying err to ensure we signal cancel; return context.Canceled with partial write
return n, context.Canceled
}

// Normal path: whole chunk fits.
n, err := w.dst.Write(p)
w.wrote += int64(n)
// If we *exactly* hit the limit after this write, cancel & error on the next call.
if w.wrote >= w.limit && w.cancel != nil {
w.cancel()
w.cancel = nil
}
return n, err
}

func TestS3GetByteRange(t *testing.T) {
// ARRANGE
setup()
Expand Down
Loading