11"""Decorators for using the mock."""
22
33import re
4- import threading
54import time
65from contextlib import ContextDecorator
7- from typing import TYPE_CHECKING , Any , Literal , Self
6+ from typing import TYPE_CHECKING , Any , Literal , Self , cast
87from urllib .parse import urljoin , urlparse
98
109import requests as requests_lib
2625from .mock_web_services_api import MockVuforiaWebServicesAPI
2726
2827if TYPE_CHECKING :
29- from collections .abc import Callable , Iterable , Mapping
30-
31- from requests import PreparedRequest
32- from requests .adapters import HTTPAdapter # noqa: F401
33-
34- ResponseType = tuple [int , Mapping [str , str ], str ]
35- Callback = Callable [[PreparedRequest ], ResponseType ] # noqa: F841
36-
37- # Thread-local storage to capture the request timeout
38- _timeout_storage = threading .local ()
28+ from collections .abc import Iterable
3929
4030_STRUCTURAL_SIMILARITY_MATCHER = StructuralSimilarityMatcher ()
4131_BRISQUE_TRACKING_RATER = BrisqueTargetTrackingRater ()
@@ -149,44 +139,32 @@ def __enter__(self) -> Self:
149139 compiled_url_patterns : Iterable [re .Pattern [str ]] = set ()
150140 delay_seconds = self ._response_delay_seconds
151141
152- def wrap_callback (callback : "Callback" ) -> "Callback" :
153- """Wrap a callback to add a response delay."""
142+ mock = RequestsMock (assert_all_requests_are_fired = False )
143+
144+ if delay_seconds > 0 :
145+ mock_any = cast (Any , mock )
146+ original_on_request = mock_any ._on_request # noqa: SLF001
154147
155- def wrapped (request : "PreparedRequest" ) -> "ResponseType" :
156- # Check if the delay would exceed the request timeout
157- timeout = getattr (_timeout_storage , "timeout" , None )
158- if timeout is not None and delay_seconds > 0 :
159- # timeout can be a float or a tuple (connect, read)
148+ def patched_on_request (
149+ * args : Any , # noqa: ANN401
150+ ** kwargs : Any , # noqa: ANN401
151+ ) -> Any : # noqa: ANN401
152+ timeout = kwargs .get ("timeout" )
153+ if timeout is not None :
160154 if isinstance (timeout , tuple ):
161- effective_timeout : float | None = timeout [1 ] # read timeout
155+ effective : float | None = timeout [1 ]
162156 else :
163- effective_timeout = timeout
157+ effective = timeout
164158 if (
165- effective_timeout is not None
166- and delay_seconds > effective_timeout
159+ isinstance ( effective , ( int , float ))
160+ and delay_seconds > effective
167161 ):
168162 raise requests_lib .exceptions .Timeout
169-
170- result = callback (request )
163+ result = original_on_request (* args , ** kwargs )
171164 time .sleep (delay_seconds )
172165 return result
173166
174- return wrapped
175-
176- mock = RequestsMock (assert_all_requests_are_fired = False )
177-
178- # Patch _on_request to capture the timeout parameter
179- original_on_request = mock ._on_request # noqa: SLF001
180-
181- def patched_on_request (
182- adapter : "HTTPAdapter" ,
183- request : "PreparedRequest" ,
184- ** kwargs : Any , # noqa: ANN401
185- ) -> Any : # noqa: ANN401
186- _timeout_storage .timeout = kwargs .get ("timeout" )
187- return original_on_request (adapter , request , ** kwargs ) # type: ignore[misc]
188-
189- mock ._on_request = patched_on_request # type: ignore[method-assign] # noqa: SLF001
167+ mock_any ._on_request = patched_on_request # noqa: SLF001
190168 for vws_route in self ._mock_vws_api .routes :
191169 url_pattern = urljoin (
192170 base = self ._base_vws_url ,
@@ -205,7 +183,7 @@ def patched_on_request(
205183 mock .add_callback (
206184 method = vws_http_method ,
207185 url = compiled_url_pattern ,
208- callback = wrap_callback ( callback = original_callback ) ,
186+ callback = original_callback ,
209187 content_type = None ,
210188 )
211189
@@ -227,7 +205,7 @@ def patched_on_request(
227205 mock .add_callback (
228206 method = vwq_http_method ,
229207 url = compiled_url_pattern ,
230- callback = wrap_callback ( callback = original_callback ) ,
208+ callback = original_callback ,
231209 content_type = None ,
232210 )
233211
0 commit comments