|
1 | | -use anyhow::Result; |
2 | | -use reqwest::Client; |
3 | | -use serde_json::Value; |
| 1 | +#[path = "endpoint/inference.rs"] |
| 2 | +mod inference; |
| 3 | +#[path = "endpoint/models.rs"] |
| 4 | +mod models; |
4 | 5 |
|
5 | | -use crate::core::offline; |
6 | | - |
7 | | -pub(super) async fn test_model_inference( |
8 | | - client: &Client, |
9 | | - base_url: &str, |
10 | | - model_name: &str, |
11 | | - endpoint_type: &str, |
12 | | -) -> Result<String> { |
13 | | - let system_msg = "You are a code reviewer. Respond with a single JSON object."; |
14 | | - let user_msg = "Review this code change:\n+fn add(a: i32, b: i32) -> i32 { a + b }\nRespond with: {\"ok\": true}"; |
15 | | - |
16 | | - let messages = serde_json::json!([ |
17 | | - {"role": "system", "content": system_msg}, |
18 | | - {"role": "user", "content": user_msg} |
19 | | - ]); |
20 | | - |
21 | | - if endpoint_type == "ollama" { |
22 | | - let url = format!("{}/api/chat", base_url); |
23 | | - let body = serde_json::json!({ |
24 | | - "model": model_name, |
25 | | - "messages": messages, |
26 | | - "stream": false, |
27 | | - "options": {"num_predict": 50} |
28 | | - }); |
29 | | - |
30 | | - let resp = client |
31 | | - .post(&url) |
32 | | - .json(&body) |
33 | | - .send() |
34 | | - .await |
35 | | - .map_err(|e| anyhow::anyhow!("Request failed: {}", e))?; |
36 | | - |
37 | | - if !resp.status().is_success() { |
38 | | - let status = resp.status(); |
39 | | - let body = resp.text().await.unwrap_or_default(); |
40 | | - anyhow::bail!("HTTP {} - {}", status, body); |
41 | | - } |
42 | | - |
43 | | - let text = resp.text().await?; |
44 | | - parse_ollama_response_content(&text) |
45 | | - } else { |
46 | | - let url = format!("{}/v1/chat/completions", base_url); |
47 | | - let body = serde_json::json!({ |
48 | | - "model": model_name, |
49 | | - "messages": messages, |
50 | | - "max_tokens": 50, |
51 | | - "temperature": 0.1 |
52 | | - }); |
53 | | - |
54 | | - let resp = client |
55 | | - .post(&url) |
56 | | - .json(&body) |
57 | | - .send() |
58 | | - .await |
59 | | - .map_err(|e| anyhow::anyhow!("Request failed: {}", e))?; |
60 | | - |
61 | | - if !resp.status().is_success() { |
62 | | - let status = resp.status(); |
63 | | - let body = resp.text().await.unwrap_or_default(); |
64 | | - anyhow::bail!("HTTP {} - {}", status, body); |
65 | | - } |
66 | | - |
67 | | - let text = resp.text().await?; |
68 | | - parse_openai_response_content(&text) |
69 | | - } |
70 | | -} |
71 | | - |
72 | | -pub(super) fn estimate_tokens(text: &str) -> usize { |
73 | | - (text.len() / 4).max(1) |
74 | | -} |
75 | | - |
76 | | -pub(super) fn parse_openai_models(body: &str, models: &mut Vec<offline::LocalModel>) { |
77 | | - if let Ok(value) = serde_json::from_str::<Value>(body) { |
78 | | - if let Some(data) = value.get("data").and_then(|d| d.as_array()) { |
79 | | - for model in data { |
80 | | - if let Some(id) = model.get("id").and_then(|i| i.as_str()) { |
81 | | - models.push(offline::LocalModel { |
82 | | - name: id.to_string(), |
83 | | - size_mb: 0, |
84 | | - quantization: None, |
85 | | - modified_at: None, |
86 | | - family: None, |
87 | | - parameter_size: None, |
88 | | - }); |
89 | | - } |
90 | | - } |
91 | | - } |
92 | | - } |
93 | | -} |
94 | | - |
95 | | -fn parse_ollama_response_content(text: &str) -> Result<String> { |
96 | | - let value: Value = serde_json::from_str(text)?; |
97 | | - Ok(value |
98 | | - .get("message") |
99 | | - .and_then(|message| message.get("content")) |
100 | | - .and_then(|content| content.as_str()) |
101 | | - .unwrap_or("") |
102 | | - .to_string()) |
103 | | -} |
104 | | - |
105 | | -fn parse_openai_response_content(text: &str) -> Result<String> { |
106 | | - let value: Value = serde_json::from_str(text)?; |
107 | | - Ok(value |
108 | | - .get("choices") |
109 | | - .and_then(|choices| choices.as_array()) |
110 | | - .and_then(|choices| choices.first()) |
111 | | - .and_then(|choice| choice.get("message")) |
112 | | - .and_then(|message| message.get("content")) |
113 | | - .and_then(|content| content.as_str()) |
114 | | - .unwrap_or("") |
115 | | - .to_string()) |
116 | | -} |
117 | | - |
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse_openai_models` against `body` and returns the resulting
    /// vector, removing the repeated Vec-plus-call boilerplate from each test.
    fn parsed(body: &str) -> Vec<offline::LocalModel> {
        let mut models = Vec::new();
        parse_openai_models(body, &mut models);
        models
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 1);
        assert_eq!(estimate_tokens("abcd"), 1);
        assert_eq!(estimate_tokens("abcdefgh"), 2);
        assert_eq!(estimate_tokens("a]"), 1);
    }

    #[test]
    fn test_estimate_tokens_longer_text() {
        let text = "This is a longer response with several words in it for testing.";
        // 64 chars / 4 ≈ 16 tokens — just sanity-check the rough range.
        assert!((11..30).contains(&estimate_tokens(text)));
    }

    #[test]
    fn test_parse_openai_models_valid() {
        let models = parsed(r#"{"data":[{"id":"gpt-3.5-turbo"},{"id":"codellama-7b"}]}"#);
        let names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
        assert_eq!(names, ["gpt-3.5-turbo", "codellama-7b"]);
    }

    #[test]
    fn test_parse_openai_models_empty() {
        assert!(parsed(r#"{"data":[]}"#).is_empty());
    }

    #[test]
    fn test_parse_openai_models_invalid_json() {
        assert!(parsed("not json").is_empty());
    }

    #[test]
    fn test_parse_openai_models_missing_data() {
        assert!(parsed(r#"{"models":[]}"#).is_empty());
    }

    #[test]
    fn test_parse_openai_models_missing_id() {
        assert!(parsed(r#"{"data":[{"name":"model-1"}]}"#).is_empty());
    }

    #[test]
    fn test_test_model_inference_ollama_parse() {
        let json = r#"{"message":{"role":"assistant","content":"{\"ok\": true}"}}"#;
        assert_eq!(parse_ollama_response_content(json).unwrap(), "{\"ok\": true}");
    }

    #[test]
    fn test_test_model_inference_openai_parse() {
        let json = r#"{"choices":[{"message":{"content":"{\"ok\": true}"}}]}"#;
        assert_eq!(parse_openai_response_content(json).unwrap(), "{\"ok\": true}");
    }

    #[test]
    fn test_test_model_inference_empty_choices() {
        assert_eq!(parse_openai_response_content(r#"{"choices":[]}"#).unwrap(), "");
    }
}
| 6 | +pub(super) use inference::{estimate_tokens, test_model_inference}; |
| 7 | +pub(super) use models::parse_openai_models; |
0 commit comments