@@ -53,6 +53,9 @@ class StreamingState:
53
53
refusal_content_index_and_output : tuple [int , ResponseOutputRefusal ] | None = None
54
54
reasoning_content_index_and_output : tuple [int , ResponseReasoningItem ] | None = None
55
55
function_calls : dict [int , ResponseFunctionToolCall ] = field (default_factory = dict )
56
+ # Fields for real-time function call streaming
57
+ function_call_streaming : dict [int , bool ] = field (default_factory = dict )
58
+ function_call_output_idx : dict [int , int ] = field (default_factory = dict )
56
59
57
60
58
61
class SequenceNumber :
@@ -255,9 +258,7 @@ async def handle_stream(
255
258
# Accumulate the refusal string in the output part
256
259
state .refusal_content_index_and_output [1 ].refusal += delta .refusal
257
260
258
- # Handle tool calls
259
- # Because we don't know the name of the function until the end of the stream, we'll
260
- # save everything and yield events at the end
261
+ # Handle tool calls with real-time streaming support
261
262
if delta .tool_calls :
262
263
for tc_delta in delta .tool_calls :
263
264
if tc_delta .index not in state .function_calls :
@@ -268,15 +269,76 @@ async def handle_stream(
268
269
type = "function_call" ,
269
270
call_id = "" ,
270
271
)
272
+ state .function_call_streaming [tc_delta .index ] = False
273
+
271
274
tc_function = tc_delta .function
272
275
276
+ # Accumulate arguments as they come in
273
277
state .function_calls [tc_delta .index ].arguments += (
274
278
tc_function .arguments if tc_function else ""
275
279
) or ""
276
- state .function_calls [tc_delta .index ].name += (
277
- tc_function .name if tc_function else ""
278
- ) or ""
279
- state .function_calls [tc_delta .index ].call_id = tc_delta .id or ""
280
+
281
+ # Assign the function name instead of appending to it (it's complete in the first chunk that carries a name)
282
+ if tc_function and tc_function .name :
283
+ state .function_calls [tc_delta .index ].name = tc_function .name
284
+
285
+ if tc_delta .id :
286
+ state .function_calls [tc_delta .index ].call_id = tc_delta .id
287
+
288
+ function_call = state .function_calls [tc_delta .index ]
289
+
290
+ # Start streaming as soon as we have function name and call_id
291
+ if (not state .function_call_streaming [tc_delta .index ] and
292
+ function_call .name and
293
+ function_call .call_id ):
294
+
295
+ # Calculate the output index for this function call
296
+ function_call_starting_index = 0
297
+ if state .reasoning_content_index_and_output :
298
+ function_call_starting_index += 1
299
+ if state .text_content_index_and_output :
300
+ function_call_starting_index += 1
301
+ if state .refusal_content_index_and_output :
302
+ function_call_starting_index += 1
303
+
304
+ # Add offset for already started function calls
305
+ function_call_starting_index += sum (
306
+ 1 for streaming in state .function_call_streaming .values () if streaming
307
+ )
308
+
309
+ # Mark this function call as streaming and store its output index
310
+ state .function_call_streaming [tc_delta .index ] = True
311
+ state .function_call_output_idx [
312
+ tc_delta .index
313
+ ] = function_call_starting_index
314
+
315
+ # Send initial function call added event
316
+ yield ResponseOutputItemAddedEvent (
317
+ item = ResponseFunctionToolCall (
318
+ id = FAKE_RESPONSES_ID ,
319
+ call_id = function_call .call_id ,
320
+ arguments = "" , # Start with empty arguments
321
+ name = function_call .name ,
322
+ type = "function_call" ,
323
+ ),
324
+ output_index = function_call_starting_index ,
325
+ type = "response.output_item.added" ,
326
+ sequence_number = sequence_number .get_and_increment (),
327
+ )
328
+
329
+ # Stream arguments if we've started streaming this function call
330
+ if (state .function_call_streaming .get (tc_delta .index , False ) and
331
+ tc_function and
332
+ tc_function .arguments ):
333
+
334
+ output_index = state .function_call_output_idx [tc_delta .index ]
335
+ yield ResponseFunctionCallArgumentsDeltaEvent (
336
+ delta = tc_function .arguments ,
337
+ item_id = FAKE_RESPONSES_ID ,
338
+ output_index = output_index ,
339
+ type = "response.function_call_arguments.delta" ,
340
+ sequence_number = sequence_number .get_and_increment (),
341
+ )
280
342
281
343
if state .reasoning_content_index_and_output :
282
344
yield ResponseReasoningSummaryPartDoneEvent (
@@ -327,42 +389,71 @@ async def handle_stream(
327
389
sequence_number = sequence_number .get_and_increment (),
328
390
)
329
391
330
- # Actually send events for the function calls
331
- for function_call in state .function_calls .values ():
332
- # First, a ResponseOutputItemAdded for the function call
333
- yield ResponseOutputItemAddedEvent (
334
- item = ResponseFunctionToolCall (
335
- id = FAKE_RESPONSES_ID ,
336
- call_id = function_call .call_id ,
337
- arguments = function_call .arguments ,
338
- name = function_call .name ,
339
- type = "function_call" ,
340
- ),
341
- output_index = function_call_starting_index ,
342
- type = "response.output_item.added" ,
343
- sequence_number = sequence_number .get_and_increment (),
344
- )
345
- # Then, yield the args
346
- yield ResponseFunctionCallArgumentsDeltaEvent (
347
- delta = function_call .arguments ,
348
- item_id = FAKE_RESPONSES_ID ,
349
- output_index = function_call_starting_index ,
350
- type = "response.function_call_arguments.delta" ,
351
- sequence_number = sequence_number .get_and_increment (),
352
- )
353
- # Finally, the ResponseOutputItemDone
354
- yield ResponseOutputItemDoneEvent (
355
- item = ResponseFunctionToolCall (
356
- id = FAKE_RESPONSES_ID ,
357
- call_id = function_call .call_id ,
358
- arguments = function_call .arguments ,
359
- name = function_call .name ,
360
- type = "function_call" ,
361
- ),
362
- output_index = function_call_starting_index ,
363
- type = "response.output_item.done" ,
364
- sequence_number = sequence_number .get_and_increment (),
365
- )
392
+ # Send completion events for function calls
393
+ for index , function_call in state .function_calls .items ():
394
+ if state .function_call_streaming .get (index , False ):
395
+ # Function call was streamed, just send the completion event
396
+ output_index = state .function_call_output_idx [index ]
397
+ yield ResponseOutputItemDoneEvent (
398
+ item = ResponseFunctionToolCall (
399
+ id = FAKE_RESPONSES_ID ,
400
+ call_id = function_call .call_id ,
401
+ arguments = function_call .arguments ,
402
+ name = function_call .name ,
403
+ type = "function_call" ,
404
+ ),
405
+ output_index = output_index ,
406
+ type = "response.output_item.done" ,
407
+ sequence_number = sequence_number .get_and_increment (),
408
+ )
409
+ else :
410
+ # Function call was not streamed (fallback to old behavior)
411
+ # This handles edge cases where the function name or call_id never arrived
412
+ fallback_starting_index = 0
413
+ if state .reasoning_content_index_and_output :
414
+ fallback_starting_index += 1
415
+ if state .text_content_index_and_output :
416
+ fallback_starting_index += 1
417
+ if state .refusal_content_index_and_output :
418
+ fallback_starting_index += 1
419
+
420
+ # Add offset for already started function calls
421
+ fallback_starting_index += sum (
422
+ 1 for streaming in state .function_call_streaming .values () if streaming
423
+ )
424
+
425
+ # Send all events at once (backward compatibility)
426
+ yield ResponseOutputItemAddedEvent (
427
+ item = ResponseFunctionToolCall (
428
+ id = FAKE_RESPONSES_ID ,
429
+ call_id = function_call .call_id ,
430
+ arguments = function_call .arguments ,
431
+ name = function_call .name ,
432
+ type = "function_call" ,
433
+ ),
434
+ output_index = fallback_starting_index ,
435
+ type = "response.output_item.added" ,
436
+ sequence_number = sequence_number .get_and_increment (),
437
+ )
438
+ yield ResponseFunctionCallArgumentsDeltaEvent (
439
+ delta = function_call .arguments ,
440
+ item_id = FAKE_RESPONSES_ID ,
441
+ output_index = fallback_starting_index ,
442
+ type = "response.function_call_arguments.delta" ,
443
+ sequence_number = sequence_number .get_and_increment (),
444
+ )
445
+ yield ResponseOutputItemDoneEvent (
446
+ item = ResponseFunctionToolCall (
447
+ id = FAKE_RESPONSES_ID ,
448
+ call_id = function_call .call_id ,
449
+ arguments = function_call .arguments ,
450
+ name = function_call .name ,
451
+ type = "function_call" ,
452
+ ),
453
+ output_index = fallback_starting_index ,
454
+ type = "response.output_item.done" ,
455
+ sequence_number = sequence_number .get_and_increment (),
456
+ )
366
457
367
458
# Finally, send the Response completed event
368
459
outputs : list [ResponseOutputItem ] = []
0 commit comments