Skip to content

Commit 202b314

Browse files
authored
Add Second Generic Value to _TracingSignal for Custom Context Variables (aio-libs#11268)
1 parent 331c989 commit 202b314

3 files changed

Lines changed: 63 additions & 38 deletions

File tree

CHANGES/11268.feature.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Updated ``_TracingSignal`` to utilize a secondary generic variable for type hinting custom context variables
2+
-- by :user:`Vizonex`.

aiohttp/tracing.py

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111
if TYPE_CHECKING:
1212
from .client import ClientSession
1313

14-
_ParamT_contra = TypeVar("_ParamT_contra", contravariant=True)
15-
_TracingSignal = Signal[ClientSession, SimpleNamespace, _ParamT_contra]
16-
1714

1815
__all__ = (
1916
"TraceConfig",
@@ -36,6 +33,8 @@
3633
)
3734

3835
_T = TypeVar("_T", covariant=True)
36+
_ParamT_contra = TypeVar("_ParamT_contra", contravariant=True)
37+
_TracingSignal = Signal["ClientSession", _T, _ParamT_contra]
3938

4039

4140
class _Factory(Protocol[_T]):
@@ -52,46 +51,52 @@ def __init__(self, trace_config_ctx_factory: _Factory[_T]) -> None: ...
5251
def __init__(
5352
self, trace_config_ctx_factory: _Factory[Any] = SimpleNamespace
5453
) -> None:
55-
self._on_request_start: _TracingSignal[TraceRequestStartParams] = Signal(self)
56-
self._on_request_chunk_sent: _TracingSignal[TraceRequestChunkSentParams] = (
54+
self._on_request_start: _TracingSignal[_T, TraceRequestStartParams] = Signal(
55+
self
56+
)
57+
self._on_request_chunk_sent: _TracingSignal[_T, TraceRequestChunkSentParams] = (
5758
Signal(self)
5859
)
5960
self._on_response_chunk_received: _TracingSignal[
60-
TraceResponseChunkReceivedParams
61+
_T, TraceResponseChunkReceivedParams
6162
] = Signal(self)
62-
self._on_request_end: _TracingSignal[TraceRequestEndParams] = Signal(self)
63-
self._on_request_exception: _TracingSignal[TraceRequestExceptionParams] = (
63+
self._on_request_end: _TracingSignal[_T, TraceRequestEndParams] = Signal(self)
64+
self._on_request_exception: _TracingSignal[_T, TraceRequestExceptionParams] = (
6465
Signal(self)
6566
)
66-
self._on_request_redirect: _TracingSignal[TraceRequestRedirectParams] = Signal(
67-
self
67+
self._on_request_redirect: _TracingSignal[_T, TraceRequestRedirectParams] = (
68+
Signal(self)
6869
)
6970
self._on_connection_queued_start: _TracingSignal[
70-
TraceConnectionQueuedStartParams
71+
_T, TraceConnectionQueuedStartParams
7172
] = Signal(self)
7273
self._on_connection_queued_end: _TracingSignal[
73-
TraceConnectionQueuedEndParams
74+
_T, TraceConnectionQueuedEndParams
7475
] = Signal(self)
7576
self._on_connection_create_start: _TracingSignal[
76-
TraceConnectionCreateStartParams
77+
_T, TraceConnectionCreateStartParams
7778
] = Signal(self)
7879
self._on_connection_create_end: _TracingSignal[
79-
TraceConnectionCreateEndParams
80+
_T, TraceConnectionCreateEndParams
8081
] = Signal(self)
8182
self._on_connection_reuseconn: _TracingSignal[
82-
TraceConnectionReuseconnParams
83+
_T, TraceConnectionReuseconnParams
8384
] = Signal(self)
8485
self._on_dns_resolvehost_start: _TracingSignal[
85-
TraceDnsResolveHostStartParams
86+
_T, TraceDnsResolveHostStartParams
8687
] = Signal(self)
87-
self._on_dns_resolvehost_end: _TracingSignal[TraceDnsResolveHostEndParams] = (
88-
Signal(self)
88+
self._on_dns_resolvehost_end: _TracingSignal[
89+
_T, TraceDnsResolveHostEndParams
90+
] = Signal(self)
91+
self._on_dns_cache_hit: _TracingSignal[_T, TraceDnsCacheHitParams] = Signal(
92+
self
8993
)
90-
self._on_dns_cache_hit: _TracingSignal[TraceDnsCacheHitParams] = Signal(self)
91-
self._on_dns_cache_miss: _TracingSignal[TraceDnsCacheMissParams] = Signal(self)
92-
self._on_request_headers_sent: _TracingSignal[TraceRequestHeadersSentParams] = (
93-
Signal(self)
94+
self._on_dns_cache_miss: _TracingSignal[_T, TraceDnsCacheMissParams] = Signal(
95+
self
9496
)
97+
self._on_request_headers_sent: _TracingSignal[
98+
_T, TraceRequestHeadersSentParams
99+
] = Signal(self)
95100

96101
self._trace_config_ctx_factory: _Factory[_T] = trace_config_ctx_factory
97102

@@ -118,89 +123,91 @@ def freeze(self) -> None:
118123
self._on_request_headers_sent.freeze()
119124

120125
@property
121-
def on_request_start(self) -> "_TracingSignal[TraceRequestStartParams]":
126+
def on_request_start(self) -> "_TracingSignal[_T, TraceRequestStartParams]":
122127
return self._on_request_start
123128

124129
@property
125-
def on_request_chunk_sent(self) -> "_TracingSignal[TraceRequestChunkSentParams]":
130+
def on_request_chunk_sent(
131+
self,
132+
) -> "_TracingSignal[_T, TraceRequestChunkSentParams]":
126133
return self._on_request_chunk_sent
127134

128135
@property
129136
def on_response_chunk_received(
130137
self,
131-
) -> "_TracingSignal[TraceResponseChunkReceivedParams]":
138+
) -> "_TracingSignal[_T, TraceResponseChunkReceivedParams]":
132139
return self._on_response_chunk_received
133140

134141
@property
135-
def on_request_end(self) -> "_TracingSignal[TraceRequestEndParams]":
142+
def on_request_end(self) -> "_TracingSignal[_T, TraceRequestEndParams]":
136143
return self._on_request_end
137144

138145
@property
139146
def on_request_exception(
140147
self,
141-
) -> "_TracingSignal[TraceRequestExceptionParams]":
148+
) -> "_TracingSignal[_T, TraceRequestExceptionParams]":
142149
return self._on_request_exception
143150

144151
@property
145152
def on_request_redirect(
146153
self,
147-
) -> "_TracingSignal[TraceRequestRedirectParams]":
154+
) -> "_TracingSignal[_T, TraceRequestRedirectParams]":
148155
return self._on_request_redirect
149156

150157
@property
151158
def on_connection_queued_start(
152159
self,
153-
) -> "_TracingSignal[TraceConnectionQueuedStartParams]":
160+
) -> "_TracingSignal[_T, TraceConnectionQueuedStartParams]":
154161
return self._on_connection_queued_start
155162

156163
@property
157164
def on_connection_queued_end(
158165
self,
159-
) -> "_TracingSignal[TraceConnectionQueuedEndParams]":
166+
) -> "_TracingSignal[_T, TraceConnectionQueuedEndParams]":
160167
return self._on_connection_queued_end
161168

162169
@property
163170
def on_connection_create_start(
164171
self,
165-
) -> "_TracingSignal[TraceConnectionCreateStartParams]":
172+
) -> "_TracingSignal[_T, TraceConnectionCreateStartParams]":
166173
return self._on_connection_create_start
167174

168175
@property
169176
def on_connection_create_end(
170177
self,
171-
) -> "_TracingSignal[TraceConnectionCreateEndParams]":
178+
) -> "_TracingSignal[_T, TraceConnectionCreateEndParams]":
172179
return self._on_connection_create_end
173180

174181
@property
175182
def on_connection_reuseconn(
176183
self,
177-
) -> "_TracingSignal[TraceConnectionReuseconnParams]":
184+
) -> "_TracingSignal[_T, TraceConnectionReuseconnParams]":
178185
return self._on_connection_reuseconn
179186

180187
@property
181188
def on_dns_resolvehost_start(
182189
self,
183-
) -> "_TracingSignal[TraceDnsResolveHostStartParams]":
190+
) -> "_TracingSignal[_T, TraceDnsResolveHostStartParams]":
184191
return self._on_dns_resolvehost_start
185192

186193
@property
187194
def on_dns_resolvehost_end(
188195
self,
189-
) -> "_TracingSignal[TraceDnsResolveHostEndParams]":
196+
) -> "_TracingSignal[_T, TraceDnsResolveHostEndParams]":
190197
return self._on_dns_resolvehost_end
191198

192199
@property
193-
def on_dns_cache_hit(self) -> "_TracingSignal[TraceDnsCacheHitParams]":
200+
def on_dns_cache_hit(self) -> "_TracingSignal[_T, TraceDnsCacheHitParams]":
194201
return self._on_dns_cache_hit
195202

196203
@property
197-
def on_dns_cache_miss(self) -> "_TracingSignal[TraceDnsCacheMissParams]":
204+
def on_dns_cache_miss(self) -> "_TracingSignal[_T, TraceDnsCacheMissParams]":
198205
return self._on_dns_cache_miss
199206

200207
@property
201208
def on_request_headers_sent(
202209
self,
203-
) -> "_TracingSignal[TraceRequestHeadersSentParams]":
210+
) -> "_TracingSignal[_T, TraceRequestHeadersSentParams]":
204211
return self._on_request_headers_sent
205212

206213

tests/test_tracing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import sys
12
from types import SimpleNamespace
23
from typing import Any, Tuple
34
from unittest import mock
45
from unittest.mock import Mock
56

67
import pytest
8+
from aiosignal import Signal
79

10+
from aiohttp import ClientSession
811
from aiohttp.tracing import (
912
Trace,
1013
TraceConfig,
@@ -25,15 +28,28 @@
2528
TraceResponseChunkReceivedParams,
2629
)
2730

31+
if sys.version_info >= (3, 11):
32+
from typing import assert_type
33+
2834

2935
class TestTraceConfig:
3036
def test_trace_config_ctx_default(self) -> None:
3137
trace_config = TraceConfig()
3238
assert isinstance(trace_config.trace_config_ctx(), SimpleNamespace)
39+
if sys.version_info >= (3, 11):
40+
assert_type(
41+
trace_config.on_request_chunk_sent,
42+
Signal[ClientSession, SimpleNamespace, TraceRequestChunkSentParams],
43+
)
3344

3445
def test_trace_config_ctx_factory(self) -> None:
3546
trace_config = TraceConfig(trace_config_ctx_factory=dict)
3647
assert isinstance(trace_config.trace_config_ctx(), dict)
48+
if sys.version_info >= (3, 11):
49+
assert_type(
50+
trace_config.on_request_start,
51+
Signal[ClientSession, dict[str, Any], TraceRequestStartParams],
52+
)
3753

3854
def test_trace_config_ctx_request_ctx(self) -> None:
3955
trace_request_ctx = Mock()

0 commit comments

Comments
 (0)