mirror of https://github.com/kortix-ai/suna.git
Merge pull request #1060 from KrishavRajSingh/krishav/fix/half_finished_response
Commit 4b20eb983b
@@ -146,6 +146,9 @@ class ResponseProcessor:
         prompt_messages: List[Dict[str, Any]],
         llm_model: str,
         config: ProcessorConfig = ProcessorConfig(),
+        can_auto_continue: bool = False,
+        auto_continue_count: int = 0,
+        continuous_state: Optional[Dict[str, Any]] = None,
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """Process a streaming LLM response, handling tool calls and execution.
 
@@ -155,19 +158,25 @@ class ResponseProcessor:
             prompt_messages: List of messages sent to the LLM (the prompt)
             llm_model: The name of the LLM model used
             config: Configuration for parsing and execution
+            can_auto_continue: Whether auto-continue is enabled
+            auto_continue_count: Number of auto-continue cycles
+            continuous_state: Previous state of the conversation
 
         Yields:
             Complete message objects matching the DB schema, except for content chunks.
         """
-        accumulated_content = ""
+        # Initialize from continuous state if provided (for auto-continue)
+        continuous_state = continuous_state or {}
+        accumulated_content = continuous_state.get('accumulated_content', "")
         tool_calls_buffer = {}
-        current_xml_content = ""
+        current_xml_content = accumulated_content # equal to accumulated_content if auto-continuing, else blank
         xml_chunks_buffer = []
         pending_tool_executions = []
         yielded_tool_indices = set() # Stores indices of tools whose *status* has been yielded
         tool_index = 0
         xml_tool_call_count = 0
         finish_reason = None
+        should_auto_continue = False
         last_assistant_message_object = None # Store the final saved assistant message object
         tool_result_message_objects = {} # tool_index -> full saved message object
        has_printed_thinking_prefix = False # Flag for printing thinking prefix only once
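
The hunk above makes the processor resume from carried-over state on an auto-continue pass instead of starting with empty buffers. A minimal standalone sketch of that restore step (the helper name and the assert are illustrative, not part of the patch):

from typing import Any, Dict, Optional, Tuple

def restore_stream_state(continuous_state: Optional[Dict[str, Any]]) -> Tuple[Dict[str, Any], str, str]:
    """Illustrative: seed buffers from the previous cycle's state, as the patch does."""
    state = continuous_state or {}
    accumulated_content = state.get('accumulated_content', "")
    # On an auto-continue pass current_xml_content starts equal to the carried-over
    # content; on a fresh run both are empty strings.
    current_xml_content = accumulated_content
    return state, accumulated_content, current_xml_content

state, acc, xml = restore_stream_state({'accumulated_content': 'partial answer'})
assert acc == xml == 'partial answer'
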
@@ -191,26 +200,29 @@ class ResponseProcessor:
         logger.info(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
                     f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
 
-        thread_run_id = str(uuid.uuid4())
+        # Reuse thread_run_id for auto-continue or create new one
+        thread_run_id = continuous_state.get('thread_run_id') or str(uuid.uuid4())
+        continuous_state['thread_run_id'] = thread_run_id
 
         try:
-            # --- Save and Yield Start Events ---
-            start_content = {"status_type": "thread_run_start", "thread_run_id": thread_run_id}
-            start_msg_obj = await self.add_message(
-                thread_id=thread_id, type="status", content=start_content,
-                is_llm_message=False, metadata={"thread_run_id": thread_run_id}
-            )
-            if start_msg_obj: yield format_for_yield(start_msg_obj)
+            # --- Save and Yield Start Events (only if not auto-continuing) ---
+            if auto_continue_count == 0:
+                start_content = {"status_type": "thread_run_start", "thread_run_id": thread_run_id}
+                start_msg_obj = await self.add_message(
+                    thread_id=thread_id, type="status", content=start_content,
+                    is_llm_message=False, metadata={"thread_run_id": thread_run_id}
+                )
+                if start_msg_obj: yield format_for_yield(start_msg_obj)
 
-            assist_start_content = {"status_type": "assistant_response_start"}
-            assist_start_msg_obj = await self.add_message(
-                thread_id=thread_id, type="status", content=assist_start_content,
-                is_llm_message=False, metadata={"thread_run_id": thread_run_id}
-            )
-            if assist_start_msg_obj: yield format_for_yield(assist_start_msg_obj)
+                assist_start_content = {"status_type": "assistant_response_start"}
+                assist_start_msg_obj = await self.add_message(
+                    thread_id=thread_id, type="status", content=assist_start_content,
+                    is_llm_message=False, metadata={"thread_run_id": thread_run_id}
+                )
+                if assist_start_msg_obj: yield format_for_yield(assist_start_msg_obj)
             # --- End Start Events ---
 
-            __sequence = 0
+            __sequence = continuous_state.get('sequence', 0) # get the sequence from the previous auto-continue cycle
 
             async for chunk in llm_response:
                 # Extract streaming metadata from chunks
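
With auto-continue, all cycles of one logical run share a single thread_run_id, and the start-status messages are emitted only on the first cycle (auto_continue_count == 0). A small sketch of the id reuse (standalone; the helper name is illustrative):

import uuid

def resolve_thread_run_id(continuous_state: dict) -> str:
    """Illustrative: reuse the run id stored by a previous cycle, else mint a new one."""
    thread_run_id = continuous_state.get('thread_run_id') or str(uuid.uuid4())
    continuous_state['thread_run_id'] = thread_run_id
    return thread_run_id

state: dict = {}
first = resolve_thread_run_id(state)   # fresh run: a new UUID is created and stored
second = resolve_thread_run_id(state)  # auto-continue cycle: the same id is reused
assert first == second
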
@@ -492,8 +504,12 @@ class ResponseProcessor:
                 logger.info(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls")
                 self.trace.event(name="stream_finished_with_reason_xml_tool_limit_reached_after_xml_tool_calls", level="DEFAULT", status_message=(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls"))
 
+            # Calculate if auto-continue is needed if the finish reason is length
+            should_auto_continue = (can_auto_continue and finish_reason == 'length')
+
             # --- SAVE and YIELD Final Assistant Message ---
-            if accumulated_content:
+            # Only save assistant message if NOT auto-continuing due to length to avoid duplicate messages
+            if accumulated_content and not should_auto_continue:
                 # ... (Truncate accumulated_content logic) ...
                 if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls and xml_chunks_buffer:
                     last_xml_chunk = xml_chunks_buffer[-1]
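
The guard introduced above is what prevents half-finished responses from being persisted twice: when the stream stops with finish_reason == 'length' and auto-continue is allowed, the partial assistant message is not saved yet. A compact sketch of the decision (illustrative function, not in the patch):

def should_save_assistant_message(accumulated_content: str,
                                  can_auto_continue: bool,
                                  finish_reason: str) -> bool:
    """Illustrative: mirror the patched condition around saving the assistant message."""
    should_auto_continue = can_auto_continue and finish_reason == 'length'
    # Save only when there is content and we are not about to auto-continue.
    return bool(accumulated_content) and not should_auto_continue

assert should_save_assistant_message("partial...", True, "length") is False
assert should_save_assistant_message("done.", True, "stop") is True
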
@@ -746,53 +762,55 @@ class ResponseProcessor:
                     return
 
             # --- Save and Yield assistant_response_end ---
-            if last_assistant_message_object: # Only save if assistant message was saved
-                try:
-                    # Calculate response time if we have timing data
-                    if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
-                        streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000
+            # Only save assistant_response_end if not auto-continuing (response is actually complete)
+            if not should_auto_continue:
+                if last_assistant_message_object: # Only save if assistant message was saved
+                    try:
+                        # Calculate response time if we have timing data
+                        if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
+                            streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000
 
-                    # Create a LiteLLM-like response object for streaming
-                    # Check if we have any actual usage data
-                    has_usage_data = (
-                        streaming_metadata["usage"]["prompt_tokens"] > 0 or
-                        streaming_metadata["usage"]["completion_tokens"] > 0 or
-                        streaming_metadata["usage"]["total_tokens"] > 0
-                    )
-
-                    assistant_end_content = {
-                        "choices": [
-                            {
-                                "finish_reason": finish_reason or "stop",
-                                "index": 0,
-                                "message": {
-                                    "role": "assistant",
-                                    "content": accumulated_content,
-                                    "tool_calls": complete_native_tool_calls or None
-                                }
-                            }
-                        ],
-                        "created": streaming_metadata.get("created"),
-                        "model": streaming_metadata.get("model", llm_model),
-                        "usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
-                        "streaming": True, # Add flag to indicate this was reconstructed from streaming
-                    }
-
-                    # Only include response_ms if we have timing data
-                    if streaming_metadata.get("response_ms"):
-                        assistant_end_content["response_ms"] = streaming_metadata["response_ms"]
-
-                    await self.add_message(
-                        thread_id=thread_id,
-                        type="assistant_response_end",
-                        content=assistant_end_content,
-                        is_llm_message=False,
-                        metadata={"thread_run_id": thread_run_id}
-                    )
-                    logger.info("Assistant response end saved for stream")
-                except Exception as e:
-                    logger.error(f"Error saving assistant response end for stream: {str(e)}")
-                    self.trace.event(name="error_saving_assistant_response_end_for_stream", level="ERROR", status_message=(f"Error saving assistant response end for stream: {str(e)}"))
+                        # Create a LiteLLM-like response object for streaming
+                        # Check if we have any actual usage data
+                        has_usage_data = (
+                            streaming_metadata["usage"]["prompt_tokens"] > 0 or
+                            streaming_metadata["usage"]["completion_tokens"] > 0 or
+                            streaming_metadata["usage"]["total_tokens"] > 0
+                        )
+
+                        assistant_end_content = {
+                            "choices": [
+                                {
+                                    "finish_reason": finish_reason or "stop",
+                                    "index": 0,
+                                    "message": {
+                                        "role": "assistant",
+                                        "content": accumulated_content,
+                                        "tool_calls": complete_native_tool_calls or None
+                                    }
+                                }
+                            ],
+                            "created": streaming_metadata.get("created"),
+                            "model": streaming_metadata.get("model", llm_model),
+                            "usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
+                            "streaming": True, # Add flag to indicate this was reconstructed from streaming
+                        }
+
+                        # Only include response_ms if we have timing data
+                        if streaming_metadata.get("response_ms"):
+                            assistant_end_content["response_ms"] = streaming_metadata["response_ms"]
+
+                        await self.add_message(
+                            thread_id=thread_id,
+                            type="assistant_response_end",
+                            content=assistant_end_content,
+                            is_llm_message=False,
+                            metadata={"thread_run_id": thread_run_id}
+                        )
+                        logger.info("Assistant response end saved for stream")
+                    except Exception as e:
+                        logger.error(f"Error saving assistant response end for stream: {str(e)}")
+                        self.trace.event(name="error_saving_assistant_response_end_for_stream", level="ERROR", status_message=(f"Error saving assistant response end for stream: {str(e)}"))
 
         except Exception as e:
             logger.error(f"Error processing stream: {str(e)}", exc_info=True)
@@ -815,17 +833,24 @@ class ResponseProcessor:
             raise # Use bare 'raise' to preserve the original exception with its traceback
 
         finally:
-            # Save and Yield the final thread_run_end status
-            try:
-                end_content = {"status_type": "thread_run_end"}
-                end_msg_obj = await self.add_message(
-                    thread_id=thread_id, type="status", content=end_content,
-                    is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}
-                )
-                if end_msg_obj: yield format_for_yield(end_msg_obj)
-            except Exception as final_e:
-                logger.error(f"Error in finally block: {str(final_e)}", exc_info=True)
-                self.trace.event(name="error_in_finally_block", level="ERROR", status_message=(f"Error in finally block: {str(final_e)}"))
+            # Update continuous state for potential auto-continue
+            if should_auto_continue:
+                continuous_state['accumulated_content'] = accumulated_content
+                continuous_state['sequence'] = __sequence
+
+                logger.info(f"Updated continuous state for auto-continue with {len(accumulated_content)} chars")
+            else:
+                # Save and Yield the final thread_run_end status (only if not auto-continuing and finish_reason is not 'length')
+                try:
+                    end_content = {"status_type": "thread_run_end"}
+                    end_msg_obj = await self.add_message(
+                        thread_id=thread_id, type="status", content=end_content,
+                        is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}
+                    )
+                    if end_msg_obj: yield format_for_yield(end_msg_obj)
+                except Exception as final_e:
+                    logger.error(f"Error in finally block: {str(final_e)}", exc_info=True)
+                    self.trace.event(name="error_in_finally_block", level="ERROR", status_message=(f"Error in finally block: {str(final_e)}"))
 
     async def process_non_streaming_response(
         self,
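
In the finally block, an auto-continuing cycle now hands its partial text and sequence counter back through continuous_state instead of emitting thread_run_end, so the run is only closed once the response is actually complete. A rough sketch of that handoff (illustrative names; the real code also persists the status via add_message). The hunks that follow are in the caller that drives the auto-continue loop.

from typing import Optional

def finalize_cycle(continuous_state: dict, should_auto_continue: bool,
                   accumulated_content: str, sequence: int) -> Optional[dict]:
    """Illustrative: stash state for the next cycle, or signal the end of the run."""
    if should_auto_continue:
        continuous_state['accumulated_content'] = accumulated_content
        continuous_state['sequence'] = sequence
        return None  # no thread_run_end yet; another cycle will pick the state up
    return {"status_type": "thread_run_end"}

state: dict = {}
assert finalize_cycle(state, True, "partial", 42) is None and state['sequence'] == 42
assert finalize_cycle(state, False, "full answer", 99) == {"status_type": "thread_run_end"}
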
@@ -297,6 +297,12 @@ Here are the XML tools available with examples:
         # Control whether we need to auto-continue due to tool_calls finish reason
         auto_continue = True
         auto_continue_count = 0
+
+        # Shared state for continuous streaming across auto-continues
+        continuous_state = {
+            'accumulated_content': '',
+            'thread_run_id': None
+        }
 
         # Define inner function to handle a single run
         async def _run_once(temp_msg=None):
@@ -342,6 +348,18 @@ Here are the XML tools available with examples:
                 prepared_messages.append(temp_msg)
                 logger.debug("Added temporary message to the end of prepared messages")
 
+            # Add partial assistant content for auto-continue context (without saving to DB)
+            if auto_continue_count > 0 and continuous_state.get('accumulated_content'):
+                partial_content = continuous_state.get('accumulated_content', '')
+
+                # Create temporary assistant message with just the text content
+                temporary_assistant_message = {
+                    "role": "assistant",
+                    "content": partial_content
+                }
+                prepared_messages.append(temporary_assistant_message)
+                logger.info(f"Added temporary assistant message with {len(partial_content)} chars for auto-continue context")
+
             # 4. Prepare tools for LLM call
             openapi_tool_schemas = None
             if config.native_tool_calling:
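
On a continuation cycle the carried-over partial answer is appended to the prompt as a transient assistant turn, so the model keeps writing from where it was cut off; nothing is written to the DB at this point. A minimal sketch (standalone, illustrative helper name):

def with_partial_context(prepared_messages: list, continuous_state: dict,
                         auto_continue_count: int) -> list:
    """Illustrative: add the partial answer as a temporary assistant message."""
    partial_content = continuous_state.get('accumulated_content', '')
    if auto_continue_count > 0 and partial_content:
        prepared_messages = prepared_messages + [
            {"role": "assistant", "content": partial_content}
        ]
    return prepared_messages

msgs = with_partial_context([{"role": "user", "content": "Write a long report"}],
                            {'accumulated_content': 'Section 1: ...'}, auto_continue_count=1)
assert msgs[-1]["role"] == "assistant"
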
@@ -395,6 +413,9 @@ Here are the XML tools available with examples:
                     config=config,
                     prompt_messages=prepared_messages,
                     llm_model=llm_model,
+                    can_auto_continue=(native_max_auto_continues > 0),
+                    auto_continue_count=auto_continue_count,
+                    continuous_state=continuous_state
                 )
             else:
                 # Fallback to non-streaming if response is not iterable
@@ -467,6 +488,14 @@ Here are the XML tools available with examples:
                             auto_continue = False
                             # Still yield the chunk to inform the client
 
+                    elif chunk.get('type') == 'status':
+                        # if the finish reason is length, auto-continue
+                        content = json.loads(chunk.get('content'))
+                        if content.get('finish_reason') == 'length':
+                            logger.info(f"Detected finish_reason='length', auto-continuing ({auto_continue_count + 1}/{native_max_auto_continues})")
+                            auto_continue = True
+                            auto_continue_count += 1
+                            continue
                     # Otherwise just yield the chunk normally
                     yield chunk
                 else:
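
Putting it together, the outer loop re-runs the stream whenever a status chunk reports finish_reason == 'length', up to the configured limit, and swallows that status chunk instead of forwarding it. A simplified driver under the same assumptions (names are illustrative; the real caller also handles tool_calls and xml_tool_limit_reached finishes):

import json

async def auto_continue_driver(run_once, native_max_auto_continues: int):
    """Illustrative: keep re-running the stream while 'length' finishes arrive."""
    auto_continue, auto_continue_count = True, 0
    while auto_continue and auto_continue_count < native_max_auto_continues:
        auto_continue = False
        async for chunk in run_once():
            if chunk.get('type') == 'status':
                content = json.loads(chunk.get('content'))
                if content.get('finish_reason') == 'length':
                    auto_continue = True
                    auto_continue_count += 1
                    continue  # do not forward this status chunk; another cycle will run
            yield chunk

Because each pass rebuilds the prompt with the carried-over partial content and reuses the same thread_run_id, the client sees one uninterrupted stream rather than a truncated, half-finished response.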