diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py
index 8643248b1..3e5903af8 100644
--- a/src/agents/lifecycle.py
+++ b/src/agents/lifecycle.py
@@ -1,6 +1,7 @@
-from typing import Any, Generic
+from typing import Any, Generic, Optional
 
 from .agent import Agent
+from .items import ModelResponse, TResponseInputItem
 from .run_context import RunContextWrapper, TContext
 from .tool import Tool
 
@@ -10,6 +11,25 @@ class RunHooks(Generic[TContext]):
     override the methods you need.
     """
 
+    # Two new hook methods on RunHooks that fire around the LLM call for an agent:
+    # on_llm_start runs just before the model is invoked, on_llm_end right after it returns.
+    # Useful for logging, monitoring, or modifying the context around the LLM call.
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        """Called just before invoking the LLM for this agent."""
+        pass
+
+    async def on_llm_end(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], response: ModelResponse
+    ) -> None:
+        """Called immediately after the LLM call returns for this agent."""
+        pass
+
     async def on_agent_start(
         self, context: RunContextWrapper[TContext], agent: Agent[TContext]
     ) -> None:
@@ -103,3 +123,22 @@ async def on_tool_end(
     ) -> None:
         """Called after a tool is invoked."""
         pass
+
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        """Called immediately before the agent issues an LLM call."""
+        pass
+
+    async def on_llm_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        response: ModelResponse,
+    ) -> None:
+        """Called immediately after the agent receives the LLM response."""
+        pass
diff --git a/src/agents/run.py b/src/agents/run.py
index e5f9378ec..841813bd5 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -1067,6 +1067,9 @@ async def _get_new_response(
         model = cls._get_model(agent, run_config)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
         model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
+        # If the agent has hooks, call on_llm_start just before invoking the model.
+        if agent.hooks:
+            await agent.hooks.on_llm_start(context_wrapper, agent, system_prompt, input)
 
         new_response = await model.get_response(
             system_instructions=system_prompt,
@@ -1081,6 +1084,9 @@ async def _get_new_response(
             previous_response_id=previous_response_id,
             prompt=prompt_config,
         )
+        # If the agent has hooks, call on_llm_end right after the model returns.
+        if agent.hooks:
+            await agent.hooks.on_llm_end(context_wrapper, agent, new_response)
 
         context_wrapper.usage.add(new_response.usage)
 
diff --git a/tests/test_agent_llm_hooks.py b/tests/test_agent_llm_hooks.py
new file mode 100644
index 000000000..ded7e18de
--- /dev/null
+++ b/tests/test_agent_llm_hooks.py
@@ -0,0 +1,85 @@
+from collections import defaultdict
+from typing import Any, Optional
+
+import pytest
+
+from agents.agent import Agent
+from agents.items import ModelResponse, TResponseInputItem
+from agents.lifecycle import AgentHooks
+from agents.run import Runner
+from agents.run_context import RunContextWrapper, TContext
+from agents.tool import Tool
+
+from .fake_model import FakeModel
+from .test_responses import (
+    get_function_tool,
+    get_text_message,
+)
+
+
+class AgentHooksForTests(AgentHooks):
+    def __init__(self):
+        self.events: dict[str, int] = defaultdict(int)
+
+    def reset(self):
+        self.events.clear()
+
+    async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
+        self.events["on_start"] += 1
+
+    async def on_end(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
+    ) -> None:
+        self.events["on_end"] += 1
+
+    async def on_handoff(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]
+    ) -> None:
+        self.events["on_handoff"] += 1
+
+    async def on_tool_start(
+        self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
+    ) -> None:
+        self.events["on_tool_start"] += 1
+
+    async def on_tool_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        tool: Tool,
+        result: str,
+    ) -> None:
+        self.events["on_tool_end"] += 1
+
+    # NEW: LLM hooks
+    async def on_llm_start(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        system_prompt: Optional[str],
+        input_items: list[TResponseInputItem],
+    ) -> None:
+        self.events["on_llm_start"] += 1
+
+    async def on_llm_end(
+        self,
+        context: RunContextWrapper[TContext],
+        agent: Agent[TContext],
+        response: ModelResponse,
+    ) -> None:
+        self.events["on_llm_end"] += 1
+
+
+# Example test using the above hooks:
+@pytest.mark.asyncio
+async def test_non_streamed_agent_hooks_with_llm():
+    hooks = AgentHooksForTests()
+    model = FakeModel()
+    agent = Agent(
+        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
+    )
+    # Simulate a single LLM call producing an output:
+    model.set_next_output([get_text_message("hello")])
+    await Runner.run(agent, input="hello")
+    # Expect one on_start, one on_llm_start, one on_llm_end, and one on_end.
+    assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
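
Usage sketch (not part of the patch): a minimal example of how a caller might subscribe to the new hooks, assuming the hook signatures added in src/agents/lifecycle.py and the Agent/Runner call pattern used in the test above. The LoggingHooks class name, the "gpt-4o-mini" model string, and the print-based logging are illustrative assumptions, not part of this change.

import asyncio
from typing import Any, Optional

from agents.agent import Agent
from agents.items import ModelResponse, TResponseInputItem
from agents.lifecycle import AgentHooks
from agents.run import Runner
from agents.run_context import RunContextWrapper


class LoggingHooks(AgentHooks):
    """Illustrative hooks that log each LLM call made on behalf of one agent."""

    async def on_llm_start(
        self,
        context: RunContextWrapper[Any],
        agent: Agent[Any],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        # Fires once per model call, just before the request is sent.
        print(f"[{agent.name}] LLM call starting with {len(input_items)} input item(s)")

    async def on_llm_end(
        self,
        context: RunContextWrapper[Any],
        agent: Agent[Any],
        response: ModelResponse,
    ) -> None:
        # Fires right after the response comes back; per the run.py hunk above,
        # usage is added to the run context only after this hook returns.
        print(f"[{agent.name}] LLM call finished, usage={response.usage}")


async def main() -> None:
    # The model name is a placeholder; any model the Agent accepts works here.
    agent = Agent(name="Assistant", model="gpt-4o-mini", hooks=LoggingHooks())
    await Runner.run(agent, input="hello")


if __name__ == "__main__":
    asyncio.run(main())

Because the hooks are invoked inside _get_new_response, they fire once per model call, so a run that loops through tool calls should see one on_llm_start/on_llm_end pair per turn.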