Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 78 additions & 7 deletions cpp/src/arrow/filesystem/s3fs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,29 +179,94 @@ bool S3ProxyOptions::Equals(const S3ProxyOptions& other) const {
username == other.username && password == other.password);
}

// -----------------------------------------------------------------------
// Custom comparison for AWS retry strategies
//
// AwsRetryStrategyVariant lists the AWS SDK retry strategy types that get a
// value-based comparison. To add a new strategy, add its shared_ptr type to
// AwsRetryStrategyVariant and add a matching operator() overload to the
// AwsRetryStrategyEquality struct.
using AwsRetryStrategyVariant =
std::variant<std::shared_ptr<Aws::Client::DefaultRetryStrategy>,
std::shared_ptr<Aws::Client::StandardRetryStrategy>>;

// Equality functor over pairs of retry strategy pointers, meant to be applied to
// the alternatives of two AwsRetryStrategyVariant values (e.g. through std::visit).
//
// - DefaultRetryStrategy / StandardRetryStrategy: two null pointers are equal,
//   a null and a non-null pointer are not, otherwise GetMaxAttempts() decides.
// - Any other type paired with itself: equal only when both pointers refer to
//   the same object (or are both null).
// - Two different types: never equal.
struct AwsRetryStrategyEquality {
  bool operator()(const std::shared_ptr<Aws::Client::DefaultRetryStrategy>& lhs,
                  const std::shared_ptr<Aws::Client::DefaultRetryStrategy>& rhs) const {
    return SameMaxAttempts(lhs, rhs);
  }

  bool operator()(const std::shared_ptr<Aws::Client::StandardRetryStrategy>& lhs,
                  const std::shared_ptr<Aws::Client::StandardRetryStrategy>& rhs) const {
    return SameMaxAttempts(lhs, rhs);
  }

  // Same but otherwise unknown RetryStrategy type: fall back to pointer identity.
  // shared_ptr's operator== compares the stored pointers, so two null pointers
  // compare equal and a null pointer never equals a non-null one.
  template <typename T>
  bool operator()(const std::shared_ptr<T>& lhs, const std::shared_ptr<T>& rhs) const {
    return lhs == rhs;
  }

  // Different RetryStrategy types are never equal.
  template <typename T, typename U>
  bool operator()(const std::shared_ptr<T>&, const std::shared_ptr<U>&) const {
    return false;
  }

 private:
  // Null-aware comparison shared by the strategies compared on max attempts.
  template <typename Strategy>
  static bool SameMaxAttempts(const std::shared_ptr<Strategy>& lhs,
                              const std::shared_ptr<Strategy>& rhs) {
    if (lhs == nullptr || rhs == nullptr) {
      // Equal only if both sides are missing.
      return lhs == nullptr && rhs == nullptr;
    }
    return lhs->GetMaxAttempts() == rhs->GetMaxAttempts();
  }
};

// -----------------------------------------------------------------------
// AwsRetryStrategy implementation

class AwsRetryStrategy : public S3RetryStrategy {
public:
explicit AwsRetryStrategy(std::shared_ptr<Aws::Client::RetryStrategy> retry_strategy)
explicit AwsRetryStrategy(AwsRetryStrategyVariant retry_strategy)
: retry_strategy_(std::move(retry_strategy)) {}

bool ShouldRetry(const AWSErrorDetail& detail, int64_t attempted_retries) override {
Aws::Client::AWSError<Aws::Client::CoreErrors> error = DetailToError(detail);
return retry_strategy_->ShouldRetry(
error, static_cast<long>(attempted_retries)); // NOLINT: runtime/int
return std::visit(
[&](const auto& strategy) {
return strategy->ShouldRetry(
error, static_cast<long>(attempted_retries)); // NOLINT: runtime/int
},
retry_strategy_);
}

int64_t CalculateDelayBeforeNextRetry(const AWSErrorDetail& detail,
int64_t attempted_retries) override {
Aws::Client::AWSError<Aws::Client::CoreErrors> error = DetailToError(detail);
return retry_strategy_->CalculateDelayBeforeNextRetry(
error, static_cast<long>(attempted_retries)); // NOLINT: runtime/int
return std::visit(
[&](const auto& strategy) {
return strategy->CalculateDelayBeforeNextRetry(
error, static_cast<long>(attempted_retries)); // NOLINT: runtime/int
},
retry_strategy_);
}

bool Equals(const S3RetryStrategy& other) const override {
auto other_aws = dynamic_cast<const AwsRetryStrategy*>(&other);
if (!other_aws) {
return false;
}

return std::visit(
[](const auto& lhs, const auto& rhs) {
return AwsRetryStrategyEquality()(lhs, rhs);
},
retry_strategy_, other_aws->retry_strategy_);
}

protected:
AwsRetryStrategyVariant retry_strategy_;

private:
std::shared_ptr<Aws::Client::RetryStrategy> retry_strategy_;
static Aws::Client::AWSError<Aws::Client::CoreErrors> DetailToError(
const S3RetryStrategy::AWSErrorDetail& detail) {
auto exception_name = ToAwsString(detail.exception_name);
Expand Down Expand Up @@ -426,6 +491,12 @@ bool S3Options::Equals(const S3Options& other) const {
default_metadata_size
? (other.default_metadata && other.default_metadata->Equals(*default_metadata))
: (!other.default_metadata || other.default_metadata->size() == 0);

// Compare retry strategies
const bool retry_strategy_equals = retry_strategy && other.retry_strategy
? retry_strategy->Equals(*other.retry_strategy)
: (!retry_strategy && !other.retry_strategy);

return (smart_defaults == other.smart_defaults && region == other.region &&
connect_timeout == other.connect_timeout &&
request_timeout == other.request_timeout &&
Expand All @@ -442,7 +513,7 @@ bool S3Options::Equals(const S3Options& other) const {
tls_ca_dir_path == other.tls_ca_dir_path &&
tls_verify_certificates == other.tls_verify_certificates &&
sse_customer_key == other.sse_customer_key && default_metadata_equals &&
GetAccessKey() == other.GetAccessKey() &&
retry_strategy_equals && GetAccessKey() == other.GetAccessKey() &&
GetSecretKey() == other.GetSecretKey() &&
GetSessionToken() == other.GetSessionToken());
}
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/filesystem/s3fs.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ class ARROW_EXPORT S3RetryStrategy {
/// Returns the time in milliseconds the S3 client should sleep for until retrying.
virtual int64_t CalculateDelayBeforeNextRetry(const AWSErrorDetail& error,
int64_t attempted_retries) = 0;
/// Returns true if this retry strategy is equal to another retry strategy.
///
/// The default implementation only compares the dynamic types of the two
/// objects and ignores any member data; subclasses whose configuration should
/// take part in the comparison need to override it.
virtual bool Equals(const S3RetryStrategy& other) const {
  const std::type_info& this_type = typeid(*this);
  const std::type_info& other_type = typeid(other);
  return this_type == other_type;
}
/// Returns a stock AWS Default retry strategy.
static std::shared_ptr<S3RetryStrategy> GetAwsDefaultRetryStrategy(
int64_t max_attempts);
Expand Down
103 changes: 103 additions & 0 deletions cpp/src/arrow/filesystem/s3fs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,109 @@ TEST_F(S3OptionsTest, FromAssumeRole) {
options = S3Options::FromAssumeRole("my_role_arn", "session", "id", 42, sts_client);
}

// Stock AWS retry strategies compare equal only when both the strategy kind and
// max_attempts match.
TEST_F(S3OptionsTest, RetryStrategyEquals) {
  // DefaultRetryStrategy: equality follows max_attempts
  const auto default_3a = S3RetryStrategy::GetAwsDefaultRetryStrategy(3);
  const auto default_3b = S3RetryStrategy::GetAwsDefaultRetryStrategy(3);
  const auto default_5 = S3RetryStrategy::GetAwsDefaultRetryStrategy(5);

  ASSERT_TRUE(default_3a->Equals(*default_3b));
  ASSERT_FALSE(default_3a->Equals(*default_5));

  // StandardRetryStrategy: equality follows max_attempts
  const auto standard_3a = S3RetryStrategy::GetAwsStandardRetryStrategy(3);
  const auto standard_3b = S3RetryStrategy::GetAwsStandardRetryStrategy(3);
  const auto standard_5 = S3RetryStrategy::GetAwsStandardRetryStrategy(5);

  ASSERT_TRUE(standard_3a->Equals(*standard_3b));
  ASSERT_FALSE(standard_3a->Equals(*standard_5));

  // Different strategy kinds never match, whichever side does the comparing
  ASSERT_FALSE(default_3a->Equals(*standard_3a));
  ASSERT_FALSE(standard_3a->Equals(*default_3a));
}

// S3Options::Equals must take the retry strategy into account: its kind, its
// max_attempts, and whether one is set at all.
TEST_F(S3OptionsTest, RetryStrategyInS3Options) {
  // Defaults() leaves the retry strategy unset
  S3Options no_strategy = S3Options::Defaults();
  ASSERT_EQ(no_strategy.retry_strategy, nullptr);

  // DefaultRetryStrategy with two different max_attempts values
  S3Options default_3 = S3Options::Defaults();
  default_3.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(3);
  ASSERT_NE(default_3.retry_strategy, nullptr);

  S3Options default_5 = S3Options::Defaults();
  default_5.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(5);
  ASSERT_NE(default_5.retry_strategy, nullptr);

  // StandardRetryStrategy with two different max_attempts values
  S3Options standard_3 = S3Options::Defaults();
  standard_3.retry_strategy = S3RetryStrategy::GetAwsStandardRetryStrategy(3);
  ASSERT_NE(standard_3.retry_strategy, nullptr);

  S3Options standard_5 = S3Options::Defaults();
  standard_5.retry_strategy = S3RetryStrategy::GetAwsStandardRetryStrategy(5);
  ASSERT_NE(standard_5.retry_strategy, nullptr);

  // Independently built options with the same kind and max_attempts are equal
  S3Options default_3_twin = S3Options::Defaults();
  default_3_twin.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(3);
  ASSERT_TRUE(default_3.Equals(default_3_twin));

  S3Options standard_5_twin = S3Options::Defaults();
  standard_5_twin.retry_strategy = S3RetryStrategy::GetAwsStandardRetryStrategy(5);
  ASSERT_TRUE(standard_5.Equals(standard_5_twin));

  // Same kind, different max_attempts: not equal
  ASSERT_FALSE(default_3.Equals(default_5));
  ASSERT_FALSE(standard_3.Equals(standard_5));

  // Different kinds: not equal, even with the same max_attempts
  ASSERT_FALSE(default_3.Equals(standard_3));
  ASSERT_FALSE(standard_5.Equals(default_5));

  // Unset vs. set retry strategy: not equal, in either direction
  ASSERT_FALSE(no_strategy.Equals(default_3));
  ASSERT_FALSE(default_3.Equals(no_strategy));
}

// The retry strategy given in S3Options must be visible through the filesystem's
// options() and must take part in S3FileSystem equality.
TEST_F(S3OptionsTest, RetryStrategyInS3FileSystem) {
  // Filesystem built without a retry strategy
  S3Options no_strategy_opts = S3Options::Defaults();
  ASSERT_OK_AND_ASSIGN(auto fs_no_strategy, S3FileSystem::Make(no_strategy_opts));
  ASSERT_EQ(fs_no_strategy->options().retry_strategy, nullptr);

  // Filesystem built with a DefaultRetryStrategy (3 attempts)
  S3Options default_3_opts = S3Options::Defaults();
  default_3_opts.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(3);
  ASSERT_OK_AND_ASSIGN(auto fs_default_3, S3FileSystem::Make(default_3_opts));
  ASSERT_NE(fs_default_3->options().retry_strategy, nullptr);
  ASSERT_TRUE(
      fs_default_3->options().retry_strategy->Equals(*default_3_opts.retry_strategy));

  // Same strategy kind but a different max_attempts yields a different
  // filesystem
  S3Options default_5_opts = S3Options::Defaults();
  default_5_opts.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(5);
  ASSERT_OK_AND_ASSIGN(auto fs_default_5, S3FileSystem::Make(default_5_opts));
  ASSERT_FALSE(fs_default_3->Equals(*fs_default_5));

  // Filesystem built with a StandardRetryStrategy (5 attempts)
  S3Options standard_5_opts = S3Options::Defaults();
  standard_5_opts.retry_strategy = S3RetryStrategy::GetAwsStandardRetryStrategy(5);
  ASSERT_OK_AND_ASSIGN(auto fs_standard_5, S3FileSystem::Make(standard_5_opts));
  ASSERT_NE(fs_standard_5->options().retry_strategy, nullptr);
  ASSERT_TRUE(
      fs_standard_5->options().retry_strategy->Equals(*standard_5_opts.retry_strategy));

  // Filesystems configured with different retry strategies are all distinct
  ASSERT_FALSE(fs_no_strategy->Equals(*fs_default_3));
  ASSERT_FALSE(fs_default_3->Equals(*fs_standard_5));
  ASSERT_FALSE(fs_no_strategy->Equals(*fs_standard_5));
}

////////////////////////////////////////////////////////////////////////////
// Region resolution test

Expand Down
14 changes: 14 additions & 0 deletions python/pyarrow/_s3fs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ class S3RetryStrategy:
def __init__(self, max_attempts=3):
# Keep the attempt limit on the instance; __reduce__ passes it back to the
# constructor when the object is reconstructed.
self.max_attempts = max_attempts

def __reduce__(self):
    # Pickle support: rebuild the strategy by calling its own class with the
    # stored max_attempts, so a subclass instance is restored as that subclass.
    ctor_args = (self.max_attempts,)
    return (self.__class__, ctor_args)


class AwsStandardS3RetryStrategy(S3RetryStrategy):
"""
Expand Down Expand Up @@ -281,6 +284,7 @@ cdef class S3FileSystem(FileSystem):

cdef:
CS3FileSystem* s3fs
object _retry_strategy

def __init__(self, *, access_key=None, secret_key=None, session_token=None,
bint anonymous=False, region=None, request_timeout=None,
Expand Down Expand Up @@ -412,9 +416,11 @@ cdef class S3FileSystem(FileSystem):
if isinstance(retry_strategy, AwsStandardS3RetryStrategy):
options.value().retry_strategy = CS3RetryStrategy.GetAwsStandardRetryStrategy(
retry_strategy.max_attempts)
self._retry_strategy = retry_strategy
elif isinstance(retry_strategy, AwsDefaultS3RetryStrategy):
options.value().retry_strategy = CS3RetryStrategy.GetAwsDefaultRetryStrategy(
retry_strategy.max_attempts)
self._retry_strategy = retry_strategy
else:
raise ValueError(f'Invalid retry_strategy {retry_strategy!r}')
if tls_ca_file_path is not None:
Expand Down Expand Up @@ -470,6 +476,7 @@ cdef class S3FileSystem(FileSystem):
allow_bucket_creation=opts.allow_bucket_creation,
allow_bucket_deletion=opts.allow_bucket_deletion,
check_directory_existence_before_creation=opts.check_directory_existence_before_creation,
retry_strategy=self._retry_strategy,
default_metadata=pyarrow_wrap_metadata(opts.default_metadata),
proxy_options={'scheme': frombytes(opts.proxy_options.scheme),
'host': frombytes(opts.proxy_options.host),
Expand All @@ -489,3 +496,10 @@ cdef class S3FileSystem(FileSystem):
The AWS region this filesystem connects to.
"""
return frombytes(self.s3fs.region())

@property
def retry_strategy(self):
"""
The retry strategy currently configured for this S3 filesystem.
"""
# NOTE(review): in the visible code _retry_strategy is only assigned inside
# __init__ (when a retry_strategy argument is accepted). For instances
# created through another path it presumably stays None -- confirm, since
# __reduce__ forwards this value as the retry_strategy= argument.
return self._retry_strategy
32 changes: 26 additions & 6 deletions python/pyarrow/tests/test_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,14 +1226,34 @@ def test_s3_options(pickle_module):
assert isinstance(fs, S3FileSystem)
assert pickle_module.loads(pickle_module.dumps(fs)) == fs

# Note that the retry strategy won't survive pickling for now
fs = S3FileSystem(
# Test S3FileSystem with different retry strategies
# They are equal only when the retry strategy is from the same class
# and the same parameters
fs_std_5 = S3FileSystem(
retry_strategy=AwsStandardS3RetryStrategy(max_attempts=5))
assert isinstance(fs, S3FileSystem)

fs = S3FileSystem(
assert isinstance(fs_std_5, S3FileSystem)
assert pickle_module.loads(pickle_module.dumps(fs_std_5)) == fs_std_5
assert fs_std_5.retry_strategy.max_attempts == 5

fs_std_10 = S3FileSystem(
retry_strategy=AwsStandardS3RetryStrategy(max_attempts=10))
assert isinstance(fs_std_10, S3FileSystem)
assert pickle_module.loads(pickle_module.dumps(fs_std_10)) == fs_std_10
assert fs_std_10.retry_strategy.max_attempts == 10
assert fs_std_10 != fs_std_5

fs_std_10_2 = S3FileSystem(
retry_strategy=AwsStandardS3RetryStrategy(max_attempts=10))
assert isinstance(fs_std_10_2, S3FileSystem)
assert pickle_module.loads(pickle_module.dumps(fs_std_10_2)) == fs_std_10_2
assert fs_std_10_2 == fs_std_10

fs_def_5 = S3FileSystem(
retry_strategy=AwsDefaultS3RetryStrategy(max_attempts=5))
assert isinstance(fs, S3FileSystem)
assert isinstance(fs_def_5, S3FileSystem)
assert pickle_module.loads(pickle_module.dumps(fs_def_5)) == fs_def_5
assert fs_def_5.retry_strategy.max_attempts == 5
assert fs_def_5 != fs_std_5

fs2 = S3FileSystem(role_arn='role')
assert isinstance(fs2, S3FileSystem)
Expand Down
Loading