Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/authzed/api/materialize/v0/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, target, credentials, options=None, compression=None):
self.init_stubs(channel)

def init_stubs(self, channel):
self._channel = channel
WatchPermissionsServiceStub.__init__(self, channel)
WatchPermissionSetsServiceStub.__init__(self, channel)

Expand All @@ -56,6 +57,21 @@ def create_channel(self, target, credentials, options=None, compression=None):

return channelfn(target, credentials, options, compression)

def close(self):
"""
Close the underlying gRPC channel.

For async channels (``grpc.aio.Channel``), this returns a coroutine that
must be awaited; the caller is expected to ``await client.close()``.
For sync channels (``grpc.Channel``), the channel is closed
synchronously and ``None`` is returned.

Closing the channel cancels in-flight RPCs and prevents new RPCs from
being issued through this client. Calling ``close`` more than once is
safe.
"""
return self._channel.close()


class AsyncClient(Client):
"""
Expand All @@ -66,6 +82,16 @@ def __init__(self, target, credentials, options=None, compression=None):
channel = grpc.aio.secure_channel(target, credentials, options, compression)
self.init_stubs(channel)

async def close(self, grace=None):
"""
Close the underlying async gRPC channel.

``grace`` is forwarded to ``grpc.aio.Channel.close``; when set, the
channel waits up to ``grace`` seconds for pending RPCs to finish
before cancelling them.
"""
await self._channel.close(grace)


class SyncClient(Client):
"""
Expand All @@ -76,6 +102,10 @@ def __init__(self, target, credentials, options=None, compression=None):
channel = grpc.secure_channel(target, credentials, options, compression)
self.init_stubs(channel)

def close(self):
"""Close the underlying sync gRPC channel."""
self._channel.close()


class TokenAuthorization(ClientInterceptor):
def __init__(self, token: str):
Expand Down
5 changes: 4 additions & 1 deletion src/authzed/api/materialize/v0/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Awaitable, Optional, Sequence, Tuple, Union

import grpc

Expand Down Expand Up @@ -42,6 +42,7 @@ class Client(WatchPermissionsServiceStub, WatchPermissionSetsServiceStub):
options: Optional[Sequence[Tuple[str, Any]]] = None,
compression: Optional[grpc.Compression] = None,
) -> None: ...
def close(self) -> Union[None, Awaitable[None]]: ...

class SyncClient(WatchPermissionsServiceStub, WatchPermissionSetsServiceStub):
def __init__(
Expand All @@ -51,6 +52,7 @@ class SyncClient(WatchPermissionsServiceStub, WatchPermissionSetsServiceStub):
options: Optional[Sequence[Tuple[str, Any]]] = None,
compression: Optional[grpc.Compression] = None,
) -> None: ...
def close(self) -> None: ...

class AsyncClient(WatchPermissionsServiceStub, WatchPermissionSetsServiceStub):
def __init__(
Expand All @@ -60,6 +62,7 @@ class AsyncClient(WatchPermissionsServiceStub, WatchPermissionSetsServiceStub):
options: Optional[Sequence[Tuple[str, Any]]] = None,
compression: Optional[grpc.Compression] = None,
) -> None: ...
async def close(self, grace: Optional[float] = None) -> None: ...

class InsecureClient(Client):
def __init__(
Expand Down
30 changes: 30 additions & 0 deletions src/authzed/api/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(self, target, credentials, options=None, compression=None):
self.init_stubs(channel)

def init_stubs(self, channel):
self._channel = channel
SchemaServiceStub.__init__(self, channel)
PermissionsServiceStub.__init__(self, channel)
ExperimentalServiceStub.__init__(self, channel)
Expand All @@ -110,6 +111,21 @@ def create_channel(self, target, credentials, options=None, compression=None):

return channelfn(target, credentials, options, compression)

def close(self):
"""
Close the underlying gRPC channel.

For async channels (``grpc.aio.Channel``), this returns a coroutine that
must be awaited; the caller is expected to ``await client.close()``.
For sync channels (``grpc.Channel``), the channel is closed
synchronously and ``None`` is returned.

Closing the channel cancels in-flight RPCs and prevents new RPCs from
being issued through this client. Calling ``close`` more than once is
safe.
"""
return self._channel.close()


class AsyncClient(Client):
"""
Expand All @@ -120,6 +136,16 @@ def __init__(self, target, credentials, options=None, compression=None):
channel = grpc.aio.secure_channel(target, credentials, options, compression)
self.init_stubs(channel)

async def close(self, grace=None):
"""
Close the underlying async gRPC channel.

``grace`` is forwarded to ``grpc.aio.Channel.close``; when set, the
channel waits up to ``grace`` seconds for pending RPCs to finish
before cancelling them.
"""
await self._channel.close(grace)


class SyncClient(Client):
"""
Expand All @@ -130,6 +156,10 @@ def __init__(self, target, credentials, options=None, compression=None):
channel = grpc.secure_channel(target, credentials, options, compression)
self.init_stubs(channel)

def close(self):
"""Close the underlying sync gRPC channel."""
self._channel.close()


class TokenAuthorization(ClientInterceptor):
def __init__(self, token: str):
Expand Down
5 changes: 4 additions & 1 deletion src/authzed/api/v1/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Awaitable, Optional, Sequence, Tuple, Union

import grpc

Expand Down Expand Up @@ -81,6 +81,7 @@ class Client(SchemaServiceStub, PermissionsServiceStub, ExperimentalServiceStub,
options: Optional[Sequence[Tuple[str, Any]]] = None,
compression: Optional[grpc.Compression] = None,
) -> None: ...
def close(self) -> Union[None, Awaitable[None]]: ...

class SyncClient(
SchemaServiceStub, PermissionsServiceStub, ExperimentalServiceStub, WatchServiceStub
Expand All @@ -92,6 +93,7 @@ class SyncClient(
options: Optional[Sequence[Tuple[str, Any]]] = None,
compression: Optional[grpc.Compression] = None,
) -> None: ...
def close(self) -> None: ...

class AsyncClient(
SchemaServiceAsyncStub,
Expand All @@ -106,6 +108,7 @@ class AsyncClient(
options: Optional[Sequence[Tuple[str, Any]]] = None,
compression: Optional[grpc.Compression] = None,
) -> None: ...
async def close(self, grace: Optional[float] = None) -> None: ...

class InsecureClient(Client):
def __init__(
Expand Down
59 changes: 59 additions & 0 deletions tests/close_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import asyncio

import grpc
import grpc.aio
import pytest

from authzed.api.v1 import AsyncClient, Client, InsecureClient, SyncClient
from grpcutil import insecure_bearer_token_credentials


def test_sync_client_close_closes_channel(token):
client = SyncClient("localhost:50051", insecure_bearer_token_credentials(token))
assert isinstance(client._channel, grpc.Channel)

# close should be idempotent and not raise
client.close()
client.close()


def test_insecure_client_close_closes_channel(token):
client = InsecureClient("localhost:50051", token)
# InsecureClient uses an intercepted channel, but it still exposes close()
client.close()
client.close()


async def test_async_client_close_closes_channel(token):
client = AsyncClient("localhost:50051", insecure_bearer_token_credentials(token))
assert isinstance(client._channel, grpc.aio.Channel)

# close should be awaitable and idempotent
await client.close()
await client.close()


async def test_async_client_close_accepts_grace(token):
client = AsyncClient("localhost:50051", insecure_bearer_token_credentials(token))
await client.close(grace=0)


def test_autodetect_client_close_when_sync(token):
# Outside of an event loop, Client builds a sync channel; close() should
# return None (not a coroutine).
with pytest.raises(RuntimeError):
asyncio.get_running_loop()
client = Client("localhost:50051", insecure_bearer_token_credentials(token))
assert isinstance(client._channel, grpc.Channel)
result = client.close()
assert result is None


async def test_autodetect_client_close_when_async(token):
# Inside an event loop, Client builds an async channel; close() returns a
# coroutine that must be awaited.
client = Client("localhost:50051", insecure_bearer_token_credentials(token))
assert isinstance(client._channel, grpc.aio.Channel)
awaitable = client.close()
assert asyncio.iscoroutine(awaitable)
await awaitable
Loading