Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 54 additions & 32 deletions jigsawstack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
from typing import Dict, Union
from typing import Dict, List, Optional, Union

import aiohttp

from ._config import ClientConfig
from .audio import AsyncAudio, Audio
from .classification import AsyncClassification, Classification
from .embedding import AsyncEmbedding, Embedding
Expand Down Expand Up @@ -155,51 +158,70 @@ def __init__(
self.base_url = base_url
self.headers = headers or {"Content-Type": "application/json"}

self.web = AsyncWeb(api_key=api_key, base_url=base_url + "/v1", headers=headers)
# _async_services holds every async service instance so that
# __aenter__ / aclose() can inject / remove the shared session.
self._async_services: List[ClientConfig] = []
self._session: Optional[aiohttp.ClientSession] = None

self.validate = AsyncValidate(api_key=api_key, base_url=base_url + "/v1", headers=headers)
def _reg(svc: ClientConfig) -> ClientConfig:
self._async_services.append(svc)
return svc

self.audio = AsyncAudio(api_key=api_key, base_url=base_url + "/v1", headers=headers)
self.web = _reg(AsyncWeb(api_key=api_key, base_url=base_url + "/v1", headers=headers))

self.vision = AsyncVision(api_key=api_key, base_url=base_url + "/v1", headers=headers)
self.validate = _reg(AsyncValidate(api_key=api_key, base_url=base_url + "/v1", headers=headers))

self.store = AsyncStore(api_key=api_key, base_url=base_url + "/v1", headers=headers)
self.audio = _reg(AsyncAudio(api_key=api_key, base_url=base_url + "/v1", headers=headers))

self.summary = AsyncSummary(
api_key=api_key, base_url=base_url + "/v1", headers=headers
).summarize
self.vision = _reg(AsyncVision(api_key=api_key, base_url=base_url + "/v1", headers=headers))

self.prediction = AsyncPrediction(api_key=api_key, base_url=base_url + "/v1").predict
self.store = _reg(AsyncStore(api_key=api_key, base_url=base_url + "/v1", headers=headers))

self.text_to_sql = AsyncSQL(
api_key=api_key, base_url=base_url + "/v1", headers=headers
).text_to_sql
_summary = _reg(AsyncSummary(api_key=api_key, base_url=base_url + "/v1", headers=headers))
self.summary = _summary.summarize

self.sentiment = AsyncSentiment(
api_key=api_key, base_url=base_url + "/v1", headers=headers
).analyze
_prediction = _reg(AsyncPrediction(api_key=api_key, base_url=base_url + "/v1", headers=headers))
self.prediction = _prediction.predict

self.translate = AsyncTranslate(api_key=api_key, base_url=base_url + "/v1", headers=headers)
_sql = _reg(AsyncSQL(api_key=api_key, base_url=base_url + "/v1", headers=headers))
self.text_to_sql = _sql.text_to_sql

self.embedding = AsyncEmbedding(
api_key=api_key, base_url=base_url + "/v1", headers=headers
).execute
_sentiment = _reg(AsyncSentiment(api_key=api_key, base_url=base_url + "/v1", headers=headers))
self.sentiment = _sentiment.analyze

self.embedding_v2 = AsyncEmbeddingV2(
api_key=api_key, base_url=base_url + "/v2", headers=headers
).execute
self.translate = _reg(AsyncTranslate(api_key=api_key, base_url=base_url + "/v1", headers=headers))

self.image_generation = AsyncImageGeneration(
api_key=api_key, base_url=base_url + "/v1", headers=headers
).image_generation
_embedding = _reg(AsyncEmbedding(api_key=api_key, base_url=base_url + "/v1", headers=headers))
self.embedding = _embedding.execute

self.classification = AsyncClassification(
api_key=api_key, base_url=base_url + "/v1", headers=headers
).classify
_embedding_v2 = _reg(AsyncEmbeddingV2(api_key=api_key, base_url=base_url + "/v2", headers=headers))
self.embedding_v2 = _embedding_v2.execute

self.prompt_engine = AsyncPromptEngine(
api_key=api_key, base_url=base_url + "/v1", headers=headers
)
_image_gen = _reg(AsyncImageGeneration(api_key=api_key, base_url=base_url + "/v1", headers=headers))
self.image_generation = _image_gen.image_generation

_classification = _reg(AsyncClassification(api_key=api_key, base_url=base_url + "/v1", headers=headers))
self.classification = _classification.classify

self.prompt_engine = _reg(AsyncPromptEngine(api_key=api_key, base_url=base_url + "/v1", headers=headers))

async def __aenter__(self) -> "AsyncJigsawStack":
"""Open a shared aiohttp.ClientSession reused across all requests."""
self._session = aiohttp.ClientSession()
for svc in self._async_services:
svc.config["session"] = self._session
return self

async def aclose(self) -> None:
"""Close the shared session and clear it from all service configs."""
if self._session is not None:
for svc in self._async_services:
svc.config.pop("session", None)
await self._session.close()
self._session = None

async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
await self.aclose()


# Create a global instance of the Web class
Expand Down
34 changes: 27 additions & 7 deletions jigsawstack/async_request.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, Generic, List, TypedDict, Union, cast
from typing import Any, AsyncGenerator, Dict, Generic, List, Optional, TypedDict, Union, cast

import aiohttp
from typing_extensions import Literal, TypeVar
from typing_extensions import Literal, NotRequired, TypeVar

from .exceptions import NoContentError, raise_for_code_and_type

Expand All @@ -16,6 +16,23 @@ class AsyncRequestConfig(TypedDict):
base_url: str
api_key: str
headers: Union[Dict[str, str], None]
session: NotRequired[aiohttp.ClientSession]


class _SessionContext:
"""Async context manager that wraps an existing ClientSession without closing it.
Used when a shared session is injected from the client (AsyncJigsawStack).
"""
__slots__ = ("_session",)

def __init__(self, session: aiohttp.ClientSession) -> None:
self._session = session

async def __aenter__(self) -> aiohttp.ClientSession:
return self._session

async def __aexit__(self, *_: object) -> None:
pass # session lifetime is managed by the caller


class AsyncRequest(Generic[T]):
Expand All @@ -38,6 +55,8 @@ def __init__(
self.headers = config.get("headers", None) or {"Content-Type": "application/json"}
self.stream = stream
self.files = files # Store files for multipart requests
# Optional shared session injected by AsyncJigsawStack.
self._shared_session: Optional[aiohttp.ClientSession] = config.get("session", None)

def __convert_params(
self, params: Union[Dict[Any, Any], List[Dict[Any, Any]]]
Expand Down Expand Up @@ -269,13 +288,14 @@ async def make_request(
headers=headers,
)

def __get_session(self) -> aiohttp.ClientSession:
def __get_session(self) -> Union["_SessionContext", aiohttp.ClientSession]:
"""
Create and return an async client session.

Returns:
aiohttp.ClientSession: An async client session
Return an async context manager that provides a ClientSession.
If a shared session was injected via config, reuse it without closing.
Otherwise open a fresh session for this request only.
"""
if self._shared_session is not None:
return _SessionContext(self._shared_session)
return aiohttp.ClientSession()

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions jigsawstack/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def delete(self, key: str) -> FileDeleteResponse:
resp = Request(
config=self.config,
path=path,
params=key,
params={},
verb="delete",
).perform_with_content()
return resp
Expand Down Expand Up @@ -140,7 +140,7 @@ async def delete(self, key: str) -> FileDeleteResponse:
resp = await AsyncRequest(
config=self.config,
path=path,
params=key,
params={},
verb="delete",
).perform_with_content()
return resp
119 changes: 119 additions & 0 deletions tests/test_session_lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
tests/test_session_lifecycle.py

Tests for claim 1: AsyncJigsawStack session management.

Verifies that:
- Without __aenter__, each request creates its own temporary session.
- With __aenter__, a single shared ClientSession is injected into every
service config and reused across requests.
- __aexit__ / aclose() removes the session from all configs and closes it.
- Re-entering after aclose() works correctly (fresh session).
No real network calls are made.
"""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp
import pytest

from jigsawstack import AsyncJigsawStack
from jigsawstack.async_request import AsyncRequest, AsyncRequestConfig


# ---------------------------------------------------------------------------
# Unit: _SessionContext
# ---------------------------------------------------------------------------

class TestSessionContext:
def test_reuses_session_without_closing(self):
from jigsawstack.async_request import _SessionContext

mock_session = MagicMock(spec=aiohttp.ClientSession)
ctx = _SessionContext(mock_session)

async def run():
async with ctx as s:
assert s is mock_session
mock_session.close.assert_not_called()

asyncio.run(run())


# ---------------------------------------------------------------------------
# Unit: AsyncRequest picks up shared session from config
# ---------------------------------------------------------------------------

class TestAsyncRequestSessionInjection:
def test_no_session_in_config_uses_own_session(self):
config = AsyncRequestConfig(base_url="http://test", api_key="key", headers=None)
r = AsyncRequest(config=config, path="/x", params={}, verb="get")
assert r._shared_session is None

def test_session_in_config_is_stored(self):
mock_session = MagicMock(spec=aiohttp.ClientSession)
config = AsyncRequestConfig(base_url="http://test", api_key="key", headers=None)
config["session"] = mock_session
r = AsyncRequest(config=config, path="/x", params={}, verb="get")
assert r._shared_session is mock_session


# ---------------------------------------------------------------------------
# Integration: AsyncJigsawStack as async context manager
# ---------------------------------------------------------------------------

class TestAsyncJigsawStackContextManager:
def test_no_session_before_enter(self):
client = AsyncJigsawStack(api_key="test-key")
assert client._session is None
# configs should not have a session key yet
for svc in client._async_services:
assert svc.config.get("session") is None

def test_enter_injects_session_into_all_services(self):
async def run():
async with AsyncJigsawStack(api_key="test-key") as client:
assert isinstance(client._session, aiohttp.ClientSession)
for svc in client._async_services:
assert svc.config.get("session") is client._session
await client._session.close() # prevent ResourceWarning in test

asyncio.run(run())

def test_exit_clears_session_from_all_services(self):
async def run():
client = AsyncJigsawStack(api_key="test-key")
await client.__aenter__()
session = client._session
await client.__aexit__(None, None, None)

assert client._session is None
assert session.closed
for svc in client._async_services:
assert svc.config.get("session") is None

asyncio.run(run())

def test_aclose_is_idempotent(self):
async def run():
async with AsyncJigsawStack(api_key="test-key") as client:
pass
await client.aclose()

asyncio.run(run())

def test_reenter_after_aclose_creates_fresh_session(self):
async def run():
client = AsyncJigsawStack(api_key="test-key")
await client.__aenter__()
first_session = client._session
await client.aclose()

await client.__aenter__()
second_session = client._session
assert second_session is not first_session
assert isinstance(second_session, aiohttp.ClientSession)
await client.aclose()

asyncio.run(run())
Loading