Skip to content

Commit 0fe0249

Browse files
committed
fix to bug 565;
1 parent abbc460 commit 0fe0249

File tree

1 file changed

+134
-10
lines changed

1 file changed

+134
-10
lines changed

src/strands/tools/decorator.py

Lines changed: 134 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,98 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6666
logger = logging.getLogger(__name__)
6767

6868

69+
def _resolve_json_schema_references(schema: dict[str, Any]) -> dict[str, Any]:
70+
"""Resolve all $ref references in a JSON schema by inlining definitions.
71+
72+
Some model providers (e.g., Bedrock via LiteLLM) don't support JSON Schema
73+
$ref references. This function flattens the schema by replacing all $ref
74+
occurrences with their actual definitions from the $defs section.
75+
76+
This is particularly important for Pydantic-generated schemas that use $defs
77+
for enum types, as these would otherwise cause validation errors with certain
78+
model providers.
79+
80+
Args:
81+
schema: A JSON schema dict that may contain $ref references and a $defs section
82+
83+
Returns:
84+
A new schema dict with all $ref references replaced by their definitions.
85+
The $defs section is removed from the result.
86+
87+
Example:
88+
Input schema with $ref:
89+
{
90+
"$defs": {"Color": {"type": "string", "enum": ["red", "blue"]}},
91+
"properties": {"color": {"$ref": "#/$defs/Color"}}
92+
}
93+
94+
Output schema with resolved reference:
95+
{
96+
"properties": {"color": {"type": "string", "enum": ["red", "blue"]}}
97+
}
98+
"""
99+
# Get definitions if they exist
100+
defs = schema.get("$defs", {})
101+
if not defs:
102+
return schema
103+
104+
def resolve_node(node: Any) -> Any:
105+
"""Recursively process a schema node, replacing any $ref with actual definitions.
106+
107+
Args:
108+
node: Any value from the schema (dict, list, or primitive)
109+
110+
Returns:
111+
The node with all $ref references resolved
112+
"""
113+
if not isinstance(node, dict):
114+
return node
115+
116+
# If this node is a $ref, replace it with the referenced definition
117+
if "$ref" in node:
118+
# Extract the definition name from the reference (e.g., "#/$defs/Color" -> "Color")
119+
ref_name = node["$ref"].split("/")[-1]
120+
if ref_name in defs:
121+
# Copy the referenced definition to avoid modifying the original
122+
resolved = defs[ref_name].copy()
123+
# Preserve any additional properties from the $ref node (e.g., "default", "description")
124+
for key, value in node.items():
125+
if key != "$ref":
126+
resolved[key] = value
127+
# Recursively resolve in case the definition itself contains references
128+
return resolve_node(resolved)
129+
# If reference not found, return as-is (shouldn't happen with valid schemas)
130+
return node
131+
132+
# For dict nodes, recursively process all values
133+
result: dict[str, Any] = {}
134+
for key, value in node.items():
135+
if isinstance(value, list):
136+
# For arrays, resolve each item
137+
result[key] = [resolve_node(item) for item in value]
138+
elif isinstance(value, dict):
139+
# For objects, check if this is a properties dict that needs special handling
140+
if key == "properties" and isinstance(value, dict):
141+
# Ensure all property definitions are fully resolved
142+
result[key] = {
143+
prop_name: resolve_node(prop_schema)
144+
for prop_name, prop_schema in value.items()
145+
}
146+
else:
147+
result[key] = resolve_node(value)
148+
else:
149+
# Primitive values are copied as-is
150+
result[key] = value
151+
return result
152+
153+
# Process the entire schema, excluding the $defs section from the result
154+
result = {
155+
key: resolve_node(value) for key, value in schema.items() if key != "$defs"
156+
}
157+
158+
return result
159+
160+
69161
# Type for wrapped function
70162
T = TypeVar("T", bound=Callable[..., Any])
71163

@@ -101,7 +193,8 @@ def __init__(self, func: Callable[..., Any]) -> None:
101193

102194
# Get parameter descriptions from parsed docstring
103195
self.param_descriptions = {
104-
param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params
196+
param.arg_name: param.description or f"Parameter {param.arg_name}"
197+
for param in self.doc.params
105198
}
106199

107200
# Create a Pydantic model for validation
@@ -131,7 +224,10 @@ def _create_input_model(self) -> Type[BaseModel]:
131224
description = self.param_descriptions.get(name, f"Parameter {name}")
132225

133226
# Create Field with description and default
134-
field_definitions[name] = (param_type, Field(default=default, description=description))
227+
field_definitions[name] = (
228+
param_type,
229+
Field(default=default, description=description),
230+
)
135231

136232
# Create model name based on function name
137233
model_name = f"{self.func.__name__.capitalize()}Tool"
@@ -173,8 +269,17 @@ def extract_metadata(self) -> ToolSpec:
173269
# Clean up Pydantic-specific schema elements
174270
self._clean_pydantic_schema(input_schema)
175271

272+
# Flatten schema by resolving $ref references to their definitions
273+
# This is required for compatibility with model providers that don't support
274+
# JSON Schema $ref (e.g., Bedrock/Anthropic via LiteLLM)
275+
input_schema = _resolve_json_schema_references(input_schema)
276+
176277
# Create tool specification
177-
tool_spec: ToolSpec = {"name": func_name, "description": description, "inputSchema": {"json": input_schema}}
278+
tool_spec: ToolSpec = {
279+
"name": func_name,
280+
"description": description,
281+
"inputSchema": {"json": input_schema},
282+
}
178283

179284
return tool_spec
180285

@@ -206,7 +311,9 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None:
206311
if "anyOf" in prop_schema:
207312
any_of = prop_schema["anyOf"]
208313
# Handle Optional[Type] case (represented as anyOf[Type, null])
209-
if len(any_of) == 2 and any(item.get("type") == "null" for item in any_of):
314+
if len(any_of) == 2 and any(
315+
item.get("type") == "null" for item in any_of
316+
):
210317
# Find the non-null type
211318
for item in any_of:
212319
if item.get("type") != "null":
@@ -250,7 +357,9 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
250357
except Exception as e:
251358
# Re-raise with more detailed error message
252359
error_msg = str(e)
253-
raise ValueError(f"Validation failed for input parameters: {error_msg}") from e
360+
raise ValueError(
361+
f"Validation failed for input parameters: {error_msg}"
362+
) from e
254363

255364

256365
P = ParamSpec("P") # Captures all parameters
@@ -296,7 +405,9 @@ def __init__(
296405

297406
functools.update_wrapper(wrapper=self, wrapped=self._tool_func)
298407

299-
def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]":
408+
def __get__(
409+
self, instance: Any, obj_type: Optional[Type] = None
410+
) -> "DecoratedFunctionTool[P, R]":
300411
"""Descriptor protocol implementation for proper method binding.
301412
302413
This method enables the decorated function to work correctly when used as a class method.
@@ -325,7 +436,9 @@ def my_tool():
325436
if instance is not None and not inspect.ismethod(self._tool_func):
326437
# Create a bound method
327438
tool_func = self._tool_func.__get__(instance, instance.__class__)
328-
return DecoratedFunctionTool(self._tool_name, self._tool_spec, tool_func, self._metadata)
439+
return DecoratedFunctionTool(
440+
self._tool_name, self._tool_spec, tool_func, self._metadata
441+
)
329442

330443
return self
331444

@@ -372,7 +485,9 @@ def tool_type(self) -> str:
372485
return "function"
373486

374487
@override
375-
async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any) -> ToolGenerator:
488+
async def stream(
489+
self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs: Any
490+
) -> ToolGenerator:
376491
"""Stream the tool with a tool use specification.
377492
378493
This method handles tool use streams from a Strands Agent. It validates the input,
@@ -403,7 +518,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
403518
validated_input = self._metadata.validate_input(tool_input)
404519

405520
# Pass along the agent if provided and expected by the function
406-
if "agent" in invocation_state and "agent" in self._metadata.signature.parameters:
521+
if (
522+
"agent" in invocation_state
523+
and "agent" in self._metadata.signature.parameters
524+
):
407525
validated_input["agent"] = invocation_state.get("agent")
408526

409527
# "Too few arguments" expected, hence the type ignore
@@ -468,21 +586,27 @@ def get_display_properties(self) -> dict[str, str]:
468586
# Handle @decorator
469587
@overload
470588
def tool(__func: Callable[P, R]) -> DecoratedFunctionTool[P, R]: ...
589+
590+
471591
# Handle @decorator()
472592
@overload
473593
def tool(
474594
description: Optional[str] = None,
475595
inputSchema: Optional[JSONSchema] = None,
476596
name: Optional[str] = None,
477597
) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ...
598+
599+
478600
# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
479601
# call site, but the actual implementation handles that and it's not representable via the type-system
480602
def tool( # type: ignore
481603
func: Optional[Callable[P, R]] = None,
482604
description: Optional[str] = None,
483605
inputSchema: Optional[JSONSchema] = None,
484606
name: Optional[str] = None,
485-
) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]:
607+
) -> Union[
608+
DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]
609+
]:
486610
"""Decorator that transforms a Python function into a Strands tool.
487611
488612
This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool.

0 commit comments

Comments
 (0)