Skip to content

AIDEV-1390 | Added support for "return" handoffs #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ class NextStepHandoff:
new_agent: Agent[Any]


@dataclass
class NextStepHandoffReturnControl:
    """Next step that gives control back to an agent that previously handed off.

    Produced in place of a final output when there is still an agent waiting to
    regain control; the run loop then makes ``previous_agent`` the current agent.
    """

    previous_agent: Agent[Any]
    """The agent that becomes the current agent again."""


@dataclass
class NextStepFinalOutput:
output: Any
Expand All @@ -201,7 +206,9 @@ class SingleStepResult:
new_step_items: list[RunItem]
"""Items generated during this current step."""

next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain
next_step: (
NextStepHandoff | NextStepFinalOutput | NextStepRunAgain | NextStepHandoffReturnControl
)
"""The next step to take."""

@property
Expand Down Expand Up @@ -238,6 +245,7 @@ async def execute_tools_and_side_effects(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
previous_agents: list[Agent],
) -> SingleStepResult:
# Make a copy of the generated items
pre_step_items = list(pre_step_items)
Expand Down Expand Up @@ -286,6 +294,7 @@ async def execute_tools_and_side_effects(
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
previous_agents=previous_agents,
)

# Next, we'll check if the tool use should result in a final output
Expand Down Expand Up @@ -316,6 +325,7 @@ async def execute_tools_and_side_effects(
final_output=check_tool_use.final_output,
hooks=hooks,
context_wrapper=context_wrapper,
previous_agents=previous_agents,
)

# Now we can check if the model also produced a final output
Expand All @@ -340,6 +350,7 @@ async def execute_tools_and_side_effects(
final_output=final_output,
hooks=hooks,
context_wrapper=context_wrapper,
previous_agents=previous_agents,
)
elif (
not output_schema or output_schema.is_plain_text()
Expand All @@ -353,6 +364,7 @@ async def execute_tools_and_side_effects(
final_output=potential_final_output_text or "",
hooks=hooks,
context_wrapper=context_wrapper,
previous_agents=previous_agents,
)
else:
# If there's no final output, we can just run again
Expand Down Expand Up @@ -663,6 +675,7 @@ async def execute_handoffs(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
previous_agents: list[Agent[TContext]],
) -> SingleStepResult:
# If there is more than one handoff, add tool responses that reject those handoffs
multiple_handoffs = len(run_handoffs) > 1
Expand All @@ -684,6 +697,8 @@ async def execute_handoffs(
actual_handoff = run_handoffs[0]
with handoff_span(from_agent=agent.name) as span_handoff:
handoff = actual_handoff.handoff
if handoff.should_return_control:
previous_agents.append(agent)
new_agent: Agent[Any] = await handoff.on_invoke_handoff(
context_wrapper, actual_handoff.tool_call.arguments
)
Expand Down Expand Up @@ -825,16 +840,21 @@ async def execute_final_output(
final_output: Any,
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
previous_agents: list[Agent[TContext]],
) -> SingleStepResult:
is_returning_control = len(previous_agents) > 0
# Run the on_end hooks
await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output)

await cls.run_final_output_hooks(
agent, hooks, context_wrapper, final_output, is_returning_control
)
return SingleStepResult(
original_input=original_input,
model_response=new_response,
pre_step_items=pre_step_items,
new_step_items=new_step_items,
next_step=NextStepFinalOutput(final_output),
next_step=NextStepHandoffReturnControl(previous_agents.pop())
if is_returning_control
else NextStepFinalOutput(final_output),
)

@classmethod
Expand All @@ -844,13 +864,19 @@ async def run_final_output_hooks(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
final_output: Any,
is_returning_control: bool,
):
await asyncio.gather(
hooks.on_agent_end(context_wrapper, agent, final_output),
agent.hooks.on_end(context_wrapper, agent, final_output)
if agent.hooks
else _coro.noop_coroutine(),
)
# If the agent is not returning control, run both the run-level and the agent-level end hooks
if not is_returning_control:
await asyncio.gather(
hooks.on_agent_end(context_wrapper, agent, final_output),
agent.hooks.on_end(context_wrapper, agent, final_output)
if agent.hooks
else _coro.noop_coroutine(),
)
# If the agent is returning control, skip the run-level on_agent_end hook and only run the agent's own on_end hook
elif agent.hooks:
await agent.hooks.on_end(context_wrapper, agent, final_output)

@classmethod
async def run_single_input_guardrail(
Expand Down
13 changes: 12 additions & 1 deletion src/agents/handoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ class Handoff(Generic[TContext]):
True, as it increases the likelihood of correct JSON input.
"""

should_return_control: bool = False
"""Whether the Agent that receives control during a handoff should return control to the
original (previous) Agent upon completion of its work. If False, after the Agent that received
the handoff completes its work, the interaction will end.
"""

def get_transfer_message(self, agent: Agent[Any]) -> str:
return json.dumps({"assistant": agent.name})

Expand All @@ -121,6 +127,7 @@ def handoff(
tool_name_override: str | None = None,
tool_description_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
should_return_control: bool = False,
) -> Handoff[TContext]: ...


Expand All @@ -133,6 +140,7 @@ def handoff(
tool_description_override: str | None = None,
tool_name_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
should_return_control: bool = False,
) -> Handoff[TContext]: ...


Expand All @@ -144,6 +152,7 @@ def handoff(
tool_description_override: str | None = None,
tool_name_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
should_return_control: bool = False,
) -> Handoff[TContext]: ...


Expand All @@ -154,6 +163,7 @@ def handoff(
on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
input_type: type[THandoffInput] | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
should_return_control: bool = False,
) -> Handoff[TContext]:
"""Create a handoff from an agent.

Expand All @@ -168,7 +178,7 @@ def handoff(
input_filter: a function that filters the inputs that are passed to the next agent.
"""
assert (on_handoff and input_type) or not (on_handoff and input_type), (
"You must provide either both on_handoff and input_type, or neither"
"You must provide either both on_input and input_type, or neither"
)
type_adapter: TypeAdapter[Any] | None
if input_type is not None:
Expand Down Expand Up @@ -233,4 +243,5 @@ async def _invoke_handoff(
on_invoke_handoff=_invoke_handoff,
input_filter=input_filter,
agent_name=agent.name,
should_return_control=should_return_control,
)
37 changes: 33 additions & 4 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
AgentToolUseTracker,
NextStepFinalOutput,
NextStepHandoff,
NextStepHandoffReturnControl,
NextStepRunAgain,
QueueCompleteSentinel,
RunImpl,
Expand Down Expand Up @@ -119,6 +120,7 @@ async def run(
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
previous_agents: list[Agent[TContext]] | None = None,
) -> RunResult:
"""Run a workflow starting at the given agent. The agent will run in a loop until a final
output is generated. The loop runs like so:
Expand Down Expand Up @@ -154,6 +156,8 @@ async def run(
hooks = RunHooks[Any]()
if run_config is None:
run_config = RunConfig()
if previous_agents is None:
previous_agents = []

tool_use_tracker = AgentToolUseTracker()

Expand Down Expand Up @@ -235,6 +239,7 @@ async def run(
should_run_agent_start_hooks=should_run_agent_start_hooks,
tool_use_tracker=tool_use_tracker,
previous_response_id=previous_response_id,
previous_agents=previous_agents,
),
)
else:
Expand All @@ -249,6 +254,7 @@ async def run(
should_run_agent_start_hooks=should_run_agent_start_hooks,
tool_use_tracker=tool_use_tracker,
previous_response_id=previous_response_id,
previous_agents=previous_agents,
)
should_run_agent_start_hooks = False

Expand All @@ -273,8 +279,13 @@ async def run(
output_guardrail_results=output_guardrail_results,
context_wrapper=context_wrapper,
)
elif isinstance(turn_result.next_step, NextStepHandoff):
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
elif isinstance(turn_result.next_step, NextStepHandoff) or isinstance(
turn_result.next_step, NextStepHandoffReturnControl
):
if isinstance(turn_result.next_step, NextStepHandoffReturnControl):
current_agent = turn_result.next_step.previous_agent
else:
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
current_span.finish(reset_current=True)
current_span = None
should_run_agent_start_hooks = True
Expand Down Expand Up @@ -367,6 +378,7 @@ def run_streamed(
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
previous_agents: list[Agent[TContext]] | None = None,
) -> RunResultStreaming:
"""Run a workflow starting at the given agent in streaming mode. The returned result object
contains a method you can use to stream semantic events as they are generated.
Expand Down Expand Up @@ -402,6 +414,8 @@ def run_streamed(
hooks = RunHooks[Any]()
if run_config is None:
run_config = RunConfig()
if previous_agents is None:
previous_agents = []

# If there's already a trace, we don't create a new one. In addition, we can't end the
# trace here, because the actual work is done in `stream_events` and this method ends
Expand Down Expand Up @@ -450,6 +464,7 @@ def run_streamed(
context_wrapper=context_wrapper,
run_config=run_config,
previous_response_id=previous_response_id,
previous_agents=previous_agents,
)
)
return streamed_result
Expand Down Expand Up @@ -508,6 +523,7 @@ async def _run_streamed_impl(
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
previous_response_id: str | None,
previous_agents: list[Agent[TContext]],
):
if streamed_result.trace:
streamed_result.trace.start(mark_as_current=True)
Expand Down Expand Up @@ -581,6 +597,7 @@ async def _run_streamed_impl(
tool_use_tracker,
all_tools,
previous_response_id,
previous_agents,
)
should_run_agent_start_hooks = False

Expand All @@ -590,8 +607,14 @@ async def _run_streamed_impl(
streamed_result.input = turn_result.original_input
streamed_result.new_items = turn_result.generated_items

if isinstance(turn_result.next_step, NextStepHandoff):
current_agent = turn_result.next_step.new_agent
if isinstance(turn_result.next_step, NextStepHandoff) or isinstance(
turn_result.next_step, NextStepHandoffReturnControl
):
if isinstance(turn_result.next_step, NextStepHandoff):
current_agent = turn_result.next_step.new_agent
else:
current_agent = turn_result.next_step.previous_agent

current_span.finish(reset_current=True)
current_span = None
should_run_agent_start_hooks = True
Expand Down Expand Up @@ -666,6 +689,7 @@ async def _run_single_turn_streamed(
tool_use_tracker: AgentToolUseTracker,
all_tools: list[Tool],
previous_response_id: str | None,
previous_agents: list[Agent[TContext]],
) -> SingleStepResult:
if should_run_agent_start_hooks:
await asyncio.gather(
Expand Down Expand Up @@ -746,6 +770,7 @@ async def _run_single_turn_streamed(
context_wrapper=context_wrapper,
run_config=run_config,
tool_use_tracker=tool_use_tracker,
previous_agents=previous_agents,
)

RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
Expand All @@ -765,6 +790,7 @@ async def _run_single_turn(
should_run_agent_start_hooks: bool,
tool_use_tracker: AgentToolUseTracker,
previous_response_id: str | None,
previous_agents: list[Agent[TContext]],
) -> SingleStepResult:
# Ensure we run the hooks before anything else
if should_run_agent_start_hooks:
Expand Down Expand Up @@ -809,6 +835,7 @@ async def _run_single_turn(
context_wrapper=context_wrapper,
run_config=run_config,
tool_use_tracker=tool_use_tracker,
previous_agents=previous_agents,
)

@classmethod
Expand All @@ -826,6 +853,7 @@ async def _get_single_step_result_from_response(
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
tool_use_tracker: AgentToolUseTracker,
previous_agents: list[Agent[TContext]],
) -> SingleStepResult:
processed_response = RunImpl.process_model_response(
agent=agent,
Expand All @@ -847,6 +875,7 @@ async def _get_single_step_result_from_response(
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
previous_agents=previous_agents,
)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions src/agents/tool_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
def _assert_must_pass_tool_call_id() -> str:
raise ValueError("tool_call_id must be passed to ToolContext")


@dataclass
class ToolContext(RunContextWrapper[TContext]):
"""The context of a tool call."""
Expand Down
Loading