Skip to content

Changing base url - Usage with open source LLM - Invalid status code: 404 Not Found #173

@louis030195

Description

@louis030195

Hey I'm trying to use async-openai with axum and open source LLMs through perplexity.ai in my test.

Basically my endpoint would route the request to OpenAI API or an OpenAI API like API changing the URL of the API based on the given model like gpt-4 would go to openai and mistralai/whatever would go to MODEL_URL

Getting a 404. Not sure if I'm doing something wrong or this use case is implemented?

These are my env var

MODEL_API_KEY="pplx-..."
MODEL_URL="https://api.perplexity.ai/chat/completions"

My code:

use async_openai::Client;
use async_openai::{config::OpenAIConfig, types::CreateChatCompletionRequest};
use axum::{
    extract::{Extension, Json, Path, State},
    http::StatusCode,
    response::IntoResponse,
    response::Json as JsonResponse,
};

use async_stream::try_stream;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures::Stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::error::Error;
use std::io::{stdout, Write};
use tokio::sync::broadcast::Receiver;
use tokio_stream::wrappers::BroadcastStream;
use url::Url;

fn extract_base_url(model_url: &str) -> Result<String, url::ParseError> {
    let url = Url::parse(model_url)?;
    let base_url = url.join("/")?;
    Ok(base_url.as_str().to_string())
}
// copied from https://github.com/tokio-rs/axum/discussions/1670

pub async fn stream_chat_handler(
    Json(request): Json<CreateChatCompletionRequest>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, String)> {
    let model_name = &request.model;
    let model_url = std::env::var("MODEL_URL")
        .unwrap_or_else(|_| String::from("http://localhost:8000/v1/chat/completions"));
    let base_url = extract_base_url(&model_url).unwrap_or_else(|_| model_url);
    let (api_key, base_url) = if model_name.contains("/") {
        // Open Source model
        (std::env::var("MODEL_API_KEY").unwrap_or_default(), base_url)
    } else {
        // OpenAI model
        (
            std::env::var("OPENAI_API_KEY").unwrap_or_default(),
            String::from("https://api.openai.com"),
        )
    };
    let client = Client::with_config(
        OpenAIConfig::new()
            .with_api_key(&api_key)
            .with_api_base(&base_url),
    );

    let mut stream = client
        .chat()
        .create_stream(request)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;

    let sse_stream = try_stream! {
        while let Some(result) = stream.next().await {
            match result {
                Ok(response) => {
                    for chat_choice in response.choices.iter() {
                        if let Some(ref content) = chat_choice.delta.content {
                            yield Event::default().data(content.clone());
                        }
                    }
                }
                Err(err) => {
                    println!("Error: {}", err);
                    tracing::error!("Error: {}", err);
                }
            }
        }
    };

    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{self, Request};
    use axum::response::Response;
    use axum::routing::post;
    use axum::Router;
    use dotenv::dotenv;
    use serde_json::json;
    use std::convert::Infallible;
    use tower::{Service, ServiceExt};
    use tower_http::trace::TraceLayer;

    fn app() -> Router {
        Router::new()
            .route("/chat/completions", post(stream_chat_handler))
            .layer(TraceLayer::new_for_http())
    }

    #[tokio::test]
    async fn test_stream_chat_handler() {
        dotenv().ok();
        let app = app();

        let chat_input = json!({
            "model": "mistralai/mixtral-8x7b-instruct",
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "Hello!"
                }
            ]
        });

        let request = Request::builder()
            .method(http::Method::POST)
            .uri("/chat/completions")
            .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(Body::from(json!(chat_input).to_string()))
            .unwrap();

        let response = app.clone().oneshot(request).await.unwrap();

        assert_eq!(
            response.status(),
            StatusCode::OK,
            "response: {:?}",
            hyper::body::to_bytes(response.into_body()).await.unwrap()
        );

        let response = hyper::body::to_bytes(response.into_body()).await.unwrap();
        println!("response: {:?}", response);
    }
}

Invalid status code: 404 Not Found

Any help appreciated 🙏

Metadata

Metadata

Assignees

No one assigned

    Labels

    out of scopeRequests which are not related to OpenAI APIwontfixThis will not be worked on

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions