Skip to content

Commit 3ad53b6

Browse files
committed
fix: Properly handle prompt=None & avoid agent hanging
bedrock.py now catches all exceptions in _stream so it no longer hangs when invalid content is passed. In addition, since we don't allow agent(None), go ahead and validate that none is not passed throughout our agent calls.
1 parent adac26f commit 3ad53b6

File tree

5 files changed

+52
-11
lines changed

5 files changed

+52
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ test-integ = [
234234
"hatch test tests_integ {args}"
235235
]
236236
prepare = [
237-
"hatch fmt --linter",
238237
"hatch fmt --formatter",
238+
"hatch fmt --linter",
239239
"hatch run test-lint",
240240
"hatch test --all"
241241
]

src/strands/agent/agent.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,9 @@ def __call__(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> Age
367367
- message: The final message from the model
368368
- metrics: Performance metrics from the event loop
369369
- state: The final state of the event loop
370+
371+
Raises:
372+
ValueError: If prompt is None.
370373
"""
371374

372375
def execute() -> AgentResult:
@@ -393,6 +396,9 @@ async def invoke_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
393396
- message: The final message from the model
394397
- metrics: Performance metrics from the event loop
395398
- state: The final state of the event loop
399+
400+
Raises:
401+
ValueError: If prompt is None.
396402
"""
397403
events = self.stream_async(prompt, **kwargs)
398404
async for event in events:
@@ -452,8 +458,7 @@ async def structured_output_async(
452458

453459
# Create temporary messages array if prompt is provided
454460
if prompt:
455-
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
456-
temp_messages = self.messages + [{"role": "user", "content": content}]
461+
temp_messages = self.messages + self._standardize_prompt(prompt)
457462
else:
458463
temp_messages = self.messages
459464

@@ -489,6 +494,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
489494
- And other event data provided by the callback handler
490495
491496
Raises:
497+
ValueError: If prompt is None.
492498
Exception: Any exceptions from the agent invocation will be propagated to the caller.
493499
494500
Example:
@@ -500,8 +506,7 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
500506
"""
501507
callback_handler = kwargs.get("callback_handler", self.callback_handler)
502508

503-
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
504-
message: Message = {"role": "user", "content": content}
509+
message = self._standardize_prompt(prompt)
505510

506511
self.trace_span = self._start_agent_trace_span(message)
507512
with trace_api.use_span(self.trace_span):
@@ -563,6 +568,15 @@ async def _run_loop(
563568
self.conversation_manager.apply_management(self)
564569
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
565570

571+
def _standardize_prompt(self, prompt: Union[str, list[ContentBlock]]) -> Message:
572+
"""Convert the prompt into a Message, validating it along the way."""
573+
if prompt is None:
574+
raise ValueError("User prompt must not be None")
575+
576+
content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
577+
message: Message = {"role": "user", "content": content}
578+
return message
579+
566580
async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
567581
"""Execute the event loop cycle with retry logic for context window limits.
568582

src/strands/models/bedrock.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,14 +418,14 @@ def _stream(
418418
ContextWindowOverflowException: If the input exceeds the model's context window.
419419
ModelThrottledException: If the model service is throttling requests.
420420
"""
421-
logger.debug("formatting request")
422-
request = self.format_request(messages, tool_specs, system_prompt)
423-
logger.debug("request=<%s>", request)
421+
try:
422+
logger.debug("formatting request")
423+
request = self.format_request(messages, tool_specs, system_prompt)
424+
logger.debug("request=<%s>", request)
424425

425-
logger.debug("invoking model")
426-
streaming = self.config.get("streaming", True)
426+
logger.debug("invoking model")
427+
streaming = self.config.get("streaming", True)
427428

428-
try:
429429
logger.debug("got response from model")
430430
if streaming:
431431
response = self.client.converse_stream(**request)

tests/strands/agent/test_agent.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,24 @@ async def test_agent__call__in_async_context(mock_model, agent, agenerator):
750750
assert tru_message == exp_message
751751

752752

753+
@pytest.mark.asyncio
754+
async def test_agent_invocations_prompt_validation(agent, alist):
755+
with pytest.raises(ValueError):
756+
await agent.invoke_async(prompt=None)
757+
758+
with pytest.raises(ValueError):
759+
await agent(prompt=None)
760+
761+
with pytest.raises(ValueError):
762+
await alist(agent.stream_async(prompt=None))
763+
764+
with pytest.raises(ValueError):
765+
await agent.structured_output(type(user), prompt=None)
766+
767+
with pytest.raises(ValueError):
768+
await agent.structured_output_async(type(user), prompt=None)
769+
770+
753771
@pytest.mark.asyncio
754772
async def test_agent_invoke_async(mock_model, agent, agenerator):
755773
mock_model.mock_stream.return_value = agenerator(

tests/strands/models/test_bedrock.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,15 @@ async def test_stream_throttling_exception_from_event_stream_error(bedrock_clien
419419
)
420420

421421

422+
@pytest.mark.asyncio
423+
async def test_stream_with_invalid_content_throws(bedrock_client, model, alist):
424+
# We used to hang on None, so ensure we don't regress: https://github.com/strands-agents/sdk-python/issues/642
425+
messages = [{"role": "user", "content": None}]
426+
427+
with pytest.raises(TypeError):
428+
await alist(model.stream(messages))
429+
430+
422431
@pytest.mark.asyncio
423432
async def test_stream_throttling_exception_from_general_exception(bedrock_client, model, messages, alist):
424433
error_message = "ThrottlingException: Rate exceeded for ConverseStream"

0 commit comments

Comments
 (0)