Merge pull request #1060 from KrishavRajSingh/krishav/fix/half_finished_response

Marko Kraemer 2025-07-24 18:02:50 +02:00 committed by GitHub
commit 4b20eb983b
2 changed files with 128 additions and 74 deletions


@@ -146,6 +146,9 @@ class ResponseProcessor:
prompt_messages: List[Dict[str, Any]],
llm_model: str,
config: ProcessorConfig = ProcessorConfig(),
can_auto_continue: bool = False,
auto_continue_count: int = 0,
continuous_state: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""Process a streaming LLM response, handling tool calls and execution.
@@ -155,19 +158,25 @@ class ResponseProcessor:
prompt_messages: List of messages sent to the LLM (the prompt)
llm_model: The name of the LLM model used
config: Configuration for parsing and execution
can_auto_continue: Whether auto-continue is enabled when the response is cut off
auto_continue_count: Number of auto-continue cycles performed so far
continuous_state: Streaming state carried over from the previous auto-continue cycle
Yields:
Complete message objects matching the DB schema, except for content chunks.
"""
accumulated_content = ""
# Initialize from continuous state if provided (for auto-continue)
continuous_state = continuous_state or {}
accumulated_content = continuous_state.get('accumulated_content', "")
tool_calls_buffer = {}
current_xml_content = ""
current_xml_content = accumulated_content # equal to accumulated_content if auto-continuing, else blank
xml_chunks_buffer = []
pending_tool_executions = []
yielded_tool_indices = set() # Stores indices of tools whose *status* has been yielded
tool_index = 0
xml_tool_call_count = 0
finish_reason = None
should_auto_continue = False
last_assistant_message_object = None # Store the final saved assistant message object
tool_result_message_objects = {} # tool_index -> full saved message object
has_printed_thinking_prefix = False # Flag for printing thinking prefix only once
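A minimal sketch of the state seeding this hunk introduces, pulled out into a standalone helper for readability (the helper name and function form are illustrative, not part of the commit):

```python
from typing import Any, Dict, Optional

def seed_stream_state(continuous_state: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Return the per-cycle working variables, resuming from a previous auto-continue cycle if present."""
    continuous_state = continuous_state or {}
    accumulated_content = continuous_state.get('accumulated_content', "")
    return {
        'continuous_state': continuous_state,
        'accumulated_content': accumulated_content,
        # XML parsing resumes from the partial content when auto-continuing.
        'current_xml_content': accumulated_content,
        'should_auto_continue': False,
    }
```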
@@ -191,26 +200,29 @@ class ResponseProcessor:
logger.info(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
thread_run_id = str(uuid.uuid4())
# Reuse the thread_run_id from a previous auto-continue cycle, or create a new one
thread_run_id = continuous_state.get('thread_run_id') or str(uuid.uuid4())
continuous_state['thread_run_id'] = thread_run_id
try:
# --- Save and Yield Start Events ---
start_content = {"status_type": "thread_run_start", "thread_run_id": thread_run_id}
start_msg_obj = await self.add_message(
thread_id=thread_id, type="status", content=start_content,
is_llm_message=False, metadata={"thread_run_id": thread_run_id}
)
if start_msg_obj: yield format_for_yield(start_msg_obj)
# --- Save and Yield Start Events (only if not auto-continuing) ---
if auto_continue_count == 0:
start_content = {"status_type": "thread_run_start", "thread_run_id": thread_run_id}
start_msg_obj = await self.add_message(
thread_id=thread_id, type="status", content=start_content,
is_llm_message=False, metadata={"thread_run_id": thread_run_id}
)
if start_msg_obj: yield format_for_yield(start_msg_obj)
assist_start_content = {"status_type": "assistant_response_start"}
assist_start_msg_obj = await self.add_message(
thread_id=thread_id, type="status", content=assist_start_content,
is_llm_message=False, metadata={"thread_run_id": thread_run_id}
)
if assist_start_msg_obj: yield format_for_yield(assist_start_msg_obj)
assist_start_content = {"status_type": "assistant_response_start"}
assist_start_msg_obj = await self.add_message(
thread_id=thread_id, type="status", content=assist_start_content,
is_llm_message=False, metadata={"thread_run_id": thread_run_id}
)
if assist_start_msg_obj: yield format_for_yield(assist_start_msg_obj)
# --- End Start Events ---
__sequence = 0
__sequence = continuous_state.get('sequence', 0) # get the sequence from the previous auto-continue cycle
async for chunk in llm_response:
# Extract streaming metadata from chunks
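The gating above boils down to two rules: the run id persists across cycles, and the start events are emitted only on the first cycle. A small illustrative sketch (these helpers are not part of the commit):

```python
import uuid
from typing import Any, Dict

def resolve_thread_run_id(continuous_state: Dict[str, Any]) -> str:
    """Reuse the run id across auto-continue cycles; mint one only on the first cycle."""
    thread_run_id = continuous_state.get('thread_run_id') or str(uuid.uuid4())
    continuous_state['thread_run_id'] = thread_run_id
    return thread_run_id

def should_emit_start_events(auto_continue_count: int) -> bool:
    """thread_run_start and assistant_response_start belong to the first cycle only."""
    return auto_continue_count == 0
```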
@@ -492,8 +504,12 @@ class ResponseProcessor:
logger.info(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls")
self.trace.event(name="stream_finished_with_reason_xml_tool_limit_reached_after_xml_tool_calls", level="DEFAULT", status_message=(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls"))
# Determine whether auto-continue is needed (auto-continue enabled and the finish reason is 'length')
should_auto_continue = (can_auto_continue and finish_reason == 'length')
# --- SAVE and YIELD Final Assistant Message ---
if accumulated_content:
# Only save assistant message if NOT auto-continuing due to length to avoid duplicate messages
if accumulated_content and not should_auto_continue:
# ... (Truncate accumulated_content logic) ...
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls and xml_chunks_buffer:
last_xml_chunk = xml_chunks_buffer[-1]
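The new auto-continue decision and the deferred save can be summarized as follows (illustrative sketch, not the committed code):

```python
from typing import Optional

def needs_auto_continue(can_auto_continue: bool, finish_reason: Optional[str]) -> bool:
    """Continue automatically only when the model stopped because it hit its output-token limit."""
    return bool(can_auto_continue and finish_reason == 'length')

def should_persist_assistant_message(accumulated_content: str, auto_continuing: bool) -> bool:
    """Persist the assistant message once, when the response is actually complete."""
    return bool(accumulated_content) and not auto_continuing
```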
@@ -746,53 +762,55 @@ class ResponseProcessor:
return
# --- Save and Yield assistant_response_end ---
if last_assistant_message_object: # Only save if assistant message was saved
try:
# Calculate response time if we have timing data
if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000
# Only save assistant_response_end if not auto-continuing (response is actually complete)
if not should_auto_continue:
if last_assistant_message_object: # Only save if assistant message was saved
try:
# Calculate response time if we have timing data
if streaming_metadata["first_chunk_time"] and streaming_metadata["last_chunk_time"]:
streaming_metadata["response_ms"] = (streaming_metadata["last_chunk_time"] - streaming_metadata["first_chunk_time"]) * 1000
# Create a LiteLLM-like response object for streaming
# Check if we have any actual usage data
has_usage_data = (
streaming_metadata["usage"]["prompt_tokens"] > 0 or
streaming_metadata["usage"]["completion_tokens"] > 0 or
streaming_metadata["usage"]["total_tokens"] > 0
)
assistant_end_content = {
"choices": [
{
"finish_reason": finish_reason or "stop",
"index": 0,
"message": {
"role": "assistant",
"content": accumulated_content,
"tool_calls": complete_native_tool_calls or None
# Create a LiteLLM-like response object for streaming
# Check if we have any actual usage data
has_usage_data = (
streaming_metadata["usage"]["prompt_tokens"] > 0 or
streaming_metadata["usage"]["completion_tokens"] > 0 or
streaming_metadata["usage"]["total_tokens"] > 0
)
assistant_end_content = {
"choices": [
{
"finish_reason": finish_reason or "stop",
"index": 0,
"message": {
"role": "assistant",
"content": accumulated_content,
"tool_calls": complete_native_tool_calls or None
}
}
}
],
"created": streaming_metadata.get("created"),
"model": streaming_metadata.get("model", llm_model),
"usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
"streaming": True, # Add flag to indicate this was reconstructed from streaming
}
# Only include response_ms if we have timing data
if streaming_metadata.get("response_ms"):
assistant_end_content["response_ms"] = streaming_metadata["response_ms"]
await self.add_message(
thread_id=thread_id,
type="assistant_response_end",
content=assistant_end_content,
is_llm_message=False,
metadata={"thread_run_id": thread_run_id}
)
logger.info("Assistant response end saved for stream")
except Exception as e:
logger.error(f"Error saving assistant response end for stream: {str(e)}")
self.trace.event(name="error_saving_assistant_response_end_for_stream", level="ERROR", status_message=(f"Error saving assistant response end for stream: {str(e)}"))
],
"created": streaming_metadata.get("created"),
"model": streaming_metadata.get("model", llm_model),
"usage": streaming_metadata["usage"], # Always include usage like LiteLLM does
"streaming": True, # Add flag to indicate this was reconstructed from streaming
}
# Only include response_ms if we have timing data
if streaming_metadata.get("response_ms"):
assistant_end_content["response_ms"] = streaming_metadata["response_ms"]
await self.add_message(
thread_id=thread_id,
type="assistant_response_end",
content=assistant_end_content,
is_llm_message=False,
metadata={"thread_run_id": thread_run_id}
)
logger.info("Assistant response end saved for stream")
except Exception as e:
logger.error(f"Error saving assistant response end for stream: {str(e)}")
self.trace.event(name="error_saving_assistant_response_end_for_stream", level="ERROR", status_message=(f"Error saving assistant response end for stream: {str(e)}"))
except Exception as e:
logger.error(f"Error processing stream: {str(e)}", exc_info=True)
@@ -815,17 +833,24 @@ class ResponseProcessor:
raise # Use bare 'raise' to preserve the original exception with its traceback
finally:
# Save and Yield the final thread_run_end status
try:
end_content = {"status_type": "thread_run_end"}
end_msg_obj = await self.add_message(
thread_id=thread_id, type="status", content=end_content,
is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}
)
if end_msg_obj: yield format_for_yield(end_msg_obj)
except Exception as final_e:
logger.error(f"Error in finally block: {str(final_e)}", exc_info=True)
self.trace.event(name="error_in_finally_block", level="ERROR", status_message=(f"Error in finally block: {str(final_e)}"))
# Update continuous state for potential auto-continue
if should_auto_continue:
continuous_state['accumulated_content'] = accumulated_content
continuous_state['sequence'] = __sequence
logger.info(f"Updated continuous state for auto-continue with {len(accumulated_content)} chars")
else:
# Save and Yield the final thread_run_end status (only when not auto-continuing)
try:
end_content = {"status_type": "thread_run_end"}
end_msg_obj = await self.add_message(
thread_id=thread_id, type="status", content=end_content,
is_llm_message=False, metadata={"thread_run_id": thread_run_id if 'thread_run_id' in locals() else None}
)
if end_msg_obj: yield format_for_yield(end_msg_obj)
except Exception as final_e:
logger.error(f"Error in finally block: {str(final_e)}", exc_info=True)
self.trace.event(name="error_in_finally_block", level="ERROR", status_message=(f"Error in finally block: {str(final_e)}"))
async def process_non_streaming_response(
self,
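The reworked finally block either carries the partial state forward or closes the run. Roughly (illustrative helper, not the committed code):

```python
from typing import Any, Dict

def finish_or_carry_over(should_auto_continue: bool, continuous_state: Dict[str, Any],
                         accumulated_content: str, sequence: int) -> bool:
    """Return True when the run is complete and thread_run_end should be emitted;
    otherwise stash the partial content and chunk sequence for the next cycle."""
    if should_auto_continue:
        continuous_state['accumulated_content'] = accumulated_content
        continuous_state['sequence'] = sequence
        return False
    return True
```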


@@ -297,6 +297,12 @@ Here are the XML tools available with examples:
# Control whether we need to auto-continue due to tool_calls finish reason
auto_continue = True
auto_continue_count = 0
# Shared state for continuous streaming across auto-continues
continuous_state = {
'accumulated_content': '',
'thread_run_id': None
}
# Define inner function to handle a single run
async def _run_once(temp_msg=None):
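For illustration, this is the shape continuous_state is expected to take after a truncated cycle (the values below are invented; only the keys come from the diff):

```python
example_continuous_state = {
    'accumulated_content': 'The assistant was saying... (partial text)',  # invented value
    'thread_run_id': 'c0ffee00-0000-4000-8000-000000000000',              # invented value, reused by the next cycle
    'sequence': 113,                                                      # invented value, chunk sequence to resume from
}
```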
@@ -342,6 +348,18 @@ Here are the XML tools available with examples:
prepared_messages.append(temp_msg)
logger.debug("Added temporary message to the end of prepared messages")
# Add partial assistant content for auto-continue context (without saving to DB)
if auto_continue_count > 0 and continuous_state.get('accumulated_content'):
partial_content = continuous_state.get('accumulated_content', '')
# Create temporary assistant message with just the text content
temporary_assistant_message = {
"role": "assistant",
"content": partial_content
}
prepared_messages.append(temporary_assistant_message)
logger.info(f"Added temporary assistant message with {len(partial_content)} chars for auto-continue context")
# 4. Prepare tools for LLM call
openapi_tool_schemas = None
if config.native_tool_calling:
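Pulled out as a standalone helper for clarity (illustrative; the diff keeps this logic inline in _run_once):

```python
from typing import Any, Dict, List

def add_partial_assistant_context(prepared_messages: List[Dict[str, Any]],
                                  continuous_state: Dict[str, Any],
                                  auto_continue_count: int) -> None:
    """On auto-continue, append the partial assistant text so the model resumes mid-response.
    The temporary message is never written to the database."""
    partial_content = continuous_state.get('accumulated_content', '')
    if auto_continue_count > 0 and partial_content:
        prepared_messages.append({"role": "assistant", "content": partial_content})
```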
@@ -395,6 +413,9 @@ Here are the XML tools available with examples:
config=config,
prompt_messages=prepared_messages,
llm_model=llm_model,
can_auto_continue=(native_max_auto_continues > 0),
auto_continue_count=auto_continue_count,
continuous_state=continuous_state
)
else:
# Fallback to non-streaming if response is not iterable
@@ -467,6 +488,14 @@ Here are the XML tools available with examples:
auto_continue = False
# Still yield the chunk to inform the client
elif chunk.get('type') == 'status':
# if the finish reason is length, auto-continue
content = json.loads(chunk.get('content'))
if content.get('finish_reason') == 'length':
logger.info(f"Detected finish_reason='length', auto-continuing ({auto_continue_count + 1}/{native_max_auto_continues})")
auto_continue = True
auto_continue_count += 1
continue
# Otherwise just yield the chunk normally
yield chunk
else:
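Putting the two files together, the driver loop roughly works as sketched below. This is a hypothetical, simplified outline (run_once stands in for _run_once, and the loop-termination condition is assumed), not the committed run_thread code:

```python
import json

async def drive_auto_continue(run_once, native_max_auto_continues: int):
    """Hypothetical driver: rerun the cycle whenever a status chunk reports finish_reason='length'."""
    continuous_state = {'accumulated_content': '', 'thread_run_id': None}
    auto_continue, auto_continue_count = True, 0
    while auto_continue and auto_continue_count <= native_max_auto_continues:
        auto_continue = False
        async for chunk in run_once(continuous_state, auto_continue_count):
            if chunk.get('type') == 'status':
                content = json.loads(chunk.get('content') or '{}')
                if content.get('finish_reason') == 'length':
                    auto_continue = True
                    auto_continue_count += 1
                    continue  # swallow this status chunk; the next cycle extends the same response
            yield chunk
```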