Skip to content

Commit 9662367

Browse files
committed
feat: add cached token metrics support for Amazon Bedrock
1 parent adac26f commit 9662367

File tree

5 files changed

+66
-17
lines changed

5 files changed

+66
-17
lines changed

src/strands/telemetry/metrics.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from ..telemetry import metrics_constants as constants
1313
from ..types.content import Message
14-
from ..types.streaming import Metrics, Usage
14+
from ..types.event_loop import Metrics, Usage
1515
from ..types.tools import ToolUse
1616

1717
logger = logging.getLogger(__name__)
@@ -264,6 +264,21 @@ def update_usage(self, usage: Usage) -> None:
264264
self.accumulated_usage["outputTokens"] += usage["outputTokens"]
265265
self.accumulated_usage["totalTokens"] += usage["totalTokens"]
266266

267+
# Handle optional cached token metrics
268+
if "cacheReadInputTokens" in usage:
269+
cache_read_tokens = usage["cacheReadInputTokens"]
270+
self._metrics_client.event_loop_cache_read_input_tokens.record(cache_read_tokens)
271+
self.accumulated_usage["cacheReadInputTokens"] = (
272+
self.accumulated_usage.get("cacheReadInputTokens", 0) + cache_read_tokens
273+
)
274+
275+
if "cacheWriteInputTokens" in usage:
276+
cache_write_tokens = usage["cacheWriteInputTokens"]
277+
self._metrics_client.event_loop_cache_write_input_tokens.record(cache_write_tokens)
278+
self.accumulated_usage["cacheWriteInputTokens"] = (
279+
self.accumulated_usage.get("cacheWriteInputTokens", 0) + cache_write_tokens
280+
)
281+
267282
def update_metrics(self, metrics: Metrics) -> None:
268283
"""Update the accumulated performance metrics with new metrics data.
269284
@@ -325,11 +340,21 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name
325340
f"├─ Cycles: total={summary['total_cycles']}, avg_time={summary['average_cycle_time']:.3f}s, "
326341
f"total_time={summary['total_duration']:.3f}s"
327342
)
328-
yield (
329-
f"├─ Tokens: in={summary['accumulated_usage']['inputTokens']}, "
330-
f"out={summary['accumulated_usage']['outputTokens']}, "
331-
f"total={summary['accumulated_usage']['totalTokens']}"
332-
)
343+
344+
# Build token display with optional cached tokens
345+
token_parts = [
346+
f"in={summary['accumulated_usage']['inputTokens']}",
347+
f"out={summary['accumulated_usage']['outputTokens']}",
348+
f"total={summary['accumulated_usage']['totalTokens']}",
349+
]
350+
351+
# Add cached token info if present
352+
if summary["accumulated_usage"].get("cacheReadInputTokens"):
353+
token_parts.append(f"cache_read={summary['accumulated_usage']['cacheReadInputTokens']}")
354+
if summary["accumulated_usage"].get("cacheWriteInputTokens"):
355+
token_parts.append(f"cache_write={summary['accumulated_usage']['cacheWriteInputTokens']}")
356+
357+
yield f"├─ Tokens: {', '.join(token_parts)}"
333358
yield f"├─ Bedrock Latency: {summary['accumulated_metrics']['latencyMs']}ms"
334359

335360
yield "├─ Tool Usage:"
@@ -421,6 +446,8 @@ class MetricsClient:
421446
event_loop_latency: Histogram
422447
event_loop_input_tokens: Histogram
423448
event_loop_output_tokens: Histogram
449+
event_loop_cache_read_input_tokens: Histogram
450+
event_loop_cache_write_input_tokens: Histogram
424451

425452
tool_call_count: Counter
426453
tool_success_count: Counter
@@ -474,3 +501,9 @@ def create_instruments(self) -> None:
474501
self.event_loop_output_tokens = self.meter.create_histogram(
475502
name=constants.STRANDS_EVENT_LOOP_OUTPUT_TOKENS, unit="token"
476503
)
504+
self.event_loop_cache_read_input_tokens = self.meter.create_histogram(
505+
name=constants.STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS, unit="token"
506+
)
507+
self.event_loop_cache_write_input_tokens = self.meter.create_histogram(
508+
name=constants.STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS, unit="token"
509+
)

src/strands/telemetry/metrics_constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@
1313
STRANDS_EVENT_LOOP_CYCLE_DURATION = "strands.event_loop.cycle_duration"
1414
STRANDS_EVENT_LOOP_INPUT_TOKENS = "strands.event_loop.input.tokens"
1515
STRANDS_EVENT_LOOP_OUTPUT_TOKENS = "strands.event_loop.output.tokens"
16+
STRANDS_EVENT_LOOP_CACHE_READ_INPUT_TOKENS = "strands.event_loop.cache_read.input.tokens"
17+
STRANDS_EVENT_LOOP_CACHE_WRITE_INPUT_TOKENS = "strands.event_loop.cache_write.input.tokens"

src/strands/types/event_loop.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,25 @@
22

33
from typing import Literal
44

5-
from typing_extensions import TypedDict
5+
from typing_extensions import Required, TypedDict
66

77

8-
class Usage(TypedDict):
8+
class Usage(TypedDict, total=False):
99
"""Token usage information for model interactions.
1010
1111
Attributes:
12-
inputTokens: Number of tokens sent in the request to the model..
12+
inputTokens: Number of tokens sent in the request to the model.
1313
outputTokens: Number of tokens that the model generated for the request.
1414
totalTokens: Total number of tokens (input + output).
15+
cacheReadInputTokens: Number of tokens read from cache (optional).
16+
cacheWriteInputTokens: Number of tokens written to cache (optional).
1517
"""
1618

17-
inputTokens: int
18-
outputTokens: int
19-
totalTokens: int
19+
inputTokens: Required[int]
20+
outputTokens: Required[int]
21+
totalTokens: Required[int]
22+
cacheReadInputTokens: int
23+
cacheWriteInputTokens: int
2024

2125

2226
class Metrics(TypedDict):

tests/strands/event_loop/test_streaming.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,18 @@ def test_extract_usage_metrics():
260260
assert tru_usage == exp_usage and tru_metrics == exp_metrics
261261

262262

263+
def test_extract_usage_metrics_with_cache_tokens():
264+
event = {
265+
"usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0, "cacheReadInputTokens": 0},
266+
"metrics": {"latencyMs": 0},
267+
}
268+
269+
tru_usage, tru_metrics = strands.event_loop.streaming.extract_usage_metrics(event)
270+
exp_usage, exp_metrics = event["usage"], event["metrics"]
271+
272+
assert tru_usage == exp_usage and tru_metrics == exp_metrics
273+
274+
263275
@pytest.mark.parametrize(
264276
("response", "exp_events"),
265277
[

tests/strands/telemetry/test_metrics.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def usage(request):
9090
"inputTokens": 1,
9191
"outputTokens": 2,
9292
"totalTokens": 3,
93+
"cacheWriteInputTokens": 10,
9394
}
9495
if hasattr(request, "param"):
9596
params.update(request.param)
@@ -315,17 +316,14 @@ def test_event_loop_metrics_update_usage(usage, event_loop_metrics, mock_get_met
315316
event_loop_metrics.update_usage(usage)
316317

317318
tru_usage = event_loop_metrics.accumulated_usage
318-
exp_usage = Usage(
319-
inputTokens=3,
320-
outputTokens=6,
321-
totalTokens=9,
322-
)
319+
exp_usage = Usage(inputTokens=3, outputTokens=6, totalTokens=9, cacheWriteInputTokens=30)
323320

324321
assert tru_usage == exp_usage
325322
mock_get_meter_provider.return_value.get_meter.assert_called()
326323
metrics_client = event_loop_metrics._metrics_client
327324
metrics_client.event_loop_input_tokens.record.assert_called()
328325
metrics_client.event_loop_output_tokens.record.assert_called()
326+
metrics_client.event_loop_cache_write_input_tokens.record.assert_called()
329327

330328

331329
def test_event_loop_metrics_update_metrics(metrics, event_loop_metrics, mock_get_meter_provider):

0 commit comments

Comments (0)