use crate::{ config::ProviderConfig, message::{Message, Role}, providers::{Model, Provider, ProviderError}, }; use async_trait::async_trait; use futures::{stream::BoxStream, StreamExt}; use serde::{Deserialize, Serialize}; pub struct OpenAICompatibleProvider { config: ProviderConfig, client: reqwest::Client, } impl OpenAICompatibleProvider { pub fn new(config: ProviderConfig) -> Self { Self { config, client: reqwest::Client::new(), } } fn api_url(&self, path: &str) -> String { let base = self.config.base_url.trim_end_matches('/'); format!("{}{}", base, path) } } #[async_trait] impl Provider for OpenAICompatibleProvider { fn name(&self) -> &str { &self.config.name } fn config(&self) -> &ProviderConfig { &self.config } async fn list_models(&self) -> Result, ProviderError> { let response = self .client .get(self.api_url("/v1/models")) .header("Authorization", format!("Bearer {}", self.config.api_key)) .send() .await?; if !response.status().is_success() { let text = response.text().await.unwrap_or_default(); return Err(ProviderError::Api { message: text }); } let data: ModelsResponse = response.json().await?; Ok(data .data .into_iter() .map(|m| Model { id: m.id.clone(), name: m.id, }) .collect()) } async fn chat_stream( &self, model: &str, messages: &[Message], ) -> Result>, ProviderError> { let body = ChatCompletionRequest { model: model.to_string(), messages: messages.iter().map(|m| ChatMessage::from(m.clone())).collect(), stream: true, }; let response = self .client .post(self.api_url("/v1/chat/completions")) .header("Authorization", format!("Bearer {}", self.config.api_key)) .header("Content-Type", "application/json") .json(&body) .send() .await?; if !response.status().is_success() { let text = response.text().await.unwrap_or_default(); return Err(ProviderError::Api { message: text }); } let stream = response.bytes_stream().map(move |result| { result.map_err(ProviderError::Http) }); let parsed = stream .flat_map(|result| { futures::stream::iter(match result { Ok(bytes) => parse_sse_chunk(&bytes), Err(e) => vec![Err(e)], }) }) .boxed(); Ok(parsed) } } fn parse_sse_chunk(bytes: &[u8]) -> Vec> { let text = String::from_utf8_lossy(bytes); let mut results = Vec::new(); for line in text.lines() { let line = line.trim(); if !line.starts_with("data: ") { continue; } let data = &line[6..]; if data == "[DONE]" { break; } match serde_json::from_str::(data) { Ok(chunk) => { if let Some(choice) = chunk.choices.first() { if let Some(content) = &choice.delta.content { if !content.is_empty() { results.push(Ok(content.clone())); } } } } Err(e) => { results.push(Err(ProviderError::StreamParse(e.to_string()))); } } } results } #[derive(Debug, Deserialize)] struct ModelsResponse { data: Vec, } #[derive(Debug, Deserialize)] struct ModelData { id: String, } #[derive(Debug, Serialize)] struct ChatCompletionRequest { model: String, messages: Vec, stream: bool, } #[derive(Debug, Serialize, Deserialize, Clone)] struct ChatMessage { role: String, content: String, } impl From for ChatMessage { fn from(msg: Message) -> Self { Self { role: match msg.role { Role::System => "system".to_string(), Role::User => "user".to_string(), Role::Assistant => "assistant".to_string(), }, content: msg.content, } } } #[derive(Debug, Deserialize)] struct ChatCompletionChunk { choices: Vec, } #[derive(Debug, Deserialize)] struct Choice { delta: Delta, } #[derive(Debug, Deserialize)] struct Delta { content: Option, }