fix: auto continue response if finish_reason is length

Krishav Raj Singh 2025-07-24 20:22:44 +05:30
parent 6a201da1b4
commit e4a6f5a1ef
2 changed files with 127 additions and 74 deletions

View File

@@ -146,6 +146,9 @@ class ResponseProcessor:
prompt_messages: List[Dict[str, Any]],
llm_model: str,
config: ProcessorConfig = ProcessorConfig(),
can_auto_continue: bool = False,
auto_continue_count: int = 0,
continuous_state: Optional[Dict[str, Any]] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
"""Process a streaming LLM response, handling tool calls and execution.
@@ -155,13 +158,18 @@
prompt_messages: List of messages sent to the LLM (the prompt)
llm_model: The name of the LLM model used
config: Configuration for parsing and execution
can_auto_continue: Whether auto-continue is enabled
auto_continue_count: Number of auto-continue cycles already completed for this run
continuous_state: Mutable state shared across auto-continue cycles (accumulated content, thread_run_id, sequence)
Yields:
Complete message objects matching the DB schema, except for content chunks.
"""
accumulated_content = ""
# Initialize from continuous state if provided (for auto-continue)
continuous_state = continuous_state or {}
accumulated_content = continuous_state.get('accumulated_content', "")
tool_calls_buffer = {}
current_xml_content = ""
current_xml_content = accumulated_content # carries over the accumulated content when auto-continuing, otherwise starts blank
xml_chunks_buffer = []
pending_tool_executions = []
yielded_tool_indices = set() # Stores indices of tools whose *status* has been yielded
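To make the handoff concrete, here is a minimal sketch of the shared dict at each stage, assuming a single truncated cycle; the key names come from this diff, the values are purely illustrative:

    # before the first cycle (created by the caller, as shown in the second file below):
    continuous_state = {'accumulated_content': '', 'thread_run_id': None}

    # after a cycle that ended with finish_reason == 'length':
    continuous_state = {
        'accumulated_content': 'The report so far, cut off mid-sent',  # illustrative partial text
        'thread_run_id': 'b3f2c6d8-illustrative-uuid',
        'sequence': 87,  # illustrative chunk counter, resumed by the next cycle
    }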
@@ -191,10 +199,13 @@
logger.info(f"Streaming Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
f"Execute on stream={config.execute_on_stream}, Strategy={config.tool_execution_strategy}")
thread_run_id = str(uuid.uuid4())
# Reuse thread_run_id for auto-continue or create new one
thread_run_id = continuous_state.get('thread_run_id') or str(uuid.uuid4())
continuous_state['thread_run_id'] = thread_run_id
try:
# --- Save and Yield Start Events ---
# --- Save and Yield Start Events (only if not auto-continuing) ---
if auto_continue_count == 0:
start_content = {"status_type": "thread_run_start", "thread_run_id": thread_run_id}
start_msg_obj = await self.add_message(
thread_id=thread_id, type="status", content=start_content,
@@ -210,7 +221,7 @@
if assist_start_msg_obj: yield format_for_yield(assist_start_msg_obj)
# --- End Start Events ---
__sequence = 0
__sequence = continuous_state.get('sequence', 0) # resume the chunk sequence from the previous auto-continue cycle, or start at 0
async for chunk in llm_response:
# Extract streaming metadata from chunks
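The finish_reason this commit keys on is read off the streamed chunks inside this loop. A hedged sketch of where it typically comes from, assuming an OpenAI/LiteLLM-style streaming response (the actual extraction sits in the part of the loop the diff elides):

    finish_reason = None
    async for chunk in llm_response:
        if chunk.choices:
            delta = chunk.choices[0].delta
            if delta is not None and getattr(delta, 'content', None):
                accumulated_content += delta.content
            # The provider reports why generation stopped on the final chunk:
            # 'stop', 'length' (max_tokens was hit), 'tool_calls', etc.
            if chunk.choices[0].finish_reason:
                finish_reason = chunk.choices[0].finish_reason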
@@ -493,7 +504,9 @@
self.trace.event(name="stream_finished_with_reason_xml_tool_limit_reached_after_xml_tool_calls", level="DEFAULT", status_message=(f"Stream finished with reason: xml_tool_limit_reached after {xml_tool_call_count} XML tool calls"))
# --- SAVE and YIELD Final Assistant Message ---
if accumulated_content:
# Only save assistant message if NOT auto-continuing due to length to avoid duplicate messages
should_auto_continue = (can_auto_continue and finish_reason == 'length')
if accumulated_content and not should_auto_continue:
# ... (Truncate accumulated_content logic) ...
if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls and xml_chunks_buffer:
last_xml_chunk = xml_chunks_buffer[-1]
@@ -746,6 +759,9 @@
return
# --- Save and Yield assistant_response_end ---
# Only save assistant_response_end if not auto-continuing (response is actually complete)
should_auto_continue = (can_auto_continue and finish_reason == 'length')
if not should_auto_continue:
if last_assistant_message_object: # Only save if assistant message was saved
try:
# Calculate response time if we have timing data
@@ -815,7 +831,15 @@
raise # Use bare 'raise' to preserve the original exception with its traceback
finally:
# Save and Yield the final thread_run_end status
# Update continuous state for potential auto-continue
should_auto_continue = (can_auto_continue and finish_reason == 'length')
if should_auto_continue:
continuous_state['accumulated_content'] = accumulated_content
continuous_state['sequence'] = __sequence
logger.info(f"Updated continuous state for auto-continue with {len(accumulated_content)} chars")
else:
# Save and Yield the final thread_run_end status (only when not auto-continuing)
try:
end_content = {"status_type": "thread_run_end"}
end_msg_obj = await self.add_message(
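Put together, the per-cycle bookkeeping these branches imply looks roughly like this (an illustrative trace, not output from the code):

    # cycle 0: finish_reason == 'length' -> no assistant message saved, no
    #          assistant_response_end, no thread_run_end; state is stashed instead
    # cycle 1: finish_reason == 'length' -> same again, accumulated_content keeps growing
    # cycle 2: finish_reason == 'stop'   -> the complete assistant message is saved once,
    #          then assistant_response_end and thread_run_end exactly as before this commit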

View File

@@ -298,6 +298,12 @@ Here are the XML tools available with examples:
auto_continue = True
auto_continue_count = 0
# Shared state for continuous streaming across auto-continues
continuous_state = {
'accumulated_content': '',
'thread_run_id': None
}
# Define inner function to handle a single run
async def _run_once(temp_msg=None):
try:
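Only fragments of the surrounding driver loop appear in this diff. A minimal sketch of the pattern it implies, where only auto_continue, auto_continue_count, native_max_auto_continues, continuous_state and _run_once are taken from the diff and the loop shape itself is assumed:

    while auto_continue and auto_continue_count < native_max_auto_continues:
        auto_continue = False  # stays False unless a status chunk below re-enables it
        async for chunk in _run_once(temp_msg if auto_continue_count == 0 else None):
            # the status handling shown further down flips auto_continue back to True
            # (and bumps auto_continue_count) when finish_reason == 'length'
            yield chunk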
@@ -342,6 +348,18 @@ Here are the XML tools available with examples:
prepared_messages.append(temp_msg)
logger.debug("Added temporary message to the end of prepared messages")
# Add partial assistant content for auto-continue context (without saving to DB)
if auto_continue_count > 0 and continuous_state.get('accumulated_content'):
partial_content = continuous_state.get('accumulated_content', '')
# Create temporary assistant message with just the text content
temporary_assistant_message = {
"role": "assistant",
"content": partial_content
}
prepared_messages.append(temporary_assistant_message)
logger.info(f"Added temporary assistant message with {len(partial_content)} chars for auto-continue context")
# 4. Prepare tools for LLM call
openapi_tool_schemas = None
if config.native_tool_calling:
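To make the effect of that temporary assistant message concrete, a continuation cycle's prompt ends up shaped roughly as follows (message contents are illustrative):

    prepared_messages = [
        {"role": "system", "content": "...system prompt..."},           # illustrative
        {"role": "user", "content": "Write a long report about X."},    # illustrative
        # appended only when auto_continue_count > 0, and never written to the DB:
        {"role": "assistant", "content": continuous_state['accumulated_content']},
    ]
    # With the partial text sitting in the final assistant turn, the model is prompted
    # to carry on from where the truncated response stopped, and the processor keeps
    # appending the newly streamed tokens to the same accumulated_content.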
@@ -395,6 +413,9 @@ Here are the XML tools available with examples:
config=config,
prompt_messages=prepared_messages,
llm_model=llm_model,
can_auto_continue=(native_max_auto_continues > 0),
auto_continue_count=auto_continue_count,
continuous_state=continuous_state
)
else:
# Fallback to non-streaming if response is not iterable
@@ -467,6 +488,14 @@ Here are the XML tools available with examples:
auto_continue = False
# Still yield the chunk to inform the client
elif chunk.get('type') == 'status':
# If the finish reason is 'length', trigger another auto-continue cycle instead of yielding the status
content = json.loads(chunk.get('content'))
if content.get('finish_reason') == 'length':
logger.info(f"Detected finish_reason='length', auto-continuing ({auto_continue_count + 1}/{native_max_auto_continues})")
auto_continue = True
auto_continue_count += 1
continue
# Otherwise just yield the chunk normally
yield chunk
else:
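For reference, the status chunk handled above presumably carries JSON along these lines; the content field is a JSON string (hence the json.loads), and apart from finish_reason the exact keys are assumptions:

    import json

    auto_continue, auto_continue_count = False, 0
    status_chunk = {
        "type": "status",
        "content": '{"status_type": "finish", "finish_reason": "length"}',  # JSON string, not a dict
    }
    if json.loads(status_chunk["content"]).get("finish_reason") == "length":
        auto_continue = True       # drives one more cycle of the outer loop
        auto_continue_count += 1
    # any other finish_reason (or none) simply lets the chunk be yielded to the client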