Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 14 additions & 66 deletions src/workos/connect.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from functools import partial
from typing import Optional, Protocol, Sequence

from workos.types.connect import ClientSecret, ConnectApplication
from workos.types.connect.connect_application import ApplicationType
from workos.types.connect.list_filters import (
ClientSecretListFilters,
ConnectApplicationListFilters,
)
from workos.types.connect.list_filters import ConnectApplicationListFilters
from workos.types.connect.redirect_uri_input import RedirectUriInput
from workos.types.list_resource import ListMetadata, ListPage, WorkOSListResource
from workos.typing.sync_or_async import SyncOrAsync
from workos.utils.http_client import AsyncHTTPClient, SyncHTTPClient
Expand All @@ -26,10 +23,6 @@
ConnectApplication, ConnectApplicationListFilters, ListMetadata
]

ClientSecretsListResource = WorkOSListResource[
ClientSecret, ClientSecretListFilters, ListMetadata
]


class ConnectModule(Protocol):
"""Offers methods through the WorkOS Connect service."""
Expand Down Expand Up @@ -76,7 +69,7 @@ def create_application(
is_first_party: bool,
description: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[RedirectUriInput]] = None,
uses_pkce: Optional[bool] = None,
organization_id: Optional[str] = None,
) -> SyncOrAsync[ConnectApplication]:
Expand Down Expand Up @@ -104,7 +97,7 @@ def update_application(
name: Optional[str] = None,
description: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[RedirectUriInput]] = None,
) -> SyncOrAsync[ConnectApplication]:
"""Update a connect application.

Expand Down Expand Up @@ -145,25 +138,14 @@ def create_client_secret(self, application_id: str) -> SyncOrAsync[ClientSecret]
def list_client_secrets(
self,
application_id: str,
*,
limit: int = DEFAULT_LIST_RESPONSE_LIMIT,
before: Optional[str] = None,
after: Optional[str] = None,
order: PaginationOrder = "desc",
) -> SyncOrAsync[ClientSecretsListResource]:
) -> SyncOrAsync[Sequence[ClientSecret]]:
"""List client secrets for a connect application.

Args:
application_id (str): Application ID or client ID.

Kwargs:
limit (int): Maximum number of records to return. (Optional)
before (str): Pagination cursor to receive records before a provided ID. (Optional)
after (str): Pagination cursor to receive records after a provided ID. (Optional)
order (Literal["asc","desc"]): Sort records in either ascending or descending order. (Optional)

Returns:
ClientSecretsListResource: Client secrets list response from WorkOS.
Sequence[ClientSecret]: Client secrets for the application.
"""
...

Expand Down Expand Up @@ -232,7 +214,7 @@ def create_application(
is_first_party: bool,
description: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[RedirectUriInput]] = None,
uses_pkce: Optional[bool] = None,
organization_id: Optional[str] = None,
) -> ConnectApplication:
Expand Down Expand Up @@ -262,7 +244,7 @@ def update_application(
name: Optional[str] = None,
description: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[RedirectUriInput]] = None,
) -> ConnectApplication:
json = {
"name": name,
Expand Down Expand Up @@ -297,30 +279,13 @@ def create_client_secret(self, application_id: str) -> ClientSecret:
def list_client_secrets(
self,
application_id: str,
*,
limit: int = DEFAULT_LIST_RESPONSE_LIMIT,
before: Optional[str] = None,
after: Optional[str] = None,
order: PaginationOrder = "desc",
) -> ClientSecretsListResource:
list_params: ClientSecretListFilters = {
"limit": limit,
"before": before,
"after": after,
"order": order,
}

) -> Sequence[ClientSecret]:
response = self._http_client.request(
f"{CONNECT_APPLICATIONS_PATH}/{application_id}/client_secrets",
method=REQUEST_METHOD_GET,
params=list_params,
)

return WorkOSListResource[ClientSecret, ClientSecretListFilters, ListMetadata](
list_method=partial(self.list_client_secrets, application_id),
list_args=list_params,
**ListPage[ClientSecret](**response).model_dump(),
)
return [ClientSecret.model_validate(secret) for secret in response]

def delete_client_secret(self, client_secret_id: str) -> None:
self._http_client.request(
Expand Down Expand Up @@ -382,7 +347,7 @@ async def create_application(
is_first_party: bool,
description: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[RedirectUriInput]] = None,
uses_pkce: Optional[bool] = None,
organization_id: Optional[str] = None,
) -> ConnectApplication:
Expand Down Expand Up @@ -412,7 +377,7 @@ async def update_application(
name: Optional[str] = None,
description: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[str]] = None,
redirect_uris: Optional[Sequence[RedirectUriInput]] = None,
) -> ConnectApplication:
json = {
"name": name,
Expand Down Expand Up @@ -447,30 +412,13 @@ async def create_client_secret(self, application_id: str) -> ClientSecret:
async def list_client_secrets(
self,
application_id: str,
*,
limit: int = DEFAULT_LIST_RESPONSE_LIMIT,
before: Optional[str] = None,
after: Optional[str] = None,
order: PaginationOrder = "desc",
) -> ClientSecretsListResource:
list_params: ClientSecretListFilters = {
"limit": limit,
"before": before,
"after": after,
"order": order,
}

) -> Sequence[ClientSecret]:
response = await self._http_client.request(
f"{CONNECT_APPLICATIONS_PATH}/{application_id}/client_secrets",
method=REQUEST_METHOD_GET,
params=list_params,
)

return WorkOSListResource[ClientSecret, ClientSecretListFilters, ListMetadata](
list_method=partial(self.list_client_secrets, application_id),
list_args=list_params,
**ListPage[ClientSecret](**response).model_dump(),
)
return [ClientSecret.model_validate(secret) for secret in response]

async def delete_client_secret(self, client_secret_id: str) -> None:
await self._http_client.request(
Expand Down
4 changes: 0 additions & 4 deletions src/workos/types/connect/list_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,3 @@

class ConnectApplicationListFilters(ListArgs, total=False):
organization_id: Optional[str]


class ClientSecretListFilters(ListArgs, total=False):
pass
9 changes: 9 additions & 0 deletions src/workos/types/connect/redirect_uri_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing_extensions import TypedDict


class _RedirectUriInputRequired(TypedDict):
uri: str


class RedirectUriInput(_RedirectUriInputRequired, total=False):
default: bool
48 changes: 6 additions & 42 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,7 @@ def mock_client_secret(self):

@pytest.fixture
def mock_client_secrets(self):
secret_list = [MockClientSecret(id=f"cs_{i}").dict() for i in range(10)]
return {
"data": secret_list,
"list_metadata": {"before": None, "after": None},
"object": "list",
}

@pytest.fixture
def mock_client_secrets_multiple_data_pages(self):
secrets_list = [MockClientSecret(id=f"cs_{i + 1}").dict() for i in range(40)]
return list_response_of(data=secrets_list)
return [MockClientSecret(id=f"cs_{i}").dict() for i in range(10)]

# --- Application Tests ---

Expand Down Expand Up @@ -149,7 +139,9 @@ def test_create_oauth_application(
name="Test Application",
application_type="oauth",
is_first_party=True,
redirect_uris=["https://example.com/callback"],
redirect_uris=[
{"uri": "https://example.com/callback", "default": True}
],
uses_pkce=True,
)
)
Expand All @@ -158,7 +150,7 @@ def test_create_oauth_application(
assert request_kwargs["method"] == "post"
assert request_kwargs["json"]["application_type"] == "oauth"
assert request_kwargs["json"]["redirect_uris"] == [
"https://example.com/callback"
{"uri": "https://example.com/callback", "default": True}
]

def test_update_application(
Expand Down Expand Up @@ -250,9 +242,7 @@ def test_list_client_secrets(
assert request_kwargs["url"].endswith(
"/connect/applications/app_01ABC/client_secrets"
)
assert (
list(map(lambda x: x.dict(), response.data)) == mock_client_secrets["data"]
)
assert [secret.dict() for secret in response] == mock_client_secrets

def test_delete_client_secret(self, capture_and_mock_http_client_request):
request_kwargs = capture_and_mock_http_client_request(
Expand All @@ -269,29 +259,3 @@ def test_delete_client_secret(self, capture_and_mock_http_client_request):
assert request_kwargs["url"].endswith("/connect/client_secrets/cs_01ABC")
assert request_kwargs["method"] == "delete"
assert response is None

def test_list_client_secrets_auto_pagination_for_single_page(
self,
mock_client_secrets,
test_auto_pagination: TestAutoPaginationFunction,
):
test_auto_pagination(
http_client=self.http_client,
list_function=self.connect.list_client_secrets,
expected_all_page_data=mock_client_secrets["data"],
list_function_params={"application_id": "app_01ABC"},
url_path_keys=["application_id"],
)

def test_list_client_secrets_auto_pagination_for_multiple_pages(
self,
mock_client_secrets_multiple_data_pages,
test_auto_pagination: TestAutoPaginationFunction,
):
test_auto_pagination(
http_client=self.http_client,
list_function=self.connect.list_client_secrets,
expected_all_page_data=mock_client_secrets_multiple_data_pages["data"],
list_function_params={"application_id": "app_01ABC"},
url_path_keys=["application_id"],
)