diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index c1e99f1a2..8af4bd042 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -14,7 +14,7 @@ from ..types.content import ContentBlock, Messages from ..types.streaming import StreamEvent -from ..types.tools import ToolSpec +from ..types.tools import ToolResult, ToolSpec from .openai import OpenAIModel logger = logging.getLogger(__name__) @@ -103,6 +103,93 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] return super().format_request_message_content(content) + @override + @classmethod + def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + """Format a LiteLLM compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + LiteLLM compatible tool message with plain string content. + """ + contents = cast( + list[ContentBlock], + [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ], + ) + + # Extract plain text content for LiteLLM/OpenRouter compatibility + content_text = "" + if contents: + if "text" in contents[0]: + content_text = contents[0]["text"] + elif "json" in contents[0]: + content_text = json.dumps(contents[0]["json"]) + + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": content_text, # Plain string instead of content blocks + } + + @override + @classmethod + def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format a LiteLLM compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A LiteLLM compatible messages array. + """ + formatted_messages: list[dict[str, Any]] + formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + + for message in messages: + contents = message["content"] + + formatted_contents = [ + cls.format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + formatted_tool_calls = [ + cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content + ] + formatted_tool_messages = [ + cls.format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + # Handle content format for messages with tool_calls + if formatted_tool_calls: + # Extract text from first content block if available + content_value = None + if formatted_contents and isinstance(formatted_contents, list): + if formatted_contents[0].get("text"): + content_value = formatted_contents[0]["text"] + formatted_message = { + "role": message["role"], + "content": content_value, # String/None for messages with tool_calls + "tool_calls": formatted_tool_calls, + } + else: + formatted_message = { + "role": message["role"], + "content": formatted_contents, + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + @override async def stream( self,