Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions scripts/lint/src/lint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def raise_err(code: int) -> None:

def main() -> None:
fix = ["--fix"] if "--fix" in sys.argv else []
raise_err(os.system(" ".join(["ruff", "check", "src"] + fix)))
raise_err(os.system("ruff format src"))
raise_err(os.system(" ".join(["ruff", "check", "src", "scripts", "tests"] + fix)))
raise_err(os.system("ruff format src scripts tests"))
raise_err(os.system("mypy src"))
raise_err(os.system("pyright src"))
62 changes: 32 additions & 30 deletions scripts/parity/check_parity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Literal, TypedDict, TypeVar, Union
from typing import Any, Callable, Literal, TypedDict, TypeVar

import pyd
import tyd
Expand Down Expand Up @@ -85,35 +85,37 @@ def testAgenttoollanguageserverOpendocumentInput() -> None:
)


kind_type = Union[
Literal[1],
Literal[2],
Literal[3],
Literal[4],
Literal[5],
Literal[6],
Literal[7],
Literal[8],
Literal[9],
Literal[10],
Literal[11],
Literal[12],
Literal[13],
Literal[14],
Literal[15],
Literal[16],
Literal[17],
Literal[18],
Literal[19],
Literal[20],
Literal[21],
Literal[22],
Literal[23],
Literal[24],
Literal[25],
Literal[26],
None,
]
kind_type = (
Literal[
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
]
| None
)


def testAgenttoollanguageserverGetcodesymbolInput() -> None:
Expand Down
4 changes: 2 additions & 2 deletions scripts/parity/gen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
import string
from typing import Callable, Optional, TypeVar
from typing import Callable, TypeVar

A = TypeVar("A")

Expand Down Expand Up @@ -37,7 +37,7 @@ def gen_choice(choices: list[A]) -> Callable[[], A]:
return lambda: random.choice(choices)


def gen_opt(gen_x: Callable[[], A]) -> Callable[[], Optional[A]]:
def gen_opt(gen_x: Callable[[], A]) -> Callable[[], A | None]:
return lambda: gen_x() if gen_bool() else None


Expand Down
18 changes: 9 additions & 9 deletions src/replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, AsyncGenerator, Generator, Generic, Literal, Optional, Union
from typing import Any, AsyncGenerator, Generator, Generic, Literal

from opentelemetry import trace
from opentelemetry.trace import Span, SpanKind, Status, StatusCode
Expand Down Expand Up @@ -100,9 +100,9 @@ async def send_upload(
self,
service_name: str,
procedure_name: str,
init: Optional[InitType],
init: InitType | None,
request: AsyncIterable[RequestType],
init_serializer: Optional[Callable[[InitType], Any]],
init_serializer: Callable[[InitType], Any] | None,
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
Expand All @@ -129,7 +129,7 @@ async def send_subscription(
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
) -> AsyncGenerator[ResponseType | RiverError, None]:
with _trace_procedure(
"subscription", service_name, procedure_name
) as span_handle:
Expand All @@ -151,13 +151,13 @@ async def send_stream(
self,
service_name: str,
procedure_name: str,
init: Optional[InitType],
init: InitType | None,
request: AsyncIterable[RequestType],
init_serializer: Optional[Callable[[InitType], Any]],
init_serializer: Callable[[InitType], Any] | None,
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
) -> AsyncGenerator[ResponseType | RiverError, None]:
with _trace_procedure("stream", service_name, procedure_name) as span_handle:
session = await self._transport.get_or_create_session()
async for msg in session.send_stream(
Expand Down Expand Up @@ -185,8 +185,8 @@ class _SpanHandle:

def set_status(
self,
status: Union[Status, StatusCode],
description: Optional[str] = None,
status: Status | StatusCode,
description: str | None = None,
) -> None:
if self.did_set_status:
return
Expand Down
14 changes: 7 additions & 7 deletions src/replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections.abc import AsyncIterable
from datetime import timedelta
from typing import Any, AsyncGenerator, Callable, Optional, Union
from typing import Any, AsyncGenerator, Callable

import nanoid # type: ignore
from aiochannel import Channel
Expand Down Expand Up @@ -102,9 +102,9 @@ async def send_upload(
self,
service_name: str,
procedure_name: str,
init: Optional[InitType],
init: InitType | None,
request: AsyncIterable[RequestType],
init_serializer: Optional[Callable[[InitType], Any]],
init_serializer: Callable[[InitType], Any] | None,
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
Expand Down Expand Up @@ -194,7 +194,7 @@ async def send_subscription(
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
) -> AsyncGenerator[ResponseType | ErrorType, None]:
"""Sends a subscription request to the server.

Expects the input and output be messages that will be msgpacked.
Expand Down Expand Up @@ -241,14 +241,14 @@ async def send_stream(
self,
service_name: str,
procedure_name: str,
init: Optional[InitType],
init: InitType | None,
request: AsyncIterable[RequestType],
init_serializer: Optional[Callable[[InitType], Any]],
init_serializer: Callable[[InitType], Any] | None,
request_serializer: Callable[[RequestType], Any],
response_deserializer: Callable[[Any], ResponseType],
error_deserializer: Callable[[Any], ErrorType],
span: Span,
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
) -> AsyncGenerator[ResponseType | ErrorType, None]:
"""Sends a subscription request to the server.

Expects the input and output be messages that will be msgpacked.
Expand Down
16 changes: 8 additions & 8 deletions src/replit_river/client_transport.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
from collections.abc import Awaitable, Callable
from typing import Generic, Optional, Tuple
from typing import Generic

import websockets
from pydantic import ValidationError
Expand Down Expand Up @@ -98,7 +98,7 @@ async def get_or_create_session(self) -> ClientSession:
await existing_session.close()
return await self._create_new_session()

async def _get_existing_session(self) -> Optional[ClientSession]:
async def _get_existing_session(self) -> ClientSession | None:
async with self._session_lock:
if not self._sessions:
return None
Expand All @@ -117,8 +117,8 @@ async def _get_existing_session(self) -> Optional[ClientSession]:

async def _establish_new_connection(
self,
old_session: Optional[ClientSession] = None,
) -> Tuple[
old_session: ClientSession | None = None,
) -> tuple[
WebSocketCommonProtocol,
ControlMessageHandshakeRequest[HandshakeMetadataType],
ControlMessageHandshakeResponse,
Expand All @@ -129,7 +129,7 @@ async def _establish_new_connection(
client_id = self._client_id
logger.info("Attempting to establish new ws connection")

last_error: Optional[Exception] = None
last_error: Exception | None = None
for i in range(max_retry):
if i > 0:
logger.info(f"Retrying build handshake number {i} times")
Expand Down Expand Up @@ -221,7 +221,7 @@ async def _send_handshake_request(
transport_id: str,
to_id: str,
session_id: str,
handshake_metadata: Optional[HandshakeMetadataType],
handshake_metadata: HandshakeMetadataType | None,
websocket: WebSocketCommonProtocol,
expected_session_state: ExpectedSessionState,
) -> ControlMessageHandshakeRequest[HandshakeMetadataType]:
Expand Down Expand Up @@ -291,8 +291,8 @@ async def _establish_handshake(
session_id: str,
handshake_metadata: HandshakeMetadataType,
websocket: WebSocketCommonProtocol,
old_session: Optional[ClientSession],
) -> Tuple[
old_session: ClientSession | None,
) -> tuple[
ControlMessageHandshakeRequest[HandshakeMetadataType],
ControlMessageHandshakeResponse,
]:
Expand Down
Loading
Loading