-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtoolhive_client.py
More file actions
330 lines (278 loc) · 10.8 KB
/
toolhive_client.py
File metadata and controls
330 lines (278 loc) · 10.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
import asyncio
import atexit
import os
import subprocess
import time
import httpx
import mcp_client
# Global variable to hold the thv serve process
# (set by start_thv_serve(), terminated by stop_thv_serve()).
thv_process: subprocess.Popen | None = None
# Default API configuration
DEFAULT_HOST = "127.0.0.1"  # loopback; overridable via TOOLHIVE_HOST env var
DEFAULT_PORT = 8080  # port used when port discovery is skipped
DEFAULT_SCAN_PORT_START = 50000  # inclusive start of discovery scan range
DEFAULT_SCAN_PORT_END = 50100  # inclusive end of discovery scan range
DEFAULT_TIMEOUT = 10  # per-request HTTP timeout, seconds
# Global variables to store discovered connection info
# (cache shared by discover_toolhive() / discover_toolhive_async()).
_discovered_host: str | None = None
_discovered_port: int | None = None
def start_thv_serve():
    """Start the ``thv serve`` subprocess and record it in ``thv_process``.

    The child's output is discarded: the previous ``subprocess.PIPE``
    setup was never read by anything in this module, so a chatty child
    would eventually fill the OS pipe buffer and block. ``DEVNULL``
    avoids that deadlock.

    Raises:
        FileNotFoundError: if the ``thv`` binary is not on PATH.
    """
    global thv_process
    print("Starting thv serve...")
    thv_process = subprocess.Popen(
        ["thv", "serve"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    print(f"thv serve started with PID: {thv_process.pid}")
    # Give it a moment to start up
    time.sleep(1)
def stop_thv_serve():
    """Terminate the running ``thv serve`` process, if there is one."""
    global thv_process
    if not thv_process:
        return
    print("Stopping thv serve...")
    thv_process.terminate()
    try:
        # Allow a 5s grace period before escalating to SIGKILL.
        thv_process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        thv_process.kill()
    print("thv serve stopped")
async def _is_toolhive_available(
    host: str, port: int, timeout: float = DEFAULT_TIMEOUT
) -> tuple[int, int]:
    """
    Probe the given host/port for a running ToolHive instance (async).

    A GET against the /api/v1beta/version endpoint must succeed and
    return a JSON object containing a "version" key.

    Returns:
        Tuple of (port, port) if available, raises ConnectionError otherwise.
        The duplicate port value makes sorting by version possible in the future.
    """
    url = f"http://{host}:{port}/api/v1beta/version"
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(url)
            response.raise_for_status()
            data = response.json()
    except (httpx.HTTPError, OSError) as e:
        raise ConnectionError(f"ToolHive not available at {host}:{port}: {e}")
    # Something answered, but reject it unless it looks like ToolHive.
    if isinstance(data, dict) and "version" in data:
        return (port, port)
    raise ConnectionError(
        f"Port {port} on host {host} did not respond with ToolHive format"
    )
async def _scan_for_toolhive_async(
    host: str,
    scan_port_start: int = DEFAULT_SCAN_PORT_START,
    scan_port_end: int = DEFAULT_SCAN_PORT_END,
) -> int:
    """
    Scan for ToolHive in the specified port range using concurrent async checks.

    All ports are probed concurrently via asyncio.gather; the lowest
    responding port wins (gather preserves input order).

    Args:
        host: Host to probe.
        scan_port_start: First port of the (inclusive) range.
        scan_port_end: Last port of the (inclusive) range.

    Returns:
        The first port where ToolHive is found.

    Raises:
        ConnectionError: if the scan times out or no port responds.
    """
    print(
        f"Scanning for ToolHive on {host} in port range {scan_port_start}-{scan_port_end}..."
    )
    num_ports = scan_port_end - scan_port_start + 1
    # Cap the whole scan at 30s regardless of how many ports are probed.
    total_timeout = min(DEFAULT_TIMEOUT * num_ports, 30.0)
    try:
        task_outcomes = await asyncio.wait_for(
            asyncio.gather(
                *[
                    _is_toolhive_available(host, port)
                    for port in range(scan_port_start, scan_port_end + 1)
                ],
                return_exceptions=True,
            ),
            timeout=total_timeout,
        )
    # Before Python 3.11, asyncio.wait_for raises asyncio.TimeoutError,
    # which is NOT the builtin TimeoutError; catching only the builtin
    # let the timeout escape on 3.10. Catch both for compatibility.
    except (asyncio.TimeoutError, TimeoutError):
        raise ConnectionError(
            f"ToolHive port scan timed out after {total_timeout}s "
            f"on {host} in port range {scan_port_start}-{scan_port_end}"
        )
    # Successful probes return (port, port); failed ones come back as
    # exception instances because of return_exceptions=True.
    found_ports = [
        result for result in task_outcomes if not isinstance(result, BaseException)
    ]
    if found_ports:
        port = found_ports[0][0]
        print(f"✓ ToolHive found at {host}:{port}")
        return port
    raise ConnectionError(
        f"ToolHive not found on {host} in port range {scan_port_start}-{scan_port_end}. "
        f"Is 'thv serve' running?"
    )
async def discover_toolhive_async(
    host: str | None = None,
    port: int | None = None,
    scan_port_start: int = DEFAULT_SCAN_PORT_START,
    scan_port_end: int = DEFAULT_SCAN_PORT_END,
    skip_port_discovery: bool = False,
) -> tuple[str, int]:
    """
    Discover ToolHive connection parameters (async version with concurrent scanning).

    This implements the same discovery algorithm as mcp-optimizer:
    1. Use explicit host/port if provided and working
    2. Fall back to scanning a port range concurrently
    3. Support skipping discovery for known environments (K8s)

    Results are cached in the module-level ``_discovered_host`` /
    ``_discovered_port`` globals, so only the first call pays the
    discovery cost.

    Args:
        host: ToolHive host (defaults to env TOOLHIVE_HOST or "127.0.0.1")
        port: ToolHive port (if None, will scan for it)
        scan_port_start: Start of port scan range
        scan_port_end: End of port scan range
        skip_port_discovery: Skip port scanning (useful in K8s with known ports)

    Returns:
        tuple of (host, port)

    Raises:
        ConnectionError: if no candidate host responds within the scan range.
    """
    global _discovered_host, _discovered_port
    # Use cached values if available
    if _discovered_host and _discovered_port:
        return _discovered_host, _discovered_port
    # Get host from parameter, env, or default
    host = host or os.environ.get("TOOLHIVE_HOST", DEFAULT_HOST)
    # Handle port discovery
    if skip_port_discovery:
        # Trust caller/defaults without probing (e.g. Kubernetes with known ports).
        port = port or DEFAULT_PORT
        print(f"Using ToolHive at {host}:{port} (port discovery skipped)")
    elif port is not None:
        # Try the provided port first with retries (3 attempts, 1s apart)
        found = False
        for attempt in range(3):
            try:
                await _is_toolhive_available(host, port)
                print(f"✓ ToolHive found at {host}:{port}")
                found = True
                break
            except ConnectionError:
                # Sleep between attempts, but not after the final one.
                if attempt < 2:
                    await asyncio.sleep(1)
        if not found:
            # Fall back to scanning
            print(f"Port {port} not available, scanning for ToolHive...")
            port = await _scan_for_toolhive_async(host, scan_port_start, scan_port_end)
    else:
        # Scan for ToolHive with host fallbacks
        scan_hosts = [host]
        # Add sensible fallbacks to handle containerized environments across OSes
        if host != DEFAULT_HOST:
            scan_hosts.append(DEFAULT_HOST)
        if host != "host.docker.internal":
            scan_hosts.append("host.docker.internal")
        last_error = None
        # Try each candidate host in order; first successful scan wins.
        for candidate in scan_hosts:
            try:
                port = await _scan_for_toolhive_async(
                    candidate, scan_port_start, scan_port_end
                )
                host = candidate
                break
            except ConnectionError as e:
                last_error = e
                continue
        else:
            # for/else: reached only when no candidate host succeeded.
            raise last_error or ConnectionError(
                f"ToolHive not found via hosts {scan_hosts} in ports {scan_port_start}-{scan_port_end}"
            )
    # Cache the discovered values
    _discovered_host = host
    _discovered_port = port
    return host, port
def discover_toolhive(
    host: str | None = None,
    port: int | None = None,
    scan_port_start: int = DEFAULT_SCAN_PORT_START,
    scan_port_end: int = DEFAULT_SCAN_PORT_END,
    skip_port_discovery: bool = False,
) -> tuple[str, int]:
    """
    Discover ToolHive connection parameters (sync wrapper).

    Returns cached values when available. Otherwise, if no event loop is
    running, delegates to discover_toolhive_async() via asyncio.run().
    When called from inside a running loop (where asyncio.run() would
    fail) it falls back to env/default values; such callers should use
    discover_toolhive_async directly.
    """
    global _discovered_host, _discovered_port
    # Fast path: a previous discovery already populated the cache.
    if _discovered_host and _discovered_port:
        return _discovered_host, _discovered_port
    loop_running = True
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_running = False
    if loop_running:
        # Cannot nest asyncio.run(); hand back env/defaults instead.
        resolved_host = host or os.environ.get("TOOLHIVE_HOST", DEFAULT_HOST)
        resolved_port = port or DEFAULT_PORT
        return resolved_host, resolved_port
    # No running loop - safe to use asyncio.run()
    return asyncio.run(
        discover_toolhive_async(
            host, port, scan_port_start, scan_port_end, skip_port_discovery
        )
    )
def list_workloads(host: str = DEFAULT_HOST, port: int = DEFAULT_PORT) -> dict:
    """Fetch the list of running workloads from the ToolHive REST API.

    Never raises: all failures are reported through the returned dict.

    Returns:
        ``{"success": True, "endpoint": ..., "data": ...}`` on success,
        ``{"success": False, "error": ...}`` on any failure.
    """
    endpoint = "/api/v1beta/workloads"
    url = f"http://{host}:{port}{endpoint}"
    try:
        with httpx.Client(timeout=5.0) as client:
            response = client.get(url)
            response.raise_for_status()
            payload = response.json()
    except Exception as e:
        return {"success": False, "error": str(e)}
    return {"success": True, "endpoint": endpoint, "data": payload}
def _print_workload_summary(workloads: dict) -> None:
    """Pretty-print the result dict produced by list_workloads()."""
    print("\n=== Current Workloads ===")
    if workloads.get("success"):
        print(f"Endpoint: {workloads.get('endpoint')}")
        print(f"Data: {workloads.get('data')}")
    else:
        print(f"Error: {workloads.get('error')}")
    print("=" * 25 + "\n")


def _tool_display_names(tools) -> list[str]:
    """Normalize tool entries (dicts or strings) to a list of display names."""
    # tools may be a list of dicts ({"name": ..., "description": ...})
    # or a list of strings (back-compat). Normalize to names for display.
    try:
        return [
            (entry.get("name", "") if isinstance(entry, dict) else str(entry))
            for entry in tools
        ]
    except Exception:
        # Fallback: stringify everything
        return [str(entry) for entry in tools]


def _print_available_tools(host: str, port: int) -> None:
    """Query the MCP servers via mcp_client and print each server's tools."""
    print("=== Available Tools ===")
    try:
        tools_list = asyncio.run(mcp_client.list_tools(host=host, port=port))
        for server_tools in tools_list:
            print(f"\nWorkload: {server_tools.get('workload', 'unknown')}")
            print(f" Status: {server_tools.get('status', 'unknown')}")
            tools = server_tools.get("tools", [])
            if tools:
                names = _tool_display_names(tools)
                print(f" Tools: {', '.join(names)}")
            error = server_tools.get("error")
            if error:
                print(f" Error: {error}")
    except Exception as e:
        print(f"Error listing tools: {str(e)}")
    print("=" * 25 + "\n")


def initialize():
    """
    Initialize the ToolHive client - starts thv serve and returns workload info.

    Uses port discovery to automatically find ToolHive, making it work
    in containerized environments and dynamic port scenarios.

    Returns:
        The workloads dict produced by list_workloads() for the
        discovered (or fallback) endpoint.
    """
    # Make sure the child process is cleaned up no matter how we exit.
    atexit.register(stop_thv_serve)
    start_thv_serve()
    # Locate ToolHive via port scanning; fall back to defaults on failure.
    try:
        host, port = discover_toolhive()
        print(f"Connected to ToolHive at {host}:{port}\n")
    except ConnectionError as e:
        print(f"Warning: {e}")
        print("Falling back to default connection parameters...\n")
        host, port = DEFAULT_HOST, DEFAULT_PORT
    # List current workloads using discovered connection
    workloads = list_workloads(host=host, port=port)
    _print_workload_summary(workloads)
    # List all tools from MCP servers using discovered connection
    _print_available_tools(host, port)
    return workloads