@@ -66,6 +66,98 @@ def my_tool(param1: str, param2: int = 42) -> dict:
66
66
logger = logging .getLogger (__name__ )
67
67
68
68
69
+ def _resolve_json_schema_references (schema : dict [str , Any ]) -> dict [str , Any ]:
70
+ """Resolve all $ref references in a JSON schema by inlining definitions.
71
+
72
+ Some model providers (e.g., Bedrock via LiteLLM) don't support JSON Schema
73
+ $ref references. This function flattens the schema by replacing all $ref
74
+ occurrences with their actual definitions from the $defs section.
75
+
76
+ This is particularly important for Pydantic-generated schemas that use $defs
77
+ for enum types, as these would otherwise cause validation errors with certain
78
+ model providers.
79
+
80
+ Args:
81
+ schema: A JSON schema dict that may contain $ref references and a $defs section
82
+
83
+ Returns:
84
+ A new schema dict with all $ref references replaced by their definitions.
85
+ The $defs section is removed from the result.
86
+
87
+ Example:
88
+ Input schema with $ref:
89
+ {
90
+ "$defs": {"Color": {"type": "string", "enum": ["red", "blue"]}},
91
+ "properties": {"color": {"$ref": "#/$defs/Color"}}
92
+ }
93
+
94
+ Output schema with resolved reference:
95
+ {
96
+ "properties": {"color": {"type": "string", "enum": ["red", "blue"]}}
97
+ }
98
+ """
99
+ # Get definitions if they exist
100
+ defs = schema .get ("$defs" , {})
101
+ if not defs :
102
+ return schema
103
+
104
+ def resolve_node (node : Any ) -> Any :
105
+ """Recursively process a schema node, replacing any $ref with actual definitions.
106
+
107
+ Args:
108
+ node: Any value from the schema (dict, list, or primitive)
109
+
110
+ Returns:
111
+ The node with all $ref references resolved
112
+ """
113
+ if not isinstance (node , dict ):
114
+ return node
115
+
116
+ # If this node is a $ref, replace it with the referenced definition
117
+ if "$ref" in node :
118
+ # Extract the definition name from the reference (e.g., "#/$defs/Color" -> "Color")
119
+ ref_name = node ["$ref" ].split ("/" )[- 1 ]
120
+ if ref_name in defs :
121
+ # Copy the referenced definition to avoid modifying the original
122
+ resolved = defs [ref_name ].copy ()
123
+ # Preserve any additional properties from the $ref node (e.g., "default", "description")
124
+ for key , value in node .items ():
125
+ if key != "$ref" :
126
+ resolved [key ] = value
127
+ # Recursively resolve in case the definition itself contains references
128
+ return resolve_node (resolved )
129
+ # If reference not found, return as-is (shouldn't happen with valid schemas)
130
+ return node
131
+
132
+ # For dict nodes, recursively process all values
133
+ result : dict [str , Any ] = {}
134
+ for key , value in node .items ():
135
+ if isinstance (value , list ):
136
+ # For arrays, resolve each item
137
+ result [key ] = [resolve_node (item ) for item in value ]
138
+ elif isinstance (value , dict ):
139
+ # For objects, check if this is a properties dict that needs special handling
140
+ if key == "properties" and isinstance (value , dict ):
141
+ # Ensure all property definitions are fully resolved
142
+ result [key ] = {
143
+ prop_name : resolve_node (prop_schema )
144
+ for prop_name , prop_schema in value .items ()
145
+ }
146
+ else :
147
+ result [key ] = resolve_node (value )
148
+ else :
149
+ # Primitive values are copied as-is
150
+ result [key ] = value
151
+ return result
152
+
153
+ # Process the entire schema, excluding the $defs section from the result
154
+ result = {
155
+ key : resolve_node (value ) for key , value in schema .items () if key != "$defs"
156
+ }
157
+
158
+ return result
159
+
160
+
69
161
# Type for wrapped function
70
162
T = TypeVar ("T" , bound = Callable [..., Any ])
71
163
@@ -101,7 +193,8 @@ def __init__(self, func: Callable[..., Any]) -> None:
101
193
102
194
# Get parameter descriptions from parsed docstring
103
195
self .param_descriptions = {
104
- param .arg_name : param .description or f"Parameter { param .arg_name } " for param in self .doc .params
196
+ param .arg_name : param .description or f"Parameter { param .arg_name } "
197
+ for param in self .doc .params
105
198
}
106
199
107
200
# Create a Pydantic model for validation
@@ -131,7 +224,10 @@ def _create_input_model(self) -> Type[BaseModel]:
131
224
description = self .param_descriptions .get (name , f"Parameter { name } " )
132
225
133
226
# Create Field with description and default
134
- field_definitions [name ] = (param_type , Field (default = default , description = description ))
227
+ field_definitions [name ] = (
228
+ param_type ,
229
+ Field (default = default , description = description ),
230
+ )
135
231
136
232
# Create model name based on function name
137
233
model_name = f"{ self .func .__name__ .capitalize ()} Tool"
@@ -173,8 +269,17 @@ def extract_metadata(self) -> ToolSpec:
173
269
# Clean up Pydantic-specific schema elements
174
270
self ._clean_pydantic_schema (input_schema )
175
271
272
+ # Flatten schema by resolving $ref references to their definitions
273
+ # This is required for compatibility with model providers that don't support
274
+ # JSON Schema $ref (e.g., Bedrock/Anthropic via LiteLLM)
275
+ input_schema = _resolve_json_schema_references (input_schema )
276
+
176
277
# Create tool specification
177
- tool_spec : ToolSpec = {"name" : func_name , "description" : description , "inputSchema" : {"json" : input_schema }}
278
+ tool_spec : ToolSpec = {
279
+ "name" : func_name ,
280
+ "description" : description ,
281
+ "inputSchema" : {"json" : input_schema },
282
+ }
178
283
179
284
return tool_spec
180
285
@@ -206,7 +311,9 @@ def _clean_pydantic_schema(self, schema: dict[str, Any]) -> None:
206
311
if "anyOf" in prop_schema :
207
312
any_of = prop_schema ["anyOf" ]
208
313
# Handle Optional[Type] case (represented as anyOf[Type, null])
209
- if len (any_of ) == 2 and any (item .get ("type" ) == "null" for item in any_of ):
314
+ if len (any_of ) == 2 and any (
315
+ item .get ("type" ) == "null" for item in any_of
316
+ ):
210
317
# Find the non-null type
211
318
for item in any_of :
212
319
if item .get ("type" ) != "null" :
@@ -250,7 +357,9 @@ def validate_input(self, input_data: dict[str, Any]) -> dict[str, Any]:
250
357
except Exception as e :
251
358
# Re-raise with more detailed error message
252
359
error_msg = str (e )
253
- raise ValueError (f"Validation failed for input parameters: { error_msg } " ) from e
360
+ raise ValueError (
361
+ f"Validation failed for input parameters: { error_msg } "
362
+ ) from e
254
363
255
364
256
365
P = ParamSpec ("P" ) # Captures all parameters
@@ -296,7 +405,9 @@ def __init__(
296
405
297
406
functools .update_wrapper (wrapper = self , wrapped = self ._tool_func )
298
407
299
- def __get__ (self , instance : Any , obj_type : Optional [Type ] = None ) -> "DecoratedFunctionTool[P, R]" :
408
+ def __get__ (
409
+ self , instance : Any , obj_type : Optional [Type ] = None
410
+ ) -> "DecoratedFunctionTool[P, R]" :
300
411
"""Descriptor protocol implementation for proper method binding.
301
412
302
413
This method enables the decorated function to work correctly when used as a class method.
@@ -325,7 +436,9 @@ def my_tool():
325
436
if instance is not None and not inspect .ismethod (self ._tool_func ):
326
437
# Create a bound method
327
438
tool_func = self ._tool_func .__get__ (instance , instance .__class__ )
328
- return DecoratedFunctionTool (self ._tool_name , self ._tool_spec , tool_func , self ._metadata )
439
+ return DecoratedFunctionTool (
440
+ self ._tool_name , self ._tool_spec , tool_func , self ._metadata
441
+ )
329
442
330
443
return self
331
444
@@ -372,7 +485,9 @@ def tool_type(self) -> str:
372
485
return "function"
373
486
374
487
@override
375
- async def stream (self , tool_use : ToolUse , invocation_state : dict [str , Any ], ** kwargs : Any ) -> ToolGenerator :
488
+ async def stream (
489
+ self , tool_use : ToolUse , invocation_state : dict [str , Any ], ** kwargs : Any
490
+ ) -> ToolGenerator :
376
491
"""Stream the tool with a tool use specification.
377
492
378
493
This method handles tool use streams from a Strands Agent. It validates the input,
@@ -403,7 +518,10 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
403
518
validated_input = self ._metadata .validate_input (tool_input )
404
519
405
520
# Pass along the agent if provided and expected by the function
406
- if "agent" in invocation_state and "agent" in self ._metadata .signature .parameters :
521
+ if (
522
+ "agent" in invocation_state
523
+ and "agent" in self ._metadata .signature .parameters
524
+ ):
407
525
validated_input ["agent" ] = invocation_state .get ("agent" )
408
526
409
527
# "Too few arguments" expected, hence the type ignore
@@ -468,21 +586,27 @@ def get_display_properties(self) -> dict[str, str]:
468
586
# Handle @decorator
469
587
@overload
470
588
def tool (__func : Callable [P , R ]) -> DecoratedFunctionTool [P , R ]: ...
589
+
590
+
471
591
# Handle @decorator()
472
592
@overload
473
593
def tool (
474
594
description : Optional [str ] = None ,
475
595
inputSchema : Optional [JSONSchema ] = None ,
476
596
name : Optional [str ] = None ,
477
597
) -> Callable [[Callable [P , R ]], DecoratedFunctionTool [P , R ]]: ...
598
+
599
+
478
600
# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the
479
601
# call site, but the actual implementation handles that and it's not representable via the type-system
480
602
def tool ( # type: ignore
481
603
func : Optional [Callable [P , R ]] = None ,
482
604
description : Optional [str ] = None ,
483
605
inputSchema : Optional [JSONSchema ] = None ,
484
606
name : Optional [str ] = None ,
485
- ) -> Union [DecoratedFunctionTool [P , R ], Callable [[Callable [P , R ]], DecoratedFunctionTool [P , R ]]]:
607
+ ) -> Union [
608
+ DecoratedFunctionTool [P , R ], Callable [[Callable [P , R ]], DecoratedFunctionTool [P , R ]]
609
+ ]:
486
610
"""Decorator that transforms a Python function into a Strands tool.
487
611
488
612
This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool.
0 commit comments