From 2ce434fd86dca1513f73ae1e65534998fdfd87eb Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Mon, 9 Oct 2023 14:50:04 -0700 Subject: [PATCH 1/2] add thin client for azure --- src/openai/azure/__init__.py | 7 +++++ src/openai/azure/_client.py | 47 +++++++++++++++++++++++++++++++++ src/openai/azure/_credential.py | 32 ++++++++++++++++++++++ 3 files changed, 86 insertions(+) create mode 100644 src/openai/azure/__init__.py create mode 100644 src/openai/azure/_client.py create mode 100644 src/openai/azure/_credential.py diff --git a/src/openai/azure/__init__.py b/src/openai/azure/__init__.py new file mode 100644 index 0000000000..a556e54549 --- /dev/null +++ b/src/openai/azure/__init__.py @@ -0,0 +1,7 @@ +from ._client import AzureOpenAIClient, AsyncAzureOpenAIClient + + +__all__ = [ + "AzureOpenAIClient", + "AsyncAzureOpenAIClient", +] \ No newline at end of file diff --git a/src/openai/azure/_client.py b/src/openai/azure/_client.py new file mode 100644 index 0000000000..193484f455 --- /dev/null +++ b/src/openai/azure/_client.py @@ -0,0 +1,47 @@ +import httpx +from typing import Any, Optional, Dict + +from openai import Client, AsyncClient +from ._credential import TokenAuth + + +class AzureOpenAIClient(Client): + + def __init__(self, *args: Any, base_url: str, credential: Optional["TokenCredential"] = None, api_version: str = '2023-09-01-preview', **kwargs: Any): + default_query = kwargs.get('default_query', {}) + default_query.setdefault('api-version', api_version) + kwargs['default_query'] = default_query + self.credential = credential + if credential: + kwargs['api_key'] = 'Placeholder: AAD' + super().__init__(*args, base_url=base_url, **kwargs) + + @property + def auth_headers(self) -> Dict[str, str]: + return {"api-key": self.api_key} + + @property + def custom_auth(self) -> Optional[httpx.Auth]: + if self.credential: + return TokenAuth(self.credential) + + +class AsyncAzureOpenAIClient(AsyncClient): + + def __init__(self, *args: Any, credential: Optional["TokenCredential"] = None, api_version: str = '2023-09-01-preview', **kwargs: Any): + default_query = kwargs.get('default_query', {}) + default_query.setdefault('api-version', api_version) + kwargs['default_query'] = default_query + self.credential = credential + if credential: + kwargs['api_key'] = 'Placeholder: AAD' + super().__init__(*args, **kwargs) + + @property + def auth_headers(self) -> Dict[str, str]: + return {"api-key": self.api_key} + + @property + def custom_auth(self) -> httpx.Auth | None: + if self.credential: + return TokenAuth(self.credential) \ No newline at end of file diff --git a/src/openai/azure/_credential.py b/src/openai/azure/_credential.py new file mode 100644 index 0000000000..fbf3647c17 --- /dev/null +++ b/src/openai/azure/_credential.py @@ -0,0 +1,32 @@ +import time +import asyncio +import httpx +from typing import Any, Generator, AsyncGenerator + + +class TokenAuth(httpx.Auth): + def __init__(self, credential: "TokenCredential") -> None: + self._credential = credential + self._async_lock = asyncio.Lock() + self.cached_token = None + + def sync_get_token(self) -> str: + if not self.cached_token or self.cached_token.expires_on - time.time() < 300: + return self._credential.get_token("https://cognitiveservices.azure.com/.default").token + return self.cached_token.token + + def sync_auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, Any, Any]: + token = self.sync_get_token() + request.headers["Authorization"] = f"Bearer {token}" + yield request + + async def async_get_token(self) -> str: + async with self._async_lock: + if not self.cached_token or self.cached_token.expires_on - time.time() < 300: + return (await self._credential.get_token("https://cognitiveservices.azure.com/.default")).token + return self.cached_token.token + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, Any]: + token = await self.async_get_token() + request.headers["Authorization"] = f"Bearer {token}" + yield request From 8cb6c2a8d9a49fe65e8a089c821a453097c7fdcb Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Mon, 9 Oct 2023 14:51:16 -0700 Subject: [PATCH 2/2] format --- src/openai/azure/__init__.py | 2 +- src/openai/azure/_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/openai/azure/__init__.py b/src/openai/azure/__init__.py index a556e54549..296a7da84d 100644 --- a/src/openai/azure/__init__.py +++ b/src/openai/azure/__init__.py @@ -4,4 +4,4 @@ __all__ = [ "AzureOpenAIClient", "AsyncAzureOpenAIClient", -] \ No newline at end of file +] diff --git a/src/openai/azure/_client.py b/src/openai/azure/_client.py index 193484f455..7f49ec81b2 100644 --- a/src/openai/azure/_client.py +++ b/src/openai/azure/_client.py @@ -44,4 +44,4 @@ def auth_headers(self) -> Dict[str, str]: @property def custom_auth(self) -> httpx.Auth | None: if self.credential: - return TokenAuth(self.credential) \ No newline at end of file + return TokenAuth(self.credential)