-
Notifications
You must be signed in to change notification settings - Fork 285
feat: claude citation support with BedrockModel #631
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c586720
a156321
20ee330
88ebbeb
7f36078
4d55809
441e583
08b8bdb
df2e579
7e70bf7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -128,6 +128,13 @@ def handle_content_block_delta( | |
state["text"] += delta_content["text"] | ||
callback_event["callback"] = {"data": delta_content["text"], "delta": delta_content} | ||
|
||
elif "citation" in delta_content: | ||
if "citationsContent" not in state: | ||
state["citationsContent"] = [] | ||
|
||
state["citationsContent"].append(delta_content["citation"]) | ||
callback_event["callback"] = {"citation_metadata": delta_content["citation"], "delta": delta_content} | ||
|
||
elif "reasoningContent" in delta_content: | ||
if "text" in delta_content["reasoningContent"]: | ||
if "reasoningText" not in state: | ||
|
@@ -168,6 +175,7 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: | |
current_tool_use = state["current_tool_use"] | ||
text = state["text"] | ||
reasoning_text = state["reasoningText"] | ||
citations_content = state["citationsContent"] if "citationsContent" in state else [] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reviewer comment (truncated in extraction): When is …
||
|
||
if current_tool_use: | ||
if "input" not in current_tool_use: | ||
|
@@ -206,6 +214,18 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: | |
) | ||
state["reasoningText"] = "" | ||
|
||
# Handle citations_content independently - not as elif since we can have both text and citations | ||
if citations_content: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reviewer comment: Can this be made into an elif?
||
# Convert CitationsDelta objects back to CitationsContentBlock format | ||
# that matches non-streaming behavior | ||
from ..types.citations import CitationsContentBlock | ||
|
||
citations_block: CitationsContentBlock = { | ||
"citations": citations_content # citations_content contains CitationsDelta objects | ||
} | ||
content.append({"citationsContent": citations_block}) | ||
state["citationsContent"] = [] | ||
|
||
return state | ||
|
||
|
||
|
@@ -264,6 +284,7 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d | |
"current_tool_use": {}, | ||
"reasoningText": "", | ||
"signature": "", | ||
"citationsContent": [], | ||
} | ||
state["content"] = state["message"]["content"] | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
import json | ||
import logging | ||
import os | ||
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union | ||
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, cast | ||
|
||
import boto3 | ||
from botocore.config import Config as BotocoreConfig | ||
|
@@ -18,7 +18,11 @@ | |
from ..event_loop import streaming | ||
from ..tools import convert_pydantic_to_tool_spec | ||
from ..types.content import ContentBlock, Message, Messages | ||
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException | ||
from ..types.exceptions import ( | ||
ContextWindowOverflowException, | ||
ModelThrottledException, | ||
UnsupportedModelCitationsException, | ||
) | ||
from ..types.streaming import StreamEvent | ||
from ..types.tools import ToolResult, ToolSpec | ||
from .model import Model | ||
|
@@ -34,6 +38,15 @@ | |
"too many total text bytes", | ||
] | ||
|
||
# Bedrock model IDs for which citations may be requested.
# Entries are bare model IDs (no cross-region prefix such as "us."); callers
# compare by suffix so that prefixed IDs still match.
CITATION_SUPPORTED_MODELS = [
    "anthropic.claude-3-5-sonnet-20241022-v2:0",
    "anthropic.claude-3-7-sonnet-20250219-v1:0",
    "anthropic.claude-opus-4-20250514-v1:0",
    "anthropic.claude-sonnet-4-20250514-v1:0",
    "anthropic.claude-opus-4-1-20250805-v1:0",
]
|
||
T = TypeVar("T", bound=BaseModel) | ||
|
||
|
||
|
@@ -349,6 +362,42 @@ def _generate_redaction_events(self) -> list[StreamEvent]: | |
|
||
return events | ||
|
||
def _has_citations_config(self, messages: Messages) -> bool: | ||
"""Check if any message contains document content with citations enabled. | ||
|
||
Args: | ||
messages: List of messages to check for citations config. | ||
|
||
Returns: | ||
True if any message contains a document with citations enabled, False otherwise. | ||
""" | ||
for message in messages: | ||
for content_block in message["content"]: | ||
if "document" in content_block: | ||
document = content_block["document"] | ||
if "citations" in document and document["citations"] is not None: | ||
citations_config = document["citations"] | ||
if "enabled" in citations_config and citations_config["enabled"]: | ||
return True | ||
return False | ||
|
||
def _validate_citations_support(self, messages: Messages) -> None:
    """Reject citation requests aimed at a model that is not on the supported list.

    Args:
        messages: List of messages to check for citations config.

    Raises:
        UnsupportedModelCitationsException: If citations are requested but the
            configured model ID does not end with any supported model ID.
    """
    # Nothing to validate unless some document actually asks for citations.
    if not self._has_citations_config(messages):
        return

    model_id = self.config["model_id"]
    # Bedrock model IDs may include a cross-region prefix (e.g., "us.") before
    # the model ID, so match on the suffix. str.endswith accepts a tuple of
    # candidates and returns True if any of them matches.
    # NOTE(review): an ID that does not end with a listed model ID (for example
    # an ARN-style identifier) would be rejected here — confirm that is intended.
    if not model_id.endswith(tuple(CITATION_SUPPORTED_MODELS)):
        raise UnsupportedModelCitationsException(model_id, CITATION_SUPPORTED_MODELS)
|
||
@override | ||
async def stream( | ||
self, | ||
|
@@ -374,7 +423,10 @@ async def stream( | |
Raises: | ||
ContextWindowOverflowException: If the input exceeds the model's context window. | ||
ModelThrottledException: If the model service is throttling requests. | ||
UnsupportedModelCitationsException: If citations are requested but the model doesn't support them. | ||
""" | ||
# Validate citations support before starting the thread (fail fast in async context) | ||
self._validate_citations_support(messages) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reviewer comment: I would rather not have this logic client side, and rely on the API to throw a validation exception here.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Author reply: Totally get your concern, but I'm guessing provider SDK server errors will be cryptic. Happy to take it off though.
||
|
||
def callback(event: Optional[StreamEvent] = None) -> None: | ||
loop.call_soon_threadsafe(queue.put_nowait, event) | ||
|
@@ -510,7 +562,7 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera | |
yield {"messageStart": {"role": response["output"]["message"]["role"]}} | ||
|
||
# Process content blocks | ||
for content in response["output"]["message"]["content"]: | ||
for content in cast(list[ContentBlock], response["output"]["message"]["content"]): | ||
# Yield contentBlockStart event if needed | ||
if "toolUse" in content: | ||
yield { | ||
|
@@ -553,6 +605,25 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera | |
} | ||
} | ||
} | ||
elif "citationsContent" in content: | ||
# For non-streaming citations, emit text and metadata deltas in sequence | ||
# to match streaming behavior where they flow naturally | ||
if "content" in content["citationsContent"]: | ||
text_content = "".join([content["text"] for content in content["citationsContent"]["content"]]) | ||
yield { | ||
"contentBlockDelta": {"delta": {"text": text_content}}, | ||
} | ||
|
||
for citation in content["citationsContent"]["citations"]: | ||
# Then emit citation metadata (for structure) | ||
from ..types.streaming import CitationsDelta | ||
|
||
citation_metadata: CitationsDelta = { | ||
"title": citation["title"], | ||
"location": citation["location"], | ||
"sourceContent": citation["sourceContent"], | ||
} | ||
yield {"contentBlockDelta": {"delta": {"citation": citation_metadata}}} | ||
|
||
# Yield contentBlockStop event | ||
yield {"contentBlockStop": {}} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,10 +24,11 @@ | |
from mcp.types import ImageContent as MCPImageContent | ||
from mcp.types import TextContent as MCPTextContent | ||
|
||
from strands.types.tools import ToolResultContent, ToolResultStatus | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Reviewer comment: Why relative to absolute import?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Author reply: Mypy was complaining without it.
||
|
||
from ...types import PaginatedList | ||
from ...types.exceptions import MCPClientInitializationError | ||
from ...types.media import ImageFormat | ||
from ...types.tools import ToolResultContent, ToolResultStatus | ||
from .mcp_agent_tool import MCPAgentTool | ||
from .mcp_instrumentation import mcp_instrumentation | ||
from .mcp_types import MCPToolResult, MCPTransport | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
"""Citation type definitions for the SDK. | ||
|
||
These types are modeled after the Bedrock API. | ||
""" | ||
|
||
from typing import List, Union | ||
|
||
from typing_extensions import TypedDict | ||
|
||
|
||
class CitationsConfig(TypedDict):
    """Per-document switch for citation generation.

    Attributes:
        enabled: Whether citations are enabled for this document.
    """

    enabled: bool
|
||
|
||
class DocumentCharLocation(TypedDict, total=False):
    """Character-level location of cited content within a document.

    Positions are expressed as start and end character indices. All keys are
    optional (``total=False``).

    Attributes:
        documentIndex: The index of the document within the array of documents
            provided in the request. Minimum value of 0.
        start: The starting character position of the cited content within
            the document. Minimum value of 0.
        end: The ending character position of the cited content within
            the document. Minimum value of 0.
    """

    documentIndex: int
    start: int
    end: int
|
||
|
||
class DocumentChunkLocation(TypedDict, total=False):
    """Chunk-level location of cited content within a document.

    Positions are expressed in terms of logical document segments or chunks.
    All keys are optional (``total=False``).

    Attributes:
        documentIndex: The index of the document within the array of documents
            provided in the request. Minimum value of 0.
        start: The starting chunk identifier or index of the cited content
            within the document. Minimum value of 0.
        end: The ending chunk identifier or index of the cited content
            within the document. Minimum value of 0.
    """

    documentIndex: int
    start: int
    end: int
|
||
|
||
class DocumentPageLocation(TypedDict, total=False):
    """Page-level location of cited content within a document.

    Positions are expressed as page numbers. All keys are optional
    (``total=False``).

    Attributes:
        documentIndex: The index of the document within the array of documents
            provided in the request. Minimum value of 0.
        start: The starting page number of the cited content within
            the document. Minimum value of 0.
        end: The ending page number of the cited content within
            the document. Minimum value of 0.
    """

    documentIndex: int
    start: int
    end: int
|
||
|
||
# A citation location at any one of the supported granularities
# (character, chunk, or page).
# NOTE(review): the Bedrock API reference appears to describe CitationLocation
# as a tagged union keyed by "documentChar" / "documentChunk" / "documentPage";
# confirm that this flat union matches the payloads actually received.
CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation]
|
||
|
||
class CitationSourceContent(TypedDict, total=False):
    """Text from a source document that is cited or referenced in the model's response.

    Note:
        This is a UNION type, so only one of the members can be specified.

    Attributes:
        text: The text content from the source document that is being cited.
    """

    text: str
|
||
|
||
class CitationGeneratedContent(TypedDict, total=False):
    """Generated text that corresponds to, or is supported by, a citation from a source document.

    Note:
        This is a UNION type, so only one of the members can be specified.

    Attributes:
        text: The text content that was generated by the model and is
            supported by the associated citation.
    """

    text: str
|
||
|
||
class Citation(TypedDict, total=False):
    """A single citation that references a source document.

    Citations provide traceability between the model's generated response
    and the source documents that informed that response. All keys are
    optional (``total=False``).

    Attributes:
        location: The precise location within the source document where the
            cited content can be found, including character positions, page
            numbers, or chunk identifiers.
        sourceContent: The specific content from the source document that was
            referenced or cited in the generated response.
        title: The title or identifier of the source document being cited.
    """

    location: CitationLocation
    sourceContent: List[CitationSourceContent]
    title: str
|
||
|
||
class CitationsContentBlock(TypedDict, total=False):
    """Content block that pairs generated text with the citations backing it.

    This block type is returned when document citations are enabled, providing
    traceability between the generated content and the source documents that
    informed the response. All keys are optional (``total=False``).

    Attributes:
        citations: An array of citations that reference the source documents
            used to generate the associated content.
        content: The generated content that is supported by the associated
            citations.
    """

    citations: List[Citation]
    content: List[CitationGeneratedContent]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of just appending the response to a list, I can see that the `CitationSourceContentDelta` event mentions that it contains "incremental updates to the source content text during streaming responses". Does this mean we are concatenating the strings of the `CitationSourceContentDelta` events, or just appending the events to an array?