|
2 | 2 |
|
3 | 3 | import re |
4 | 4 | import time |
| 5 | +from collections.abc import Callable, Mapping |
5 | 6 | from contextlib import ContextDecorator |
6 | | -from typing import TYPE_CHECKING, Literal, Self |
| 7 | +from typing import Literal, Self |
7 | 8 | from urllib.parse import urljoin, urlparse |
8 | 9 |
|
9 | 10 | import requests |
10 | 11 | from beartype import BeartypeConf, beartype |
| 12 | +from requests import PreparedRequest |
11 | 13 | from responses import RequestsMock |
12 | 14 |
|
13 | 15 | from mock_vws.database import VuforiaDatabase |
|
24 | 26 | from .mock_web_query_api import MockVuforiaWebQueryAPI |
25 | 27 | from .mock_web_services_api import MockVuforiaWebServicesAPI |
26 | 28 |
|
27 | | -if TYPE_CHECKING: |
28 | | - from collections.abc import Callable, Iterable, Mapping |
29 | | - |
30 | | - from requests import PreparedRequest |
31 | | - |
32 | | - ResponseType = tuple[int, Mapping[str, str], str] |
33 | | - Callback = Callable[[PreparedRequest], ResponseType] |
| 29 | +_ResponseType = tuple[int, Mapping[str, str], str] |
| 30 | +_Callback = Callable[[PreparedRequest], _ResponseType] |
34 | 31 |
|
35 | 32 | _STRUCTURAL_SIMILARITY_MATCHER = StructuralSimilarityMatcher() |
36 | 33 | _BRISQUE_TRACKING_RATER = BrisqueTargetTrackingRater() |
@@ -135,91 +132,69 @@ def add_database(self, database: VuforiaDatabase) -> None: |
135 | 132 | """ |
136 | 133 | self._target_manager.add_database(database=database) |
137 | 134 |
|
| 135 | + @staticmethod |
| 136 | + def _wrap_callback( |
| 137 | + callback: _Callback, |
| 138 | + delay_seconds: float, |
| 139 | + ) -> _Callback: |
| 140 | + """Wrap a callback to add a response delay.""" |
| 141 | + |
| 142 | + def wrapped( |
| 143 | + request: "PreparedRequest", |
| 144 | + ) -> "_ResponseType": |
| 145 | + # req_kwargs is added dynamically by the responses |
| 146 | + # library onto PreparedRequest objects - it is not |
| 147 | + # in the requests type stubs. |
| 148 | + timeout = request.req_kwargs.get("timeout") # type: ignore[attr-defined] |
| 149 | + # requests allows timeout as a (connect, read) |
| 150 | + # tuple. The delay simulates server response |
| 151 | + # time, so compare against the read timeout. |
| 152 | + effective: float | None = None |
| 153 | + if isinstance(timeout, tuple): |
| 154 | + effective = timeout[1] |
| 155 | + elif isinstance(timeout, (int, float)): |
| 156 | + effective = timeout |
| 157 | + |
| 158 | + if effective is not None and delay_seconds > effective: |
| 159 | + time.sleep(effective) |
| 160 | + raise requests.exceptions.Timeout |
| 161 | + |
| 162 | + result = callback(request) |
| 163 | + time.sleep(delay_seconds) |
| 164 | + return result |
| 165 | + |
| 166 | + return wrapped |
| 167 | + |
138 | 168 | def __enter__(self) -> Self: |
139 | 169 | """Start an instance of a Vuforia mock. |
140 | 170 |
|
141 | 171 | Returns: |
142 | 172 | ``self``. |
143 | 173 | """ |
144 | | - compiled_url_patterns: Iterable[re.Pattern[str]] = set() |
145 | | - delay_seconds = self._response_delay_seconds |
146 | | - |
147 | | - def wrap_callback(callback: "Callback") -> "Callback": |
148 | | - """Wrap a callback to add a response delay.""" |
149 | | - |
150 | | - def wrapped( |
151 | | - request: "PreparedRequest", |
152 | | - ) -> "ResponseType": |
153 | | - # req_kwargs is added dynamically by the responses |
154 | | - # library onto PreparedRequest objects - it is not |
155 | | - # in the requests type stubs. |
156 | | - timeout = request.req_kwargs.get("timeout") # type: ignore[attr-defined] |
157 | | - # requests allows timeout as a (connect, read) |
158 | | - # tuple. The delay simulates server response |
159 | | - # time, so compare against the read timeout. |
160 | | - effective: float | None = None |
161 | | - if isinstance(timeout, tuple): |
162 | | - effective = timeout[1] |
163 | | - elif isinstance(timeout, (int, float)): |
164 | | - effective = timeout |
165 | | - |
166 | | - if ( |
167 | | - effective is not None |
168 | | - and delay_seconds > effective |
169 | | - ): |
170 | | - time.sleep(effective) |
171 | | - raise requests.exceptions.Timeout |
172 | | - |
173 | | - result = callback(request) |
174 | | - time.sleep(delay_seconds) |
175 | | - return result |
176 | | - |
177 | | - return wrapped |
178 | | - |
179 | 174 | mock = RequestsMock(assert_all_requests_are_fired=False) |
180 | | - for vws_route in self._mock_vws_api.routes: |
181 | | - url_pattern = urljoin( |
182 | | - base=self._base_vws_url, |
183 | | - url=f"{vws_route.path_pattern}$", |
184 | | - ) |
185 | | - compiled_url_pattern = re.compile(pattern=url_pattern) |
186 | | - compiled_url_patterns = { |
187 | | - *compiled_url_patterns, |
188 | | - compiled_url_pattern, |
189 | | - } |
190 | | - |
191 | | - for vws_http_method in vws_route.http_methods: |
192 | | - original_callback = getattr( |
193 | | - self._mock_vws_api, vws_route.route_name |
194 | | - ) |
195 | | - mock.add_callback( |
196 | | - method=vws_http_method, |
197 | | - url=compiled_url_pattern, |
198 | | - callback=wrap_callback(callback=original_callback), |
199 | | - content_type=None, |
200 | | - ) |
201 | 175 |
|
202 | | - for vwq_route in self._mock_vwq_api.routes: |
203 | | - url_pattern = urljoin( |
204 | | - base=self._base_vwq_url, |
205 | | - url=f"{vwq_route.path_pattern}$", |
206 | | - ) |
207 | | - compiled_url_pattern = re.compile(pattern=url_pattern) |
208 | | - compiled_url_patterns = { |
209 | | - *compiled_url_patterns, |
210 | | - compiled_url_pattern, |
211 | | - } |
212 | | - |
213 | | - for vwq_http_method in vwq_route.http_methods: |
214 | | - original_callback = getattr( |
215 | | - self._mock_vwq_api, vwq_route.route_name |
216 | | - ) |
217 | | - mock.add_callback( |
218 | | - method=vwq_http_method, |
219 | | - url=compiled_url_pattern, |
220 | | - callback=wrap_callback(callback=original_callback), |
221 | | - content_type=None, |
| 176 | + for api, base_url in ( |
| 177 | + (self._mock_vws_api, self._base_vws_url), |
| 178 | + (self._mock_vwq_api, self._base_vwq_url), |
| 179 | + ): |
| 180 | + for route in api.routes: |
| 181 | + url_pattern = urljoin( |
| 182 | + base=base_url, |
| 183 | + url=f"{route.path_pattern}$", |
222 | 184 | ) |
| 185 | + compiled_url_pattern = re.compile(pattern=url_pattern) |
| 186 | + |
| 187 | + for http_method in route.http_methods: |
| 188 | + original_callback = getattr(api, route.route_name) |
| 189 | + mock.add_callback( |
| 190 | + method=http_method, |
| 191 | + url=compiled_url_pattern, |
| 192 | + callback=self._wrap_callback( |
| 193 | + callback=original_callback, |
| 194 | + delay_seconds=self._response_delay_seconds, |
| 195 | + ), |
| 196 | + content_type=None, |
| 197 | + ) |
223 | 198 |
|
224 | 199 | if self._real_http: |
225 | 200 | all_requests_pattern = re.compile(pattern=".*") |
|
0 commit comments