Skip to content

Commit 00c87ea

Browse files
authored
Add Streaming of Function Call Arguments to Chat Completions (#999)
## Summary

This PR implements real-time streaming of function call arguments as requested in #834. Previously, function call arguments were only emitted after the entire function call was complete, causing a poor user experience for large parameter generation.

## Changes

- **Enhanced `ChatCmplStreamHandler`**: Added real-time streaming of function call arguments during generation
- **New streaming logic**: Function call arguments now stream incrementally as they are generated, similar to text content
- **Backward compatibility**: Maintains existing behavior for completed function calls
- **Comprehensive testing**: Added tests for both OpenAI and LiteLLM models
- **Example implementation**: Created demonstration code showing the new streaming capability

Closes #834
1 parent 99ba260 commit 00c87ea

File tree

3 files changed

+419
-77
lines changed

3 files changed

+419
-77
lines changed

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 134 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ class StreamingState:
5353
refusal_content_index_and_output: tuple[int, ResponseOutputRefusal] | None = None
5454
reasoning_content_index_and_output: tuple[int, ResponseReasoningItem] | None = None
5555
function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
56+
# Fields for real-time function call streaming
57+
function_call_streaming: dict[int, bool] = field(default_factory=dict)
58+
function_call_output_idx: dict[int, int] = field(default_factory=dict)
5659

5760

5861
class SequenceNumber:
@@ -255,9 +258,7 @@ async def handle_stream(
255258
# Accumulate the refusal string in the output part
256259
state.refusal_content_index_and_output[1].refusal += delta.refusal
257260

258-
# Handle tool calls
259-
# Because we don't know the name of the function until the end of the stream, we'll
260-
# save everything and yield events at the end
261+
# Handle tool calls with real-time streaming support
261262
if delta.tool_calls:
262263
for tc_delta in delta.tool_calls:
263264
if tc_delta.index not in state.function_calls:
@@ -268,15 +269,76 @@ async def handle_stream(
268269
type="function_call",
269270
call_id="",
270271
)
272+
state.function_call_streaming[tc_delta.index] = False
273+
271274
tc_function = tc_delta.function
272275

276+
# Accumulate arguments as they come in
273277
state.function_calls[tc_delta.index].arguments += (
274278
tc_function.arguments if tc_function else ""
275279
) or ""
276-
state.function_calls[tc_delta.index].name += (
277-
tc_function.name if tc_function else ""
278-
) or ""
279-
state.function_calls[tc_delta.index].call_id = tc_delta.id or ""
280+
281+
# Set function name directly (it's correct from the first function call chunk)
282+
if tc_function and tc_function.name:
283+
state.function_calls[tc_delta.index].name = tc_function.name
284+
285+
if tc_delta.id:
286+
state.function_calls[tc_delta.index].call_id = tc_delta.id
287+
288+
function_call = state.function_calls[tc_delta.index]
289+
290+
# Start streaming as soon as we have function name and call_id
291+
if (not state.function_call_streaming[tc_delta.index] and
292+
function_call.name and
293+
function_call.call_id):
294+
295+
# Calculate the output index for this function call
296+
function_call_starting_index = 0
297+
if state.reasoning_content_index_and_output:
298+
function_call_starting_index += 1
299+
if state.text_content_index_and_output:
300+
function_call_starting_index += 1
301+
if state.refusal_content_index_and_output:
302+
function_call_starting_index += 1
303+
304+
# Add offset for already started function calls
305+
function_call_starting_index += sum(
306+
1 for streaming in state.function_call_streaming.values() if streaming
307+
)
308+
309+
# Mark this function call as streaming and store its output index
310+
state.function_call_streaming[tc_delta.index] = True
311+
state.function_call_output_idx[
312+
tc_delta.index
313+
] = function_call_starting_index
314+
315+
# Send initial function call added event
316+
yield ResponseOutputItemAddedEvent(
317+
item=ResponseFunctionToolCall(
318+
id=FAKE_RESPONSES_ID,
319+
call_id=function_call.call_id,
320+
arguments="", # Start with empty arguments
321+
name=function_call.name,
322+
type="function_call",
323+
),
324+
output_index=function_call_starting_index,
325+
type="response.output_item.added",
326+
sequence_number=sequence_number.get_and_increment(),
327+
)
328+
329+
# Stream arguments if we've started streaming this function call
330+
if (state.function_call_streaming.get(tc_delta.index, False) and
331+
tc_function and
332+
tc_function.arguments):
333+
334+
output_index = state.function_call_output_idx[tc_delta.index]
335+
yield ResponseFunctionCallArgumentsDeltaEvent(
336+
delta=tc_function.arguments,
337+
item_id=FAKE_RESPONSES_ID,
338+
output_index=output_index,
339+
type="response.function_call_arguments.delta",
340+
sequence_number=sequence_number.get_and_increment(),
341+
)
280342

281343
if state.reasoning_content_index_and_output:
282344
yield ResponseReasoningSummaryPartDoneEvent(
@@ -327,42 +389,71 @@ async def handle_stream(
327389
sequence_number=sequence_number.get_and_increment(),
328390
)
329391

330-
# Actually send events for the function calls
331-
for function_call in state.function_calls.values():
332-
# First, a ResponseOutputItemAdded for the function call
333-
yield ResponseOutputItemAddedEvent(
334-
item=ResponseFunctionToolCall(
335-
id=FAKE_RESPONSES_ID,
336-
call_id=function_call.call_id,
337-
arguments=function_call.arguments,
338-
name=function_call.name,
339-
type="function_call",
340-
),
341-
output_index=function_call_starting_index,
342-
type="response.output_item.added",
343-
sequence_number=sequence_number.get_and_increment(),
344-
)
345-
# Then, yield the args
346-
yield ResponseFunctionCallArgumentsDeltaEvent(
347-
delta=function_call.arguments,
348-
item_id=FAKE_RESPONSES_ID,
349-
output_index=function_call_starting_index,
350-
type="response.function_call_arguments.delta",
351-
sequence_number=sequence_number.get_and_increment(),
352-
)
353-
# Finally, the ResponseOutputItemDone
354-
yield ResponseOutputItemDoneEvent(
355-
item=ResponseFunctionToolCall(
356-
id=FAKE_RESPONSES_ID,
357-
call_id=function_call.call_id,
358-
arguments=function_call.arguments,
359-
name=function_call.name,
360-
type="function_call",
361-
),
362-
output_index=function_call_starting_index,
363-
type="response.output_item.done",
364-
sequence_number=sequence_number.get_and_increment(),
365-
)
392+
# Send completion events for function calls
393+
for index, function_call in state.function_calls.items():
394+
if state.function_call_streaming.get(index, False):
395+
# Function call was streamed, just send the completion event
396+
output_index = state.function_call_output_idx[index]
397+
yield ResponseOutputItemDoneEvent(
398+
item=ResponseFunctionToolCall(
399+
id=FAKE_RESPONSES_ID,
400+
call_id=function_call.call_id,
401+
arguments=function_call.arguments,
402+
name=function_call.name,
403+
type="function_call",
404+
),
405+
output_index=output_index,
406+
type="response.output_item.done",
407+
sequence_number=sequence_number.get_and_increment(),
408+
)
409+
else:
410+
# Function call was not streamed (fallback to old behavior)
411+
# This handles edge cases where function name never arrived
412+
fallback_starting_index = 0
413+
if state.reasoning_content_index_and_output:
414+
fallback_starting_index += 1
415+
if state.text_content_index_and_output:
416+
fallback_starting_index += 1
417+
if state.refusal_content_index_and_output:
418+
fallback_starting_index += 1
419+
420+
# Add offset for already started function calls
421+
fallback_starting_index += sum(
422+
1 for streaming in state.function_call_streaming.values() if streaming
423+
)
424+
425+
# Send all events at once (backward compatibility)
426+
yield ResponseOutputItemAddedEvent(
427+
item=ResponseFunctionToolCall(
428+
id=FAKE_RESPONSES_ID,
429+
call_id=function_call.call_id,
430+
arguments=function_call.arguments,
431+
name=function_call.name,
432+
type="function_call",
433+
),
434+
output_index=fallback_starting_index,
435+
type="response.output_item.added",
436+
sequence_number=sequence_number.get_and_increment(),
437+
)
438+
yield ResponseFunctionCallArgumentsDeltaEvent(
439+
delta=function_call.arguments,
440+
item_id=FAKE_RESPONSES_ID,
441+
output_index=fallback_starting_index,
442+
type="response.function_call_arguments.delta",
443+
sequence_number=sequence_number.get_and_increment(),
444+
)
445+
yield ResponseOutputItemDoneEvent(
446+
item=ResponseFunctionToolCall(
447+
id=FAKE_RESPONSES_ID,
448+
call_id=function_call.call_id,
449+
arguments=function_call.arguments,
450+
name=function_call.name,
451+
type="function_call",
452+
),
453+
output_index=fallback_starting_index,
454+
type="response.output_item.done",
455+
sequence_number=sequence_number.get_and_increment(),
456+
)
366457

367458
# Finally, send the Response completed event
368459
outputs: list[ResponseOutputItem] = []

tests/models/test_litellm_chatcompletions_stream.py

Lines changed: 131 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -214,17 +214,18 @@ async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None:
214214
the model is streaming a function/tool call instead of plain text.
215215
The function call will be split across two chunks.
216216
"""
217-
# Simulate a single tool call whose ID stays constant and function name/args built over chunks.
217+
# Simulate a single tool call with complete function name in first chunk
218+
# and arguments split across chunks (reflecting real API behavior)
218219
tool_call_delta1 = ChoiceDeltaToolCall(
219220
index=0,
220221
id="tool-id",
221-
function=ChoiceDeltaToolCallFunction(name="my_", arguments="arg1"),
222+
function=ChoiceDeltaToolCallFunction(name="my_func", arguments="arg1"),
222223
type="function",
223224
)
224225
tool_call_delta2 = ChoiceDeltaToolCall(
225226
index=0,
226227
id="tool-id",
227-
function=ChoiceDeltaToolCallFunction(name="func", arguments="arg2"),
228+
function=ChoiceDeltaToolCallFunction(name=None, arguments="arg2"),
228229
type="function",
229230
)
230231
chunk1 = ChatCompletionChunk(
@@ -284,18 +285,131 @@ async def patched_fetch_response(self, *args, **kwargs):
284285
# The added item should be a ResponseFunctionToolCall.
285286
added_fn = output_events[1].item
286287
assert isinstance(added_fn, ResponseFunctionToolCall)
287-
assert added_fn.name == "my_func" # Name should be concatenation of both chunks.
288-
assert added_fn.arguments == "arg1arg2"
288+
assert added_fn.name == "my_func" # Name should be complete from first chunk
289+
assert added_fn.arguments == "" # Arguments start empty
289290
assert output_events[2].type == "response.function_call_arguments.delta"
290-
assert output_events[2].delta == "arg1arg2"
291-
assert output_events[3].type == "response.output_item.done"
292-
assert output_events[4].type == "response.completed"
293-
assert output_events[2].delta == "arg1arg2"
294-
assert output_events[3].type == "response.output_item.done"
295-
assert output_events[4].type == "response.completed"
296-
assert added_fn.name == "my_func" # Name should be concatenation of both chunks.
297-
assert added_fn.arguments == "arg1arg2"
298-
assert output_events[2].type == "response.function_call_arguments.delta"
299-
assert output_events[2].delta == "arg1arg2"
300-
assert output_events[3].type == "response.output_item.done"
301-
assert output_events[4].type == "response.completed"
291+
assert output_events[2].delta == "arg1" # First argument chunk
292+
assert output_events[3].type == "response.function_call_arguments.delta"
293+
assert output_events[3].delta == "arg2" # Second argument chunk
294+
assert output_events[4].type == "response.output_item.done"
295+
assert output_events[5].type == "response.completed"
296+
# Final function call should have complete arguments
297+
final_fn = output_events[4].item
298+
assert isinstance(final_fn, ResponseFunctionToolCall)
299+
assert final_fn.name == "my_func"
300+
assert final_fn.arguments == "arg1arg2"
301+
302+
303+
@pytest.mark.allow_call_model_methods
304+
@pytest.mark.asyncio
305+
async def test_stream_response_yields_real_time_function_call_arguments(monkeypatch) -> None:
306+
"""
307+
Validate that LiteLLM `stream_response` also emits function call arguments in real-time
308+
as they are received, ensuring consistent behavior across model providers.
309+
"""
310+
# Simulate realistic chunks: name first, then arguments incrementally
311+
tool_call_delta1 = ChoiceDeltaToolCall(
312+
index=0,
313+
id="litellm-call-456",
314+
function=ChoiceDeltaToolCallFunction(name="generate_code", arguments=""),
315+
type="function",
316+
)
317+
tool_call_delta2 = ChoiceDeltaToolCall(
318+
index=0,
319+
function=ChoiceDeltaToolCallFunction(arguments='{"language": "'),
320+
type="function",
321+
)
322+
tool_call_delta3 = ChoiceDeltaToolCall(
323+
index=0,
324+
function=ChoiceDeltaToolCallFunction(arguments='python", "task": "'),
325+
type="function",
326+
)
327+
tool_call_delta4 = ChoiceDeltaToolCall(
328+
index=0,
329+
function=ChoiceDeltaToolCallFunction(arguments='hello world"}'),
330+
type="function",
331+
)
332+
333+
chunk1 = ChatCompletionChunk(
334+
id="chunk-id",
335+
created=1,
336+
model="fake",
337+
object="chat.completion.chunk",
338+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))],
339+
)
340+
chunk2 = ChatCompletionChunk(
341+
id="chunk-id",
342+
created=1,
343+
model="fake",
344+
object="chat.completion.chunk",
345+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))],
346+
)
347+
chunk3 = ChatCompletionChunk(
348+
id="chunk-id",
349+
created=1,
350+
model="fake",
351+
object="chat.completion.chunk",
352+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta3]))],
353+
)
354+
chunk4 = ChatCompletionChunk(
355+
id="chunk-id",
356+
created=1,
357+
model="fake",
358+
object="chat.completion.chunk",
359+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta4]))],
360+
usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
361+
)
362+
363+
async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
364+
for c in (chunk1, chunk2, chunk3, chunk4):
365+
yield c
366+
367+
async def patched_fetch_response(self, *args, **kwargs):
368+
resp = Response(
369+
id="resp-id",
370+
created_at=0,
371+
model="fake-model",
372+
object="response",
373+
output=[],
374+
tool_choice="none",
375+
tools=[],
376+
parallel_tool_calls=False,
377+
)
378+
return resp, fake_stream()
379+
380+
monkeypatch.setattr(LitellmModel, "_fetch_response", patched_fetch_response)
381+
model = LitellmProvider().get_model("gpt-4")
382+
output_events = []
383+
async for event in model.stream_response(
384+
system_instructions=None,
385+
input="",
386+
model_settings=ModelSettings(),
387+
tools=[],
388+
output_schema=None,
389+
handoffs=[],
390+
tracing=ModelTracing.DISABLED,
391+
previous_response_id=None,
392+
prompt=None,
393+
):
394+
output_events.append(event)
395+
396+
# Extract events by type
397+
function_args_delta_events = [
398+
e for e in output_events if e.type == "response.function_call_arguments.delta"
399+
]
400+
output_item_added_events = [e for e in output_events if e.type == "response.output_item.added"]
401+
402+
# Verify we got real-time streaming (3 argument delta events)
403+
assert len(function_args_delta_events) == 3
404+
assert len(output_item_added_events) == 1
405+
406+
# Verify the deltas were streamed correctly
407+
expected_deltas = ['{"language": "', 'python", "task": "', 'hello world"}']
408+
for i, delta_event in enumerate(function_args_delta_events):
409+
assert delta_event.delta == expected_deltas[i]
410+
411+
# Verify function call metadata
412+
added_event = output_item_added_events[0]
413+
assert isinstance(added_event.item, ResponseFunctionToolCall)
414+
assert added_event.item.name == "generate_code"
415+
assert added_event.item.call_id == "litellm-call-456"

0 commit comments

Comments (0)