Skip to content

Commit 939a025

Browse files
committed
Fixes:
- Correctly import the async blob client from azure.storage.blob.aio
- Default capabilities to an empty list
- Revert deferred (lazy) initialization of the async client
- Add async tests
- More aggressive checks for externalized payload tokens
1 parent b5c7ec3 commit 939a025

File tree

5 files changed

+87
-43
lines changed

5 files changed

+87
-43
lines changed

durabletask/extensions/azure_blob_payloads/blob_payload_store.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,10 @@
88
import gzip
99
import logging
1010
import uuid
11-
from typing import TYPE_CHECKING, Optional
11+
from typing import Optional
1212

1313
from azure.storage.blob import BlobServiceClient
14-
15-
if TYPE_CHECKING:
16-
from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient
14+
from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient
1715

1816
from durabletask.extensions.azure_blob_payloads.options import BlobPayloadStoreOptions
1917
from durabletask.payload.store import PayloadStore
@@ -69,35 +67,25 @@ def __init__(self, options: BlobPayloadStoreOptions):
6967
**extra_kwargs,
7068
)
7169

72-
# Async client is built lazily to avoid importing
73-
# azure.storage.blob.aio when only sync methods are used.
74-
self._async_blob_service_client: AsyncBlobServiceClient | None = None
75-
self._extra_kwargs = extra_kwargs
70+
# Build async client
71+
if options.connection_string:
72+
self._async_blob_service_client = AsyncBlobServiceClient.from_connection_string(
73+
options.connection_string, **extra_kwargs,
74+
)
75+
else:
76+
assert options.account_url is not None # guaranteed by validation above
77+
self._async_blob_service_client = AsyncBlobServiceClient(
78+
account_url=options.account_url,
79+
credential=options.credential,
80+
**extra_kwargs,
81+
)
7682

7783
self._ensure_container_created = False
7884

7985
@property
8086
def options(self) -> BlobPayloadStoreOptions:
8187
return self._options
8288

83-
def _get_async_blob_service_client(self) -> AsyncBlobServiceClient:
84-
"""Lazily create the async blob service client."""
85-
if self._async_blob_service_client is None:
86-
from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient
87-
88-
if self._options.connection_string:
89-
self._async_blob_service_client = AsyncBlobServiceClient.from_connection_string(
90-
self._options.connection_string, **self._extra_kwargs,
91-
)
92-
else:
93-
assert self._options.account_url is not None
94-
self._async_blob_service_client = AsyncBlobServiceClient(
95-
account_url=self._options.account_url,
96-
credential=self._options.credential,
97-
**self._extra_kwargs,
98-
)
99-
return self._async_blob_service_client
100-
10189
# ------------------------------------------------------------------
10290
# Sync operations
10391
# ------------------------------------------------------------------
@@ -138,8 +126,7 @@ async def upload_async(self, data: bytes, *, instance_id: Optional[str] = None)
138126
data = gzip.compress(data)
139127

140128
blob_name = self._make_blob_name(instance_id)
141-
client = self._get_async_blob_service_client()
142-
container_client = client.get_container_client(self._container_name)
129+
container_client = self._async_blob_service_client.get_container_client(self._container_name)
143130
await container_client.upload_blob(name=blob_name, data=data, overwrite=True)
144131

145132
token = f"{_TOKEN_PREFIX}{self._container_name}:{blob_name}"
@@ -148,8 +135,7 @@ async def upload_async(self, data: bytes, *, instance_id: Optional[str] = None)
148135

149136
async def download_async(self, token: str) -> bytes:
150137
container, blob_name = self._parse_token(token)
151-
client = self._get_async_blob_service_client()
152-
container_client = client.get_container_client(container)
138+
container_client = self._async_blob_service_client.get_container_client(container)
153139
stream = await container_client.download_blob(blob_name)
154140
blob_data = await stream.readall()
155141

@@ -164,7 +150,11 @@ async def download_async(self, token: str) -> bytes:
164150
# ------------------------------------------------------------------
165151

166152
def is_known_token(self, value: str) -> bool:
167-
return value.startswith(_TOKEN_PREFIX)
153+
try:
154+
self._parse_token(value)
155+
return True
156+
except ValueError:
157+
return False
168158

169159
@staticmethod
170160
def _parse_token(token: str) -> tuple[str, str]:
@@ -203,8 +193,7 @@ def _ensure_container_sync(self) -> None:
203193
async def _ensure_container_async(self) -> None:
204194
if self._ensure_container_created:
205195
return
206-
client = self._get_async_blob_service_client()
207-
container_client = client.get_container_client(self._container_name)
196+
container_client = self._async_blob_service_client.get_container_client(self._container_name)
208197
try:
209198
await container_client.create_container()
210199
except Exception:

durabletask/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def should_invalidate_connection(rpc_error):
508508
get_work_items_request = pb.GetWorkItemsRequest(
509509
maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items,
510510
maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items,
511-
capabilities=capabilities if capabilities else None,
511+
capabilities=capabilities,
512512
)
513513
self._response_stream = stub.GetWorkItems(get_work_items_request)
514514
self._logger.info(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ opentelemetry = [
3737
"opentelemetry-sdk>=1.0.0"
3838
]
3939
azure-blob-payloads = [
40-
"azure-storage-blob>=12.0.0"
40+
"azure-storage-blob[aio]>=12.0.0"
4141
]
4242

4343
[project.urls]

tests/durabletask/test_large_payload.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
"""Tests for large-payload externalization and de-externalization."""
55

6-
import asyncio
76
from typing import Optional
87
from unittest.mock import MagicMock
98

@@ -317,7 +316,8 @@ def test_externalize_then_deexternalize(self):
317316

318317

319318
class TestAsyncPayloadHelpers:
320-
def test_async_externalize_and_deexternalize(self):
319+
@pytest.mark.asyncio
320+
async def test_async_externalize_and_deexternalize(self):
321321
"""Async versions should work identically to sync."""
322322
store = FakePayloadStore(threshold_bytes=10)
323323
original = "async round trip " * 20
@@ -328,14 +328,10 @@ def test_async_externalize_and_deexternalize(self):
328328
input=sv(original),
329329
)
330330

331-
asyncio.get_event_loop().run_until_complete(
332-
externalize_payloads_async(req, store, instance_id="async-1")
333-
)
331+
await externalize_payloads_async(req, store, instance_id="async-1")
334332
assert req.input.value.startswith(FakePayloadStore.TOKEN_PREFIX)
335333

336-
asyncio.get_event_loop().run_until_complete(
337-
deexternalize_payloads_async(req, store)
338-
)
334+
await deexternalize_payloads_async(req, store)
339335
assert req.input.value == original
340336

341337

@@ -402,10 +398,13 @@ def test_is_known_token(self):
402398

403399
store = MagicMock(spec=BlobPayloadStore)
404400
store.is_known_token = BlobPayloadStore.is_known_token.__get__(store)
401+
store._parse_token = BlobPayloadStore._parse_token
405402

406403
assert store.is_known_token("blob:v1:c:b") is True
407404
assert store.is_known_token("not-a-token") is False
408405
assert store.is_known_token("") is False
406+
assert store.is_known_token("blob:v1:") is False
407+
assert store.is_known_token("blob:v1:container:") is False
409408

410409

411410
# ------------------------------------------------------------------

tests/durabletask/test_large_payload_e2e.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,59 @@ def echo(ctx: task.OrchestrationContext, inp: str):
357357
svc.delete_container(fresh_container)
358358
except Exception:
359359
pass
360+
361+
362+
class TestAsyncBlobClient:
363+
"""Test upload_async / download_async against Azurite."""
364+
365+
@pytest.fixture()
366+
def async_store(self):
367+
"""Per-test BlobPayloadStore with a unique container."""
368+
container = f"async-test-{uuid.uuid4().hex[:8]}"
369+
store = BlobPayloadStore(BlobPayloadStoreOptions(
370+
connection_string=AZURITE_CONN_STR,
371+
container_name=container,
372+
threshold_bytes=THRESHOLD_BYTES,
373+
enable_compression=True,
374+
api_version=AZURITE_API_VERSION,
375+
))
376+
yield store
377+
try:
378+
svc = azure_blob.BlobServiceClient.from_connection_string(
379+
AZURITE_CONN_STR, api_version=AZURITE_API_VERSION,
380+
)
381+
svc.delete_container(container)
382+
except Exception:
383+
pass
384+
385+
@pytest.mark.asyncio
386+
async def test_async_upload_and_download_round_trip(self, async_store):
387+
"""upload_async stores data that download_async can retrieve."""
388+
payload = b"async round-trip payload " * 200
389+
token = await async_store.upload_async(payload, instance_id="async-1")
390+
391+
assert async_store.is_known_token(token)
392+
result = await async_store.download_async(token)
393+
assert result == payload
394+
395+
@pytest.mark.asyncio
396+
async def test_async_upload_with_compression(self, async_store):
397+
"""Compressed upload should still decompress on download."""
398+
payload = b"Z" * 5000
399+
token = await async_store.upload_async(payload)
400+
401+
downloaded = await async_store.download_async(token)
402+
assert downloaded == payload
403+
404+
@pytest.mark.asyncio
405+
async def test_async_upload_instance_id_scopes_blob(self, async_store):
406+
"""Blobs uploaded with instance_id are scoped under that prefix."""
407+
payload = b"scoped payload"
408+
token = await async_store.upload_async(payload, instance_id="inst-42")
409+
410+
# Token format: blob:v1:<container>:<instance_id>/<uuid>
411+
_, blob_name = BlobPayloadStore._parse_token(token)
412+
assert blob_name.startswith("inst-42/")
413+
414+
downloaded = await async_store.download_async(token)
415+
assert downloaded == payload

0 commit comments

Comments
 (0)