119 changes: 101 additions & 18 deletions aws/s3/s3_concurrent.go
@@ -157,35 +157,90 @@ func newConcurrencyManager(maxWorkers, maxWorkersPerRequest, maxBytes int) *Conc
// containing a single HydratedFile with an error is returned.
// Version can be empty, but must be the same for all objects.
func (s *S3Concurrent) GetAllConcurrently(bucket, version string, objects []types.Object) chan HydratedFile {
return s.GetAllConcurrentlyWithContext(context.Background(), bucket, version, objects)
}

// GetAllConcurrentlyWithContext gets the objects from the specified bucket, using the provided context, and writes
// the resulting HydratedFiles to the returned output channel. Closing this channel is handled internally; however,
// it is the caller's responsibility to drain the channel and handle any errors present in the HydratedFiles.
// If the ConcurrencyManager is not initialised before calling GetAllConcurrentlyWithContext, an output channel
// containing a single HydratedFile with an error is returned.
// Version can be empty, but must be the same for all objects.
func (s *S3Concurrent) GetAllConcurrentlyWithContext(
ctx context.Context,
bucket, version string,
objects []types.Object,
) chan HydratedFile {

output := make(chan HydratedFile, 1)

// Early cancel check
select {
case <-ctx.Done():
output <- HydratedFile{Error: ctx.Err()}
close(output)
return output
default:
}

if s.manager == nil {
output := make(chan HydratedFile, 1)
output <- HydratedFile{Error: errors.New("error getting files from S3, Concurrency Manager not initialised")}
output <- HydratedFile{
Error: errors.New("error getting files from S3, Concurrency Manager not initialised"),
}
close(output)
return output
}

if s.manager.memoryTotalSize < s.manager.calculateRequiredMemoryFor(objects) {
output := make(chan HydratedFile, 1)
output <- HydratedFile{Error: fmt.Errorf("error: bytes requested greater than max allowed by server (%v)", s.manager.memoryTotalSize)}
output <- HydratedFile{
Error: fmt.Errorf(
"error: bytes requested greater than max allowed by server (%v)",
s.manager.memoryTotalSize,
),
}
close(output)
return output
}

// Secure memory for all objects upfront.
s.manager.secureMemory(objects) // 0.

// ensure memory is released if context cancels before processing finishes
go func() {
<-ctx.Done()
// release all secured memory
for _, o := range objects {
s.manager.releaseMemory(aws.ToInt64(o.Size))
}
}()

processFunc := func(input types.Object) HydratedFile {
// Respect cancellation before starting work
select {
case <-ctx.Done():
return HydratedFile{Error: ctx.Err()}
default:
}

buf := bytes.NewBuffer(make([]byte, 0, int(*input.Size)))
key := aws.ToString(input.Key)
err := s.Get(bucket, key, version, buf)

_, err := s.GetWithContext(ctx, bucket, key, version, buf)

// If context was cancelled during S3 read, surface that
if ctx.Err() != nil {
return HydratedFile{Error: ctx.Err()}
}

return HydratedFile{
Key: key,
Data: buf.Bytes(),
Error: err,
}
}
return s.manager.Process(processFunc, objects)

// Process with a context
return s.manager.Process(ctx, processFunc, objects)
}
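
// Illustrative only: a caller-side sketch of how GetAllConcurrentlyWithContext is meant to be
// used, not part of this file. The bucket name, timeout, and pool sizes are placeholder values;
// the NewConcurrent and ListAllObjects calls mirror the tests further down, and the snippet
// assumes the context, log, and time packages are imported.
//
//	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
//	defer cancel()
//
//	// maxWorkers, maxWorkersPerRequest, maxBytes
//	client, err := NewConcurrent(100, 10, 1_000_000)
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	objects, err := client.ListAllObjects("example-bucket", "")
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	// The returned channel is closed for us, but it must be drained and each
//	// HydratedFile's Error must be checked.
//	for hf := range client.GetAllConcurrentlyWithContext(ctx, "example-bucket", "", objects) {
//		if hf.Error != nil {
//			log.Printf("download of %q failed: %v", hf.Key, hf.Error)
//			continue
//		}
//		log.Printf("downloaded %q (%d bytes)", hf.Key, len(hf.Data))
//	}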

// getWorker retrieves a number of workers from the manager's worker pool.
@@ -244,24 +299,46 @@ func (cm *ConcurrencyManager) releaseMemory(size int64) {
}
}

// Functions for providing a fan-out/fan-in operation. Workers are taken from the
// Process provides a fan-out/fan-in operation with the provided context. Workers are taken from the
// worker pool and added to a WorkerGroup. All workers are returned to the pool once
// the jobs have finished.
func (cm *ConcurrencyManager) Process(asyncProcessor FileProcessor, objects []types.Object) chan HydratedFile {
workerGroup := cm.newWorkerGroup(context.Background(), asyncProcessor, cm.maxWorkersPerRequest) // 1.
func (cm *ConcurrencyManager) Process(
ctx context.Context,
asyncProcessor FileProcessor,
objects []types.Object,
) chan HydratedFile {

workerGroup := cm.newWorkerGroup(ctx, asyncProcessor, cm.maxWorkersPerRequest)

go func() {
defer func() {
close(workerGroup.reception)
workerGroup.stopWork()
}()

for _, obj := range objects {
workerGroup.addWork(obj)
select {
case <-ctx.Done():
return
default:
if !workerGroup.addWork(ctx, obj) {
return
}
}
}
workerGroup.stopWork() // 9.
}()
return workerGroup.returnOutput() // 2.

return workerGroup.returnOutput()
}
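
// A standalone, simplified sketch of the ordered fan-out/fan-in pattern used by Process above
// and addWork further down. The names (orderedFanOutIn, fanWorker) are hypothetical and not part
// of this package; worker shutdown is omitted for brevity. Each worker registers its personal
// output channel on a "reception" channel whenever it is handed work, so results are read back
// in submission order even though the work itself runs concurrently. Only the standard library
// ("context") is assumed.
type fanWorker struct {
	input  chan string
	output chan string
}

func orderedFanOutIn(ctx context.Context, jobs []string, workers int, process func(string) string) <-chan string {
	roster := make(chan *fanWorker, workers)       // available workers
	reception := make(chan chan string, len(jobs)) // output channels, in submission order

	// Start the workers: each takes a job, emits its result, and rejoins the roster.
	for i := 0; i < workers; i++ {
		w := &fanWorker{input: make(chan string), output: make(chan string, 1)}
		roster <- w
		go func(w *fanWorker) {
			for j := range w.input {
				w.output <- process(j)
				roster <- w
			}
		}(w)
	}

	// Fan out: hand each job to the next free worker and remember its output channel.
	go func() {
		defer close(reception)
		for _, j := range jobs {
			select {
			case <-ctx.Done():
				return
			case w := <-roster:
				w.input <- j
				reception <- w.output
			}
		}
	}()

	// Fan in: forward results in the order the jobs were handed out.
	out := make(chan string)
	go func() {
		defer close(out)
		for outputCh := range reception {
			out <- <-outputCh
		}
	}()
	return out
}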

// start begins a worker's process of making itself available for work, doing the work,
// and repeating, until all work is done.
func (w *worker) start(ctx context.Context, processor FileProcessor, roster chan *worker, wg *sync.WaitGroup) {
func (w *worker) start(
ctx context.Context,
processor FileProcessor,
roster chan *worker,
wg *sync.WaitGroup,
) {
go func() {
defer func() {
wg.Done()
@@ -341,20 +418,26 @@ func (wg *workerGroup) startOutput() {
func (wg *workerGroup) cleanUp(ctx context.Context) {
<-ctx.Done()
wg.group.Wait() // 9.
close(wg.reception)
//close(wg.reception)
close(wg.roster)
}

// addWork gets the first available worker from the workerGroup's
// roster, and gives it an S3 Object to download. The worker's output
// channel is registered to the workerGroup's reception so that
// order is retained. It returns false if the work could not be handed
// to a worker, for example because the context was cancelled.
func (wg *workerGroup) addWork(newWork types.Object) { // 4.
func (wg *workerGroup) addWork(ctx context.Context, newWork types.Object) bool {
for w := range wg.roster {
w.input <- newWork
wg.reception <- w.output
break
select {
case <-ctx.Done():
return false
default:
w.input <- newWork
wg.reception <- w.output
return true
}
}
return false
}

// returnOutput returns the workerGroup's output channel.
178 changes: 177 additions & 1 deletion aws/s3/s3_concurrent_test.go
@@ -1,6 +1,7 @@
package s3

import (
"context"
"fmt"
"testing"
"time"
@@ -89,7 +90,7 @@ func TestS3GetAllConcurrently(t *testing.T) {
}

// ASSERT input and output order is the same.
require.Equal(t, len(outputKeys), total)
require.Equal(t, total, len(outputKeys))
for i := 0; i < total; i++ {
assert.Equal(t, aws.ToString(objects[i].Key), outputKeys[i])
}
@@ -121,3 +122,178 @@
}
}
}

// go test --run TestS3GetAllConcurrentlyWithContext -v
func TestS3GetAllConcurrentlyWithContext(t *testing.T) {
// ARRANGE
setup()
defer teardown()

// ASSERT parameter errors.
_, err := NewConcurrent(0, 100, 1000)
assert.NotNil(t, err)
_, err = NewConcurrent(100, 0, 1000)
assert.NotNil(t, err)
_, err = NewConcurrent(100, 100, 0)
assert.NotNil(t, err)
_, err = NewConcurrent(100, 10, 99)
assert.NotNil(t, err)
_, err = NewConcurrent(100, 101, 1000)
assert.NotNil(t, err)

client, err := NewConcurrent(100, 10, 1000)
require.Nil(t, err, fmt.Sprintf("error creating s3 client concurrency manager: %v", err))

// ASSERT computed fields.
assert.Equal(t, 100, len(client.manager.workerPool.channel))
assert.Equal(t, 100, len(client.manager.memoryPool.channel))
assert.Equal(t, int64(10), client.manager.memoryChunkSize)
assert.Equal(t, int64(10*100), client.manager.memoryTotalSize)
assert.Equal(t, 10, client.manager.maxWorkersPerRequest)

// ASSERT memory chunk size is correct in memory pool.
chunk := <-client.manager.memoryPool.channel
assert.Equal(t, int64(10), chunk)
client.manager.memoryPool.channel <- chunk

// ASSERT worker get/release methods work as expected.
w := client.manager.getWorkers(1)
assert.Equal(t, 99, len(client.manager.workerPool.channel))
client.manager.returnWorker(w[0])
assert.Equal(t, 100, len(client.manager.workerPool.channel))

// ASSERT memory get/release methods work as expected.
elevenByteFile := types.Object{Size: aws.Int64(11)} // requires 2 memory chunks.
client.manager.secureMemory([]types.Object{elevenByteFile})
assert.Equal(t, 98, len(client.manager.memoryPool.channel))
client.manager.releaseMemory(20)
assert.Equal(t, 100, len(client.manager.memoryPool.channel))
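
// (Assumed arithmetic, inferred from the assertions above: with a 10-byte memory chunk size,
// an 11-byte object needs ceil(11/10) = 2 chunks, i.e. 20 bytes, which is why the pool drops
// from 100 to 98 chunks and why releasing 20 bytes restores it to 100.)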

// ARRANGE bucket with test objects.
total := 20
keys := make([]string, total)
for i := 0; i < total; i++ {
keys[i] = fmt.Sprintf("%s-%v", testObjectKey, i)
}
awsCmdPutKeys(keys)

// ACTION
objects, _ := client.ListAllObjects(testBucket, "")
tooManyBytes := make([]types.Object, 0, 10*len(objects))
for _, o := range objects {
for i := 0; i < 10; i++ {
tooManyBytes = append(tooManyBytes, o)
}
}
output := client.GetAllConcurrentlyWithContext(context.Background(), testBucket, "", tooManyBytes)

// ASSERT error returned
for hf := range output {
assert.NotNil(t, hf.Error)
}

// ACTION
objects, _ = client.ListAllObjects(testBucket, "")
output = client.GetAllConcurrentlyWithContext(context.Background(), testBucket, "", objects)
outputKeys := make([]string, 0)
for hf := range output {
outputKeys = append(outputKeys, hf.Key)
}

// ASSERT input and output order is the same.
require.Equal(t, total, len(outputKeys))
for i := 0; i < total; i++ {
assert.Equal(t, aws.ToString(objects[i].Key), outputKeys[i])
}

// ASSERT all workers and memory returned to pools.
time.Sleep(2 * time.Second)
assert.Equal(t, 100, len(client.manager.workerPool.channel))
assert.Equal(t, 100, len(client.manager.memoryPool.channel))

// ASSERT that the call blocks while all memory is secured.
tenByteFile := types.Object{Size: aws.Int64(10)}
oneThousandBytesOfFiles := make([]types.Object, 100)
for i := 0; i < 100; i++ {
oneThousandBytesOfFiles[i] = tenByteFile
}
client.manager.secureMemory(oneThousandBytesOfFiles)
ch := make(chan chan HydratedFile)
go func() {
ch <- client.GetAllConcurrentlyWithContext(context.Background(), testBucket, "", objects)
}()

for {
select {
case <-ch:
t.Error("process was not blocked")
case <-time.After(time.Second):
// Timed out as expected
return
}
}
}

// go test --run TestS3GetAllConcurrentlyWithContext_Cancel -v
func TestS3GetAllConcurrentlyWithContext_Cancel(t *testing.T) {
// ARRANGE
setup()
defer teardown()

client, err := NewConcurrent(100, 10, 1000)
require.NoError(t, err)

total := 20
keys := make([]string, total)
for i := 0; i < total; i++ {
keys[i] = fmt.Sprintf("%s-%v", testObjectKey, i)
}
awsCmdPutKeys(keys)
objects, _ := client.ListAllObjects(testBucket, "")

t.Run("early-cancel-before-start", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

out := client.GetAllConcurrentlyWithContext(ctx, testBucket, "", objects)

var got []HydratedFile
for hf := range out {
got = append(got, hf)
}
require.Len(t, got, 1)
require.ErrorIs(t, got[0].Error, context.Canceled)

time.Sleep(200 * time.Millisecond)
assert.Equal(t, 100, len(client.manager.workerPool.channel))
assert.Equal(t, 100, len(client.manager.memoryPool.channel))
})
t.Run("cancel-during-processing", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
out := client.GetAllConcurrentlyWithContext(ctx, testBucket, "", objects)

collected := make([]HydratedFile, 0, len(objects))
cancelAfter := 3

for hf := range out {
collected = append(collected, hf)
if len(collected) == cancelAfter {
cancel()
}
}
// At least some work completed
require.GreaterOrEqual(t, len(collected), cancelAfter)
// But not all objects should be processed
require.Less(t, len(collected), len(objects))
// Pool recovery
require.Eventually(t, func() bool {
return len(client.manager.workerPool.channel) == 100
}, 5*time.Second, 10*time.Millisecond, fmt.Sprintf("workers pool not recovered, expected %d actual %d", 100, len(client.manager.workerPool.channel)))
require.Eventually(t, func() bool {
return len(client.manager.memoryPool.channel) == 100
}, 5*time.Second, 10*time.Millisecond, fmt.Sprintf("memory pool not recovered, expected %d actual %d", 100, len(client.manager.memoryPool.channel)))

})

}