55import logging
66from contextlib import asynccontextmanager
77from dataclasses import dataclass , field
8+ from typing import Any
89
910from anyio import create_task_group
1011from google .protobuf import empty_pb2
1112from grpc import StatusCode
12- from grpc .aio import AioRpcError , Channel
13+ from grpc .aio import AioRpcError
1314from jumpstarter_protocol import jumpstarter_pb2 , jumpstarter_pb2_grpc , router_pb2_grpc
1415
1516from jumpstarter .common import Metadata
@@ -60,16 +61,14 @@ class AsyncDriverClient(
6061 Backing implementation of blocking driver client.
6162 """
6263
63- channel : Channel
64+ stub : Any
6465
6566 log_level : str = "INFO"
6667 logger : logging .Logger = field (init = False )
6768
6869 def __post_init__ (self ):
6970 if hasattr (super (), "__post_init__" ):
7071 super ().__post_init__ ()
71- jumpstarter_pb2_grpc .ExporterServiceStub .__init__ (self , self .channel )
72- router_pb2_grpc .RouterServiceStub .__init__ (self , self .channel )
7372 self .logger = logging .getLogger (self .__class__ .__name__ )
7473 self .logger .setLevel (self .log_level )
7574
@@ -89,7 +88,7 @@ async def call_async(self, method, *args):
8988 )
9089
9190 try :
92- response = await self .DriverCall (request )
91+ response = await self .stub . DriverCall (request )
9392 except AioRpcError as e :
9493 match e .code ():
9594 case StatusCode .UNIMPLEMENTED :
@@ -113,7 +112,7 @@ async def streamingcall_async(self, method, *args):
113112 )
114113
115114 try :
116- async for response in self .StreamingDriverCall (request ):
115+ async for response in self .stub . StreamingDriverCall (request ):
117116 yield decode_value (response .result )
118117 except AioRpcError as e :
119118 match e .code ():
@@ -128,7 +127,7 @@ async def streamingcall_async(self, method, *args):
128127
129128 @asynccontextmanager
130129 async def stream_async (self , method ):
131- context = self .Stream (
130+ context = self .stub . Stream (
132131 metadata = StreamRequestMetadata .model_construct (request = DriverStreamRequest (uuid = self .uuid , method = method ))
133132 .model_dump (mode = "json" , round_trip = True )
134133 .items (),
@@ -142,7 +141,7 @@ async def resource_async(
142141 self ,
143142 stream ,
144143 ):
145- context = self .Stream (
144+ context = self .stub . Stream (
146145 metadata = StreamRequestMetadata .model_construct (request = ResourceStreamRequest (uuid = self .uuid ))
147146 .model_dump (mode = "json" , round_trip = True )
148147 .items (),
@@ -160,7 +159,7 @@ def __log(self, level: int, msg: str):
160159 @asynccontextmanager
161160 async def log_stream_async (self ):
162161 async def log_stream ():
163- async for response in self .LogStream (empty_pb2 .Empty ()):
162+ async for response in self .stub . LogStream (empty_pb2 .Empty ()):
164163 self .__log (logging .getLevelName (response .severity ), response .message )
165164
166165 async with create_task_group () as tg :
0 commit comments