diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 5ec324b68..ffac02ead 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -66,6 +66,98 @@ def my_tool(param1: str, param2: int = 42) -> dict: logger = logging.getLogger(__name__) +def _resolve_json_schema_references(schema: dict[str, Any]) -> dict[str, Any]: + """Resolve all $ref references in a JSON schema by inlining definitions. + + Some model providers (e.g., Bedrock via LiteLLM) don't support JSON Schema + $ref references. This function flattens the schema by replacing all $ref + occurrences with their actual definitions from the $defs section. + + This is particularly important for Pydantic-generated schemas that use $defs + for enum types, as these would otherwise cause validation errors with certain + model providers. + + Args: + schema: A JSON schema dict that may contain $ref references and a $defs section + + Returns: + A new schema dict with all $ref references replaced by their definitions. + The $defs section is removed from the result. + + Example: + Input schema with $ref: + { + "$defs": {"Color": {"type": "string", "enum": ["red", "blue"]}}, + "properties": {"color": {"$ref": "#/$defs/Color"}} + } + + Output schema with resolved reference: + { + "properties": {"color": {"type": "string", "enum": ["red", "blue"]}} + } + """ + # Get definitions if they exist + defs = schema.get("$defs", {}) + if not defs: + return schema + + def resolve_node(node: Any) -> Any: + """Recursively process a schema node, replacing any $ref with actual definitions. + + Args: + node: Any value from the schema (dict, list, or primitive) + + Returns: + The node with all $ref references resolved + """ + if not isinstance(node, dict): + return node + + # If this node is a $ref, replace it with the referenced definition + if "$ref" in node: + # Extract the definition name from the reference (e.g., "#/$defs/Color" -> "Color") + ref_name = node["$ref"].split("/")[-1] + if ref_name in defs: + # Copy the referenced definition to avoid modifying the original + resolved = defs[ref_name].copy() + # Preserve any additional properties from the $ref node (e.g., "default", "description") + for key, value in node.items(): + if key != "$ref": + resolved[key] = value + # Recursively resolve in case the definition itself contains references + return resolve_node(resolved) + # If reference not found, return as-is (shouldn't happen with valid schemas) + return node + + # For dict nodes, recursively process all values + result: dict[str, Any] = {} + for key, value in node.items(): + if isinstance(value, list): + # For arrays, resolve each item + result[key] = [resolve_node(item) for item in value] + elif isinstance(value, dict): + # For objects, check if this is a properties dict that needs special handling + if key == "properties" and isinstance(value, dict): + # Ensure all property definitions are fully resolved + result[key] = { + prop_name: resolve_node(prop_schema) + for prop_name, prop_schema in value.items() + } + else: + result[key] = resolve_node(value) + else: + # Primitive values are copied as-is + result[key] = value + return result + + # Process the entire schema, excluding the $defs section from the result + result = { + key: resolve_node(value) for key, value in schema.items() if key != "$defs" + } + + return result + + # Type for wrapped function T = TypeVar("T", bound=Callable[..., Any]) @@ -101,7 +193,8 @@ def __init__(self, func: Callable[..., Any]) -> None: # Get parameter descriptions from parsed docstring self.param_descriptions = { - param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params + param.arg_name: param.description or f"Parameter {param.arg_name}" + for param in self.doc.params } # Create a Pydantic model for validation @@ -131,7 +224,10 @@ def _create_input_model(self) -> Type[BaseModel]: description = self.param_descriptions.get(name, f"Parameter {name}") # Create Field with description and default - field_definitions[name] = (param_type, Field(default=default, description=description)) + field_definitions[name] = ( + param_type, + Field(default=default, description=description), + ) # Create model name based on function name model_name = f"{self.func.__name__.capitalize()}Tool" @@ -173,8 +269,17 @@ def extract_metadata(self) -> ToolSpec: # Clean up Pydantic-specific schema elements self._clean_pydantic_schema(input_schema) + # Flatten schema by resolving $ref references to their definitions + # This is required for compatibility with model providers that don't support + # JSON Schema $ref (e.g., Bedrock/Anthropic via LiteLLM) + input_schema = _resolve_json_schema_references(input_schema) + # Create tool specification - tool_spec: ToolSpec = {"name": func_name, "description": description, "inputSchema": {"json": input_schema}} + tool_spec: ToolSpec = { + "name": func_name, + "description": description, + "inputSchema": {"json": input_schema}, + } return tool_spec @@ -206,7 +311,9 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None: if "anyOf" in prop_schema: any_of = prop_schema["anyOf"] # Handle Optional[Type] case (represented as anyOf[Type, null]) - if len(any_of) == 2 and any(item.get("type") == "null" for item in any_of): + if len(any_of) == 2 and any( + item.get("type") == "null" for item in any_of + ): # Find the non-null type for item in any_of: if item.get("type") != "null": @@ -250,7 +357,9 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]: except Exception as e: # Re-raise with more detailed error message error_msg = str(e) - raise ValueError(f"Validation failed for input parameters: {error_msg}") from e + raise ValueError( + f"Validation failed for input parameters: {error_msg}" + ) from e P = ParamSpec("P") # Captures all parameters @@ -296,7 +405,9 @@ def __init__( functools.update_wrapper(wrapper=self, wrapped=self._tool_func) - def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]": + def __get__( + self, instance: Any, obj_type: Optional[Type] = None + ) -> "DecoratedFunctionTool[P, R]": """Descriptor protocol implementation for proper method binding. This method enables the decorated function to work correctly when used as a class method. @@ -325,7 +436,9 @@ def my_tool(): if instance is not None and not inspect.ismethod(self._tool_func): # Create a bound method tool_func = self._tool_func.__get__(instance, instance.__class__) - return DecoratedFunctionTool(self._tool_name, self._tool_spec, tool_func, self._metadata) + return DecoratedFunctionTool( + self._tool_name, self._tool_spec, tool_func, self._metadata + ) return self @@ -372,7 +485,9 @@ def tool_type(self) -> str: return "function" @override - async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator: + async def stream( + self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any + ) -> ToolGenerator: """Stream the tool with a tool use specification. This method handles tool use streams from a Strands Agent. It validates the input, @@ -403,7 +518,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw validated_input = self._metadata.validate_input(tool_input) # Pass along the agent if provided and expected by the function - if "agent" in invocation_state and "agent" in self._metadata.signature.parameters: + if ( + "agent" in invocation_state + and "agent" in self._metadata.signature.parameters + ): validated_input["agent"] = invocation_state.get("agent") # "Too few arguments" expected, hence the type ignore @@ -468,6 +586,8 @@ def get_display_properties(self) -> dict[str, str]: # Handle @decorator @overload def tool(__func: Callable[P, R]) -> DecoratedFunctionTool[P, R]: ... + + # Handle @decorator() @overload def tool( @@ -475,6 +595,8 @@ def tool( inputSchema: Optional[JSONSchema] = None, name: Optional[str] = None, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... + + # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system def tool( # type: ignore @@ -482,7 +604,9 @@ def tool( # type: ignore description: Optional[str] = None, inputSchema: Optional[JSONSchema] = None, name: Optional[str] = None, -) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]: +) -> Union[ + DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]] +]: """Decorator that transforms a Python function into a Strands tool. This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool.