Skip to content

Commit 479c171

Browse files
authored
Realtime: handoffs (#1139)
--- [//]: # (BEGIN SAPLING FOOTER) * #1141 * __->__ #1139
1 parent 6293d66 commit 479c171

14 files changed

+393
-48
lines changed

src/agents/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class Agent(AgentBase, Generic[TContext]):
158158
usable with OpenAI models, using the Responses API.
159159
"""
160160

161-
handoffs: list[Agent[Any] | Handoff[TContext]] = field(default_factory=list)
161+
handoffs: list[Agent[Any] | Handoff[TContext, Any]] = field(default_factory=list)
162162
"""Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
163163
and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
164164
modularity.

src/agents/guardrail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def decorator(
244244
return InputGuardrail(
245245
guardrail_function=f,
246246
# If not set, guardrail name uses the function’s name by default.
247-
name=name if name else f.__name__
247+
name=name if name else f.__name__,
248248
)
249249

250250
if func is not None:

src/agents/handoffs.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818
from .util._types import MaybeAwaitable
1919

2020
if TYPE_CHECKING:
21-
from .agent import Agent
21+
from .agent import Agent, AgentBase
2222

2323

2424
# The handoff input type is the type of data passed when the agent is called via a handoff.
2525
THandoffInput = TypeVar("THandoffInput", default=Any)
2626

27+
# The agent type that the handoff returns
28+
TAgent = TypeVar("TAgent", bound="AgentBase[Any]", default="Agent[Any]")
29+
2730
OnHandoffWithInput = Callable[[RunContextWrapper[Any], THandoffInput], Any]
2831
OnHandoffWithoutInput = Callable[[RunContextWrapper[Any]], Any]
2932

@@ -52,7 +55,7 @@ class HandoffInputData:
5255

5356

5457
@dataclass
55-
class Handoff(Generic[TContext]):
58+
class Handoff(Generic[TContext, TAgent]):
5659
"""A handoff is when an agent delegates a task to another agent.
5760
For example, in a customer support scenario you might have a "triage agent" that determines
5861
which agent should handle the user's request, and sub-agents that specialize in different
@@ -69,7 +72,7 @@ class Handoff(Generic[TContext]):
6972
"""The JSON schema for the handoff input. Can be empty if the handoff does not take an input.
7073
"""
7174

72-
on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[Agent[TContext]]]
75+
on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[TAgent]]
7376
"""The function that invokes the handoff. The parameters passed are:
7477
1. The handoff run context
7578
2. The arguments from the LLM, as a JSON string. Empty string if input_json_schema is empty.
@@ -100,20 +103,22 @@ class Handoff(Generic[TContext]):
100103
True, as it increases the likelihood of correct JSON input.
101104
"""
102105

103-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
106+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = (
107+
True
108+
)
104109
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
105110
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
106111
a handoff based on your context/state."""
107112

108-
def get_transfer_message(self, agent: Agent[Any]) -> str:
113+
def get_transfer_message(self, agent: AgentBase[Any]) -> str:
109114
return json.dumps({"assistant": agent.name})
110115

111116
@classmethod
112-
def default_tool_name(cls, agent: Agent[Any]) -> str:
117+
def default_tool_name(cls, agent: AgentBase[Any]) -> str:
113118
return _transforms.transform_string_function_style(f"transfer_to_{agent.name}")
114119

115120
@classmethod
116-
def default_tool_description(cls, agent: Agent[Any]) -> str:
121+
def default_tool_description(cls, agent: AgentBase[Any]) -> str:
117122
return (
118123
f"Handoff to the {agent.name} agent to handle the request. "
119124
f"{agent.handoff_description or ''}"
@@ -128,7 +133,7 @@ def handoff(
128133
tool_description_override: str | None = None,
129134
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
130135
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
131-
) -> Handoff[TContext]: ...
136+
) -> Handoff[TContext, Agent[TContext]]: ...
132137

133138

134139
@overload
@@ -141,7 +146,7 @@ def handoff(
141146
tool_name_override: str | None = None,
142147
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
143148
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
144-
) -> Handoff[TContext]: ...
149+
) -> Handoff[TContext, Agent[TContext]]: ...
145150

146151

147152
@overload
@@ -153,7 +158,7 @@ def handoff(
153158
tool_name_override: str | None = None,
154159
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
155160
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
156-
) -> Handoff[TContext]: ...
161+
) -> Handoff[TContext, Agent[TContext]]: ...
157162

158163

159164
def handoff(
@@ -163,8 +168,9 @@ def handoff(
163168
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
164169
input_type: type[THandoffInput] | None = None,
165170
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
166-
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
167-
) -> Handoff[TContext]:
171+
is_enabled: bool
172+
| Callable[[RunContextWrapper[Any], Agent[TContext]], MaybeAwaitable[bool]] = True,
173+
) -> Handoff[TContext, Agent[TContext]]:
168174
"""Create a handoff from an agent.
169175
170176
Args:
@@ -202,7 +208,7 @@ def handoff(
202208

203209
async def _invoke_handoff(
204210
ctx: RunContextWrapper[Any], input_json: str | None = None
205-
) -> Agent[Any]:
211+
) -> Agent[TContext]:
206212
if input_type is not None and type_adapter is not None:
207213
if input_json is None:
208214
_error_tracing.attach_error_to_current_span(
@@ -239,12 +245,24 @@ async def _invoke_handoff(
239245
# If there is a need, we can make this configurable in the future
240246
input_json_schema = ensure_strict_json_schema(input_json_schema)
241247

248+
async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -> bool:
249+
from .agent import Agent
250+
251+
assert callable(is_enabled), "is_enabled must be non-null here"
252+
assert isinstance(agent_base, Agent), "Can't handoff to a non-Agent"
253+
result = is_enabled(ctx, agent_base)
254+
255+
if inspect.isawaitable(result):
256+
return await result
257+
258+
return result
259+
242260
return Handoff(
243261
tool_name=tool_name,
244262
tool_description=tool_description,
245263
input_json_schema=input_json_schema,
246264
on_invoke_handoff=_invoke_handoff,
247265
input_filter=input_filter,
248266
agent_name=agent.name,
249-
is_enabled=is_enabled,
267+
is_enabled=_is_enabled if callable(is_enabled) else is_enabled,
250268
)

src/agents/models/chatcmpl_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def tool_to_openai(cls, tool: Tool) -> ChatCompletionToolParam:
484484
)
485485

486486
@classmethod
487-
def convert_handoff_tool(cls, handoff: Handoff[Any]) -> ChatCompletionToolParam:
487+
def convert_handoff_tool(cls, handoff: Handoff[Any, Any]) -> ChatCompletionToolParam:
488488
return {
489489
"type": "function",
490490
"function": {

src/agents/models/openai_responses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def get_response_format(
370370
def convert_tools(
371371
cls,
372372
tools: list[Tool],
373-
handoffs: list[Handoff[Any]],
373+
handoffs: list[Handoff[Any, Any]],
374374
) -> ConvertedTools:
375375
converted_tools: list[ToolParam] = []
376376
includes: list[ResponseIncludable] = []

src/agents/realtime/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
RealtimeToolEnd,
3131
RealtimeToolStart,
3232
)
33+
from .handoffs import realtime_handoff
3334
from .items import (
3435
AssistantMessageItem,
3536
AssistantText,
@@ -92,6 +93,8 @@
9293
"RealtimeAgentHooks",
9394
"RealtimeRunHooks",
9495
"RealtimeRunner",
96+
# Handoffs
97+
"realtime_handoff",
9598
# Config
9699
"RealtimeAudioFormat",
97100
"RealtimeClientMessage",

src/agents/realtime/agent.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import dataclasses
44
import inspect
55
from collections.abc import Awaitable
6-
from dataclasses import dataclass
6+
from dataclasses import dataclass, field
77
from typing import Any, Callable, Generic, cast
88

99
from ..agent import AgentBase
10+
from ..handoffs import Handoff
1011
from ..lifecycle import AgentHooksBase, RunHooksBase
1112
from ..logger import logger
1213
from ..run_context import RunContextWrapper, TContext
@@ -53,6 +54,14 @@ class RealtimeAgent(AgentBase, Generic[TContext]):
5354
return a string.
5455
"""
5556

57+
handoffs: list[RealtimeAgent[Any] | Handoff[TContext, RealtimeAgent[Any]]] = field(
58+
default_factory=list
59+
)
60+
"""Handoffs are sub-agents that the agent can delegate to. You can provide a list of handoffs,
61+
and the agent can choose to delegate to them if relevant. Allows for separation of concerns and
62+
modularity.
63+
"""
64+
5665
hooks: RealtimeAgentHooks | None = None
5766
"""A class that receives callbacks on various lifecycle events for this agent.
5867
"""

src/agents/realtime/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing_extensions import NotRequired, TypeAlias, TypedDict
1010

1111
from ..guardrail import OutputGuardrail
12+
from ..handoffs import Handoff
1213
from ..model_settings import ToolChoice
1314
from ..tool import Tool
1415

@@ -71,6 +72,7 @@ class RealtimeSessionModelSettings(TypedDict):
7172

7273
tool_choice: NotRequired[ToolChoice]
7374
tools: NotRequired[list[Tool]]
75+
handoffs: NotRequired[list[Handoff]]
7476

7577
tracing: NotRequired[RealtimeModelTracingConfig | None]
7678

src/agents/realtime/handoffs.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
from typing import TYPE_CHECKING, Any, Callable, cast, overload
5+
6+
from pydantic import TypeAdapter
7+
from typing_extensions import TypeVar
8+
9+
from ..exceptions import ModelBehaviorError, UserError
10+
from ..handoffs import Handoff
11+
from ..run_context import RunContextWrapper, TContext
12+
from ..strict_schema import ensure_strict_json_schema
13+
from ..tracing.spans import SpanError
14+
from ..util import _error_tracing, _json
15+
from ..util._types import MaybeAwaitable
16+
17+
if TYPE_CHECKING:
18+
from ..agent import AgentBase
19+
from . import RealtimeAgent
20+
21+
22+
# The handoff input type is the type of data passed when the agent is called via a handoff.
23+
THandoffInput = TypeVar("THandoffInput", default=Any)
24+
25+
OnHandoffWithInput = Callable[[RunContextWrapper[Any], THandoffInput], Any]
26+
OnHandoffWithoutInput = Callable[[RunContextWrapper[Any]], Any]
27+
28+
29+
@overload
30+
def realtime_handoff(
31+
agent: RealtimeAgent[TContext],
32+
*,
33+
tool_name_override: str | None = None,
34+
tool_description_override: str | None = None,
35+
is_enabled: bool
36+
| Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True,
37+
) -> Handoff[TContext, RealtimeAgent[TContext]]: ...
38+
39+
40+
@overload
41+
def realtime_handoff(
42+
agent: RealtimeAgent[TContext],
43+
*,
44+
on_handoff: OnHandoffWithInput[THandoffInput],
45+
input_type: type[THandoffInput],
46+
tool_description_override: str | None = None,
47+
tool_name_override: str | None = None,
48+
is_enabled: bool
49+
| Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True,
50+
) -> Handoff[TContext, RealtimeAgent[TContext]]: ...
51+
52+
53+
@overload
54+
def realtime_handoff(
55+
agent: RealtimeAgent[TContext],
56+
*,
57+
on_handoff: OnHandoffWithoutInput,
58+
tool_description_override: str | None = None,
59+
tool_name_override: str | None = None,
60+
is_enabled: bool
61+
| Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True,
62+
) -> Handoff[TContext, RealtimeAgent[TContext]]: ...
63+
64+
65+
def realtime_handoff(
66+
agent: RealtimeAgent[TContext],
67+
tool_name_override: str | None = None,
68+
tool_description_override: str | None = None,
69+
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
70+
input_type: type[THandoffInput] | None = None,
71+
is_enabled: bool
72+
| Callable[[RunContextWrapper[Any], RealtimeAgent[Any]], MaybeAwaitable[bool]] = True,
73+
) -> Handoff[TContext, RealtimeAgent[TContext]]:
74+
"""Create a handoff from a RealtimeAgent.
75+
76+
Args:
77+
agent: The RealtimeAgent to handoff to, or a function that returns a RealtimeAgent.
78+
tool_name_override: Optional override for the name of the tool that represents the handoff.
79+
tool_description_override: Optional override for the description of the tool that
80+
represents the handoff.
81+
on_handoff: A function that runs when the handoff is invoked.
82+
input_type: the type of the input to the handoff. If provided, the input will be validated
83+
against this type. Only relevant if you pass a function that takes an input.
84+
is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run
85+
context and agent and returns whether the handoff is enabled. Disabled handoffs are
86+
hidden from the LLM at runtime.
87+
88+
Note: input_filter is not supported for RealtimeAgent handoffs.
89+
"""
90+
assert (on_handoff and input_type) or not (on_handoff and input_type), (
91+
"You must provide either both on_handoff and input_type, or neither"
92+
)
93+
type_adapter: TypeAdapter[Any] | None
94+
if input_type is not None:
95+
assert callable(on_handoff), "on_handoff must be callable"
96+
sig = inspect.signature(on_handoff)
97+
if len(sig.parameters) != 2:
98+
raise UserError("on_handoff must take two arguments: context and input")
99+
100+
type_adapter = TypeAdapter(input_type)
101+
input_json_schema = type_adapter.json_schema()
102+
else:
103+
type_adapter = None
104+
input_json_schema = {}
105+
if on_handoff is not None:
106+
sig = inspect.signature(on_handoff)
107+
if len(sig.parameters) != 1:
108+
raise UserError("on_handoff must take one argument: context")
109+
110+
async def _invoke_handoff(
111+
ctx: RunContextWrapper[Any], input_json: str | None = None
112+
) -> RealtimeAgent[TContext]:
113+
if input_type is not None and type_adapter is not None:
114+
if input_json is None:
115+
_error_tracing.attach_error_to_current_span(
116+
SpanError(
117+
message="Handoff function expected non-null input, but got None",
118+
data={"details": "input_json is None"},
119+
)
120+
)
121+
raise ModelBehaviorError("Handoff function expected non-null input, but got None")
122+
123+
validated_input = _json.validate_json(
124+
json_str=input_json,
125+
type_adapter=type_adapter,
126+
partial=False,
127+
)
128+
input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff)
129+
if inspect.iscoroutinefunction(input_func):
130+
await input_func(ctx, validated_input)
131+
else:
132+
input_func(ctx, validated_input)
133+
elif on_handoff is not None:
134+
no_input_func = cast(OnHandoffWithoutInput, on_handoff)
135+
if inspect.iscoroutinefunction(no_input_func):
136+
await no_input_func(ctx)
137+
else:
138+
no_input_func(ctx)
139+
140+
return agent
141+
142+
tool_name = tool_name_override or Handoff.default_tool_name(agent)
143+
tool_description = tool_description_override or Handoff.default_tool_description(agent)
144+
145+
# Always ensure the input JSON schema is in strict mode
146+
# If there is a need, we can make this configurable in the future
147+
input_json_schema = ensure_strict_json_schema(input_json_schema)
148+
149+
async def _is_enabled(ctx: RunContextWrapper[Any], agent_base: AgentBase[Any]) -> bool:
150+
assert callable(is_enabled), "is_enabled must be non-null here"
151+
assert isinstance(agent_base, RealtimeAgent), "Can't handoff to a non-RealtimeAgent"
152+
result = is_enabled(ctx, agent_base)
153+
if inspect.isawaitable(result):
154+
return await result
155+
return result
156+
157+
return Handoff(
158+
tool_name=tool_name,
159+
tool_description=tool_description,
160+
input_json_schema=input_json_schema,
161+
on_invoke_handoff=_invoke_handoff,
162+
input_filter=None, # Not supported for RealtimeAgent handoffs
163+
agent_name=agent.name,
164+
is_enabled=_is_enabled if callable(is_enabled) else is_enabled,
165+
)

0 commit comments

Comments
 (0)