diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index cb2752e4f..76c93a298 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -15,7 +15,6 @@ from .strict_schema import ensure_strict_json_schema from .tracing.spans import SpanError from .util import _error_tracing, _json, _transforms -from .util._types import MaybeAwaitable if TYPE_CHECKING: from .agent import Agent @@ -100,11 +99,6 @@ class Handoff(Generic[TContext]): True, as it increases the likelihood of correct JSON input. """ - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True - """Whether the handoff is enabled. Either a bool or a Callable that takes the run context and - agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable - a handoff based on your context/state.""" - def get_transfer_message(self, agent: Agent[Any]) -> str: return json.dumps({"assistant": agent.name}) @@ -127,7 +121,6 @@ def handoff( tool_name_override: str | None = None, tool_description_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -140,7 +133,6 @@ def handoff( tool_description_override: str | None = None, tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -152,7 +144,6 @@ def handoff( tool_description_override: str | None = None, tool_name_override: str | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: ... @@ -163,7 +154,6 @@ def handoff( on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None, input_type: type[THandoffInput] | None = None, input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None, - is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True, ) -> Handoff[TContext]: """Create a handoff from an agent. @@ -176,9 +166,6 @@ def handoff( input_type: the type of the input to the handoff. If provided, the input will be validated against this type. Only relevant if you pass a function that takes an input. input_filter: a function that filters the inputs that are passed to the next agent. - is_enabled: Whether the handoff is enabled. Can be a bool or a callable that takes the run - context and agent and returns whether the handoff is enabled. Disabled handoffs are - hidden from the LLM at runtime. """ assert (on_handoff and input_type) or not (on_handoff and input_type), ( "You must provide either both on_handoff and input_type, or neither" @@ -246,5 +233,4 @@ async def _invoke_handoff( on_invoke_handoff=_invoke_handoff, input_filter=input_filter, agent_name=agent.name, - is_enabled=is_enabled, ) diff --git a/src/agents/run.py b/src/agents/run.py index e5f9378ec..8a44a0e54 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -2,7 +2,6 @@ import asyncio import copy -import inspect from dataclasses import dataclass, field from typing import Any, Generic, cast @@ -362,8 +361,7 @@ async def run( # agent changes, or if the agent loop ends. if current_span is None: handoff_names = [ - h.agent_name - for h in await AgentRunner._get_handoffs(current_agent, context_wrapper) + h.agent_name for h in AgentRunner._get_handoffs(current_agent) ] if output_schema := AgentRunner._get_output_schema(current_agent): output_type_name = output_schema.name() @@ -643,10 +641,7 @@ async def _start_streaming( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. if current_span is None: - handoff_names = [ - h.agent_name - for h in await cls._get_handoffs(current_agent, context_wrapper) - ] + handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] if output_schema := cls._get_output_schema(current_agent): output_type_name = output_schema.name() else: @@ -803,7 +798,7 @@ async def _run_single_turn_streamed( agent.get_prompt(context_wrapper), ) - handoffs = await cls._get_handoffs(agent, context_wrapper) + handoffs = cls._get_handoffs(agent) model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) @@ -903,7 +898,7 @@ async def _run_single_turn( ) output_schema = cls._get_output_schema(agent) - handoffs = await cls._get_handoffs(agent, context_wrapper) + handoffs = cls._get_handoffs(agent) input = ItemHelpers.input_to_new_input_list(original_input) input.extend([generated_item.to_input_item() for generated_item in generated_items]) @@ -1096,28 +1091,14 @@ def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: return AgentOutputSchema(agent.output_type) @classmethod - async def _get_handoffs( - cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any] - ) -> list[Handoff]: + def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]: handoffs = [] for handoff_item in agent.handoffs: if isinstance(handoff_item, Handoff): handoffs.append(handoff_item) elif isinstance(handoff_item, Agent): handoffs.append(handoff(handoff_item)) - - async def _check_handoff_enabled(handoff_obj: Handoff) -> bool: - attr = handoff_obj.is_enabled - if isinstance(attr, bool): - return attr - res = attr(context_wrapper, agent) - if inspect.isawaitable(res): - return bool(await res) - return bool(res) - - results = await asyncio.gather(*(_check_handoff_enabled(h) for h in handoffs)) - enabled: list[Handoff] = [h for h, ok in zip(handoffs, results) if ok] - return enabled + return handoffs @classmethod async def _get_all_tools( diff --git a/tests/test_agent_config.py b/tests/test_agent_config.py index a985fd60d..f9423619d 100644 --- a/tests/test_agent_config.py +++ b/tests/test_agent_config.py @@ -43,7 +43,7 @@ async def test_handoff_with_agents(): handoffs=[agent_1, agent_2], ) - handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) + handoffs = AgentRunner._get_handoffs(agent_3) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" @@ -78,7 +78,7 @@ async def test_handoff_with_handoff_obj(): ], ) - handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) + handoffs = AgentRunner._get_handoffs(agent_3) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" @@ -112,7 +112,7 @@ async def test_handoff_with_handoff_obj_and_agent(): handoffs=[handoff(agent_1), agent_2], ) - handoffs = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(None)) + handoffs = AgentRunner._get_handoffs(agent_3) assert len(handoffs) == 2 assert handoffs[0].agent_name == "agent_1" diff --git a/tests/test_handoff_tool.py b/tests/test_handoff_tool.py index 0f7fc2166..a1b5b80ba 100644 --- a/tests/test_handoff_tool.py +++ b/tests/test_handoff_tool.py @@ -38,17 +38,16 @@ def get_len(data: HandoffInputData) -> int: return input_len + pre_handoff_len + new_items_len -@pytest.mark.asyncio -async def test_single_handoff_setup(): +def test_single_handoff_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2", handoffs=[agent_1]) assert not agent_1.handoffs assert agent_2.handoffs == [agent_1] - assert not (await AgentRunner._get_handoffs(agent_1, RunContextWrapper(agent_1))) + assert not AgentRunner._get_handoffs(agent_1) - handoff_objects = await AgentRunner._get_handoffs(agent_2, RunContextWrapper(agent_2)) + handoff_objects = AgentRunner._get_handoffs(agent_2) assert len(handoff_objects) == 1 obj = handoff_objects[0] assert obj.tool_name == Handoff.default_tool_name(agent_1) @@ -56,8 +55,7 @@ async def test_single_handoff_setup(): assert obj.agent_name == agent_1.name -@pytest.mark.asyncio -async def test_multiple_handoffs_setup(): +def test_multiple_handoffs_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent(name="test_3", handoffs=[agent_1, agent_2]) @@ -66,7 +64,7 @@ async def test_multiple_handoffs_setup(): assert not agent_1.handoffs assert not agent_2.handoffs - handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3)) + handoff_objects = AgentRunner._get_handoffs(agent_3) assert len(handoff_objects) == 2 assert handoff_objects[0].tool_name == Handoff.default_tool_name(agent_1) assert handoff_objects[1].tool_name == Handoff.default_tool_name(agent_2) @@ -78,8 +76,7 @@ async def test_multiple_handoffs_setup(): assert handoff_objects[1].agent_name == agent_2.name -@pytest.mark.asyncio -async def test_custom_handoff_setup(): +def test_custom_handoff_setup(): agent_1 = Agent(name="test_1") agent_2 = Agent(name="test_2") agent_3 = Agent( @@ -98,7 +95,7 @@ async def test_custom_handoff_setup(): assert not agent_1.handoffs assert not agent_2.handoffs - handoff_objects = await AgentRunner._get_handoffs(agent_3, RunContextWrapper(agent_3)) + handoff_objects = AgentRunner._get_handoffs(agent_3) assert len(handoff_objects) == 2 first_handoff = handoff_objects[0] @@ -287,86 +284,3 @@ def test_get_transfer_message_is_valid_json() -> None: obj = handoff(agent) transfer = obj.get_transfer_message(agent) assert json.loads(transfer) == {"assistant": agent.name} - - -def test_handoff_is_enabled_bool(): - """Test that handoff respects is_enabled boolean parameter.""" - agent = Agent(name="test") - - # Test enabled handoff (default) - handoff_enabled = handoff(agent) - assert handoff_enabled.is_enabled is True - - # Test explicitly enabled handoff - handoff_explicit_enabled = handoff(agent, is_enabled=True) - assert handoff_explicit_enabled.is_enabled is True - - # Test disabled handoff - handoff_disabled = handoff(agent, is_enabled=False) - assert handoff_disabled.is_enabled is False - - -@pytest.mark.asyncio -async def test_handoff_is_enabled_callable(): - """Test that handoff respects is_enabled callable parameter.""" - agent = Agent(name="test") - - # Test callable that returns True - def always_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: - return True - - handoff_callable_enabled = handoff(agent, is_enabled=always_enabled) - assert callable(handoff_callable_enabled.is_enabled) - result = handoff_callable_enabled.is_enabled(RunContextWrapper(agent), agent) - assert result is True - - # Test callable that returns False - def always_disabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: - return False - - handoff_callable_disabled = handoff(agent, is_enabled=always_disabled) - assert callable(handoff_callable_disabled.is_enabled) - result = handoff_callable_disabled.is_enabled(RunContextWrapper(agent), agent) - assert result is False - - # Test async callable - async def async_enabled(ctx: RunContextWrapper[Any], agent: Agent[Any]) -> bool: - return True - - handoff_async_enabled = handoff(agent, is_enabled=async_enabled) - assert callable(handoff_async_enabled.is_enabled) - result = await handoff_async_enabled.is_enabled(RunContextWrapper(agent), agent) # type: ignore - assert result is True - - -@pytest.mark.asyncio -async def test_handoff_is_enabled_filtering_integration(): - """Integration test that disabled handoffs are filtered out by the runner.""" - - # Set up agents - agent_1 = Agent(name="agent_1") - agent_2 = Agent(name="agent_2") - agent_3 = Agent(name="agent_3") - - # Create main agent with mixed enabled/disabled handoffs - main_agent = Agent( - name="main_agent", - handoffs=[ - handoff(agent_1, is_enabled=True), # enabled - handoff(agent_2, is_enabled=False), # disabled - handoff(agent_3, is_enabled=lambda ctx, agent: True), # enabled callable - ], - ) - - context_wrapper = RunContextWrapper(main_agent) - - # Get filtered handoffs using the runner's method - filtered_handoffs = await AgentRunner._get_handoffs(main_agent, context_wrapper) - - # Should only have 2 handoffs (agent_1 and agent_3), agent_2 should be filtered out - assert len(filtered_handoffs) == 2 - - # Check that the correct agents are present - agent_names = {h.agent_name for h in filtered_handoffs} - assert agent_names == {"agent_1", "agent_3"} - assert "agent_2" not in agent_names diff --git a/tests/test_run_step_execution.py b/tests/test_run_step_execution.py index 4cf9ae832..2454a4462 100644 --- a/tests/test_run_step_execution.py +++ b/tests/test_run_step_execution.py @@ -325,7 +325,7 @@ async def get_execute_result( run_config: RunConfig | None = None, ) -> SingleStepResult: output_schema = AgentRunner._get_output_schema(agent) - handoffs = await AgentRunner._get_handoffs(agent, context_wrapper or RunContextWrapper(None)) + handoffs = AgentRunner._get_handoffs(agent) processed_response = RunImpl.process_model_response( agent=agent, diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 6a2904791..5a75ec837 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -186,7 +186,7 @@ async def test_handoffs_parsed_correctly(): agent=agent_3, response=response, output_schema=None, - handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), + handoffs=AgentRunner._get_handoffs(agent_3), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert len(result.handoffs) == 1, "Should have a handoff here" @@ -216,7 +216,7 @@ async def test_missing_handoff_fails(): agent=agent_3, response=response, output_schema=None, - handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), + handoffs=AgentRunner._get_handoffs(agent_3), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) @@ -239,7 +239,7 @@ async def test_multiple_handoffs_doesnt_error(): agent=agent_3, response=response, output_schema=None, - handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), + handoffs=AgentRunner._get_handoffs(agent_3), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert len(result.handoffs) == 2, "Should have multiple handoffs here" @@ -471,7 +471,7 @@ async def test_tool_and_handoff_parsed_correctly(): agent=agent_3, response=response, output_schema=None, - handoffs=await AgentRunner._get_handoffs(agent_3, _dummy_ctx()), + handoffs=AgentRunner._get_handoffs(agent_3), all_tools=await agent_3.get_all_tools(_dummy_ctx()), ) assert result.functions and len(result.functions) == 1