Skip to content

Commit 92fb660

Browse files
committed
GH-47348: [C++][Python] S3FileSystem can be pickled with AwsRetryStrategy
[C++] Added Equals() to S3RetryStrategy. Now S3Options and S3FileSystem will compare their retry strategies as well. [Python] Included retry_strategy in S3FileSystem.__reduce__ method. S3FileSystem's retry_strategy should survive serialization and deserialization. Signed-off-by: Kit Lee <7000003+wingkitlee0@users.noreply.github.com>
1 parent fddd356 commit 92fb660

File tree

5 files changed

+224
-13
lines changed

5 files changed

+224
-13
lines changed

cpp/src/arrow/filesystem/s3fs.cc

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,29 +179,92 @@ bool S3ProxyOptions::Equals(const S3ProxyOptions& other) const {
179179
username == other.username && password == other.password);
180180
}
181181

182+
// -----------------------------------------------------------------------
183+
// Custom comparison for AWS retry strategies
184+
// To add a new strategy, add it to the AwsRetryStrategyVariant and
185+
// add a new specialization to the AwsRetryStrategyEquality struct
186+
using AwsRetryStrategyVariant =
187+
std::variant<std::shared_ptr<Aws::Client::DefaultRetryStrategy>,
188+
std::shared_ptr<Aws::Client::StandardRetryStrategy>>;
189+
190+
struct AwsRetryStrategyEquality {
191+
bool operator()(const std::shared_ptr<Aws::Client::DefaultRetryStrategy>& lhs,
192+
const std::shared_ptr<Aws::Client::DefaultRetryStrategy>& rhs) const {
193+
if (!lhs && !rhs) return true;
194+
if (!lhs || !rhs) return false;
195+
196+
return lhs->GetMaxAttempts() == rhs->GetMaxAttempts();
197+
}
198+
199+
bool operator()(const std::shared_ptr<Aws::Client::StandardRetryStrategy>& lhs,
200+
const std::shared_ptr<Aws::Client::StandardRetryStrategy>& rhs) const {
201+
if (!lhs && !rhs) return true;
202+
if (!lhs || !rhs) return false;
203+
204+
return lhs->GetMaxAttempts() == rhs->GetMaxAttempts();
205+
}
206+
207+
// Template function for same unknown RetryStrategy type - returns true if same pointer
208+
template <typename T>
209+
bool operator()(const std::shared_ptr<T>& lhs, const std::shared_ptr<T>& rhs) const {
210+
if (!lhs && !rhs) return true;
211+
if (!lhs || !rhs) return false;
212+
213+
return lhs.get() == rhs.get();
214+
}
215+
216+
// Template function for different RetryStrategy types - returns false for different
217+
// types
218+
template <typename T, typename U>
219+
bool operator()(const std::shared_ptr<T>& lhs, const std::shared_ptr<U>& rhs) const {
220+
return false;
221+
}
222+
};
223+
182224
// -----------------------------------------------------------------------
183225
// AwsRetryStrategy implementation
184226

185227
class AwsRetryStrategy : public S3RetryStrategy {
186228
public:
187-
explicit AwsRetryStrategy(std::shared_ptr<Aws::Client::RetryStrategy> retry_strategy)
229+
explicit AwsRetryStrategy(AwsRetryStrategyVariant retry_strategy)
188230
: retry_strategy_(std::move(retry_strategy)) {}
189231

190232
bool ShouldRetry(const AWSErrorDetail& detail, int64_t attempted_retries) override {
191233
Aws::Client::AWSError<Aws::Client::CoreErrors> error = DetailToError(detail);
192-
return retry_strategy_->ShouldRetry(
193-
error, static_cast<long>(attempted_retries)); // NOLINT: runtime/int
234+
return std::visit(
235+
[&](const auto& strategy) {
236+
return strategy->ShouldRetry(error, attempted_retries);
237+
},
238+
retry_strategy_);
194239
}
195240

196241
int64_t CalculateDelayBeforeNextRetry(const AWSErrorDetail& detail,
197242
int64_t attempted_retries) override {
198243
Aws::Client::AWSError<Aws::Client::CoreErrors> error = DetailToError(detail);
199-
return retry_strategy_->CalculateDelayBeforeNextRetry(
200-
error, static_cast<long>(attempted_retries)); // NOLINT: runtime/int
244+
return std::visit(
245+
[&](const auto& strategy) {
246+
return strategy->CalculateDelayBeforeNextRetry(error, attempted_retries);
247+
},
248+
retry_strategy_);
201249
}
202250

251+
bool Equals(const S3RetryStrategy& other) const override {
252+
auto other_aws = dynamic_cast<const AwsRetryStrategy*>(&other);
253+
if (!other_aws) {
254+
return false;
255+
}
256+
257+
return std::visit(
258+
[](const auto& lhs, const auto& rhs) {
259+
return AwsRetryStrategyEquality()(lhs, rhs);
260+
},
261+
retry_strategy_, other_aws->retry_strategy_);
262+
}
263+
264+
protected:
265+
AwsRetryStrategyVariant retry_strategy_;
266+
203267
private:
204-
std::shared_ptr<Aws::Client::RetryStrategy> retry_strategy_;
205268
static Aws::Client::AWSError<Aws::Client::CoreErrors> DetailToError(
206269
const S3RetryStrategy::AWSErrorDetail& detail) {
207270
auto exception_name = ToAwsString(detail.exception_name);
@@ -426,6 +489,12 @@ bool S3Options::Equals(const S3Options& other) const {
426489
default_metadata_size
427490
? (other.default_metadata && other.default_metadata->Equals(*default_metadata))
428491
: (!other.default_metadata || other.default_metadata->size() == 0);
492+
493+
// Compare retry strategies
494+
const bool retry_strategy_equals = retry_strategy && other.retry_strategy
495+
? retry_strategy->Equals(*other.retry_strategy)
496+
: (!retry_strategy && !other.retry_strategy);
497+
429498
return (smart_defaults == other.smart_defaults && region == other.region &&
430499
connect_timeout == other.connect_timeout &&
431500
request_timeout == other.request_timeout &&
@@ -442,7 +511,7 @@ bool S3Options::Equals(const S3Options& other) const {
442511
tls_ca_dir_path == other.tls_ca_dir_path &&
443512
tls_verify_certificates == other.tls_verify_certificates &&
444513
sse_customer_key == other.sse_customer_key && default_metadata_equals &&
445-
GetAccessKey() == other.GetAccessKey() &&
514+
retry_strategy_equals && GetAccessKey() == other.GetAccessKey() &&
446515
GetSecretKey() == other.GetSecretKey() &&
447516
GetSessionToken() == other.GetSessionToken());
448517
}

cpp/src/arrow/filesystem/s3fs.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ class ARROW_EXPORT S3RetryStrategy {
8686
/// Returns the time in milliseconds the S3 client should sleep for until retrying.
8787
virtual int64_t CalculateDelayBeforeNextRetry(const AWSErrorDetail& error,
8888
int64_t attempted_retries) = 0;
89+
/// Returns true if this retry strategy is equal to another retry strategy.
90+
/// By default, it returns true if the two objects are of the same type.
91+
virtual bool Equals(const S3RetryStrategy& other) const {
92+
return typeid(*this) == typeid(other);
93+
}
8994
/// Returns a stock AWS Default retry strategy.
9095
static std::shared_ptr<S3RetryStrategy> GetAwsDefaultRetryStrategy(
9196
int64_t max_attempts);

cpp/src/arrow/filesystem/s3fs_test.cc

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,109 @@ TEST_F(S3OptionsTest, FromAssumeRole) {
414414
options = S3Options::FromAssumeRole("my_role_arn", "session", "id", 42, sts_client);
415415
}
416416

417+
TEST_F(S3OptionsTest, RetryStrategyEquals) {
418+
// Test DefaultRetryStrategy equality
419+
auto default_strategy1 = S3RetryStrategy::GetAwsDefaultRetryStrategy(3);
420+
auto default_strategy2 = S3RetryStrategy::GetAwsDefaultRetryStrategy(3);
421+
auto default_strategy3 = S3RetryStrategy::GetAwsDefaultRetryStrategy(5);
422+
423+
ASSERT_TRUE(default_strategy1->Equals(*default_strategy2));
424+
ASSERT_FALSE(default_strategy1->Equals(*default_strategy3));
425+
426+
// Test StandardRetryStrategy equality
427+
auto standard_strategy1 = S3RetryStrategy::GetAwsStandardRetryStrategy(3);
428+
auto standard_strategy2 = S3RetryStrategy::GetAwsStandardRetryStrategy(3);
429+
auto standard_strategy3 = S3RetryStrategy::GetAwsStandardRetryStrategy(5);
430+
431+
ASSERT_TRUE(standard_strategy1->Equals(*standard_strategy2));
432+
ASSERT_FALSE(standard_strategy1->Equals(*standard_strategy3));
433+
434+
// Test different strategy types
435+
ASSERT_FALSE(default_strategy1->Equals(*standard_strategy1));
436+
ASSERT_FALSE(standard_strategy1->Equals(*default_strategy1));
437+
}
438+
439+
TEST_F(S3OptionsTest, RetryStrategyInS3Options) {
440+
// Test S3Options with null retry strategy
441+
S3Options options_null = S3Options::Defaults();
442+
ASSERT_EQ(options_null.retry_strategy, nullptr);
443+
444+
// Test S3Options with DefaultRetryStrategy - different max_attempts
445+
S3Options options_default_3 = S3Options::Defaults();
446+
options_default_3.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(3);
447+
ASSERT_NE(options_default_3.retry_strategy, nullptr);
448+
449+
S3Options options_default_5 = S3Options::Defaults();
450+
options_default_5.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(5);
451+
ASSERT_NE(options_default_5.retry_strategy, nullptr);
452+
453+
// Test S3Options with StandardRetryStrategy - different max_attempts
454+
S3Options options_standard_3 = S3Options::Defaults();
455+
options_standard_3.retry_strategy = S3RetryStrategy::GetAwsStandardRetryStrategy(3);
456+
ASSERT_NE(options_standard_3.retry_strategy, nullptr);
457+
458+
S3Options options_standard_5 = S3Options::Defaults();
459+
options_standard_5.retry_strategy = S3RetryStrategy::GetAwsStandardRetryStrategy(5);
460+
ASSERT_NE(options_standard_5.retry_strategy, nullptr);
461+
462+
// Test equality: same strategy type and max_attempts should be equal
463+
S3Options options_default_3_copy = S3Options::Defaults();
464+
options_default_3_copy.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(3);
465+
ASSERT_TRUE(options_default_3.Equals(options_default_3_copy));
466+
467+
S3Options options_standard_5_copy = S3Options::Defaults();
468+
options_standard_5_copy.retry_strategy =
469+
S3RetryStrategy::GetAwsStandardRetryStrategy(5);
470+
ASSERT_TRUE(options_standard_5.Equals(options_standard_5_copy));
471+
472+
// Test inequality: different max_attempts should not be equal
473+
ASSERT_FALSE(options_default_3.Equals(options_default_5));
474+
ASSERT_FALSE(options_standard_3.Equals(options_standard_5));
475+
476+
// Test inequality: different strategy types should not be equal
477+
ASSERT_FALSE(options_default_3.Equals(options_standard_3));
478+
ASSERT_FALSE(options_standard_5.Equals(options_default_5));
479+
480+
// Test inequality: null vs non-null retry strategy should not be equal
481+
ASSERT_FALSE(options_null.Equals(options_default_3));
482+
ASSERT_FALSE(options_default_3.Equals(options_null));
483+
}
484+
485+
TEST_F(S3OptionsTest, RetryStrategyInS3FileSystem) {
486+
// Test S3FileSystem with null retry strategy
487+
S3Options options_null = S3Options::Defaults();
488+
ASSERT_OK_AND_ASSIGN(auto fs_null, S3FileSystem::Make(options_null));
489+
ASSERT_EQ(fs_null->options().retry_strategy, nullptr);
490+
491+
// Test S3FileSystem with DefaultRetryStrategy
492+
S3Options options_default = S3Options::Defaults();
493+
options_default.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(3);
494+
ASSERT_OK_AND_ASSIGN(auto fs_default, S3FileSystem::Make(options_default));
495+
ASSERT_NE(fs_default->options().retry_strategy, nullptr);
496+
ASSERT_TRUE(
497+
fs_default->options().retry_strategy->Equals(*options_default.retry_strategy));
498+
499+
// Test that same default strategy but different max_attempts create different file
500+
// systems
501+
S3Options options_default_5 = S3Options::Defaults();
502+
options_default_5.retry_strategy = S3RetryStrategy::GetAwsDefaultRetryStrategy(5);
503+
ASSERT_OK_AND_ASSIGN(auto fs_default_5, S3FileSystem::Make(options_default_5));
504+
ASSERT_FALSE(fs_default->Equals(*fs_default_5));
505+
506+
// Test S3FileSystem with StandardRetryStrategy
507+
S3Options options_standard = S3Options::Defaults();
508+
options_standard.retry_strategy = S3RetryStrategy::GetAwsStandardRetryStrategy(5);
509+
ASSERT_OK_AND_ASSIGN(auto fs_standard, S3FileSystem::Make(options_standard));
510+
ASSERT_NE(fs_standard->options().retry_strategy, nullptr);
511+
ASSERT_TRUE(
512+
fs_standard->options().retry_strategy->Equals(*options_standard.retry_strategy));
513+
514+
// Test that different retry strategies create different file systems
515+
ASSERT_FALSE(fs_null->Equals(*fs_default));
516+
ASSERT_FALSE(fs_default->Equals(*fs_standard));
517+
ASSERT_FALSE(fs_null->Equals(*fs_standard));
518+
}
519+
417520
////////////////////////////////////////////////////////////////////////////
418521
// Region resolution test
419522

python/pyarrow/_s3fs.pyx

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ class S3RetryStrategy:
120120
def __init__(self, max_attempts=3):
121121
self.max_attempts = max_attempts
122122

123+
def __reduce__(self):
124+
return (self.__class__, (self.max_attempts,))
125+
123126

124127
class AwsStandardS3RetryStrategy(S3RetryStrategy):
125128
"""
@@ -281,6 +284,7 @@ cdef class S3FileSystem(FileSystem):
281284

282285
cdef:
283286
CS3FileSystem* s3fs
287+
object _retry_strategy
284288

285289
def __init__(self, *, access_key=None, secret_key=None, session_token=None,
286290
bint anonymous=False, region=None, request_timeout=None,
@@ -412,9 +416,11 @@ cdef class S3FileSystem(FileSystem):
412416
if isinstance(retry_strategy, AwsStandardS3RetryStrategy):
413417
options.value().retry_strategy = CS3RetryStrategy.GetAwsStandardRetryStrategy(
414418
retry_strategy.max_attempts)
419+
self._retry_strategy = retry_strategy
415420
elif isinstance(retry_strategy, AwsDefaultS3RetryStrategy):
416421
options.value().retry_strategy = CS3RetryStrategy.GetAwsDefaultRetryStrategy(
417422
retry_strategy.max_attempts)
423+
self._retry_strategy = retry_strategy
418424
else:
419425
raise ValueError(f'Invalid retry_strategy {retry_strategy!r}')
420426
if tls_ca_file_path is not None:
@@ -470,6 +476,7 @@ cdef class S3FileSystem(FileSystem):
470476
allow_bucket_creation=opts.allow_bucket_creation,
471477
allow_bucket_deletion=opts.allow_bucket_deletion,
472478
check_directory_existence_before_creation=opts.check_directory_existence_before_creation,
479+
retry_strategy=self._retry_strategy,
473480
default_metadata=pyarrow_wrap_metadata(opts.default_metadata),
474481
proxy_options={'scheme': frombytes(opts.proxy_options.scheme),
475482
'host': frombytes(opts.proxy_options.host),
@@ -489,3 +496,10 @@ cdef class S3FileSystem(FileSystem):
489496
The AWS region this filesystem connects to.
490497
"""
491498
return frombytes(self.s3fs.region())
499+
500+
@property
501+
def retry_strategy(self):
502+
"""
503+
The retry strategy currently configured for this S3 filesystem.
504+
"""
505+
return self._retry_strategy

python/pyarrow/tests/test_fs.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,14 +1225,34 @@ def test_s3_options(pickle_module):
12251225
assert isinstance(fs, S3FileSystem)
12261226
assert pickle_module.loads(pickle_module.dumps(fs)) == fs
12271227

1228-
# Note that the retry strategy won't survive pickling for now
1229-
fs = S3FileSystem(
1228+
# Test S3FileSystem with different retry strategies
1229+
# They are equal only when the retry strategy is from the same class
1230+
# and the same parameters
1231+
fs_std_5 = S3FileSystem(
12301232
retry_strategy=AwsStandardS3RetryStrategy(max_attempts=5))
1231-
assert isinstance(fs, S3FileSystem)
1232-
1233-
fs = S3FileSystem(
1233+
assert isinstance(fs_std_5, S3FileSystem)
1234+
assert pickle_module.loads(pickle_module.dumps(fs_std_5)) == fs_std_5
1235+
assert fs_std_5.retry_strategy.max_attempts == 5
1236+
1237+
fs_std_10 = S3FileSystem(
1238+
retry_strategy=AwsStandardS3RetryStrategy(max_attempts=10))
1239+
assert isinstance(fs_std_10, S3FileSystem)
1240+
assert pickle_module.loads(pickle_module.dumps(fs_std_10)) == fs_std_10
1241+
assert fs_std_10.retry_strategy.max_attempts == 10
1242+
assert fs_std_10 != fs_std_5
1243+
1244+
fs_std_10_2 = S3FileSystem(
1245+
retry_strategy=AwsStandardS3RetryStrategy(max_attempts=10))
1246+
assert isinstance(fs_std_10_2, S3FileSystem)
1247+
assert pickle_module.loads(pickle_module.dumps(fs_std_10_2)) == fs_std_10_2
1248+
assert fs_std_10_2 == fs_std_10
1249+
1250+
fs_def_5 = S3FileSystem(
12341251
retry_strategy=AwsDefaultS3RetryStrategy(max_attempts=5))
1235-
assert isinstance(fs, S3FileSystem)
1252+
assert isinstance(fs_def_5, S3FileSystem)
1253+
assert pickle_module.loads(pickle_module.dumps(fs_def_5)) == fs_def_5
1254+
assert fs_def_5.retry_strategy.max_attempts == 5
1255+
assert fs_def_5 != fs_std_5
12361256

12371257
fs2 = S3FileSystem(role_arn='role')
12381258
assert isinstance(fs2, S3FileSystem)

0 commit comments

Comments
 (0)