Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/mock_vws/_requests_mock_server/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _wrap_callback(
callback: _MockCallback,
delay_seconds: float,
sleep_fn: Callable[[float], None],
base_path: str,
) -> _ResponsesCallback:
"""Wrap a callback to add a response delay."""

Expand Down Expand Up @@ -197,9 +198,13 @@ def wrapped(
else:
body_bytes = raw_body

path = request.path_url
if base_path and path.startswith(base_path):
path = path[len(base_path) :]

request_data = RequestData(
method=request.method or "",
path=request.path_url,
path=path,
headers=dict(request.headers),
body=body_bytes,
)
Expand All @@ -221,6 +226,7 @@ def __enter__(self) -> Self:
(self._mock_vws_api, self._base_vws_url),
(self._mock_vwq_api, self._base_vwq_url),
):
base_path = urlparse(url=base_url).path.rstrip("/")
for route in api.routes:
url_pattern = base_url.rstrip("/") + route.path_pattern + "$"
compiled_url_pattern = re.compile(pattern=url_pattern)
Expand All @@ -234,6 +240,7 @@ def __enter__(self) -> Self:
callback=original_callback,
delay_seconds=self._response_delay_seconds,
sleep_fn=self._sleep_fn,
base_path=base_path,
),
content_type=None,
)
Expand Down
21 changes: 18 additions & 3 deletions src/mock_vws/_respx_mock_server/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,26 @@
_BRISQUE_TRACKING_RATER = BrisqueTargetTrackingRater()


def _to_request_data(request: httpx.Request) -> RequestData:
def _to_request_data(
request: httpx.Request,
*,
base_path: str,
) -> RequestData:
"""Convert an httpx.Request to a RequestData.

Args:
request: The httpx request to convert.
base_path: The base path prefix to strip from the request path.

Returns:
A RequestData with method, path, headers, and body set.
"""
path = request.url.raw_path.decode(encoding="ascii")
if base_path and path.startswith(base_path):
path = path[len(base_path) :]
return RequestData(
method=request.method,
path=request.url.raw_path.decode(encoding="ascii"),
path=path,
headers={k.title(): v for k, v in request.headers.items()},
body=request.content,
)
Expand Down Expand Up @@ -155,12 +163,14 @@ def add_vumark_database(self, vumark_database: VuMarkDatabase) -> None:
def _make_callback(
self,
handler: Callable[[RequestData], _ResponseType],
base_path: str,
) -> Callable[[httpx.Request], httpx.Response]:
"""Create a respx-compatible callback from a handler.

Args:
handler: A handler that takes a RequestData and returns a
response tuple.
base_path: The base path prefix to strip from the request path.

Returns:
A callback that takes an httpx.Request and returns an
Expand All @@ -183,7 +193,10 @@ def callback(request: httpx.Request) -> httpx.Response:
Exception: A timeout error is raised when the response
delay exceeds the read timeout.
"""
request_data = _to_request_data(request=request)
request_data = _to_request_data(
request=request,
base_path=base_path,
)
timeout_info: dict[str, float | None] = request.extensions.get(
"timeout", {}
)
Expand Down Expand Up @@ -237,6 +250,7 @@ def __enter__(self) -> Self:
(self._mock_vws_api, self._base_vws_url),
(self._mock_vwq_api, self._base_vwq_url),
):
base_path = urlparse(url=base_url).path.rstrip("/")
for route in api.routes:
url_pattern = base_url.rstrip("/") + route.path_pattern + "$"
compiled_url_pattern = re.compile(pattern=url_pattern)
Expand All @@ -249,6 +263,7 @@ def __enter__(self) -> Self:
).mock(
side_effect=self._make_callback(
handler=original_callback,
base_path=base_path,
),
)

Expand Down
39 changes: 38 additions & 1 deletion tests/mock_vws/test_requests_mock_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io
import json
import socket
from http import HTTPStatus
from urllib.parse import urlparse

import pytest
Expand All @@ -13,7 +14,7 @@
from freezegun import freeze_time
from PIL import Image
from vws import VWS, CloudRecoService
from vws_auth_tools import rfc_1123_date
from vws_auth_tools import authorization_header, rfc_1123_date

from mock_vws import MissingSchemeError, MockVWS
from mock_vws.database import CloudDatabase, VuMarkDatabase
Expand Down Expand Up @@ -391,6 +392,42 @@ def test_custom_base_vwq_url_with_path_prefix() -> None:
timeout=30,
)

@staticmethod
def test_vws_operations_work_with_path_prefix() -> None:
"""VWS API operations work correctly with a base URL path
prefix.
"""
database = CloudDatabase()
base_vws_url = "https://vuforia.vws.example.com/prefix"

with MockVWS(base_vws_url=base_vws_url) as mock:
mock.add_cloud_database(cloud_database=database)

request_path = "/targets"
date = rfc_1123_date()
auth = authorization_header(
access_key=database.server_access_key,
secret_key=database.server_secret_key,
method="GET",
content=b"",
content_type="",
date=date,
request_path=request_path,
)
response = requests.get(
url=base_vws_url + request_path,
headers={
"Authorization": auth,
"Date": date,
},
timeout=30,
)

assert response.status_code == HTTPStatus.OK
response_json = response.json()
assert response_json["result_code"] == "Success"
assert response_json["results"] == []

@staticmethod
def test_no_scheme() -> None:
"""An error if raised if a URL is given with no scheme."""
Expand Down
36 changes: 36 additions & 0 deletions tests/mock_vws/test_respx_mock_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,42 @@ def test_custom_base_vwq_url_with_path_prefix() -> None:
timeout=30,
)

@staticmethod
def test_vws_operations_work_with_path_prefix() -> None:
"""VWS API operations work correctly with a base URL path
prefix.
"""
database = CloudDatabase()
base_vws_url = "https://vuforia.vws.example.com/prefix"

with MockVWSForHttpx(base_vws_url=base_vws_url) as mock:
mock.add_cloud_database(cloud_database=database)

request_path = "/targets"
date = rfc_1123_date()
auth = authorization_header(
access_key=database.server_access_key,
secret_key=database.server_secret_key,
method="GET",
content=b"",
content_type="",
date=date,
request_path=request_path,
)
response = httpx.get(
url=base_vws_url + request_path,
headers={
"Authorization": auth,
"Date": date,
},
timeout=30,
)

assert response.status_code == HTTPStatus.OK
response_json = response.json()
assert response_json["result_code"] == "Success"
assert response_json["results"] == []

@staticmethod
def test_no_scheme() -> None:
"""An error is raised if a URL is given with no scheme."""
Expand Down