Skip to content

Commit adac26f

Browse files
fix(structured_output): do not modify conversation_history when prompt is passed (#628)
1 parent 29b2127 commit adac26f

File tree

3 files changed

+67
-15
lines changed

3 files changed

+67
-15
lines changed

src/strands/agent/agent.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -403,16 +403,16 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
403403
def structured_output(self, output_model: Type[T], prompt: Optional[Union[str, list[ContentBlock]]] = None) -> T:
404404
"""This method allows you to get structured output from the agent.
405405
406-
If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
407-
If you don't pass in a prompt, it will use only the conversation history to respond.
406+
If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
407+
If you don't pass in a prompt, it will use only the existing conversation history to respond.
408408
409409
For smaller models, you may want to use the optional prompt to add additional instructions to explicitly
410410
instruct the model to output the structured data.
411411
412412
Args:
413413
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
414414
that the agent will use when responding.
415-
prompt: The prompt to use for the agent.
415+
prompt: The prompt to use for the agent (will not be added to conversation history).
416416
417417
Raises:
418418
ValueError: If no conversation history or prompt is provided.
@@ -430,16 +430,16 @@ async def structured_output_async(
430430
) -> T:
431431
"""This method allows you to get structured output from the agent.
432432
433-
If you pass in a prompt, it will be added to the conversation history and the agent will respond to it.
434-
If you don't pass in a prompt, it will use only the conversation history to respond.
433+
If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
434+
If you don't pass in a prompt, it will use only the existing conversation history to respond.
435435
436436
For smaller models, you may want to use the optional prompt to add additional instructions to explicitly
437437
instruct the model to output the structured data.
438438
439439
Args:
440440
output_model: The output model (a JSON schema written as a Pydantic BaseModel)
441441
that the agent will use when responding.
442-
prompt: The prompt to use for the agent.
442+
prompt: The prompt to use for the agent (will not be added to conversation history).
443443
444444
Raises:
445445
ValueError: If no conversation history or prompt is provided.
@@ -450,12 +450,14 @@ async def structured_output_async(
450450
if not self.messages and not prompt:
451451
raise ValueError("No conversation history or prompt provided")
452452

453-
# add the prompt as the last message
453+
# Create temporary messages array if prompt is provided
454454
if prompt:
455455
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
456-
self._append_message({"role": "user", "content": content})
456+
temp_messages = self.messages + [{"role": "user", "content": content}]
457+
else:
458+
temp_messages = self.messages
457459

458-
events = self.model.structured_output(output_model, self.messages, system_prompt=self.system_prompt)
460+
events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt)
459461
async for event in events:
460462
if "callback" in event:
461463
self.callback_handler(**cast(dict, event["callback"]))

tests/strands/agent/test_agent.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,10 +984,17 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator):
984984

985985
prompt = "Jane Doe is 30 years old and her email is [email protected]"
986986

987+
# Store initial message count
988+
initial_message_count = len(agent.messages)
989+
987990
tru_result = agent.structured_output(type(user), prompt)
988991
exp_result = user
989992
assert tru_result == exp_result
990993

994+
# Verify conversation history is not polluted
995+
assert len(agent.messages) == initial_message_count
996+
997+
# Verify the model was called with temporary messages array
991998
agent.model.structured_output.assert_called_once_with(
992999
type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
9931000
)
@@ -1008,10 +1015,17 @@ def test_agent_structured_output_multi_modal_input(agent, system_prompt, user, a
10081015
},
10091016
]
10101017

1018+
# Store initial message count
1019+
initial_message_count = len(agent.messages)
1020+
10111021
tru_result = agent.structured_output(type(user), prompt)
10121022
exp_result = user
10131023
assert tru_result == exp_result
10141024

1025+
# Verify conversation history is not polluted
1026+
assert len(agent.messages) == initial_message_count
1027+
1028+
# Verify the model was called with temporary messages array
10151029
agent.model.structured_output.assert_called_once_with(
10161030
type(user), [{"role": "user", "content": prompt}], system_prompt=system_prompt
10171031
)
@@ -1023,21 +1037,59 @@ async def test_agent_structured_output_in_async_context(agent, user, agenerator)
10231037

10241038
prompt = "Jane Doe is 30 years old and her email is [email protected]"
10251039

1040+
# Store initial message count
1041+
initial_message_count = len(agent.messages)
1042+
10261043
tru_result = await agent.structured_output_async(type(user), prompt)
10271044
exp_result = user
10281045
assert tru_result == exp_result
10291046

1047+
# Verify conversation history is not polluted
1048+
assert len(agent.messages) == initial_message_count
1049+
1050+
1051+
def test_agent_structured_output_without_prompt(agent, system_prompt, user, agenerator):
1052+
"""Test that structured_output works with existing conversation history and no new prompt."""
1053+
agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
1054+
1055+
# Add some existing messages to the agent
1056+
existing_messages = [
1057+
{"role": "user", "content": [{"text": "Jane Doe is 30 years old"}]},
1058+
{"role": "assistant", "content": [{"text": "I understand."}]},
1059+
]
1060+
agent.messages.extend(existing_messages)
1061+
1062+
initial_message_count = len(agent.messages)
1063+
1064+
tru_result = agent.structured_output(type(user)) # No prompt provided
1065+
exp_result = user
1066+
assert tru_result == exp_result
1067+
1068+
# Verify conversation history is unchanged
1069+
assert len(agent.messages) == initial_message_count
1070+
assert agent.messages == existing_messages
1071+
1072+
# Verify the model was called with existing messages only
1073+
agent.model.structured_output.assert_called_once_with(type(user), existing_messages, system_prompt=system_prompt)
1074+
10301075

10311076
@pytest.mark.asyncio
10321077
async def test_agent_structured_output_async(agent, system_prompt, user, agenerator):
10331078
agent.model.structured_output = unittest.mock.Mock(return_value=agenerator([{"output": user}]))
10341079

10351080
prompt = "Jane Doe is 30 years old and her email is [email protected]"
10361081

1082+
# Store initial message count
1083+
initial_message_count = len(agent.messages)
1084+
10371085
tru_result = agent.structured_output(type(user), prompt)
10381086
exp_result = user
10391087
assert tru_result == exp_result
10401088

1089+
# Verify conversation history is not polluted
1090+
assert len(agent.messages) == initial_message_count
1091+
1092+
# Verify the model was called with temporary messages array
10411093
agent.model.structured_output.assert_called_once_with(
10421094
type(user), [{"role": "user", "content": [{"text": prompt}]}], system_prompt=system_prompt
10431095
)

tests/strands/agent/test_agent_hooks.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,12 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
267267

268268
length, events = hook_provider.get_events()
269269

270-
assert length == 3
270+
assert length == 2
271271

272272
assert next(events) == BeforeInvocationEvent(agent=agent)
273-
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])
274273
assert next(events) == AfterInvocationEvent(agent=agent)
275274

276-
assert len(agent.messages) == 1
275+
assert len(agent.messages) == 0 # no new messages added
277276

278277

279278
@pytest.mark.asyncio
@@ -285,10 +284,9 @@ async def test_agent_structured_async_output_hooks(agent, hook_provider, user, a
285284

286285
length, events = hook_provider.get_events()
287286

288-
assert length == 3
287+
assert length == 2
289288

290289
assert next(events) == BeforeInvocationEvent(agent=agent)
291-
assert next(events) == MessageAddedEvent(agent=agent, message=agent.messages[0])
292290
assert next(events) == AfterInvocationEvent(agent=agent)
293291

294-
assert len(agent.messages) == 1
292+
assert len(agent.messages) == 0 # no new messages added

0 commit comments

Comments
 (0)