aimx/inference/provider.rs
1//! Inference provider configuration for external AI APIs used by AIMX.
2//! Defines serializable types for selecting APIs, endpoints, models,
3//! and runtime limits so hosts can configure inference behavior.
4
5use crate::inference::Capability;
6use serde::{Deserialize, Serialize};
7use std::fmt;
8
9/// AI inference APIs supported by AIMX configuration.
10#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
11pub enum Api {
12 /// Ollama API for local model inference.
13 Ollama,
14 /// OpenAI-compatible chat completions API.
15 Openai,
16}
17
18impl fmt::Display for Api {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 match self {
21 Api::Ollama => write!(f, "ollama"),
22 Api::Openai => write!(f, "openai"),
23 }
24 }
25}
26
27impl Api {
28 /// Construct from identifier: "ollama" => `Ollama`, other => `Openai`.
29 pub fn new(api: &str) -> Self {
30 match api {
31 "ollama" => Api::Ollama,
32 _ => Api::Openai,
33 }
34 }
35}
36
37/// Configuration for an AI inference provider used by AIMX.
38#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
39pub struct Provider {
40 /// API implementation.
41 pub api: Api,
42 /// Base URL for the API.
43 pub url: String,
44 /// API key or token (empty if not required).
45 pub key: String,
46 /// Model identifier for the provider.
47 pub model: String,
48 /// Abstract capability class of the model.
49 pub capability: Capability,
50 /// Sampling temperature.
51 pub temperature: f64,
52 /// Maximum tokens to generate.
53 pub max_tokens: u32,
54 /// Context window size in tokens.
55 pub context_length: u32,
56 /// Connection timeout in milliseconds.
57 #[serde(default = "default_connection_timeout")]
58 pub connection_timeout_ms: u64,
59 /// Request timeout in milliseconds.
60 #[serde(default = "default_request_timeout")]
61 pub request_timeout_ms: u64,
62}
63
/// Fallback connection timeout in milliseconds (30 seconds).
///
/// Referenced by the serde `default` attribute on
/// `Provider::connection_timeout_ms`.
fn default_connection_timeout() -> u64 {
    30_000
}
68
/// Fallback request timeout in milliseconds (120 seconds).
///
/// Referenced by the serde `default` attribute on
/// `Provider::request_timeout_ms`.
fn default_request_timeout() -> u64 {
    120_000
}