1use crate::inference::{Api, Prompt, Provider};
8use anyhow::{Context, Result};
9use reqwest::blocking::Client;
10use serde::{Serialize, Deserialize, Serializer, Deserializer};
11use std::time::{Duration, Instant};
12use std::sync::Arc;
13
/// Normalized result of one chat-completion call, independent of which
/// provider API produced it.
#[derive(Debug, Serialize, Deserialize)]
pub struct InferenceResponse {
    // Generated text; Arc<str> so clones are refcount bumps, not string copies.
    #[serde(serialize_with = "serialize_arc_str")]
    #[serde(deserialize_with = "deserialize_arc_str")]
    text: Arc<str>,
    // Prompt-side token count (provider-reported).
    input_tokens: u32,
    // Completion-side token count (provider-reported).
    output_tokens: u32,
    // Provider-reported total for OpenAI; input + output for Ollama.
    total_tokens: u32,
    // Wall-clock latency measured around the HTTP round trip in this module.
    response_time_ms: u128,
}
25
impl InferenceResponse {
    /// Generated text; cloning the `Arc` is a refcount bump, not a copy.
    pub fn text(&self) -> Arc<str> {
        self.text.clone()
    }
    /// Prompt-side token count as reported by the provider.
    pub fn input_tokens(&self) -> u32 {
        self.input_tokens
    }
    /// Completion-side token count as reported by the provider.
    pub fn output_tokens(&self) -> u32 {
        self.output_tokens
    }
    /// Total token count (provider-reported, or input + output for Ollama).
    pub fn total_tokens(&self) -> u32 {
        self.total_tokens
    }
    /// Wall-clock latency of the request in milliseconds.
    pub fn response_time_ms(&self) -> u128 {
        self.response_time_ms
    }

    /// Builds a response from its raw parts.
    ///
    /// NOTE(review): the `_for_tests` suffix is not enforced — this is not
    /// gated behind `#[cfg(test)]` — so verify callers before assuming it is
    /// test-only; consider renaming to `new` or adding the cfg gate.
    pub fn new_for_tests(
        text: Arc<str>,
        input_tokens: u32,
        output_tokens: u32,
        total_tokens: u32,
        response_time_ms: u128,
    ) -> Self {
        Self {
            text,
            input_tokens,
            output_tokens,
            total_tokens,
            response_time_ms,
        }
    }
}
60
61#[derive(Serialize, Deserialize)]
63pub struct OpenAiMessage {
64 pub role: String,
65 #[serde(serialize_with = "serialize_arc_str")]
66 #[serde(deserialize_with = "deserialize_arc_str")]
67 pub content: Arc<str>,
68}
69
70#[derive(Serialize)]
71pub struct OpenAiRequest {
72 pub model: String,
73 pub messages: Vec<OpenAiMessage>,
74 pub temperature: f64,
75 pub max_tokens: u32,
76}
77
78#[derive(Deserialize)]
79pub struct OpenAiUsage {
80 pub prompt_tokens: u32,
81 pub completion_tokens: u32,
82 pub total_tokens: u32,
83}
84
85#[derive(Deserialize)]
86pub struct OpenAiChoice {
87 pub message: OpenAiMessage,
88}
89
90#[derive(Deserialize)]
91pub struct OpenAiResponse {
92 pub choices: Vec<OpenAiChoice>,
93 pub usage: OpenAiUsage,
94}
95
/// Request payload for Ollama's `/api/chat` endpoint.
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
    pub model: String,
    pub messages: Vec<OllamaMessage>,
    // Set to `false` at this module's only construction site: a single
    // complete response is requested rather than a token stream.
    pub stream: bool,
    pub options: OllamaOptions,
}
104
/// One chat message (role + content) in the Ollama wire format.
#[derive(Debug, Serialize)]
pub struct OllamaMessage {
    // Message role; this module sends "system" and "user".
    pub role: String,
    #[serde(serialize_with = "serialize_arc_str")]
    pub content: Arc<str>,
}
111
/// Generation options forwarded to Ollama under the request's `options` key.
#[derive(Debug, Serialize)]
pub struct OllamaOptions {
    pub temperature: f64,
    // Mapped from the provider's `max_tokens` setting.
    pub num_predict: u32,
}
117
/// Response payload from Ollama's `/api/chat` endpoint.
#[derive(Debug, Deserialize)]
pub struct OllamaChatResponse {
    pub message: OllamaMessageResponse,
    // Prompt token count; `#[serde(default)]` falls back to 0 when omitted.
    #[serde(default)]
    pub prompt_eval_count: u32,
    // Generated token count; falls back to 0 when omitted.
    #[serde(default)]
    pub eval_count: u32,
    // Total request duration — presumably nanoseconds per the Ollama API
    // (unused in this module); TODO confirm the unit before relying on it.
    #[serde(default)]
    pub total_duration: u64,
}
129
/// The assistant message carried inside an Ollama chat response.
#[derive(Debug, Deserialize)]
pub struct OllamaMessageResponse {
    #[serde(deserialize_with = "deserialize_arc_str")]
    pub content: Arc<str>,
}
135
/// Transport-level result: HTTP status code plus the raw response body text.
#[derive(Debug)]
pub struct HttpResponse {
    pub status: u16,
    pub body: String,
}
142
/// Minimal HTTP transport abstraction, allowing the real `reqwest`-backed
/// client to be substituted (e.g. with a fake in tests).
pub trait InferenceHttpClient {
    /// POSTs `body` as JSON to `url`; returns status and raw body when the
    /// transport succeeds, an error when the request cannot be sent.
    fn post_json(&self, url: &str, body: &serde_json::Value) -> Result<HttpResponse>;
}
149
/// Production [`InferenceHttpClient`] backed by a blocking `reqwest` client.
pub struct ReqwestInferenceClient {
    inner: Client,
}
154
155impl ReqwestInferenceClient {
156 fn new(provider: &Provider) -> Result<Self> {
157 let inner = Client::builder()
158 .connect_timeout(Duration::from_millis(provider.connection_timeout_ms))
159 .timeout(Duration::from_millis(provider.request_timeout_ms))
160 .build()
161 .with_context(|| {
162 format!(
163 "Failed to create HTTP client with timeout configuration for {} provider",
164 provider.api
165 )
166 })?;
167 Ok(Self { inner })
168 }
169}
170
171impl InferenceHttpClient for ReqwestInferenceClient {
172 fn post_json(&self, url: &str, body: &serde_json::Value) -> Result<HttpResponse> {
173 let resp = self
174 .inner
175 .post(url)
176 .header("Content-Type", "application/json")
177 .json(body)
178 .send()
179 .with_context(|| format!("Failed to send HTTP request to {}", url))?;
180
181 let status = resp.status().as_u16();
182 let body = resp
183 .text()
184 .unwrap_or_else(|_| "Failed to read response body".to_string());
185
186 Ok(HttpResponse { status, body })
187 }
188}
189
190pub fn send_request(provider: &Provider, prompt: &Prompt) -> Result<InferenceResponse> {
192 let client = ReqwestInferenceClient::new(provider)?;
193 send_request_with_client(provider, prompt, &client)
194}
195
196pub fn send_request_with_client(
198 provider: &Provider,
199 prompt: &Prompt,
200 client: &impl InferenceHttpClient,
201) -> Result<InferenceResponse> {
202 let start_time = Instant::now();
203
204 match provider.api {
205 Api::Openai => {
206 let request_body = OpenAiRequest {
207 model: provider.model.clone(),
208 messages: vec![
209 OpenAiMessage {
210 role: "system".into(),
211 content: prompt.system().clone(),
212 },
213 OpenAiMessage {
214 role: "user".into(),
215 content: prompt.user().clone(),
216 },
217 ],
218 temperature: provider.temperature,
219 max_tokens: provider.max_tokens,
220 };
221
222 let url = format!("{}/chat/completions", provider.url);
223 let raw = client.post_json(&url, &serde_json::to_value(&request_body)?)?;
224
225 if !(200..300).contains(&raw.status) {
226 anyhow::bail!(
227 "HTTP error {} from {} provider (model: {}, url: {}): {}",
228 raw.status,
229 provider.api,
230 provider.model,
231 provider.url,
232 raw.body
233 );
234 }
235
236 let response: OpenAiResponse = serde_json::from_str(&raw.body).with_context(|| {
237 format!(
238 "Failed to parse JSON response from {} provider (model: {}, url: {}). This may indicate the model doesn't support chat completions or returned an unexpected response format.",
239 provider.api, provider.model, provider.url
240 )
241 })?;
242
243 let response_time = start_time.elapsed().as_millis();
244
245 Ok(InferenceResponse::new_for_tests(
246 response.choices[0].message.content.clone(),
247 response.usage.prompt_tokens,
248 response.usage.completion_tokens,
249 response.usage.total_tokens,
250 response_time,
251 ))
252 }
253
254 Api::Ollama => {
255 let request_body = OllamaChatRequest {
256 model: provider.model.clone(),
257 messages: vec![
258 OllamaMessage {
259 role: "system".into(),
260 content: prompt.system().clone(),
261 },
262 OllamaMessage {
263 role: "user".into(),
264 content: prompt.user().clone(),
265 },
266 ],
267 stream: false,
268 options: OllamaOptions {
269 temperature: provider.temperature,
270 num_predict: provider.max_tokens,
271 },
272 };
273
274 let url = format!("{}/api/chat", provider.url);
275 let raw = client.post_json(&url, &serde_json::to_value(&request_body)?)?;
276
277 if !(200..300).contains(&raw.status) {
278 anyhow::bail!(
279 "HTTP error {} from {} provider (model: {}, url: {}): {}",
280 raw.status,
281 provider.api,
282 provider.model,
283 provider.url,
284 raw.body
285 );
286 }
287
288 let response: OllamaChatResponse = serde_json::from_str(&raw.body).with_context(|| {
289 format!(
290 "Failed to parse JSON response from {} provider (model: {}, url: {}). This may indicate the model doesn't support chat completions or returned an unexpected response format.",
291 provider.api, provider.model, provider.url
292 )
293 })?;
294
295 let response_time = start_time.elapsed().as_millis();
296
297 Ok(InferenceResponse::new_for_tests(
298 response.message.content,
299 response.prompt_eval_count,
300 response.eval_count,
301 response.prompt_eval_count + response.eval_count,
302 response_time,
303 ))
304 }
305 }
306}
307
308fn serialize_arc_str<S>(arc_str: &Arc<str>, serializer: S) -> Result<S::Ok, S::Error>
309where
310 S: Serializer,
311{
312 serializer.serialize_str(arc_str)
313}
314
315fn deserialize_arc_str<'de, D>(deserializer: D) -> Result<Arc<str>, D::Error>
316where
317 D: Deserializer<'de>,
318{
319 let s = String::deserialize(deserializer)?;
320 Ok(Arc::from(s))
321}