mirror of https://github.com/kortix-ai/suna.git
Merge pull request #1060 from KrishavRajSingh/krishav/fix/half_finished_response
This commit is contained in:
commit
4b20eb983b
|
@ -146,6 +146,9 @@ class ResponseProcessor:
|
||||||
prompt_messages: List[Dict[str, Any]],
|
prompt_messages: List[Dict[str, Any]],
|
||||||
llm_model: str,
|
llm_model: str,
|
||||||
config: ProcessorConfig = ProcessorConfig(),
|
config: ProcessorConfig = ProcessorConfig(),
|
||||||
|
can_auto_continue: bool = False,
|
||||||
|
auto_continue_count: int = 0,
|
||||||
|
continuous_state: Optional[Dict[str, Any]] = None,
|
||||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
"""Process a streaming LLM response, handling tool calls and execution.
|
"""Process a streaming LLM response, handling tool calls and execution.
|
||||||
|
|
||||||
|
@ -155,19 +158,25 @@ class ResponseProcessor:
|
||||||
prompt_messages: List of messages sent to the LLM (the prompt)
|
prompt_messages: List of messages sent to the LLM (the prompt)
|
||||||
llm_model: The name of the LLM model used
|
llm_model: The name of the LLM model used
|
||||||
config: Configuration for parsing and execution
|
config: Configuration for parsing and execution
|
||||||
|
can_auto_continue: Whether auto-continue is enabled
|
||||||
|
auto_continue_count: Number of auto-continue cycles
|
||||||
|
continuous_state: Previous state of the conversation
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Complete message objects matching the DB schema, except for content chunks.
|
Complete message objects matching the DB schema, except for content chunks.
|
||||||
"""
|
"""
|
||||||
accumulated_content = ""
|
# Initialize from continuous state if provided (for auto-continue)
|
||||||
|
continuous_state = continuous_state or {}
|
||||||
|
accumulated_content = continuous_state.get('accumulated_content', "")
|
||||||
tool_calls_buffer = {}
|
tool_calls_buffer = {}
|
||||||
current_xml_content = ""
|
current_xml_content = accumulated_content # equal to accumulated_content if auto-continuing, else blank
|
||||||
xml_chunks_buffer = []
|
xml_chunks_buffer = []
|
||||||
pending_tool_executions = []
|
pending_tool_executions = []
|
||||||
yielded_tool_indices = set() # Stores indices of tools whose *status* has been yielded
|
yielded_tool_indices = set() # Stores indices of tools whose *status* has been yielded
|
||||||
tool_index = 0
|
tool_index = 0
|
||||||
xml_tool_call_count = 0
|
xml_tool_call_count = 0
|
||||||
finish_reason = None
|
finish_reason = None
|
||||||
|
should_auto_continue = False
|
||||||
last_assistant_message_object = None # Store the final saved assistant message object
|
last_assistant_message_object = None # Store the final saved assistant message object
|
||||||
tool_result_message_objects = {} # tool_index -> full saved message object
|
tool_result_message_objects = {} # tool_index -> full saved message object
|
||||||
has_printed_thinking_prefix = False # Flag for printing thinking prefix only once
|
has_printed_thinking_prefix = False # Flag for printing thinking prefix only once
|
||||||
|
@ -191,10 +200,13 @@ class ResponseProcessor:
|
||||||
logger.info(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
|
logger.info(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
|
||||||
f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
|
f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
|
||||||
|
|
||||||
thread_run_id = str(uuid.uuid4())
|
# Reuse thread_run_id for auto-continue or create new one
|
||||||
|
thread_run_id = continuous_state.get('thread_run_id') or str(uuid.uuid4())
|
||||||
|
continuous_state['thread_run_id'] = thread_run_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# --- Save and Yield Start Events ---
|
# --- Save and Yield Start Events (only if not auto-continuing) ---
|
||||||
|
if auto_continue_count == 0:
|
||||||
start_content = {"status_type": "thread_run_start", "thread_run_id": thread_run_id}
|
start_content = {"status_type": "thread_run_start", "thread_run_id": thread_run_id}
|
||||||
start_msg_obj = await self.add_message(
|
start_msg_obj = await self.add_message(
|
||||||
thread_id=thread_id, type="status", content=start_content,
|
thread_id=thread_id, type="status", content=start_content,
|
||||||
|
@ -210,7 +222,7 @@ class ResponseProcessor:
|
||||||
if assist_start_msg_obj: yield format_for_yield(assist_start_msg_obj)
|
if assist_start_msg_obj: yield format_for_yield(assist_start_msg_obj)
|
||||||
# --- End Start Events ---
|
# --- End Start Events ---
|
||||||
|
|
||||||
__sequence = 0
|
__sequence = continuous_state.get('sequence', 0) # get the sequence from the previous auto-continue cycle
|
||||||
|
|
||||||
async for chunk in llm_response:
|
async for chunk in llm_response:
|
||||||
# Extract streaming metadata from chunks
|
# Extract streaming metadata from chunks
|
||||||
|
@ -492,8 +504,12 @@ class ResponseProcessor:
|
||||||
logger.info(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls")
|
logger.info(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls")
|
||||||
self.trace.event(name="stream_finished_with_reason_xml_tool_limit_reached_after_xml_tool_calls", level="DEFAULT", status_message=(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls"))
|
self.trace.event(name="stream_finished_with_reason_xml_tool_limit_reached_after_xml_tool_calls", level="DEFAULT", status_message=(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls"))
|
||||||
|
|
||||||
|
# Calculate if auto-continue is needed if the finish reason is length
|
||||||
|
should_auto_continue = (can_auto_continue and finish_reason == 'length')
|
||||||
|
|
||||||
# --- SAVE and YIELD Final Assistant Message ---
|
# --- SAVE and YIELD Final Assistant Message ---
|
||||||
if accumulated_content:
|
# Only save assistant message if NOT auto-continuing due to length to avoid duplicate messages
|
||||||
|
if accumulated_content and not should_auto_continue:
|
||||||
# ... (Truncate accumulated_content logic) ...
|
# ... (Truncate accumulated_content logic) ...
|
||||||
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls and xml_chunks_buffer:
|
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls and xml_chunks_buffer:
|
||||||
last_xml_chunk = xml_chunks_buffer[-1]
|
last_xml_chunk = xml_chunks_buffer[-1]
|
||||||
|
@ -746,6 +762,8 @@ class ResponseProcessor:
|
||||||
return
|
return
|
||||||
|
|
||||||
# --- Save and Yield assistant_response_end ---
|
# --- Save and Yield assistant_response_end ---
|
||||||
|
# Only save assistant_response_end if not auto-continuing (response is actually complete)
|
||||||
|
if not should_auto_continue:
|
||||||
if last_assistant_message_object: # Only save if assistant message was saved
|
if last_assistant_message_object: # Only save if assistant message was saved
|
||||||
try:
|
try:
|
||||||
# Calculate response time if we have timing data
|
# Calculate response time if we have timing data
|
||||||
|
@ -815,7 +833,14 @@ class ResponseProcessor:
|
||||||
raise # Use bare 'raise' to preserve the original exception with its traceback
|
raise # Use bare 'raise' to preserve the original exception with its traceback
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Save and Yield the final thread_run_end status
|
# Update continuous state for potential auto-continue
|
||||||
|
if should_auto_continue:
|
||||||
|
continuous_state['accumulated_content'] = accumulated_content
|
||||||
|
continuous_state['sequence'] = __sequence
|
||||||
|
|
||||||
|
logger.info(f"Updated continuous state for auto-continue with {len(accumulated_content)} chars")
|
||||||
|
else:
|
||||||
|
# Save and Yield the final thread_run_end status (only if not auto-continuing and finish_reason is not 'length')
|
||||||
try:
|
try:
|
||||||
end_content = {"status_type": "thread_run_end"}
|
end_content = {"status_type": "thread_run_end"}
|
||||||
end_msg_obj = await self.add_message(
|
end_msg_obj = await self.add_message(
|
||||||
|
|
|
@ -298,6 +298,12 @@ Here are the XML tools available with examples:
|
||||||
auto_continue = True
|
auto_continue = True
|
||||||
auto_continue_count = 0
|
auto_continue_count = 0
|
||||||
|
|
||||||
|
# Shared state for continuous streaming across auto-continues
|
||||||
|
continuous_state = {
|
||||||
|
'accumulated_content': '',
|
||||||
|
'thread_run_id': None
|
||||||
|
}
|
||||||
|
|
||||||
# Define inner function to handle a single run
|
# Define inner function to handle a single run
|
||||||
async def _run_once(temp_msg=None):
|
async def _run_once(temp_msg=None):
|
||||||
try:
|
try:
|
||||||
|
@ -342,6 +348,18 @@ Here are the XML tools available with examples:
|
||||||
prepared_messages.append(temp_msg)
|
prepared_messages.append(temp_msg)
|
||||||
logger.debug("Added temporary message to the end of prepared messages")
|
logger.debug("Added temporary message to the end of prepared messages")
|
||||||
|
|
||||||
|
# Add partial assistant content for auto-continue context (without saving to DB)
|
||||||
|
if auto_continue_count > 0 and continuous_state.get('accumulated_content'):
|
||||||
|
partial_content = continuous_state.get('accumulated_content', '')
|
||||||
|
|
||||||
|
# Create temporary assistant message with just the text content
|
||||||
|
temporary_assistant_message = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": partial_content
|
||||||
|
}
|
||||||
|
prepared_messages.append(temporary_assistant_message)
|
||||||
|
logger.info(f"Added temporary assistant message with {len(partial_content)} chars for auto-continue context")
|
||||||
|
|
||||||
# 4. Prepare tools for LLM call
|
# 4. Prepare tools for LLM call
|
||||||
openapi_tool_schemas = None
|
openapi_tool_schemas = None
|
||||||
if config.native_tool_calling:
|
if config.native_tool_calling:
|
||||||
|
@ -395,6 +413,9 @@ Here are the XML tools available with examples:
|
||||||
config=config,
|
config=config,
|
||||||
prompt_messages=prepared_messages,
|
prompt_messages=prepared_messages,
|
||||||
llm_model=llm_model,
|
llm_model=llm_model,
|
||||||
|
can_auto_continue=(native_max_auto_continues > 0),
|
||||||
|
auto_continue_count=auto_continue_count,
|
||||||
|
continuous_state=continuous_state
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Fallback to non-streaming if response is not iterable
|
# Fallback to non-streaming if response is not iterable
|
||||||
|
@ -467,6 +488,14 @@ Here are the XML tools available with examples:
|
||||||
auto_continue = False
|
auto_continue = False
|
||||||
# Still yield the chunk to inform the client
|
# Still yield the chunk to inform the client
|
||||||
|
|
||||||
|
elif chunk.get('type') == 'status':
|
||||||
|
# if the finish reason is length, auto-continue
|
||||||
|
content = json.loads(chunk.get('content'))
|
||||||
|
if content.get('finish_reason') == 'length':
|
||||||
|
logger.info(f"Detected finish_reason='length', auto-continuing ({auto_continue_count + 1}/{native_max_auto_continues})")
|
||||||
|
auto_continue = True
|
||||||
|
auto_continue_count += 1
|
||||||
|
continue
|
||||||
# Otherwise just yield the chunk normally
|
# Otherwise just yield the chunk normally
|
||||||
yield chunk
|
yield chunk
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue