Skip to content

feat: claude citation support with BedrockModel #631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 0 additions & 1 deletion src/strands/agent/agent_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,4 @@ def __str__(self) -> str:
for item in content_array:
if isinstance(item, dict) and "text" in item:
result += item.get("text", "") + "\n"

return result
21 changes: 21 additions & 0 deletions src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ def handle_content_block_delta(
state["text"] += delta_content["text"]
callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content}

elif "citation" in delta_content:
if "citationsContent" not in state:
state["citationsContent"] = []

state["citationsContent"].append(delta_content["citation"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of just appending the response to a list, I can see that the CitationSourceContentDelta event mentions that it contains "incremental updates to the source content text during streaming responses". Does this mean we are concatenating the strings of the CitationSourceContentDelta events, or just appending the events to an array?

callback_event["callback"] = {"citation_metadata": delta_content["citation"], "delta": delta_content}

elif "reasoningContent" in delta_content:
if "text" in delta_content["reasoningContent"]:
if "reasoningText" not in state:
Expand Down Expand Up @@ -168,6 +175,7 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
current_tool_use = state["current_tool_use"]
text = state["text"]
reasoning_text = state["reasoningText"]
citations_content = state["citationsContent"] if "citationsContent" in state else []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is citationsContent not in state?


if current_tool_use:
if "input" not in current_tool_use:
Expand Down Expand Up @@ -206,6 +214,18 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
)
state["reasoningText"] = ""

# Handle citations_content independently - not as elif since we can have both text and citations
if citations_content:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be made into an elif?

# Convert CitationsDelta objects back to CitationsContentBlock format
# that matches non-streaming behavior
from ..types.citations import CitationsContentBlock

citations_block: CitationsContentBlock = {
"citations": citations_content # citations_content contains CitationsDelta objects
}
content.append({"citationsContent": citations_block})
state["citationsContent"] = []

return state


Expand Down Expand Up @@ -264,6 +284,7 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
"current_tool_use": {},
"reasoningText": "",
"signature": "",
"citationsContent": [],
}
state["content"] = state["message"]["content"]

Expand Down
77 changes: 74 additions & 3 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import logging
import os
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast

import boto3
from botocore.config import Config as BotocoreConfig
Expand All @@ -18,7 +18,11 @@
from ..event_loop import streaming
from ..tools import convert_pydantic_to_tool_spec
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
from ..types.exceptions import (
ContextWindowOverflowException,
ModelThrottledException,
UnsupportedModelCitationsException,
)
from ..types.streaming import StreamEvent
from ..types.tools import ToolResult, ToolSpec
from .model import Model
Expand All @@ -34,6 +38,15 @@
"too many total text bytes",
]

# Model IDs that support citation functionality.
# _validate_citations_support compares each entry against the END of the configured
# model ID, so an ID carrying a cross-region prefix (e.g., "us.") also matches.
CITATION_SUPPORTED_MODELS = [
    "anthropic.claude-3-5-sonnet-20241022-v2:0",
    "anthropic.claude-3-7-sonnet-20250219-v1:0",
    "anthropic.claude-opus-4-20250514-v1:0",
    "anthropic.claude-sonnet-4-20250514-v1:0",
    "anthropic.claude-opus-4-1-20250805-v1:0",
]

T = TypeVar("T", bound=BaseModel)


Expand Down Expand Up @@ -349,6 +362,42 @@ def _generate_redaction_events(self) -> list[StreamEvent]:

return events

def _has_citations_config(self, messages: Messages) -> bool:
    """Report whether any document in the conversation asks for citations.

    Args:
        messages: Messages whose content blocks are scanned for documents.

    Returns:
        True as soon as one document block carries a non-null "citations" entry
        whose "enabled" value is truthy; False when no such block is found.
    """
    for msg in messages:
        for block in msg["content"]:
            if "document" not in block:
                continue

            doc = block["document"]
            # A missing or explicitly null "citations" entry means nothing was requested.
            if "citations" not in doc or doc["citations"] is None:
                continue

            settings = doc["citations"]
            if "enabled" in settings and settings["enabled"]:
                return True

    return False

def _validate_citations_support(self, messages: Messages) -> None:
    """Reject citation requests aimed at a model that is not known to support them.

    Args:
        messages: Messages to scan for documents with citations enabled.

    Raises:
        UnsupportedModelCitationsException: If a document enables citations and the
            configured model ID does not end with any entry of CITATION_SUPPORTED_MODELS.
    """
    if not self._has_citations_config(messages):
        return

    configured_id = self.config["model_id"]
    # Bedrock model IDs may carry a cross-region prefix (e.g., "us."), so the check
    # compares suffixes rather than requiring an exact match.
    for known_id in CITATION_SUPPORTED_MODELS:
        if configured_id.endswith(known_id):
            return

    raise UnsupportedModelCitationsException(configured_id, CITATION_SUPPORTED_MODELS)

@override
async def stream(
self,
Expand All @@ -374,7 +423,10 @@ async def stream(
Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
ModelThrottledException: If the model service is throttling requests.
UnsupportedModelCitationsException: If citations are requested but the model doesn't support them.
"""
# Validate citations support before starting the thread (fail fast in async context)
self._validate_citations_support(messages)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather not have this logic client side, and rely on the API to throw a validation exception here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I totally get your concern, but I'm guessing the provider SDK's server-side errors will be cryptic. Happy to remove it, though.


def callback(event: Optional[StreamEvent] = None) -> None:
loop.call_soon_threadsafe(queue.put_nowait, event)
Expand Down Expand Up @@ -510,7 +562,7 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
yield {"messageStart": {"role": response["output"]["message"]["role"]}}

# Process content blocks
for content in response["output"]["message"]["content"]:
for content in cast(list[ContentBlock], response["output"]["message"]["content"]):
# Yield contentBlockStart event if needed
if "toolUse" in content:
yield {
Expand Down Expand Up @@ -553,6 +605,25 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
}
}
}
elif "citationsContent" in content:
# For non-streaming citations, emit text and metadata deltas in sequence
# to match streaming behavior where they flow naturally
if "content" in content["citationsContent"]:
text_content = "".join([content["text"] for content in content["citationsContent"]["content"]])
yield {
"contentBlockDelta": {"delta": {"text": text_content}},
}

for citation in content["citationsContent"]["citations"]:
# Then emit citation metadata (for structure)
from ..types.streaming import CitationsDelta

citation_metadata: CitationsDelta = {
"title": citation["title"],
"location": citation["location"],
"sourceContent": citation["sourceContent"],
}
yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}}

# Yield contentBlockStop event
yield {"contentBlockStop": {}}
Expand Down
3 changes: 2 additions & 1 deletion src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
from mcp.types import ImageContent as MCPImageContent
from mcp.types import TextContent as MCPTextContent

from strands.types.tools import ToolResultContent, ToolResultStatus
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change this from a relative import to an absolute import?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mypy was complaining without it


from ...types import PaginatedList
from ...types.exceptions import MCPClientInitializationError
from ...types.media import ImageFormat
from ...types.tools import ToolResultContent, ToolResultStatus
from .mcp_agent_tool import MCPAgentTool
from .mcp_instrumentation import mcp_instrumentation
from .mcp_types import MCPToolResult, MCPTransport
Expand Down
152 changes: 152 additions & 0 deletions src/strands/types/citations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Citation type definitions for the SDK.

These types are modeled after the Bedrock API.
"""

from typing import List, Union

from typing_extensions import TypedDict


class CitationsConfig(TypedDict):
    """Per-document switch that turns citations on or off.

    Attributes:
        enabled: True when citations should be produced for the document.
    """

    enabled: bool


class DocumentCharLocation(TypedDict, total=False):
    """Character-level position of cited content inside a document.

    The cited span is described by a start and an end character index.
    All keys are optional.

    Attributes:
        documentIndex: Position of the document in the array of documents sent
            with the request. Minimum value of 0.
        start: Character position where the cited content begins.
            Minimum value of 0.
        end: Character position where the cited content ends.
            Minimum value of 0.
    """

    documentIndex: int
    start: int
    end: int


class DocumentChunkLocation(TypedDict, total=False):
    """Chunk-level position of cited content inside a document.

    The cited span is described in terms of logical document segments (chunks).
    All keys are optional.

    Attributes:
        documentIndex: Position of the document in the array of documents sent
            with the request. Minimum value of 0.
        start: Chunk identifier or index where the cited content begins.
            Minimum value of 0.
        end: Chunk identifier or index where the cited content ends.
            Minimum value of 0.
    """

    documentIndex: int
    start: int
    end: int


class DocumentPageLocation(TypedDict, total=False):
    """Page-level position of cited content inside a document.

    The cited span is described by page numbers. All keys are optional.

    Attributes:
        documentIndex: Position of the document in the array of documents sent
            with the request. Minimum value of 0.
        start: Page number where the cited content begins. Minimum value of 0.
        end: Page number where the cited content ends. Minimum value of 0.
    """

    documentIndex: int
    start: int
    end: int


# Union type for citation locations: a citation may be located at character, chunk,
# or page granularity.
# NOTE(review): the three member TypedDicts declare identical keys (documentIndex,
# start, end), so they cannot be told apart structurally — confirm how consumers are
# expected to discriminate between them, and whether this matches the Bedrock shape.
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]


class CitationSourceContent(TypedDict, total=False):
    """Source-document content that a citation refers to.

    Holds the actual text taken from a source document that is cited or
    referenced in the model's response.

    Note:
        This is a UNION type, so only one of the members can be specified.

    Attributes:
        text: The cited text from the source document.
    """

    text: str


class CitationGeneratedContent(TypedDict, total=False):
    """Model-generated content that a citation backs up.

    Holds the generated text that corresponds to, or is supported by, a
    citation from a source document.

    Note:
        This is a UNION type, so only one of the members can be specified.

    Attributes:
        text: Text produced by the model that the associated citation supports.
    """

    text: str


class Citation(TypedDict, total=False):
    """A single reference from the model's response back to a source document.

    Citations make it possible to trace generated output to the source
    documents that informed it. All keys are optional.

    Attributes:
        location: Where in the source document the cited content can be found,
            expressed as character positions, page numbers, or chunk identifiers.
        sourceContent: The pieces of source-document content that were
            referenced or cited in the generated response.
        title: Title or identifier of the cited source document.
    """

    location: CitationLocation
    sourceContent: List[CitationSourceContent]
    title: str


class CitationsContentBlock(TypedDict, total=False):
    """Content block pairing generated text with the citations that support it.

    Returned when document citations are enabled, so generated content can be
    traced to the source documents that informed the response. All keys are
    optional.

    Attributes:
        citations: Citations referencing the source documents used to generate
            the associated content.
        content: Generated content supported by the associated citations.
    """

    citations: List[Citation]
    content: List[CitationGeneratedContent]
3 changes: 3 additions & 0 deletions src/strands/types/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from typing_extensions import TypedDict

from .citations import CitationsContentBlock
from .media import DocumentContent, ImageContent, VideoContent
from .tools import ToolResult, ToolUse

Expand Down Expand Up @@ -83,6 +84,7 @@ class ContentBlock(TypedDict, total=False):
toolResult: The result for a tool request that a model makes.
toolUse: Information about a tool use request from a model.
video: Video to include in the message.
citationsContent: Contains the citations for a document.
"""

cachePoint: CachePoint
Expand All @@ -94,6 +96,7 @@ class ContentBlock(TypedDict, total=False):
toolResult: ToolResult
toolUse: ToolUse
video: VideoContent
citationsContent: CitationsContentBlock


class SystemContentBlock(TypedDict, total=False):
Expand Down
24 changes: 24 additions & 0 deletions src/strands/types/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,27 @@ class SessionException(Exception):
"""Exception raised when session operations fail."""

pass


class UnsupportedModelCitationsException(Exception):
    """Signal that citations were requested from a model that cannot provide them.

    Raised when a user attempts to use document citations with a Bedrock model that
    does not support the citations feature. Citations are only supported by specific
    Claude models.
    """

    def __init__(self, model_id: str, supported_models: list[str]) -> None:
        """Store the model details and compose the error message.

        Args:
            model_id: ID of the model that does not support citations.
            supported_models: IDs of the models that do support citations.
        """
        self.model_id = model_id
        self.supported_models = supported_models

        supported_listing = ", ".join(supported_models)
        super().__init__(
            f"Model '{model_id}' does not support document citations. "
            f"Supported models for citations are: {supported_listing}"
        )
Loading
Loading