Skip to content

Commit 7ad129f

Browse files
few fixes after the refactor (#9)
1 parent 38929f4 commit 7ad129f

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

src/openai/azure/_async_client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,19 +418,19 @@ async def _poll(
418418
async def _request(self, cast_to: Type[ResponseT], options: FinalRequestOptions, **kwargs: Any) -> Any:
419419
if options.url == "/images/generations":
420420
options.url = "openai/images/generations:submit"
421-
response = await super().request(httpx.Response, **kwargs)
422-
operation_id = cast(Mapping[str, Any], getattr(response, 'model_extra')) or {}
421+
response = await super()._request(cast_to=cast_to, options=options, **kwargs)
422+
model_extra = cast(Mapping[str, Any], getattr(response, 'model_extra')) or {}
423+
operation_id = cast(str, model_extra['id'])
423424
return await self._poll(
424425
"get", f"openai/operations/images/{operation_id}",
425426
until=lambda response: response.json()["status"] in ["succeeded"],
426427
failed=lambda response: response.json()["status"] in ["failed"],
427428
)
428429
if isinstance(options.json_data, Mapping):
429-
model = cast(str, options.json_data["model"])
430+
model = cast(str, options.json_data["model"])
430431
if not options.url.startswith(f'openai/deployments/{model}'):
431432
if options.extra_json and options.extra_json.get("dataSources"):
432433
options.url = f'openai/deployments/{model}/extensions' + options.url
433-
else:
434+
else:
434435
options.url = f'openai/deployments/{model}' + options.url
435-
return await super().request(cast_to=cast_to, options=options, **kwargs)
436-
436+
return await super()._request(cast_to=cast_to, options=options, **kwargs)

src/openai/azure/_sync_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def auth_headers(self) -> Dict[str, str]:
384384
def _request(self, *, options: FinalRequestOptions, **kwargs: Any) -> Any:
385385
if options.url == "/images/generations":
386386
options.url = "openai/images/generations:submit"
387-
response = super().request(httpx.Response, **kwargs)
387+
response = super()._request(options=options, **kwargs)
388388
model_extra = cast(Mapping[str, Any], getattr(response, 'model_extra')) or {}
389389
operation_id = cast(str, model_extra['id'])
390390
return self._poll(
@@ -393,13 +393,13 @@ def _request(self, *, options: FinalRequestOptions, **kwargs: Any) -> Any:
393393
failed=lambda response: response.json()["status"] in ["failed"],
394394
)
395395
if isinstance(options.json_data, Mapping):
396-
model = cast(str, options.json_data["model"])
396+
model = cast(str, options.json_data["model"])
397397
if not options.url.startswith(f'openai/deployments/{model}'):
398398
if options.extra_json and options.extra_json.get("dataSources"):
399399
options.url = f'openai/deployments/{model}/extensions' + options.url
400-
else:
400+
else:
401401
options.url = f'openai/deployments/{model}' + options.url
402-
return super().request(options=options, **kwargs)
402+
return super()._request(options=options, **kwargs)
403403

404404
# Internal azure specific "helper" methods
405405
def _check_polling_response(self, response: httpx.Response, predicate: Callable[[httpx.Response], bool]) -> bool:

0 commit comments

Comments
 (0)