Skip to content

Commit 5cb7cfa

Browse files
committed
Realtime: use SDK types for all messages
1 parent 8fdbe09 commit 5cb7cfa

File tree

5 files changed

+571
-140
lines changed

5 files changed

+571
-140
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ requires-python = ">=3.9"
77
license = "MIT"
88
authors = [{ name = "OpenAI", email = "[email protected]" }]
99
dependencies = [
10-
"openai>=1.93.1, <2",
10+
"openai>=1.96.0, <2",
1111
"pydantic>=2.10, <3",
1212
"griffe>=1.5.6, <2",
1313
"typing-extensions>=4.12.2, <5",

src/agents/realtime/openai_realtime.py

Lines changed: 163 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,47 @@
1010

1111
import pydantic
1212
import websockets
13-
from openai.types.beta.realtime.conversation_item import ConversationItem
13+
from openai.types.beta.realtime.conversation_item import (
14+
ConversationItem,
15+
ConversationItem as OpenAIConversationItem,
16+
)
17+
from openai.types.beta.realtime.conversation_item_content import (
18+
ConversationItemContent as OpenAIConversationItemContent,
19+
)
20+
from openai.types.beta.realtime.conversation_item_create_event import (
21+
ConversationItemCreateEvent as OpenAIConversationItemCreateEvent,
22+
)
23+
from openai.types.beta.realtime.conversation_item_retrieve_event import (
24+
ConversationItemRetrieveEvent as OpenAIConversationItemRetrieveEvent,
25+
)
26+
from openai.types.beta.realtime.conversation_item_truncate_event import (
27+
ConversationItemTruncateEvent as OpenAIConversationItemTruncateEvent,
28+
)
29+
from openai.types.beta.realtime.input_audio_buffer_append_event import (
30+
InputAudioBufferAppendEvent as OpenAIInputAudioBufferAppendEvent,
31+
)
32+
from openai.types.beta.realtime.input_audio_buffer_commit_event import (
33+
InputAudioBufferCommitEvent as OpenAIInputAudioBufferCommitEvent,
34+
)
35+
from openai.types.beta.realtime.realtime_client_event import (
36+
RealtimeClientEvent as OpenAIRealtimeClientEvent,
37+
)
1438
from openai.types.beta.realtime.realtime_server_event import (
1539
RealtimeServerEvent as OpenAIRealtimeServerEvent,
1640
)
1741
from openai.types.beta.realtime.response_audio_delta_event import ResponseAudioDeltaEvent
42+
from openai.types.beta.realtime.response_cancel_event import (
43+
ResponseCancelEvent as OpenAIResponseCancelEvent,
44+
)
45+
from openai.types.beta.realtime.response_create_event import (
46+
ResponseCreateEvent as OpenAIResponseCreateEvent,
47+
)
1848
from openai.types.beta.realtime.session_update_event import (
1949
Session as OpenAISessionObject,
2050
SessionTool as OpenAISessionTool,
51+
SessionTracing as OpenAISessionTracing,
52+
SessionTracingTracingConfiguration as OpenAISessionTracingConfiguration,
53+
SessionUpdateEvent as OpenAISessionUpdateEvent,
2154
)
2255
from pydantic import TypeAdapter
2356
from typing_extensions import assert_never
@@ -135,12 +168,11 @@ async def _send_tracing_config(
135168
) -> None:
136169
"""Update tracing configuration via session.update event."""
137170
if tracing_config is not None:
171+
converted_tracing_config = _ConversionHelper.convert_tracing_config(tracing_config)
138172
await self._send_raw_message(
139-
RealtimeModelSendRawMessage(
140-
message={
141-
"type": "session.update",
142-
"other_data": {"session": {"tracing": tracing_config}},
143-
}
173+
OpenAISessionUpdateEvent(
174+
session=OpenAISessionObject(tracing=converted_tracing_config),
175+
type="session.update",
144176
)
145177
)
146178

@@ -199,7 +231,11 @@ async def _listen_for_messages(self):
199231
async def send_event(self, event: RealtimeModelSendEvent) -> None:
200232
"""Send an event to the model."""
201233
if isinstance(event, RealtimeModelSendRawMessage):
202-
await self._send_raw_message(event)
234+
converted = _ConversionHelper.try_convert_raw_message(event)
235+
if converted is not None:
236+
await self._send_raw_message(converted)
237+
else:
238+
logger.error(f"Failed to convert raw message: {event}")
203239
elif isinstance(event, RealtimeModelSendUserInput):
204240
await self._send_user_input(event)
205241
elif isinstance(event, RealtimeModelSendAudio):
@@ -214,73 +250,28 @@ async def send_event(self, event: RealtimeModelSendEvent) -> None:
214250
assert_never(event)
215251
raise ValueError(f"Unknown event type: {type(event)}")
216252

217-
async def _send_raw_message(self, event: RealtimeModelSendRawMessage) -> None:
253+
async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None:
218254
"""Send a raw message to the model."""
219255
assert self._websocket is not None, "Not connected"
220256

221-
converted_event = {
222-
"type": event.message["type"],
223-
}
224-
225-
converted_event.update(event.message.get("other_data", {}))
226-
227-
await self._websocket.send(json.dumps(converted_event))
257+
await self._websocket.send(event.model_dump_json(exclude_none=True, exclude_unset=True))
228258

229259
async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None:
230-
message = (
231-
event.user_input
232-
if isinstance(event.user_input, dict)
233-
else {
234-
"type": "message",
235-
"role": "user",
236-
"content": [{"type": "input_text", "text": event.user_input}],
237-
}
238-
)
239-
other_data = {
240-
"item": message,
241-
}
242-
243-
await self._send_raw_message(
244-
RealtimeModelSendRawMessage(
245-
message={"type": "conversation.item.create", "other_data": other_data}
246-
)
247-
)
248-
await self._send_raw_message(
249-
RealtimeModelSendRawMessage(message={"type": "response.create"})
250-
)
260+
converted = _ConversionHelper.convert_user_input_to_item_create(event)
261+
await self._send_raw_message(converted)
262+
await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
251263

252264
async def _send_audio(self, event: RealtimeModelSendAudio) -> None:
253-
base64_audio = base64.b64encode(event.audio).decode("utf-8")
254-
await self._send_raw_message(
255-
RealtimeModelSendRawMessage(
256-
message={
257-
"type": "input_audio_buffer.append",
258-
"other_data": {
259-
"audio": base64_audio,
260-
},
261-
}
262-
)
263-
)
265+
converted = _ConversionHelper.convert_audio_to_input_audio_buffer_append(event)
266+
await self._send_raw_message(converted)
264267
if event.commit:
265268
await self._send_raw_message(
266-
RealtimeModelSendRawMessage(message={"type": "input_audio_buffer.commit"})
269+
OpenAIInputAudioBufferCommitEvent(type="input_audio_buffer.commit")
267270
)
268271

269272
async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
270-
await self._send_raw_message(
271-
RealtimeModelSendRawMessage(
272-
message={
273-
"type": "conversation.item.create",
274-
"other_data": {
275-
"item": {
276-
"type": "function_call_output",
277-
"output": event.output,
278-
"call_id": event.tool_call.id,
279-
},
280-
},
281-
}
282-
)
283-
)
273+
converted = _ConversionHelper.convert_tool_output(event)
274+
await self._send_raw_message(converted)
284275

285276
tool_item = RealtimeToolCallItem(
286277
item_id=event.tool_call.id or "",
@@ -294,9 +285,7 @@ async def _send_tool_output(self, event: RealtimeModelSendToolOutput) -> None:
294285
await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_item))
295286

296287
if event.start_response:
297-
await self._send_raw_message(
298-
RealtimeModelSendRawMessage(message={"type": "response.create"})
299-
)
288+
await self._send_raw_message(OpenAIResponseCreateEvent(type="response.create"))
300289

301290
async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
302291
if not self._current_item_id or not self._audio_start_time:
@@ -307,18 +296,12 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
307296
elapsed_time_ms = (datetime.now() - self._audio_start_time).total_seconds() * 1000
308297
if elapsed_time_ms > 0 and elapsed_time_ms < self._audio_length_ms:
309298
await self._emit_event(RealtimeModelAudioInterruptedEvent())
310-
await self._send_raw_message(
311-
RealtimeModelSendRawMessage(
312-
message={
313-
"type": "conversation.item.truncate",
314-
"other_data": {
315-
"item_id": self._current_item_id,
316-
"content_index": self._current_audio_content_index,
317-
"audio_end_ms": elapsed_time_ms,
318-
},
319-
}
320-
)
299+
converted = _ConversionHelper.convert_interrupt(
300+
self._current_item_id,
301+
self._current_audio_content_index or 0,
302+
int(elapsed_time_ms),
321303
)
304+
await self._send_raw_message(converted)
322305

323306
self._current_item_id = None
324307
self._audio_start_time = None
@@ -365,7 +348,7 @@ async def _handle_output_item(self, item: ConversationItem) -> None:
365348
await self._emit_event(RealtimeModelItemUpdatedEvent(item=tool_call))
366349
await self._emit_event(
367350
RealtimeModelToolCallEvent(
368-
call_id=item.id or "",
351+
call_id=item.call_id or "",
369352
name=item.name or "",
370353
arguments=item.arguments or "",
371354
id=item.id or "",
@@ -404,9 +387,7 @@ async def close(self) -> None:
404387

405388
async def _cancel_response(self) -> None:
406389
if self._ongoing_response:
407-
await self._send_raw_message(
408-
RealtimeModelSendRawMessage(message={"type": "response.cancel"})
409-
)
390+
await self._send_raw_message(OpenAIResponseCancelEvent(type="response.cancel"))
410391
self._ongoing_response = False
411392

412393
async def _handle_ws_event(self, event: dict[str, Any]):
@@ -466,16 +447,13 @@ async def _handle_ws_event(self, event: dict[str, Any]):
466447
parsed.type == "conversation.item.input_audio_transcription.completed"
467448
or parsed.type == "conversation.item.truncated"
468449
):
469-
await self._send_raw_message(
470-
RealtimeModelSendRawMessage(
471-
message={
472-
"type": "conversation.item.retrieve",
473-
"other_data": {
474-
"item_id": self._current_item_id,
475-
},
476-
}
450+
if self._current_item_id:
451+
await self._send_raw_message(
452+
OpenAIConversationItemRetrieveEvent(
453+
type="conversation.item.retrieve",
454+
item_id=self._current_item_id,
455+
)
477456
)
478-
)
479457
if parsed.type == "conversation.item.input_audio_transcription.completed":
480458
await self._emit_event(
481459
RealtimeModelInputAudioTranscriptionCompletedEvent(
@@ -504,14 +482,7 @@ async def _handle_ws_event(self, event: dict[str, Any]):
504482
async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None:
505483
session_config = self._get_session_config(model_settings)
506484
await self._send_raw_message(
507-
RealtimeModelSendRawMessage(
508-
message={
509-
"type": "session.update",
510-
"other_data": {
511-
"session": session_config.model_dump(exclude_unset=True, exclude_none=True)
512-
},
513-
}
514-
)
485+
OpenAISessionUpdateEvent(session=session_config, type="session.update")
515486
)
516487

517488
def _get_session_config(
@@ -582,3 +553,98 @@ def conversation_item_to_realtime_message_item(
582553
"status": "in_progress",
583554
},
584555
)
556+
557+
@classmethod
def try_convert_raw_message(
    cls, message: RealtimeModelSendRawMessage
) -> OpenAIRealtimeClientEvent | None:
    """Best-effort conversion of a raw message into a typed OpenAI client event.

    Flattens ``message.message`` (its ``type`` plus any ``other_data``) into a
    single dict and validates it against the ``RealtimeClientEvent`` union.
    Returns None when the payload does not validate; the caller is expected to
    handle (e.g. log) that case.
    """
    try:
        payload: dict[str, Any] = {
            "type": message.message["type"],
            **message.message.get("other_data", {}),
        }
        return TypeAdapter(OpenAIRealtimeClientEvent).validate_python(payload)
    except Exception:
        # Deliberate best-effort: malformed payloads yield None instead of raising.
        return None
568+
569+
@classmethod
def convert_tracing_config(
    cls, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None
) -> OpenAISessionTracing | None:
    """Map the SDK-agnostic tracing config onto the OpenAI session tracing type.

    None and the literal "auto" pass through unchanged; a dict config is
    converted field-by-field (missing keys become None).
    """
    if tracing_config is None:
        return None
    if tracing_config == "auto":
        return "auto"
    return OpenAISessionTracingConfiguration(
        workflow_name=tracing_config.get("workflow_name"),
        group_id=tracing_config.get("group_id"),
        metadata=tracing_config.get("metadata"),
    )
582+
583+
@classmethod
def convert_user_input_to_conversation_item(
    cls, event: RealtimeModelSendUserInput
) -> OpenAIConversationItem:
    """Build a user message ConversationItem from a plain string or dict payload.

    A string becomes a single input_text content entry; a dict payload has each
    entry of its "content" list mapped to input_text with that entry's "text".
    """
    user_input = event.user_input
    if isinstance(user_input, dict):
        # NOTE(review): every dict entry is coerced to type="input_text";
        # non-text content types in the payload are not preserved — confirm intended.
        content = [
            OpenAIConversationItemContent(type="input_text", text=part.get("text"))
            for part in user_input.get("content", [])
        ]
    else:
        content = [OpenAIConversationItemContent(type="input_text", text=user_input)]
    return OpenAIConversationItem(type="message", role="user", content=content)
607+
608+
@classmethod
def convert_user_input_to_item_create(
    cls, event: RealtimeModelSendUserInput
) -> OpenAIRealtimeClientEvent:
    """Wrap the converted user input in a conversation.item.create client event."""
    item = cls.convert_user_input_to_conversation_item(event)
    return OpenAIConversationItemCreateEvent(type="conversation.item.create", item=item)
616+
617+
@classmethod
def convert_audio_to_input_audio_buffer_append(
    cls, event: RealtimeModelSendAudio
) -> OpenAIRealtimeClientEvent:
    """Base64-encode raw audio bytes into an input_audio_buffer.append event."""
    encoded = base64.b64encode(event.audio).decode("utf-8")
    return OpenAIInputAudioBufferAppendEvent(
        audio=encoded,
        type="input_audio_buffer.append",
    )
626+
627+
@classmethod
def convert_tool_output(cls, event: RealtimeModelSendToolOutput) -> OpenAIRealtimeClientEvent:
    """Wrap a tool call's output in a conversation.item.create event.

    The item is a function_call_output tied back to the originating call via
    the tool call's call_id.
    """
    output_item = OpenAIConversationItem(
        type="function_call_output",
        call_id=event.tool_call.call_id,
        output=event.output,
    )
    return OpenAIConversationItemCreateEvent(
        type="conversation.item.create",
        item=output_item,
    )
637+
638+
@classmethod
def convert_interrupt(
    cls,
    current_item_id: str,
    current_audio_content_index: int,
    elapsed_time_ms: int,
) -> OpenAIRealtimeClientEvent:
    """Build a conversation.item.truncate event cutting playback at elapsed_time_ms.

    Used when the user interrupts mid-response so the server's view of the
    item's audio matches what was actually heard.
    """
    return OpenAIConversationItemTruncateEvent(
        item_id=current_item_id,
        content_index=current_audio_content_index,
        audio_end_ms=elapsed_time_ms,
        type="conversation.item.truncate",
    )

0 commit comments

Comments
 (0)