Skip to content

Fix enum schema flattening #600

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 134 additions & 10 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,98 @@ def my_tool(param1: str, param2: int = 42) -> dict:
logger = logging.getLogger(__name__)


def _resolve_json_schema_references(schema: dict[str, Any]) -> dict[str, Any]:
"""Resolve all $ref references in a JSON schema by inlining definitions.

Some model providers (e.g., Bedrock via LiteLLM) don't support JSON Schema
$ref references. This function flattens the schema by replacing all $ref
occurrences with their actual definitions from the $defs section.

This is particularly important for Pydantic-generated schemas that use $defs
for enum types, as these would otherwise cause validation errors with certain
model providers.

Args:
schema: A JSON schema dict that may contain $ref references and a $defs section

Returns:
A new schema dict with all $ref references replaced by their definitions.
The $defs section is removed from the result.

Example:
Input schema with $ref:
{
"$defs": {"Color": {"type": "string", "enum": ["red", "blue"]}},
"properties": {"color": {"$ref": "#/$defs/Color"}}
}

Output schema with resolved reference:
{
"properties": {"color": {"type": "string", "enum": ["red", "blue"]}}
}
"""
# Get definitions if they exist
defs = schema.get("$defs", {})
if not defs:
return schema

def resolve_node(node: Any) -> Any:
"""Recursively process a schema node, replacing any $ref with actual definitions.

Args:
node: Any value from the schema (dict, list, or primitive)

Returns:
The node with all $ref references resolved
"""
if not isinstance(node, dict):
return node

# If this node is a $ref, replace it with the referenced definition
if "$ref" in node:
# Extract the definition name from the reference (e.g., "#/$defs/Color" -> "Color")
ref_name = node["$ref"].split("/")[-1]
if ref_name in defs:
# Copy the referenced definition to avoid modifying the original
resolved = defs[ref_name].copy()
# Preserve any additional properties from the $ref node (e.g., "default", "description")
for key, value in node.items():
if key != "$ref":
resolved[key] = value
# Recursively resolve in case the definition itself contains references
return resolve_node(resolved)
# If reference not found, return as-is (shouldn't happen with valid schemas)
return node

# For dict nodes, recursively process all values
result: dict[str, Any] = {}
for key, value in node.items():
if isinstance(value, list):
# For arrays, resolve each item
result[key] = [resolve_node(item) for item in value]
elif isinstance(value, dict):
# For objects, check if this is a properties dict that needs special handling
if key == "properties" and isinstance(value, dict):
# Ensure all property definitions are fully resolved
result[key] = {
prop_name: resolve_node(prop_schema)
for prop_name, prop_schema in value.items()
}
else:
result[key] = resolve_node(value)
else:
# Primitive values are copied as-is
result[key] = value
return result

# Process the entire schema, excluding the $defs section from the result
result = {
key: resolve_node(value) for key, value in schema.items() if key != "$defs"
}

return result


# Type for wrapped function
T = TypeVar("T", bound=Callable[..., Any])

Expand Down Expand Up @@ -101,7 +193,8 @@ def __init__(self, func: Callable[..., Any]) -> None:

# Get parameter descriptions from parsed docstring
self.param_descriptions = {
param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params
param.arg_name: param.description or f"Parameter {param.arg_name}"
for param in self.doc.params
}

# Create a Pydantic model for validation
Expand Down Expand Up @@ -131,7 +224,10 @@ def _create_input_model(self) -> Type[BaseModel]:
description = self.param_descriptions.get(name, f"Parameter {name}")

# Create Field with description and default
field_definitions[name] = (param_type, Field(default=default, description=description))
field_definitions[name] = (
param_type,
Field(default=default, description=description),
)

# Create model name based on function name
model_name = f"{self.func.__name__.capitalize()}Tool"
Expand Down Expand Up @@ -173,8 +269,17 @@ def extract_metadata(self) -> ToolSpec:
# Clean up Pydantic-specific schema elements
self._clean_pydantic_schema(input_schema)

# Flatten schema by resolving $ref references to their definitions
# This is required for compatibility with model providers that don't support
# JSON Schema $ref (e.g., Bedrock/Anthropic via LiteLLM)
input_schema = _resolve_json_schema_references(input_schema)

# Create tool specification
tool_spec: ToolSpec = {"name": func_name, "description": description, "inputSchema": {"json": input_schema}}
tool_spec: ToolSpec = {
"name": func_name,
"description": description,
"inputSchema": {"json": input_schema},
}

return tool_spec

Expand Down Expand Up @@ -206,7 +311,9 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None:
if "anyOf" in prop_schema:
any_of = prop_schema["anyOf"]
# Handle Optional[Type] case (represented as anyOf[Type, null])
if len(any_of) == 2 and any(item.get("type") == "null" for item in any_of):
if len(any_of) == 2 and any(
item.get("type") == "null" for item in any_of
):
# Find the non-null type
for item in any_of:
if item.get("type") != "null":
Expand Down Expand Up @@ -250,7 +357,9 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
except Exception as e:
# Re-raise with more detailed error message
error_msg = str(e)
raise ValueError(f"Validation failed for input parameters: {error_msg}") from e
raise ValueError(
f"Validation failed for input parameters: {error_msg}"
) from e


P = ParamSpec("P") # Captures all parameters
Expand Down Expand Up @@ -296,7 +405,9 @@ def __init__(

functools.update_wrapper(wrapper=self, wrapped=self._tool_func)

def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]":
def __get__(
self, instance: Any, obj_type: Optional[Type] = None
) -> "DecoratedFunctionTool[P, R]":
"""Descriptor protocol implementation for proper method binding.

This method enables the decorated function to work correctly when used as a class method.
Expand Down Expand Up @@ -325,7 +436,9 @@ def my_tool():
if instance is not None and not inspect.ismethod(self._tool_func):
# Create a bound method
tool_func = self._tool_func.__get__(instance, instance.__class__)
return DecoratedFunctionTool(self._tool_name, self._tool_spec, tool_func, self._metadata)
return DecoratedFunctionTool(
self._tool_name, self._tool_spec, tool_func, self._metadata
)

return self

Expand Down Expand Up @@ -372,7 +485,9 @@ def tool_type(self) -> str:
return "function"

@override
async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
async def stream(
self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any
) -> ToolGenerator:
"""Stream the tool with a tool use specification.

This method handles tool use streams from a Strands Agent. It validates the input,
Expand Down Expand Up @@ -403,7 +518,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
validated_input = self._metadata.validate_input(tool_input)

# Pass along the agent if provided and expected by the function
if "agent" in invocation_state and "agent" in self._metadata.signature.parameters:
if (
"agent" in invocation_state
and "agent" in self._metadata.signature.parameters
):
validated_input["agent"] = invocation_state.get("agent")

# "Too few arguments" expected, hence the type ignore
Expand Down Expand Up @@ -468,21 +586,27 @@ def get_display_properties(self) -> dict[str, str]:
# Handle @decorator
@overload
def tool(__func: Callable[P, R]) -> DecoratedFunctionTool[P, R]: ...


# Handle @decorator()
@overload
def tool(
description: Optional[str] = None,
inputSchema: Optional[JSONSchema] = None,
name: Optional[str] = None,
) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ...


# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
# call site, but the actual implementation handles that and it's not representable via the type-system
def tool( # type: ignore
func: Optional[Callable[P, R]] = None,
description: Optional[str] = None,
inputSchema: Optional[JSONSchema] = None,
name: Optional[str] = None,
) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]:
) -> Union[
DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]
]:
"""Decorator that transforms a Python function into a Strands tool.

This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool.
Expand Down