|
16 | 16 | logger = structlog.get_logger("codegate") |
17 | 17 |
|
18 | 18 |
|
19 | | -# call_directly is a function to call the model directly bypassing codegate |
20 | | -def call_directly(url: str, headers: dict, data: dict) -> Optional[requests.Response]: |
21 | | - try: |
22 | | - headers["Content-Type"] = "application/json" |
23 | | - stream = data.get("stream", False) |
24 | | - response = requests.post(url, headers=headers, json=data, stream=stream) |
25 | | - response.raise_for_status() |
26 | | - return response |
27 | | - except Exception as e: |
28 | | - logger.error(f"Error making direct request to {url}: {str(e)}") |
29 | | - return None |
30 | | - |
31 | | - |
32 | 19 | class CodegateTestRunner: |
33 | 20 | def __init__(self): |
34 | 21 | self.requester_factory = RequesterFactory() |
35 | 22 | self.failed_tests = [] # Track failed tests |
36 | 23 |
|
37 | | - def call_codegate( |
| 24 | + def call_provider( |
38 | 25 | self, url: str, headers: dict, data: dict, provider: str, method: str = "POST" |
39 | 26 | ) -> Optional[requests.Response]: |
40 | 27 | logger.debug(f"Creating requester for provider: {provider}") |
@@ -146,21 +133,23 @@ def replacement(match): |
146 | 133 | async def run_test(self, test: dict, test_headers: dict) -> bool: |
147 | 134 | test_name = test["name"] |
148 | 135 | data = json.loads(test["data"]) |
| 136 | + codegate_url = test["url"] |
| 137 | + direct_provider_url = test.get(CodeGateEnrichment.KEY)["provider_url"] |
149 | 138 | streaming = data.get("stream", False) |
150 | 139 | provider = test["provider"] |
151 | 140 | logger.info(f"Starting test: {test_name}") |
152 | 141 |
|
153 | 142 | # Call Codegate |
154 | | - response = self.call_codegate(test["url"], test_headers, data, provider) |
| 143 | + response = self.call_provider(codegate_url, test_headers, data, provider) |
155 | 144 | if not response: |
156 | 145 | logger.error(f"Test {test_name} failed: No response received") |
157 | 146 | return False |
158 | 147 |
|
159 | 148 | # Call model directly if specified |
160 | 149 | direct_response = None |
161 | 150 | if test.get(CodeGateEnrichment.KEY) is not None: |
162 | | - direct_response = call_directly( |
163 | | - test.get(CodeGateEnrichment.KEY)["provider_url"], test_headers, data |
| 151 | + direct_response = self.call_provider( |
| 152 | + direct_provider_url, test_headers, data, "not-codegate" |
164 | 153 | ) |
165 | 154 | if not direct_response: |
166 | 155 | logger.error(f"Test {test_name} failed: No direct response received") |
@@ -412,6 +401,7 @@ async def main(): |
412 | 401 | # Exit with status code 1 if any tests failed |
413 | 402 | if not all_tests_passed: |
414 | 403 | sys.exit(1) |
| 404 | + logger.info("All tests passed") |
415 | 405 |
|
416 | 406 |
|
417 | 407 | if __name__ == "__main__": |
|
0 commit comments