@@ -112,6 +112,7 @@ def __init__(
112112 read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ],
113113 write_stream : MemoryObjectSendStream [SessionMessage ],
114114 read_timeout_seconds : timedelta | None = None ,
115+ progress_callback : ProgressFnT | None = None ,
115116 sampling_callback : SamplingFnT | None = None ,
116117 elicitation_callback : ElicitationFnT | None = None ,
117118 list_roots_callback : ListRootsFnT | None = None ,
@@ -127,6 +128,7 @@ def __init__(
127128 read_timeout_seconds = read_timeout_seconds ,
128129 )
129130 self ._client_info = client_info or DEFAULT_CLIENT_INFO
131+ self ._progress_callback = progress_callback
130132 self ._sampling_callback = sampling_callback or _default_sampling_callback
131133 self ._elicitation_callback = elicitation_callback or _default_elicitation_callback
132134 self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
@@ -302,7 +304,7 @@ async def call_tool(
302304 ),
303305 types .CallToolResult ,
304306 request_read_timeout_seconds = read_timeout_seconds ,
305- progress_callback = progress_callback ,
307+ progress_callback = progress_callback or self . _progress_callback ,
306308 )
307309
308310 if not result .isError :
0 commit comments