Commit a8ced7e

feat(experimental): add write resumption strategy
1 parent 3587822 commit a8ced7e


8 files changed: +532, -31 lines


google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py

Lines changed: 25 additions & 11 deletions
@@ -174,7 +174,9 @@ async def open(self, retry_policy: Optional[AsyncRetry] = None) -> None:
         if retry_policy is None:
             # Default policy: retry generic transient errors
             retry_policy = AsyncRetry(
-                predicate=lambda e: isinstance(e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))
+                predicate=lambda e: isinstance(
+                    e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded)
+                )
             )

         async def _do_open():
@@ -201,7 +203,7 @@ async def download_ranges(
         self,
         read_ranges: List[Tuple[int, int, BytesIO]],
         lock: asyncio.Lock = None,
-        retry_policy: AsyncRetry = None
+        retry_policy: AsyncRetry = None,
     ) -> None:
         """Downloads multiple byte ranges from the object into the buffers
         provided by user with automatic retries.
@@ -260,7 +262,9 @@ async def download_ranges(

         if retry_policy is None:
             retry_policy = AsyncRetry(
-                predicate=lambda e: isinstance(e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded))
+                predicate=lambda e: isinstance(
+                    e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded)
+                )
             )

         # Initialize Global State for Retry Strategy
@@ -270,20 +274,19 @@ async def download_ranges(
             download_states[read_id] = _DownloadState(
                 initial_offset=read_range[0],
                 initial_length=read_range[1],
-                user_buffer=read_range[2]
+                user_buffer=read_range[2],
             )

         initial_state = {
             "download_states": download_states,
             "read_handle": self.read_handle,
-            "routing_token": None
+            "routing_token": None,
         }

         # Track attempts to manage stream reuse
         is_first_attempt = True

         def stream_opener(requests: List[_storage_v2.ReadRange], state: Dict[str, Any]):
-
             async def generator():
                 nonlocal is_first_attempt

@@ -294,7 +297,9 @@ async def generator():
                 # We reopen if it's a redirect (token exists) OR if this is a retry
                 # (not first attempt). This prevents trying to send data on a dead
                 # stream from a previous failed attempt.
-                should_reopen = (not is_first_attempt) or (current_token is not None)
+                should_reopen = (not is_first_attempt) or (
+                    current_token is not None
+                )

                 if should_reopen:
                     # Close existing stream if any
@@ -313,16 +318,25 @@ async def generator():
                     # Inject routing_token into metadata if present
                     metadata = []
                     if current_token:
-                        metadata.append(("x-goog-request-params", f"routing_token={current_token}"))
-
-                    await self.read_obj_str.open(metadata=metadata if metadata else None)
+                        metadata.append(
+                            (
+                                "x-goog-request-params",
+                                f"routing_token={current_token}",
+                            )
+                        )
+
+                    await self.read_obj_str.open(
+                        metadata=metadata if metadata else None
+                    )
                     self._is_stream_open = True

                 # Mark first attempt as done; next time this runs it will be a retry
                 is_first_attempt = False

                 # Send Requests
-                for i in range(0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
+                for i in range(
+                    0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
+                ):
                     batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST]
                     await self.read_obj_str.send(
                         _storage_v2.BidiReadObjectRequest(read_ranges=batch)
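
With these changes, callers can pass an AsyncRetry policy into both open() and download_ranges(), falling back to the transient-error default shown above when none is given. The sketch below is illustrative only: it assumes an already constructed downloader object (its constructor is not part of this diff) and uses made-up offsets and lengths.

    from io import BytesIO

    from google.api_core import exceptions
    from google.api_core.retry_async import AsyncRetry


    async def fetch_two_ranges(downloader):
        # Mirror the default predicate added in this commit: retry only on
        # transient availability/deadline errors.
        policy = AsyncRetry(
            predicate=lambda e: isinstance(
                e, (exceptions.ServiceUnavailable, exceptions.DeadlineExceeded)
            )
        )

        buf_a, buf_b = BytesIO(), BytesIO()
        await downloader.open(retry_policy=policy)
        # Each tuple is (offset, length, destination buffer), matching the
        # _DownloadState(initial_offset, initial_length, user_buffer) mapping above.
        await downloader.download_ranges(
            [(0, 1024, buf_a), (4096, 2048, buf_b)],
            retry_policy=policy,
        )
        return buf_a.getvalue(), buf_b.getvalue()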

google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py

Lines changed: 3 additions & 3 deletions
@@ -97,7 +97,7 @@ def update_state_from_response(
             raise DataCorruption(
                 response,
                 f"Offset mismatch for read_id {read_id}. "
-                f"Expected {read_state.next_expected_offset}, got {chunk_offset}"
+                f"Expected {read_state.next_expected_offset}, got {chunk_offset}",
             )

         # Checksum Verification
@@ -111,7 +111,7 @@ def update_state_from_response(
             raise DataCorruption(
                 response,
                 f"Checksum mismatch for read_id {read_id}. "
-                f"Server sent {server_checksum}, client calculated {client_checksum}."
+                f"Server sent {server_checksum}, client calculated {client_checksum}.",
             )

         # Update State & Write Data
@@ -130,7 +130,7 @@ def update_state_from_response(
             raise DataCorruption(
                 response,
                 f"Byte count mismatch for read_id {read_id}. "
-                f"Expected {read_state.initial_length}, got {read_state.bytes_written}"
+                f"Expected {read_state.initial_length}, got {read_state.bytes_written}",
             )

     async def recover_state_on_failure(self, error: Exception, state: Any) -> None:
async def recover_state_on_failure(self, error: Exception, state: Any) -> None:
Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, IO, Iterable, Optional, Union
+
+import google_crc32c
+from google.cloud._storage_v2.types import storage as storage_type
+from google.cloud._storage_v2.types.storage import BidiWriteObjectRedirectedError
+from google.cloud.storage._experimental.asyncio.retry.base_strategy import (
+    _BaseResumptionStrategy,
+)
+
+
+class _WriteState:
+    """A helper class to track the state of a single upload operation.
+
+    Attributes:
+        spec (AppendObjectSpec): The specification for the object to write.
+        chunk_size (int): The size of chunks to read from the buffer.
+        user_buffer (IO[bytes]): The data source.
+        persisted_size (int): The amount of data confirmed as persisted by the server.
+        bytes_sent (int): The amount of data currently sent in the active stream.
+        write_handle (bytes | BidiWriteHandle | None): The handle for the append session.
+        routing_token (str | None): Token for routing to the correct backend.
+        is_complete (bool): Whether the upload has finished.
+    """
+
+    def __init__(
+        self,
+        spec: storage_type.AppendObjectSpec,
+        chunk_size: int,
+        user_buffer: IO[bytes],
+    ):
+        self.spec = spec
+        self.chunk_size = chunk_size
+        self.user_buffer = user_buffer
+        self.persisted_size: int = 0
+        self.bytes_sent: int = 0
+        self.write_handle: Union[bytes, Any, None] = None
+        self.routing_token: Optional[str] = None
+        self.is_complete: bool = False
+
+
+class _WriteResumptionStrategy(_BaseResumptionStrategy):
+    """The concrete resumption strategy for bidi writes."""
+
+    def generate_requests(
+        self, state: Dict[str, Any]
+    ) -> Iterable[storage_type.BidiWriteObjectRequest]:
+        """Generates BidiWriteObjectRequests to resume or continue the upload.
+
+        For Appendable Objects, every stream opening should send an
+        AppendObjectSpec. If resuming, the `write_handle` is added to that spec.
+        """
+        write_state: _WriteState = state["write_state"]
+
+        # Mark that we have generated the first request for this stream attempt
+        state["first_request"] = False
+
+        if write_state.write_handle:
+            write_state.spec.write_handle = write_state.write_handle
+
+        if write_state.routing_token:
+            write_state.spec.routing_token = write_state.routing_token
+
+        do_state_lookup = write_state.write_handle is not None
+        yield storage_type.BidiWriteObjectRequest(
+            append_object_spec=write_state.spec, state_lookup=do_state_lookup
+        )
+
+        # The buffer should already be seeked to the correct position (persisted_size)
+        # by the `recover_state_on_failure` method before this is called.
+        while not write_state.is_complete:
+            chunk = write_state.user_buffer.read(write_state.chunk_size)
+
+            # End of File detection
+            if not chunk:
+                write_state.is_complete = True
+                yield storage_type.BidiWriteObjectRequest(
+                    write_offset=write_state.bytes_sent,
+                    finish_write=True,
+                )
+                return
+
+            checksummed_data = storage_type.ChecksummedData(content=chunk)
+            checksum = google_crc32c.Checksum(chunk)
+            checksummed_data.crc32c = int.from_bytes(checksum.digest(), "big")
+
+            request = storage_type.BidiWriteObjectRequest(
+                write_offset=write_state.bytes_sent,
+                checksummed_data=checksummed_data,
+            )
+            write_state.bytes_sent += len(chunk)
+
+            yield request
+
+    def update_state_from_response(
+        self, response: storage_type.BidiWriteObjectResponse, state: Dict[str, Any]
+    ) -> None:
+        """Processes a server response and updates the write state."""
+        write_state: _WriteState = state["write_state"]
+
+        if response.persisted_size is not None:
+            if response.persisted_size > write_state.persisted_size:
+                write_state.persisted_size = response.persisted_size
+
+        if response.write_handle:
+            write_state.write_handle = response.write_handle
+
+        if response.resource:
+            write_state.is_complete = True
+            write_state.persisted_size = response.resource.size
+
+    async def recover_state_on_failure(
+        self, error: Exception, state: Dict[str, Any]
+    ) -> None:
+        """Handles errors, specifically BidiWriteObjectRedirectedError, and rewinds state."""
+        write_state: _WriteState = state["write_state"]
+        cause = getattr(error, "cause", error)
+
+        # Extract routing token and potentially a new write handle.
+        if isinstance(cause, BidiWriteObjectRedirectedError):
+            if cause.routing_token:
+                write_state.routing_token = cause.routing_token
+
+            if hasattr(cause, "write_handle") and cause.write_handle:
+                write_state.write_handle = cause.write_handle
+
+        # We must assume any data sent beyond 'persisted_size' was lost.
+        # Reset the user buffer to the last known good byte.
+        write_state.user_buffer.seek(write_state.persisted_size)
+        write_state.bytes_sent = write_state.persisted_size
+
+        # Mark next pass as a retry (not the absolute first request)
+        state["first_request"] = False
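
The retry engine that consumes this strategy is not part of this commit, so the sketch below only illustrates the expected contract. It assumes _WriteState and _WriteResumptionStrategy are imported from the new module above, a hypothetical bidi stream object with send() and async iteration, and that "first_request" starts as True; the chunk size is an illustrative choice.

    import io

    from google.cloud._storage_v2.types import storage as storage_type


    def build_state(spec: storage_type.AppendObjectSpec, data: bytes) -> dict:
        # Shared mutable state handed to every strategy method.
        write_state = _WriteState(
            spec=spec, chunk_size=2 * 1024 * 1024, user_buffer=io.BytesIO(data)
        )
        return {"write_state": write_state, "first_request": True}


    async def one_attempt(strategy: _WriteResumptionStrategy, stream, state: dict) -> None:
        # Send everything the strategy yields for this attempt, then fold each
        # server response back into the shared state.
        for request in strategy.generate_requests(state):
            await stream.send(request)
        async for response in stream:
            strategy.update_state_from_response(response, state)


    # On a retriable failure the engine would call
    #     await strategy.recover_state_on_failure(error, state)
    # which rewinds user_buffer/bytes_sent to persisted_size and captures any
    # routing_token or write_handle from a BidiWriteObjectRedirectedError,
    # then reopen the stream and run one_attempt() again until is_complete is True.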

google/cloud/storage/_media/requests/download.py

Lines changed: 0 additions & 1 deletion
@@ -774,6 +774,5 @@ def flush(self):
         def has_unconsumed_tail(self) -> bool:
             return self._decoder.has_unconsumed_tail

-
 else:  # pragma: NO COVER
     _BrotliDecoder = None  # type: ignore # pragma: NO COVER

tests/unit/asyncio/retry/test_reads_resumption_strategy.py

Lines changed: 11 additions & 7 deletions
@@ -46,15 +46,10 @@ def test_initialization(self):


 class TestReadResumptionStrategy(unittest.TestCase):
-
     def setUp(self):
         self.strategy = _ReadResumptionStrategy()

-        self.state = {
-            "download_states": {},
-            "read_handle": None,
-            "routing_token": None
-        }
+        self.state = {"download_states": {}, "read_handle": None, "routing_token": None}

     def _add_download(self, read_id, offset=0, length=100, buffer=None):
         """Helper to inject a download state into the correct nested location."""
@@ -66,7 +61,16 @@ def _add_download(self, read_id, offset=0, length=100, buffer=None):
         self.state["download_states"][read_id] = state
         return state

-    def _create_response(self, content, read_id, offset, crc=None, range_end=False, handle=None, has_read_range=True):
+    def _create_response(
+        self,
+        content,
+        read_id,
+        offset,
+        crc=None,
+        range_end=False,
+        handle=None,
+        has_read_range=True,
+    ):
         """Helper to create a response object."""
         checksummed_data = None
         if content is not None:
