Skip to content

Commit a0fdf96

Browse files
committed
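security: comprehensive code review fixes - SSRF protection (base_url scheme/hostname validation), multi-line SSE data parsing, API key masking in repr, retry logic for chat_stream, removal of unused **kwargs, chat_stream tests, connection pool configuration, logging instead of silently swallowing exceptions, shared header builder, always sending the pagination limit instead of comparing against magic defaults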
1 parent 2f2cc32 commit a0fdf96

8 files changed

Lines changed: 279 additions & 78 deletions

File tree

src/hawk/agent.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,14 @@ class AgentConfig:
2020
Attributes:
2121
name: Agent identifier.
2222
model: Model to use for completions.
23-
system_prompt: Optional system prompt prepended to conversations.
2423
tools: List of tools available to the agent.
2524
max_rounds: Maximum tool-call rounds per chat turn.
26-
temperature: Sampling temperature.
27-
top_p: Nucleus sampling parameter.
2825
"""
2926

3027
name: str = "hawk-agent"
3128
model: str | None = None
32-
system_prompt: str | None = None
3329
tools: list[Tool] = field(default_factory=list)
3430
max_rounds: int = 10
35-
temperature: float | None = None
36-
top_p: float | None = None
3731

3832

3933
class Agent:

src/hawk/client.py

Lines changed: 86 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
from typing import Any
6+
from urllib.parse import urlparse
67

78
import httpx
89

@@ -23,6 +24,29 @@
2324

2425
DEFAULT_BASE_URL = "http://127.0.0.1:4590"
2526
DEFAULT_TIMEOUT = 30.0
27+
DEFAULT_POOL_CONNECTIONS = 10
28+
DEFAULT_POOL_MAXSIZE = 100
29+
30+
31+
def _build_headers(api_key: str | None) -> dict[str, str]:
    """Return the default header set sent with every Hawk API request.

    Args:
        api_key: Bearer token. When it is ``None`` or empty, no
            ``Authorization`` header is included.

    Returns:
        A new dict containing ``Accept`` and ``User-Agent`` and, if a key
        was supplied, ``Authorization``.
    """
    # Built separately so the literal below stays a single expression while
    # keeping Authorization as the last key, as before.
    auth: dict[str, str] = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    return {
        "Accept": "application/json",
        "User-Agent": f"hawk-sdk-python/{__version__}",
        **auth,
    }
40+
41+
42+
def _validate_base_url(url: str) -> str:
    """Validate that ``url`` is a well-formed HTTP(S) base URL.

    The check is purely syntactic: it inspects the parsed scheme, hostname
    and port. It does not resolve the host or restrict which addresses can
    be reached.

    Args:
        url: Base URL supplied by the caller.

    Returns:
        The same ``url`` string, unchanged.

    Raises:
        ValueError: If the scheme is not ``http``/``https``, the hostname is
            missing, or the port is not a valid port number.
    """
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        raise ValueError(f"base_url must use http or https scheme, got: {parsed.scheme}")
    if not parsed.hostname:
        raise ValueError("base_url must have a hostname")
    try:
        # urllib parses the port lazily; reading the attribute is what
        # surfaces a non-numeric or out-of-range port.
        _ = parsed.port
    except ValueError as exc:
        raise ValueError(f"base_url has an invalid port: {exc}") from exc
    return url
2650

2751

2852
class HawkClient:
@@ -40,25 +64,29 @@ def __init__(
4064
api_key: str | None = None,
4165
retry_config: RetryConfig | None = None,
4266
timeout: float = DEFAULT_TIMEOUT,
67+
pool_connections: int = DEFAULT_POOL_CONNECTIONS,
68+
pool_maxsize: int = DEFAULT_POOL_MAXSIZE,
4369
) -> None:
70+
_validate_base_url(base_url)
4471
self._base_url = base_url.rstrip("/")
4572
self._api_key = api_key
4673
self._retry_config = retry_config or DEFAULT_RETRY_CONFIG
4774
self._timeout = timeout
75+
self._pool_connections = pool_connections
76+
self._pool_maxsize = pool_maxsize
4877
self._client = httpx.Client(
4978
base_url=self._base_url,
5079
timeout=self._timeout,
51-
headers=self._build_headers(),
80+
headers=_build_headers(api_key),
81+
limits=httpx.Limits(
82+
max_keepalive_connections=pool_connections,
83+
max_connections=pool_maxsize,
84+
),
5285
)
5386

54-
def _build_headers(self) -> dict[str, str]:
55-
headers: dict[str, str] = {
56-
"Accept": "application/json",
57-
"User-Agent": f"hawk-sdk-python/{__version__}",
58-
}
59-
if self._api_key:
60-
headers["Authorization"] = f"Bearer {self._api_key}"
61-
return headers
87+
def __repr__(self) -> str:
    """Return a debug representation that never exposes the full API key.

    The last four characters of the key are shown only when the key is
    longer than eight characters; shorter keys are masked entirely, because
    a four-character suffix would reveal most or all of them. A missing or
    empty key is rendered as ``None``.
    """
    key = self._api_key
    if not key:
        masked = "None"
    elif len(key) > 8:
        masked = "***" + key[-4:]
    else:
        # Too short to show any suffix without leaking the secret.
        masked = "***"
    return f"HawkClient(base_url='{self._base_url}', api_key='{masked}')"
6290

6391
def __enter__(self) -> HawkClient:
6492
return self
@@ -96,7 +124,6 @@ def chat(
96124
autonomy: str | None = None,
97125
cwd: str | None = None,
98126
agent: str | None = None,
99-
**kwargs: Any,
100127
) -> ChatResponse:
101128
"""Send a prompt and return the complete response."""
102129
request = ChatRequest(
@@ -129,7 +156,6 @@ def chat_stream(
129156
autonomy: str | None = None,
130157
cwd: str | None = None,
131158
agent: str | None = None,
132-
**kwargs: Any,
133159
) -> StreamReader:
134160
"""Send a prompt and stream the response via SSE."""
135161
request = ChatRequest(
@@ -142,21 +168,24 @@ def chat_stream(
142168
agent=agent,
143169
)
144170

145-
response = self._client.send(
146-
self._client.build_request(
147-
"POST",
148-
"/v1/chat",
149-
json=request.model_dump(exclude_none=True, by_alias=True),
150-
headers={"Accept": "text/event-stream"},
151-
),
152-
stream=True,
153-
)
171+
def _do() -> StreamReader:
172+
response = self._client.send(
173+
self._client.build_request(
174+
"POST",
175+
"/v1/chat",
176+
json=request.model_dump(exclude_none=True, by_alias=True),
177+
headers={"Accept": "text/event-stream"},
178+
),
179+
stream=True,
180+
)
154181

155-
if response.status_code >= 400:
156-
response.read()
157-
raise parse_error(response)
182+
if response.status_code >= 400:
183+
response.read()
184+
raise parse_error(response)
158185

159-
return StreamReader(response)
186+
return StreamReader(response)
187+
188+
return with_retry_sync(_do, self._retry_config)
160189

161190
def create_session(
162191
self,
@@ -189,9 +218,7 @@ def list_sessions(self, limit: int = 20, offset: int = 0) -> PaginatedResponse[S
189218
"""List sessions with pagination."""
190219

191220
def _do() -> PaginatedResponse[SessionSummary]:
192-
params: dict[str, Any] = {}
193-
if limit != 20:
194-
params["limit"] = limit
221+
params: dict[str, Any] = {"limit": limit}
195222
if offset > 0:
196223
params["offset"] = offset
197224
resp = self._request("GET", "/v1/sessions", params=params)
@@ -213,9 +240,7 @@ def list_messages(
213240
"""List messages for a session with pagination."""
214241

215242
def _do() -> PaginatedResponse[Message]:
216-
params: dict[str, Any] = {}
217-
if limit != 50:
218-
params["limit"] = limit
243+
params: dict[str, Any] = {"limit": limit}
219244
if offset > 0:
220245
params["offset"] = offset
221246
resp = self._request("GET", f"/v1/sessions/{session_id}/messages", params=params)
@@ -248,25 +273,29 @@ def __init__(
248273
api_key: str | None = None,
249274
retry_config: RetryConfig | None = None,
250275
timeout: float = DEFAULT_TIMEOUT,
276+
pool_connections: int = DEFAULT_POOL_CONNECTIONS,
277+
pool_maxsize: int = DEFAULT_POOL_MAXSIZE,
251278
) -> None:
279+
_validate_base_url(base_url)
252280
self._base_url = base_url.rstrip("/")
253281
self._api_key = api_key
254282
self._retry_config = retry_config or DEFAULT_RETRY_CONFIG
255283
self._timeout = timeout
284+
self._pool_connections = pool_connections
285+
self._pool_maxsize = pool_maxsize
256286
self._client = httpx.AsyncClient(
257287
base_url=self._base_url,
258288
timeout=self._timeout,
259-
headers=self._build_headers(),
289+
headers=_build_headers(api_key),
290+
limits=httpx.Limits(
291+
max_keepalive_connections=pool_connections,
292+
max_connections=pool_maxsize,
293+
),
260294
)
261295

262-
def _build_headers(self) -> dict[str, str]:
263-
headers: dict[str, str] = {
264-
"Accept": "application/json",
265-
"User-Agent": f"hawk-sdk-python/{__version__}",
266-
}
267-
if self._api_key:
268-
headers["Authorization"] = f"Bearer {self._api_key}"
269-
return headers
296+
def __repr__(self) -> str:
    """Return a debug representation that never exposes the full API key.

    The last four characters of the key are shown only when the key is
    longer than eight characters; shorter keys are masked entirely, because
    a four-character suffix would reveal most or all of them. A missing or
    empty key is rendered as ``None``.
    """
    key = self._api_key
    if not key:
        masked = "None"
    elif len(key) > 8:
        masked = "***" + key[-4:]
    else:
        # Too short to show any suffix without leaking the secret.
        masked = "***"
    return f"AsyncHawkClient(base_url='{self._base_url}', api_key='{masked}')"
270299

271300
async def __aenter__(self) -> AsyncHawkClient:
272301
return self
@@ -304,7 +333,6 @@ async def chat(
304333
autonomy: str | None = None,
305334
cwd: str | None = None,
306335
agent: str | None = None,
307-
**kwargs: Any,
308336
) -> ChatResponse:
309337
"""Send a prompt and return the complete response."""
310338
request = ChatRequest(
@@ -337,7 +365,6 @@ async def chat_stream(
337365
autonomy: str | None = None,
338366
cwd: str | None = None,
339367
agent: str | None = None,
340-
**kwargs: Any,
341368
) -> AsyncStreamReader:
342369
"""Send a prompt and stream the response via SSE."""
343370
request = ChatRequest(
@@ -350,21 +377,24 @@ async def chat_stream(
350377
agent=agent,
351378
)
352379

353-
response = await self._client.send(
354-
self._client.build_request(
355-
"POST",
356-
"/v1/chat",
357-
json=request.model_dump(exclude_none=True, by_alias=True),
358-
headers={"Accept": "text/event-stream"},
359-
),
360-
stream=True,
361-
)
380+
async def _do() -> AsyncStreamReader:
381+
response = await self._client.send(
382+
self._client.build_request(
383+
"POST",
384+
"/v1/chat",
385+
json=request.model_dump(exclude_none=True, by_alias=True),
386+
headers={"Accept": "text/event-stream"},
387+
),
388+
stream=True,
389+
)
362390

363-
if response.status_code >= 400:
364-
await response.aread()
365-
raise parse_error(response)
391+
if response.status_code >= 400:
392+
await response.aread()
393+
raise parse_error(response)
394+
395+
return AsyncStreamReader(response)
366396

367-
return AsyncStreamReader(response)
397+
return await with_retry(_do, self._retry_config)
368398

369399
async def create_session(
370400
self,
@@ -399,9 +429,7 @@ async def list_sessions(
399429
"""List sessions with pagination."""
400430

401431
async def _do() -> PaginatedResponse[SessionSummary]:
402-
params: dict[str, Any] = {}
403-
if limit != 20:
404-
params["limit"] = limit
432+
params: dict[str, Any] = {"limit": limit}
405433
if offset > 0:
406434
params["offset"] = offset
407435
resp = await self._request("GET", "/v1/sessions", params=params)
@@ -423,9 +451,7 @@ async def list_messages(
423451
"""List messages for a session with pagination."""
424452

425453
async def _do() -> PaginatedResponse[Message]:
426-
params: dict[str, Any] = {}
427-
if limit != 50:
428-
params["limit"] = limit
454+
params: dict[str, Any] = {"limit": limit}
429455
if offset > 0:
430456
params["offset"] = offset
431457
resp = await self._request("GET", f"/v1/sessions/{session_id}/messages", params=params)

src/hawk/discovery.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,9 @@ async def _fetch_card(self, base_url: str) -> AgentCard | None:
158158
data = resp.json()
159159
if isinstance(data, dict):
160160
return AgentCard.from_dict(cast("dict[str, Any]", data))
161-
except Exception:
162-
pass
161+
except Exception as exc:
162+
import logging
163+
logging.getLogger(__name__).debug("Failed to fetch agent card from %s: %s", base_url, exc)
163164
return None
164165

165166

src/hawk/memory_tools.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ def record_memory(
4747
if hasattr(self._client, "remember"):
4848
self._client.remember(content, session_id=self._session_id)
4949
return f"Recorded to persistent memory: '{content[:100]}...'"
50-
except Exception:
51-
pass
50+
except Exception as exc:
51+
import logging
52+
logging.getLogger(__name__).debug("Failed to persist memory via yaad: %s", exc)
5253

5354
return f"Recorded to session memory: '{content[:100]}...'"
5455

@@ -64,8 +65,9 @@ def retrieve_memories(self, query: str, limit: int = 5) -> str:
6465
return f"Recalled {len(recalled)} memories:\n" + "\n".join(
6566
f"- {m}" for m in recalled
6667
)
67-
except Exception:
68-
pass
68+
except Exception as exc:
69+
import logging
70+
logging.getLogger(__name__).debug("Failed to recall memories via yaad: %s", exc)
6971

7072
# Fallback to local fuzzy match
7173
query_lower = query.lower()

src/hawk/retry.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ async def with_retry(
8989

9090
await asyncio.sleep(wait)
9191

92-
assert last_error is not None
92+
if last_error is None:
93+
raise HawkAPIError("Retry logic failed unexpectedly: no error recorded")
9394
raise last_error
9495

9596

@@ -138,5 +139,6 @@ def with_retry_sync(
138139

139140
time.sleep(wait)
140141

141-
assert last_error is not None
142+
if last_error is None:
143+
raise HawkAPIError("Retry logic failed unexpectedly: no error recorded")
142144
raise last_error

src/hawk/streaming.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,15 @@ def events(self) -> Iterator[StreamEvent]:
5050
if line.startswith("event: "):
5151
current_event = line[7:]
5252
elif line.startswith("data: "):
53-
current_data = line[6:]
53+
if current_data is not None:
54+
current_data += "\n" + line[6:]
55+
else:
56+
current_data = line[6:]
5457
elif line == "data:":
55-
current_data = ""
58+
if current_data is not None:
59+
current_data += "\n"
60+
else:
61+
current_data = ""
5662

5763
def collect_text(self) -> str:
5864
"""Consume the entire stream and return concatenated text content."""
@@ -126,9 +132,15 @@ async def events(self) -> AsyncIterator[StreamEvent]:
126132
if line.startswith("event: "):
127133
current_event = line[7:]
128134
elif line.startswith("data: "):
129-
current_data = line[6:]
135+
if current_data is not None:
136+
current_data += "\n" + line[6:]
137+
else:
138+
current_data = line[6:]
130139
elif line == "data:":
131-
current_data = ""
140+
if current_data is not None:
141+
current_data += "\n"
142+
else:
143+
current_data = ""
132144

133145
async def collect_text(self) -> str:
134146
"""Consume the entire stream and return concatenated text content."""

0 commit comments

Comments
 (0)