mirror of https://github.com/kortix-ai/suna.git
Enhance assistant message processing in async stream printer
- Introduced chunk accumulation for assistant messages to handle multi-part responses more effectively. - Implemented a function to rebuild full text from sorted chunks, improving the accuracy of message content. - Added state management for parsing function calls, allowing detection and handling of tool usage within messages. - Updated output formatting to provide clearer logs for assistant messages and tool interactions.
This commit is contained in:
parent
be05cc8348
commit
ac9dbd7127
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import re
|
||||
from typing import AsyncGenerator, Optional, Any
|
||||
|
||||
|
||||
|
@ -16,6 +17,25 @@ async def print_stream(stream: AsyncGenerator[str, None]):
|
|||
Follows the same output format as stream_test.py.
|
||||
"""
|
||||
stream_started = False
|
||||
chunks = [] # Store chunks to sort by sequence
|
||||
parsing_state = "text" # "text", "in_function_call", "function_call_ended"
|
||||
current_function_name = None
|
||||
invoke_name_regex = re.compile(r'<invoke\s+name="([^"]+)"')
|
||||
|
||||
def rebuild_full_text():
|
||||
"""Rebuild full text from sorted chunks like the original processor"""
|
||||
# Sort chunks by sequence
|
||||
sorted_chunks = sorted(chunks, key=lambda c: c.get("sequence", 0))
|
||||
|
||||
full_text_parts = []
|
||||
for chunk in sorted_chunks:
|
||||
content = chunk.get("content", "")
|
||||
if content:
|
||||
parsed_content = try_parse_json(content)
|
||||
if parsed_content and "content" in parsed_content:
|
||||
full_text_parts.append(parsed_content["content"])
|
||||
|
||||
return "".join(full_text_parts)
|
||||
|
||||
async for line in stream:
|
||||
line = line.strip()
|
||||
|
@ -40,28 +60,71 @@ async def print_stream(stream: AsyncGenerator[str, None]):
|
|||
stream_started = True
|
||||
|
||||
if event_type == "status":
|
||||
finish_reason = data.get("finish_reason", "received")
|
||||
status_type = data.get("status_type", data.get("status", "unknown"))
|
||||
# Parse status like the original - merge content with event data
|
||||
status_details = try_parse_json(data.get("content", "{}")) or {}
|
||||
if data.get("status"):
|
||||
status_details["status_type"] = data["status"]
|
||||
if data.get("message"):
|
||||
status_details["message"] = data["message"]
|
||||
|
||||
full_status = {**data, **status_details}
|
||||
finish_reason = full_status.get("finish_reason", "received")
|
||||
status_type = full_status.get("status_type", full_status.get("status", "unknown"))
|
||||
print(f"[STATUS] {status_type} {finish_reason}")
|
||||
|
||||
elif event_type == "assistant":
|
||||
# Handle assistant messages - print at end of message
|
||||
message_id = data.get("message_id")
|
||||
sequence = data.get("sequence")
|
||||
content = data.get("content", "")
|
||||
if content and data.get("message_id"):
|
||||
parsed_content = try_parse_json(content)
|
||||
if parsed_content:
|
||||
role = parsed_content.get("role", "unknown")
|
||||
message_content = parsed_content.get("content", "")
|
||||
preview = (
|
||||
message_content[:100] + "..."
|
||||
if len(message_content) > 100
|
||||
else message_content
|
||||
)
|
||||
print() # New line
|
||||
print(f"[ASSISTANT] {preview}")
|
||||
else:
|
||||
print() # New line
|
||||
print(f"[ASSISTANT] Failed to parse message content")
|
||||
|
||||
# Assistant chunks (message_id is null, has sequence) - accumulate text
|
||||
if message_id is None and sequence is not None:
|
||||
# Add chunk to collection
|
||||
chunks.append(data)
|
||||
|
||||
# Rebuild full text from all chunks
|
||||
full_text = rebuild_full_text()
|
||||
|
||||
# Check for function call detection
|
||||
if parsing_state == "text":
|
||||
if "<function_calls>" in full_text:
|
||||
parsing_state = "in_function_call"
|
||||
print("\n[TOOL USE DETECTED]")
|
||||
|
||||
elif parsing_state == "in_function_call":
|
||||
if current_function_name is None:
|
||||
match = invoke_name_regex.search(full_text)
|
||||
if match:
|
||||
current_function_name = match.group(1)
|
||||
print(f'[TOOL UPDATE] Calling function: "{current_function_name}"')
|
||||
|
||||
if "</function_calls>" in full_text:
|
||||
parsing_state = "function_call_ended"
|
||||
print("[TOOL USE WAITING]")
|
||||
current_function_name = None
|
||||
|
||||
# Complete assistant messages (message_id is not null) - print final message
|
||||
elif message_id is not None:
|
||||
if content:
|
||||
parsed_content = try_parse_json(content)
|
||||
if parsed_content:
|
||||
role = parsed_content.get("role", "unknown")
|
||||
message_content = parsed_content.get("content", "")
|
||||
preview = (
|
||||
message_content[:100] + "..."
|
||||
if len(message_content) > 100
|
||||
else message_content
|
||||
)
|
||||
print() # New line
|
||||
print(f"[MESSAGE] {role}: {preview}")
|
||||
else:
|
||||
print() # New line
|
||||
print(f"[MESSAGE] Failed to parse message content")
|
||||
|
||||
# Reset state for next message
|
||||
chunks = []
|
||||
parsing_state = "text"
|
||||
current_function_name = None
|
||||
|
||||
elif event_type == "tool":
|
||||
# Handle tool results
|
||||
|
|
Loading…
Reference in New Issue