Skip to content

Commit 26342bd

Browse files
committed
fix: session lifecycle and store.delete params
Claim 1 - session lifetime scoped to AsyncJigsawStack (async_request.py, __init__.py): - Add optional 'session' field to AsyncRequestConfig. - _SessionContext wraps a shared session without closing it on exit. - AsyncRequest reads the injected session; falls back to a per-request ClientSession when used standalone. - AsyncJigsawStack gains __aenter__ / __aexit__ / aclose() that open a single aiohttp.ClientSession and share it across all service calls. Claim 3 - Store.delete passes raw string as URL query params (store.py): - params=key sent the file key as a malformed query string. The key is already in the URL path. Fix is params={}. (DELETE has no body, so there is nothing to corrupt.) Tests: - tests/test_session_lifecycle.py 8 tests (claim 1) - tests/test_store_delete.py 3 tests (claim 3)
1 parent 5492b09 commit 26342bd

5 files changed

Lines changed: 311 additions & 41 deletions

File tree

jigsawstack/__init__.py

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import os
2-
from typing import Dict, Union
2+
from typing import Dict, List, Optional, Union
33

4+
import aiohttp
5+
6+
from ._config import ClientConfig
47
from .audio import AsyncAudio, Audio
58
from .classification import AsyncClassification, Classification
69
from .embedding import AsyncEmbedding, Embedding
@@ -155,51 +158,70 @@ def __init__(
155158
self.base_url = base_url
156159
self.headers = headers or {"Content-Type": "application/json"}
157160

158-
self.web = AsyncWeb(api_key=api_key, base_url=base_url + "/v1", headers=headers)
161+
# _async_services holds every async service instance so that
162+
# __aenter__ / aclose() can inject / remove the shared session.
163+
self._async_services: List[ClientConfig] = []
164+
self._session: Optional[aiohttp.ClientSession] = None
159165

160-
self.validate = AsyncValidate(api_key=api_key, base_url=base_url + "/v1", headers=headers)
166+
def _reg(svc: ClientConfig) -> ClientConfig:
167+
self._async_services.append(svc)
168+
return svc
161169

162-
self.audio = AsyncAudio(api_key=api_key, base_url=base_url + "/v1", headers=headers)
170+
self.web = _reg(AsyncWeb(api_key=api_key, base_url=base_url + "/v1", headers=headers))
163171

164-
self.vision = AsyncVision(api_key=api_key, base_url=base_url + "/v1", headers=headers)
172+
self.validate = _reg(AsyncValidate(api_key=api_key, base_url=base_url + "/v1", headers=headers))
165173

166-
self.store = AsyncStore(api_key=api_key, base_url=base_url + "/v1", headers=headers)
174+
self.audio = _reg(AsyncAudio(api_key=api_key, base_url=base_url + "/v1", headers=headers))
167175

168-
self.summary = AsyncSummary(
169-
api_key=api_key, base_url=base_url + "/v1", headers=headers
170-
).summarize
176+
self.vision = _reg(AsyncVision(api_key=api_key, base_url=base_url + "/v1", headers=headers))
171177

172-
self.prediction = AsyncPrediction(api_key=api_key, base_url=base_url + "/v1").predict
178+
self.store = _reg(AsyncStore(api_key=api_key, base_url=base_url + "/v1", headers=headers))
173179

174-
self.text_to_sql = AsyncSQL(
175-
api_key=api_key, base_url=base_url + "/v1", headers=headers
176-
).text_to_sql
180+
_summary = _reg(AsyncSummary(api_key=api_key, base_url=base_url + "/v1", headers=headers))
181+
self.summary = _summary.summarize
177182

178-
self.sentiment = AsyncSentiment(
179-
api_key=api_key, base_url=base_url + "/v1", headers=headers
180-
).analyze
183+
_prediction = _reg(AsyncPrediction(api_key=api_key, base_url=base_url + "/v1", headers=headers))
184+
self.prediction = _prediction.predict
181185

182-
self.translate = AsyncTranslate(api_key=api_key, base_url=base_url + "/v1", headers=headers)
186+
_sql = _reg(AsyncSQL(api_key=api_key, base_url=base_url + "/v1", headers=headers))
187+
self.text_to_sql = _sql.text_to_sql
183188

184-
self.embedding = AsyncEmbedding(
185-
api_key=api_key, base_url=base_url + "/v1", headers=headers
186-
).execute
189+
_sentiment = _reg(AsyncSentiment(api_key=api_key, base_url=base_url + "/v1", headers=headers))
190+
self.sentiment = _sentiment.analyze
187191

188-
self.embedding_v2 = AsyncEmbeddingV2(
189-
api_key=api_key, base_url=base_url + "/v2", headers=headers
190-
).execute
192+
self.translate = _reg(AsyncTranslate(api_key=api_key, base_url=base_url + "/v1", headers=headers))
191193

192-
self.image_generation = AsyncImageGeneration(
193-
api_key=api_key, base_url=base_url + "/v1", headers=headers
194-
).image_generation
194+
_embedding = _reg(AsyncEmbedding(api_key=api_key, base_url=base_url + "/v1", headers=headers))
195+
self.embedding = _embedding.execute
195196

196-
self.classification = AsyncClassification(
197-
api_key=api_key, base_url=base_url + "/v1", headers=headers
198-
).classify
197+
_embedding_v2 = _reg(AsyncEmbeddingV2(api_key=api_key, base_url=base_url + "/v2", headers=headers))
198+
self.embedding_v2 = _embedding_v2.execute
199199

200-
self.prompt_engine = AsyncPromptEngine(
201-
api_key=api_key, base_url=base_url + "/v1", headers=headers
202-
)
200+
_image_gen = _reg(AsyncImageGeneration(api_key=api_key, base_url=base_url + "/v1", headers=headers))
201+
self.image_generation = _image_gen.image_generation
202+
203+
_classification = _reg(AsyncClassification(api_key=api_key, base_url=base_url + "/v1", headers=headers))
204+
self.classification = _classification.classify
205+
206+
self.prompt_engine = _reg(AsyncPromptEngine(api_key=api_key, base_url=base_url + "/v1", headers=headers))
207+
208+
async def __aenter__(self) -> "AsyncJigsawStack":
209+
"""Open a shared aiohttp.ClientSession reused across all requests."""
210+
self._session = aiohttp.ClientSession()
211+
for svc in self._async_services:
212+
svc.config["session"] = self._session
213+
return self
214+
215+
async def aclose(self) -> None:
216+
"""Close the shared session and clear it from all service configs."""
217+
if self._session is not None:
218+
for svc in self._async_services:
219+
svc.config.pop("session", None)
220+
await self._session.close()
221+
self._session = None
222+
223+
async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
224+
await self.aclose()
203225

204226

205227
# Create a global instance of the Web class

jigsawstack/async_request.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import json
22
from io import BytesIO
3-
from typing import Any, AsyncGenerator, Dict, Generic, List, TypedDict, Union, cast
3+
from typing import Any, AsyncGenerator, Dict, Generic, List, Optional, TypedDict, Union, cast
44

55
import aiohttp
6-
from typing_extensions import Literal, TypeVar
6+
from typing_extensions import Literal, NotRequired, TypeVar
77

88
from .exceptions import NoContentError, raise_for_code_and_type
99

@@ -16,6 +16,23 @@ class AsyncRequestConfig(TypedDict):
1616
base_url: str
1717
api_key: str
1818
headers: Union[Dict[str, str], None]
19+
session: NotRequired[aiohttp.ClientSession]
20+
21+
22+
class _SessionContext:
23+
"""Async context manager that wraps an existing ClientSession without closing it.
24+
Used when a shared session is injected from the client (AsyncJigsawStack).
25+
"""
26+
__slots__ = ("_session",)
27+
28+
def __init__(self, session: aiohttp.ClientSession) -> None:
29+
self._session = session
30+
31+
async def __aenter__(self) -> aiohttp.ClientSession:
32+
return self._session
33+
34+
async def __aexit__(self, *_: object) -> None:
35+
pass # session lifetime is managed by the caller
1936

2037

2138
class AsyncRequest(Generic[T]):
@@ -38,6 +55,8 @@ def __init__(
3855
self.headers = config.get("headers", None) or {"Content-Type": "application/json"}
3956
self.stream = stream
4057
self.files = files # Store files for multipart requests
58+
# Optional shared session injected by AsyncJigsawStack.
59+
self._shared_session: Optional[aiohttp.ClientSession] = config.get("session", None)
4160

4261
def __convert_params(
4362
self, params: Union[Dict[Any, Any], List[Dict[Any, Any]]]
@@ -269,13 +288,14 @@ async def make_request(
269288
headers=headers,
270289
)
271290

272-
def __get_session(self) -> aiohttp.ClientSession:
291+
def __get_session(self) -> Union["_SessionContext", aiohttp.ClientSession]:
273292
"""
274-
Create and return an async client session.
275-
276-
Returns:
277-
aiohttp.ClientSession: An async client session
293+
Return an async context manager that provides a ClientSession.
294+
If a shared session was injected via config, reuse it without closing.
295+
Otherwise open a fresh session for this request only.
278296
"""
297+
if self._shared_session is not None:
298+
return _SessionContext(self._shared_session)
279299
return aiohttp.ClientSession()
280300

281301
@staticmethod

jigsawstack/store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def delete(self, key: str) -> FileDeleteResponse:
8080
resp = Request(
8181
config=self.config,
8282
path=path,
83-
params=key,
83+
params={},
8484
verb="delete",
8585
).perform_with_content()
8686
return resp
@@ -140,7 +140,7 @@ async def delete(self, key: str) -> FileDeleteResponse:
140140
resp = await AsyncRequest(
141141
config=self.config,
142142
path=path,
143-
params=key,
143+
params={},
144144
verb="delete",
145145
).perform_with_content()
146146
return resp

tests/test_session_lifecycle.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
tests/test_session_lifecycle.py
3+
4+
Tests for claim 1: AsyncJigsawStack session management.
5+
6+
Verifies that:
7+
- Without __aenter__, each request creates its own temporary session.
8+
- With __aenter__, a single shared ClientSession is injected into every
9+
service config and reused across requests.
10+
- __aexit__ / aclose() removes the session from all configs and closes it.
11+
- Re-entering after aclose() works correctly (fresh session).
12+
No real network calls are made.
13+
"""
14+
15+
import asyncio
16+
from unittest.mock import AsyncMock, MagicMock, patch
17+
18+
import aiohttp
19+
import pytest
20+
21+
from jigsawstack import AsyncJigsawStack
22+
from jigsawstack.async_request import AsyncRequest, AsyncRequestConfig
23+
24+
25+
# ---------------------------------------------------------------------------
26+
# Unit: _SessionContext
27+
# ---------------------------------------------------------------------------
28+
29+
class TestSessionContext:
30+
def test_reuses_session_without_closing(self):
31+
from jigsawstack.async_request import _SessionContext
32+
33+
mock_session = MagicMock(spec=aiohttp.ClientSession)
34+
ctx = _SessionContext(mock_session)
35+
36+
async def run():
37+
async with ctx as s:
38+
assert s is mock_session
39+
mock_session.close.assert_not_called()
40+
41+
asyncio.run(run())
42+
43+
44+
# ---------------------------------------------------------------------------
45+
# Unit: AsyncRequest picks up shared session from config
46+
# ---------------------------------------------------------------------------
47+
48+
class TestAsyncRequestSessionInjection:
49+
def test_no_session_in_config_uses_own_session(self):
50+
config = AsyncRequestConfig(base_url="http://test", api_key="key", headers=None)
51+
r = AsyncRequest(config=config, path="/x", params={}, verb="get")
52+
assert r._shared_session is None
53+
54+
def test_session_in_config_is_stored(self):
55+
mock_session = MagicMock(spec=aiohttp.ClientSession)
56+
config = AsyncRequestConfig(base_url="http://test", api_key="key", headers=None)
57+
config["session"] = mock_session
58+
r = AsyncRequest(config=config, path="/x", params={}, verb="get")
59+
assert r._shared_session is mock_session
60+
61+
62+
# ---------------------------------------------------------------------------
63+
# Integration: AsyncJigsawStack as async context manager
64+
# ---------------------------------------------------------------------------
65+
66+
class TestAsyncJigsawStackContextManager:
67+
def test_no_session_before_enter(self):
68+
client = AsyncJigsawStack(api_key="test-key")
69+
assert client._session is None
70+
# configs should not have a session key yet
71+
for svc in client._async_services:
72+
assert svc.config.get("session") is None
73+
74+
def test_enter_injects_session_into_all_services(self):
75+
async def run():
76+
async with AsyncJigsawStack(api_key="test-key") as client:
77+
assert isinstance(client._session, aiohttp.ClientSession)
78+
for svc in client._async_services:
79+
assert svc.config.get("session") is client._session
80+
await client._session.close() # prevent ResourceWarning in test
81+
82+
asyncio.run(run())
83+
84+
def test_exit_clears_session_from_all_services(self):
85+
async def run():
86+
client = AsyncJigsawStack(api_key="test-key")
87+
await client.__aenter__()
88+
session = client._session
89+
await client.__aexit__(None, None, None)
90+
91+
assert client._session is None
92+
assert session.closed
93+
for svc in client._async_services:
94+
assert svc.config.get("session") is None
95+
96+
asyncio.run(run())
97+
98+
def test_aclose_is_idempotent(self):
99+
async def run():
100+
async with AsyncJigsawStack(api_key="test-key") as client:
101+
pass
102+
await client.aclose()
103+
104+
asyncio.run(run())
105+
106+
def test_reenter_after_aclose_creates_fresh_session(self):
107+
async def run():
108+
client = AsyncJigsawStack(api_key="test-key")
109+
await client.__aenter__()
110+
first_session = client._session
111+
await client.aclose()
112+
113+
await client.__aenter__()
114+
second_session = client._session
115+
assert second_session is not first_session
116+
assert isinstance(second_session, aiohttp.ClientSession)
117+
await client.aclose()
118+
119+
asyncio.run(run())

0 commit comments

Comments
 (0)