Skip to content

Commit 2ab6964

Browse files
committed
chore: add tests
1 parent 337f530 commit 2ab6964

5 files changed

Lines changed: 1278 additions & 3 deletions

File tree

replane/_sync.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,11 +425,12 @@ def _process_stream(self, response: http.client.HTTPResponse) -> None:
425425
parser = SSEParser()
426426
last_event_time = time.monotonic()
427427

428-
# Set socket timeout for inactivity detection
429-
# Access internal socket for timeout - implementation detail
428+
# Use a short socket timeout (1s) to allow checking _stop_event frequently.
429+
# We track elapsed time separately for the real inactivity timeout.
430+
socket_timeout = 1.0
430431
sock = response.fp.raw._sock if hasattr(response.fp, "raw") else None # type: ignore[attr-defined]
431432
if sock:
432-
sock.settimeout(self._inactivity_timeout)
433+
sock.settimeout(socket_timeout)
433434

434435
buffer_size = 4096
435436

@@ -456,10 +457,12 @@ def _process_stream(self, response: http.client.HTTPResponse) -> None:
456457
self._handle_event(event)
457458

458459
except socket.timeout:
460+
# Check if we've exceeded the inactivity timeout
459461
elapsed = time.monotonic() - last_event_time
460462
if elapsed > self._inactivity_timeout:
461463
logger.debug("SSE inactivity timeout, reconnecting...")
462464
break
465+
# Otherwise, just loop and check _stop_event again
463466

464467
except Exception as e:
465468
if self._stop_event.is_set():

tests/conftest.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Pytest fixtures for Replane SDK tests."""
2+
3+
from __future__ import annotations
4+
5+
import pytest
6+
7+
from .mock_server import MockSSEServer
8+
9+
10+
@pytest.fixture
11+
def mock_server():
12+
"""Provide a mock SSE server for testing.
13+
14+
The server is started before the test and stopped after.
15+
Each test gets a fresh server with reset state.
16+
17+
Example:
18+
def test_something(mock_server):
19+
mock_server.send_init([{"name": "feature", "value": True}])
20+
client = SyncReplaneClient(base_url=mock_server.url, sdk_key="test")
21+
client.connect()
22+
assert client.get("feature") is True
23+
client.close()
24+
"""
25+
server = MockSSEServer(port=0) # Pick available port
26+
server.start()
27+
yield server
28+
server.stop()
29+
30+
31+
@pytest.fixture
32+
def server_url(mock_server: MockSSEServer) -> str:
33+
"""Provide just the URL of the mock server."""
34+
return mock_server.url

tests/mock_server.py

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
"""Mock SSE server for integration testing.
2+
3+
This module provides a controllable HTTP server that simulates a Replane server
4+
for testing the SDK clients without requiring a real server.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import json
10+
import queue
11+
import socketserver
12+
import threading
13+
import time
14+
from http.server import BaseHTTPRequestHandler, HTTPServer
15+
from typing import Any
16+
17+
18+
class MockSSEHandler(BaseHTTPRequestHandler):
19+
"""HTTP request handler for mock SSE server."""
20+
21+
# Disable logging to stderr during tests
22+
def log_message(self, format: str, *args: Any) -> None:
23+
pass
24+
25+
def do_POST(self) -> None:
26+
"""Handle POST requests to the SSE endpoint."""
27+
server: MockSSEServer = self.server # type: ignore
28+
29+
# Check path
30+
if not self.path.endswith("/api/sdk/v1/replication/stream"):
31+
self.send_error(404, "Not Found")
32+
return
33+
34+
# Check authentication
35+
auth_header = self.headers.get("Authorization", "")
36+
if server.required_sdk_key:
37+
expected = f"Bearer {server.required_sdk_key}"
38+
if auth_header != expected:
39+
self.send_response(401)
40+
self.send_header("Content-Type", "application/json")
41+
self.end_headers()
42+
self.wfile.write(b'{"error": "Unauthorized"}')
43+
return
44+
45+
# Check for forced status code
46+
if server.next_status_code != 200:
47+
status = server.next_status_code
48+
server.next_status_code = 200 # Reset for next request
49+
self.send_response(status)
50+
self.send_header("Content-Type", "application/json")
51+
self.end_headers()
52+
self.wfile.write(json.dumps({"error": f"Status {status}"}).encode())
53+
return
54+
55+
# Apply delay if configured
56+
if server.response_delay > 0:
57+
delay = server.response_delay
58+
server.response_delay = 0 # Reset for next request
59+
time.sleep(delay)
60+
61+
# Send SSE response headers
62+
self.send_response(200)
63+
self.send_header("Content-Type", "text/event-stream")
64+
self.send_header("Cache-Control", "no-cache")
65+
self.send_header("Connection", "keep-alive")
66+
self.end_headers()
67+
68+
# Track this connection
69+
server.active_connections += 1
70+
server.connection_event.set()
71+
72+
try:
73+
# Stream events from the queue
74+
while not server.should_stop:
75+
# Check for disconnect signal
76+
if server.should_disconnect:
77+
server.should_disconnect = False
78+
break
79+
80+
try:
81+
event = server.events_queue.get(timeout=0.1)
82+
except queue.Empty:
83+
continue
84+
85+
# Format and send SSE event
86+
event_type = event.get("type", "message")
87+
data = event.get("data", {})
88+
89+
sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
90+
try:
91+
self.wfile.write(sse_message.encode())
92+
self.wfile.flush()
93+
except (BrokenPipeError, ConnectionResetError):
94+
break
95+
96+
finally:
97+
server.active_connections -= 1
98+
99+
100+
class MockSSEServer(HTTPServer):
101+
"""Controllable HTTP server for testing SSE clients.
102+
103+
Example:
104+
>>> server = MockSSEServer()
105+
>>> server.start()
106+
>>> # Queue events before client connects
107+
>>> server.send_init([{"name": "feature", "value": True}])
108+
>>> # ... run client tests ...
109+
>>> server.stop()
110+
"""
111+
112+
def __init__(self, port: int = 0):
113+
"""Initialize the mock server.
114+
115+
Args:
116+
port: Port to listen on. Use 0 to pick an available port.
117+
"""
118+
# Allow address reuse to avoid "Address already in use" errors
119+
socketserver.TCPServer.allow_reuse_address = True
120+
121+
super().__init__(("127.0.0.1", port), MockSSEHandler)
122+
123+
self.events_queue: queue.Queue[dict[str, Any]] = queue.Queue()
124+
self.next_status_code = 200
125+
self.response_delay = 0.0
126+
self.should_disconnect = False
127+
self.should_stop = False
128+
self.required_sdk_key: str | None = None
129+
130+
# Connection tracking
131+
self.active_connections = 0
132+
self.connection_event = threading.Event()
133+
134+
# Server thread
135+
self._thread: threading.Thread | None = None
136+
137+
@property
138+
def port(self) -> int:
139+
"""Get the actual port the server is listening on."""
140+
return self.server_address[1]
141+
142+
@property
143+
def url(self) -> str:
144+
"""Get the base URL for connecting to this server."""
145+
return f"http://127.0.0.1:{self.port}"
146+
147+
def start(self) -> None:
148+
"""Start the server in a background thread."""
149+
self.should_stop = False
150+
self._thread = threading.Thread(
151+
target=self.serve_forever,
152+
daemon=True,
153+
name="mock-sse-server",
154+
)
155+
self._thread.start()
156+
157+
def stop(self) -> None:
158+
"""Stop the server and wait for it to finish."""
159+
self.should_stop = True
160+
self.shutdown()
161+
if self._thread:
162+
self._thread.join(timeout=2.0)
163+
self._thread = None
164+
165+
def reset(self) -> None:
166+
"""Reset server state for a new test."""
167+
# Clear event queue
168+
while not self.events_queue.empty():
169+
try:
170+
self.events_queue.get_nowait()
171+
except queue.Empty:
172+
break
173+
174+
self.next_status_code = 200
175+
self.response_delay = 0.0
176+
self.should_disconnect = False
177+
self.required_sdk_key = None
178+
self.connection_event.clear()
179+
180+
def wait_for_connection(self, timeout: float = 5.0) -> bool:
181+
"""Wait for a client to connect.
182+
183+
Args:
184+
timeout: Maximum time to wait in seconds.
185+
186+
Returns:
187+
True if a connection was made, False if timeout.
188+
"""
189+
return self.connection_event.wait(timeout=timeout)
190+
191+
def send_event(self, event_type: str, data: dict[str, Any]) -> None:
192+
"""Queue an SSE event to be sent to connected clients.
193+
194+
Args:
195+
event_type: The SSE event type (e.g., "init", "config_change").
196+
data: The event data to send as JSON.
197+
"""
198+
self.events_queue.put({"type": event_type, "data": data})
199+
200+
def send_init(self, configs: list[dict[str, Any]]) -> None:
201+
"""Send an init event with the given configs.
202+
203+
Args:
204+
configs: List of config objects with name, value, and optional overrides.
205+
"""
206+
self.send_event("init", {"type": "init", "configs": configs})
207+
208+
def send_config_change(self, config: dict[str, Any]) -> None:
209+
"""Send a config change event.
210+
211+
Args:
212+
config: Config object with name, value, and optional overrides.
213+
"""
214+
self.send_event("config_change", {"type": "config_change", "config": config})
215+
216+
def set_status_code(self, code: int) -> None:
217+
"""Set the HTTP status code for the next request.
218+
219+
Args:
220+
code: HTTP status code (e.g., 401, 500).
221+
"""
222+
self.next_status_code = code
223+
224+
def set_delay(self, seconds: float) -> None:
225+
"""Set a delay before responding to the next request.
226+
227+
Args:
228+
seconds: Delay in seconds.
229+
"""
230+
self.response_delay = seconds
231+
232+
def disconnect(self) -> None:
233+
"""Force disconnect the current SSE stream."""
234+
self.should_disconnect = True
235+
236+
def set_auth_required(self, sdk_key: str) -> None:
237+
"""Require a specific SDK key for authentication.
238+
239+
Args:
240+
sdk_key: The SDK key to require.
241+
"""
242+
self.required_sdk_key = sdk_key
243+
244+
245+
def create_config(
246+
name: str,
247+
value: Any,
248+
overrides: list[dict[str, Any]] | None = None,
249+
) -> dict[str, Any]:
250+
"""Helper to create a config object for testing.
251+
252+
Args:
253+
name: Config name.
254+
value: Config value.
255+
overrides: Optional list of override rules.
256+
257+
Returns:
258+
A config dict ready to send via SSE.
259+
"""
260+
config: dict[str, Any] = {"name": name, "value": value}
261+
if overrides:
262+
config["overrides"] = overrides
263+
return config
264+
265+
266+
def create_override(
267+
name: str,
268+
value: Any,
269+
conditions: list[dict[str, Any]],
270+
) -> dict[str, Any]:
271+
"""Helper to create an override rule for testing.
272+
273+
Args:
274+
name: Override name.
275+
value: Value when override matches.
276+
conditions: List of condition objects.
277+
278+
Returns:
279+
An override dict.
280+
"""
281+
return {"name": name, "value": value, "conditions": conditions}
282+
283+
284+
def create_condition(
285+
operator: str,
286+
property: str,
287+
value: Any,
288+
) -> dict[str, Any]:
289+
"""Helper to create a condition for testing.
290+
291+
Args:
292+
operator: Condition operator (e.g., "equals", "in").
293+
property: Context property to check.
294+
value: Expected value.
295+
296+
Returns:
297+
A condition dict.
298+
"""
299+
return {"operator": operator, "property": property, "value": value}

0 commit comments

Comments
 (0)