Closed
Labels: out of scope (requests which are not related to the OpenAI API), wontfix (this will not be worked on)
Description
Hey, I'm trying to use async-openai with axum and open-source LLMs served through perplexity.ai in my test. Basically, my endpoint routes each request either to the OpenAI API or to an OpenAI-API-compatible API, switching the API URL based on the given model: gpt-4 would go to OpenAI, and mistralai/whatever would go to MODEL_URL.
I'm getting a 404. Not sure if I'm doing something wrong or if this use case is supported?
These are my env vars:
MODEL_API_KEY="pplx-..."
MODEL_URL="https://api.perplexity.ai/chat/completions"
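If I understand async-openai's URL handling correctly, the client appends each route's path to whatever api_base is set to, so the pieces are meant to combine roughly like this (a sketch of my understanding, not the library's exact code):

let api_base = "https://api.perplexity.ai"; // what with_api_base should receive
let route = "/chat/completions";            // appended by the client per endpoint
let final_url = format!("{}{}", api_base, route);
assert_eq!(final_url, "https://api.perplexity.ai/chat/completions");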
My code:
use async_openai::Client;
use async_openai::{config::OpenAIConfig, types::CreateChatCompletionRequest};
use async_stream::try_stream;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::{
    extract::{Extension, Json, Path, State},
    http::StatusCode,
    response::IntoResponse,
    response::Json as JsonResponse,
};
use futures::Stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::error::Error;
use std::io::{stdout, Write};
use tokio::sync::broadcast::Receiver;
use tokio_stream::wrappers::BroadcastStream;
use url::Url;

/// Strip the path from a model URL, keeping only scheme://host/.
fn extract_base_url(model_url: &str) -> Result<String, url::ParseError> {
    let url = Url::parse(model_url)?;
    let base_url = url.join("/")?;
    Ok(base_url.as_str().to_string())
}
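// For reference, what extract_base_url returns for the MODEL_URL above
// (standard `url` crate behavior; note that join("/") keeps a trailing slash):
//
//     let base = extract_base_url("https://api.perplexity.ai/chat/completions");
//     assert_eq!(base.unwrap(), "https://api.perplexity.ai/");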
// copied from https://github.com/tokio-rs/axum/discussions/1670
pub async fn stream_chat_handler(
    Json(request): Json<CreateChatCompletionRequest>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, String)> {
    let model_name = &request.model;
    let model_url = std::env::var("MODEL_URL")
        .unwrap_or_else(|_| String::from("http://localhost:8000/v1/chat/completions"));
    let base_url = extract_base_url(&model_url).unwrap_or_else(|_| model_url);

    // Route by model name: names containing '/' are open-source models.
    let (api_key, base_url) = if model_name.contains("/") {
        // Open source model
        (std::env::var("MODEL_API_KEY").unwrap_or_default(), base_url)
    } else {
        // OpenAI model
        (
            std::env::var("OPENAI_API_KEY").unwrap_or_default(),
            String::from("https://api.openai.com"),
        )
    };

    let client = Client::with_config(
        OpenAIConfig::new()
            .with_api_key(&api_key)
            .with_api_base(&base_url),
    );

    let mut stream = client
        .chat()
        .create_stream(request)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;

    // Forward each content delta as an SSE event.
    let sse_stream = try_stream! {
        while let Some(result) = stream.next().await {
            match result {
                Ok(response) => {
                    for chat_choice in response.choices.iter() {
                        if let Some(ref content) = chat_choice.delta.content {
                            yield Event::default().data(content.clone());
                        }
                    }
                }
                Err(err) => {
                    println!("Error: {}", err);
                    tracing::error!("Error: {}", err);
                }
            }
        }
    };

    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}
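For what it's worth, here is a minimal standalone check of the config wiring as I understand it, assuming async-openai's Config::url simply appends the route path to api_base (and that its default base is https://api.openai.com/v1):

use async_openai::config::{Config, OpenAIConfig};

fn main() {
    // Sketch under the assumption above: whatever goes into with_api_base
    // must not already contain /chat/completions, since the client appends it.
    let config = OpenAIConfig::new()
        .with_api_key("pplx-...")
        .with_api_base("https://api.perplexity.ai");
    assert_eq!(
        config.url("/chat/completions"),
        "https://api.perplexity.ai/chat/completions"
    );
}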
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{self, Request};
    use axum::response::Response;
    use axum::routing::post;
    use axum::Router;
    use dotenv::dotenv;
    use serde_json::json;
    use std::convert::Infallible;
    use tower::{Service, ServiceExt};
    use tower_http::trace::TraceLayer;

    fn app() -> Router {
        Router::new()
            .route("/chat/completions", post(stream_chat_handler))
            .layer(TraceLayer::new_for_http())
    }

    #[tokio::test]
    async fn test_stream_chat_handler() {
        dotenv().ok();
        let app = app();

        let chat_input = json!({
            "model": "mistralai/mixtral-8x7b-instruct",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello!"
                }
            ]
        });

        let request = Request::builder()
            .method(http::Method::POST)
            .uri("/chat/completions")
            .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(chat_input.to_string()))
            .unwrap();

        let response = app.clone().oneshot(request).await.unwrap();
        // Read the status before consuming the body, so the body bytes can be
        // used both in the assertion message and in the println below.
        let status = response.status();
        let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
        assert_eq!(status, StatusCode::OK, "response: {:?}", body);
        println!("response: {:?}", body);
    }
}
The assertion fails with: Invalid status code: 404 Not Found
Any help appreciated 🙏
schneiderfelipe