aimx/inference/
request.rs

1//! Inference request management for AI model providers.
2//!
3//! Provides a minimal blocking interface for sending prompts via configured
4//! [`Provider`]s and returning a unified [`InferenceResponse`]. Supports
5//! OpenAI-compatible and Ollama chat APIs using the [`Prompt`] abstraction.
6
7use crate::inference::{Api, Prompt, Provider};
8use anyhow::{Context, Result};
9use reqwest::blocking::Client;
10use serde::{Serialize, Deserialize, Serializer, Deserializer};
11use std::time::{Duration, Instant};
12use std::sync::Arc;
13
/// Unified response for a single provider request, including text and token usage.
#[derive(Debug, Serialize, Deserialize)]
pub struct InferenceResponse {
    // Arc<str> makes clones of the response text O(1); serde needs the custom
    // (de)serializers below because Arc<str> is handled via plain strings here.
    #[serde(serialize_with = "serialize_arc_str")]
    #[serde(deserialize_with = "deserialize_arc_str")]
    text: Arc<str>,
    // Token counts as reported by the provider (Ollama counts default to 0
    // when absent; see OllamaChatResponse).
    input_tokens: u32,
    output_tokens: u32,
    total_tokens: u32,
    // Wall-clock time measured locally around the request, in milliseconds.
    response_time_ms: u128,
}
25
26impl InferenceResponse {
27    pub fn text(&self) -> Arc<str> {
28        self.text.clone()
29    }
30    pub fn input_tokens(&self) -> u32 {
31        self.input_tokens
32    }
33    pub fn output_tokens(&self) -> u32 {
34        self.output_tokens
35    }
36    pub fn total_tokens(&self) -> u32 {
37        self.total_tokens
38    }
39    pub fn response_time_ms(&self) -> u128 {
40        self.response_time_ms
41    }
42
43    /// Public helper primarily intended for tests and examples.
44    pub fn new_for_tests(
45        text: Arc<str>,
46        input_tokens: u32,
47        output_tokens: u32,
48        total_tokens: u32,
49        response_time_ms: u128,
50    ) -> Self {
51        Self {
52            text,
53            input_tokens,
54            output_tokens,
55            total_tokens,
56            response_time_ms,
57        }
58    }
59}
60
// OpenAI API structures
/// A single chat message in the OpenAI chat-completions wire format.
#[derive(Serialize, Deserialize)]
pub struct OpenAiMessage {
    /// Message role; this module sends "system" and "user" messages.
    pub role: String,
    /// Message text; Arc<str> so response content can be shared without copying.
    #[serde(serialize_with = "serialize_arc_str")]
    #[serde(deserialize_with = "deserialize_arc_str")]
    pub content: Arc<str>,
}
69
/// Request body for the OpenAI-compatible `/chat/completions` endpoint.
#[derive(Serialize)]
pub struct OpenAiRequest {
    /// Model name, taken from the provider configuration.
    pub model: String,
    /// Conversation messages, in order (system first, then user).
    pub messages: Vec<OpenAiMessage>,
    /// Sampling temperature forwarded from the provider configuration.
    pub temperature: f64,
    /// Maximum number of tokens the model may generate.
    pub max_tokens: u32,
}
77
/// Token accounting block returned by OpenAI-compatible APIs.
#[derive(Deserialize)]
pub struct OpenAiUsage {
    /// Tokens consumed by the input prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Total tokens for the request, as reported by the server.
    pub total_tokens: u32,
}
84
/// One completion alternative; this module reads only the first choice.
#[derive(Deserialize)]
pub struct OpenAiChoice {
    pub message: OpenAiMessage,
}
89
/// Top-level response body from an OpenAI-compatible chat-completions call.
#[derive(Deserialize)]
pub struct OpenAiResponse {
    /// Completion alternatives; may be empty on some error-ish responses.
    pub choices: Vec<OpenAiChoice>,
    /// Token usage for the request.
    pub usage: OpenAiUsage,
}
95
96// Ollama API structures - Chat endpoint
97#[derive(Debug, Serialize)]
98struct OllamaChatRequest {
99    pub model: String,
100    pub messages: Vec<OllamaMessage>,
101    pub stream: bool,
102    pub options: OllamaOptions,
103}
104
/// A single chat message in the Ollama chat format (serialize-only).
#[derive(Debug, Serialize)]
pub struct OllamaMessage {
    /// Message role; this module sends "system" and "user" messages.
    pub role: String,
    /// Message text, shared via Arc<str> to avoid copying prompt content.
    #[serde(serialize_with = "serialize_arc_str")]
    pub content: Arc<str>,
}
111
/// Model sampling options for an Ollama chat request.
#[derive(Debug, Serialize)]
pub struct OllamaOptions {
    /// Sampling temperature forwarded from the provider configuration.
    pub temperature: f64,
    /// Output token cap; populated from the provider's `max_tokens` setting.
    pub num_predict: u32,
}
117
/// Response structure from Ollama chat API
#[derive(Debug, Deserialize)]
pub struct OllamaChatResponse {
    pub message: OllamaMessageResponse,
    // Token counts default to 0 when Ollama omits them from the response.
    #[serde(default)]
    pub prompt_eval_count: u32,
    #[serde(default)]
    pub eval_count: u32,
    // Captured but not currently read by this module; presumably nanoseconds
    // per Ollama's API — TODO confirm before using.
    #[serde(default)]
    pub total_duration: u64,
}
129
/// The assistant message embedded in an Ollama chat response.
#[derive(Debug, Deserialize)]
pub struct OllamaMessageResponse {
    /// Generated text, deserialized into a shared Arc<str>.
    #[serde(deserialize_with = "deserialize_arc_str")]
    pub content: Arc<str>,
}
135
/// Internal, provider-agnostic HTTP response used by the inference client abstraction.
#[derive(Debug)]
pub struct HttpResponse {
    /// HTTP status code (e.g. 200, 404).
    pub status: u16,
    /// Raw response body text; may hold an error payload on non-2xx statuses.
    pub body: String,
}
142
/// Minimal HTTP client abstraction for inference requests.
///
/// This keeps `send_request` testable without depending directly on `reqwest` in tests.
pub trait InferenceHttpClient {
    /// POSTs `body` as JSON to `url` and returns the raw status plus body text.
    fn post_json(&self, url: &str, body: &serde_json::Value) -> Result<HttpResponse>;
}
149
/// Default blocking HTTP client implementation backed by `reqwest`.
pub struct ReqwestInferenceClient {
    // Blocking reqwest client configured with provider timeouts (see `new`).
    inner: Client,
}
154
155impl ReqwestInferenceClient {
156    fn new(provider: &Provider) -> Result<Self> {
157        let inner = Client::builder()
158            .connect_timeout(Duration::from_millis(provider.connection_timeout_ms))
159            .timeout(Duration::from_millis(provider.request_timeout_ms))
160            .build()
161            .with_context(|| {
162                format!(
163                    "Failed to create HTTP client with timeout configuration for {} provider",
164                    provider.api
165                )
166            })?;
167        Ok(Self { inner })
168    }
169}
170
171impl InferenceHttpClient for ReqwestInferenceClient {
172    fn post_json(&self, url: &str, body: &serde_json::Value) -> Result<HttpResponse> {
173        let resp = self
174            .inner
175            .post(url)
176            .header("Content-Type", "application/json")
177            .json(body)
178            .send()
179            .with_context(|| format!("Failed to send HTTP request to {}", url))?;
180
181        let status = resp.status().as_u16();
182        let body = resp
183            .text()
184            .unwrap_or_else(|_| "Failed to read response body".to_string());
185
186        Ok(HttpResponse { status, body })
187    }
188}
189
190/// Public entry point: send a request using the default blocking HTTP client.
191pub fn send_request(provider: &Provider, prompt: &Prompt) -> Result<InferenceResponse> {
192    let client = ReqwestInferenceClient::new(provider)?;
193    send_request_with_client(provider, prompt, &client)
194}
195
196/// Internal helper: core logic parameterized over an [`InferenceHttpClient`].
197pub fn send_request_with_client(
198    provider: &Provider,
199    prompt: &Prompt,
200    client: &impl InferenceHttpClient,
201) -> Result<InferenceResponse> {
202    let start_time = Instant::now();
203
204    match provider.api {
205        Api::Openai => {
206            let request_body = OpenAiRequest {
207                model: provider.model.clone(),
208                messages: vec![
209                    OpenAiMessage {
210                        role: "system".into(),
211                        content: prompt.system().clone(),
212                    },
213                    OpenAiMessage {
214                        role: "user".into(),
215                        content: prompt.user().clone(),
216                    },
217                ],
218                temperature: provider.temperature,
219                max_tokens: provider.max_tokens,
220            };
221
222            let url = format!("{}/chat/completions", provider.url);
223            let raw = client.post_json(&url, &serde_json::to_value(&request_body)?)?;
224
225            if !(200..300).contains(&raw.status) {
226                anyhow::bail!(
227                    "HTTP error {} from {} provider (model: {}, url: {}): {}",
228                    raw.status,
229                    provider.api,
230                    provider.model,
231                    provider.url,
232                    raw.body
233                );
234            }
235
236            let response: OpenAiResponse = serde_json::from_str(&raw.body).with_context(|| {
237                format!(
238                    "Failed to parse JSON response from {} provider (model: {}, url: {}). This may indicate the model doesn't support chat completions or returned an unexpected response format.",
239                    provider.api, provider.model, provider.url
240                )
241            })?;
242
243            let response_time = start_time.elapsed().as_millis();
244
245            Ok(InferenceResponse::new_for_tests(
246                response.choices[0].message.content.clone(),
247                response.usage.prompt_tokens,
248                response.usage.completion_tokens,
249                response.usage.total_tokens,
250                response_time,
251            ))
252        }
253
254        Api::Ollama => {
255            let request_body = OllamaChatRequest {
256                model: provider.model.clone(),
257                messages: vec![
258                    OllamaMessage {
259                        role: "system".into(),
260                        content: prompt.system().clone(),
261                    },
262                    OllamaMessage {
263                        role: "user".into(),
264                        content: prompt.user().clone(),
265                    },
266                ],
267                stream: false,
268                options: OllamaOptions {
269                    temperature: provider.temperature,
270                    num_predict: provider.max_tokens,
271                },
272            };
273
274            let url = format!("{}/api/chat", provider.url);
275            let raw = client.post_json(&url, &serde_json::to_value(&request_body)?)?;
276
277            if !(200..300).contains(&raw.status) {
278                anyhow::bail!(
279                    "HTTP error {} from {} provider (model: {}, url: {}): {}",
280                    raw.status,
281                    provider.api,
282                    provider.model,
283                    provider.url,
284                    raw.body
285                );
286            }
287
288            let response: OllamaChatResponse = serde_json::from_str(&raw.body).with_context(|| {
289                format!(
290                    "Failed to parse JSON response from {} provider (model: {}, url: {}). This may indicate the model doesn't support chat completions or returned an unexpected response format.",
291                    provider.api, provider.model, provider.url
292                )
293            })?;
294
295            let response_time = start_time.elapsed().as_millis();
296
297            Ok(InferenceResponse::new_for_tests(
298                response.message.content,
299                response.prompt_eval_count,
300                response.eval_count,
301                response.prompt_eval_count + response.eval_count,
302                response_time,
303            ))
304        }
305    }
306}
307
/// Serializes an `Arc<str>` as a plain string (derefs to `&str`).
fn serialize_arc_str<S>(arc_str: &Arc<str>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    serializer.serialize_str(arc_str)
}
314
315fn deserialize_arc_str<'de, D>(deserializer: D) -> Result<Arc<str>, D::Error>
316where
317    D: Deserializer<'de>,
318{
319    let s = String::deserialize(deserializer)?;
320    Ok(Arc::from(s))
321}