diff --git a/backend/agent/api.py b/backend/agent/api.py
index 6ab8891f..b09a5e28 100644
--- a/backend/agent/api.py
+++ b/backend/agent/api.py
@@ -157,6 +157,11 @@ async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None)
except Exception as e:
logger.error(f"Failed to find or signal active instances: {str(e)}")
+ # Make sure to remove from active_agent_runs
+ if agent_run_id in active_agent_runs:
+ del active_agent_runs[agent_run_id]
+ logger.debug(f"Removed agent run {agent_run_id} from active_agent_runs during stop")
+
logger.info(f"Successfully initiated stop process for agent run: {agent_run_id}")
async def restore_running_agent_runs():
@@ -390,43 +395,79 @@ async def stream_agent_run(
async def stream_generator():
logger.debug(f"Streaming responses for agent run: {agent_run_id}")
- # Check if this is an active run with stored responses
- if agent_run_id in active_agent_runs:
- # First, send all existing responses
- stored_responses = active_agent_runs[agent_run_id]
- logger.debug(f"Sending {len(stored_responses)} existing responses for agent run: {agent_run_id}")
-
- for response in stored_responses:
- yield f"data: {json.dumps(response)}\n\n"
-
- # If the run is still active (status is running), set up to stream new responses
- if agent_run_data['status'] == 'running':
- # Get the current length to know where to start watching for new responses
- current_length = len(stored_responses)
-
- # Keep checking for new responses
- while agent_run_id in active_agent_runs:
- # Check if there are new responses
- if len(active_agent_runs[agent_run_id]) > current_length:
- # Send all new responses
- for i in range(current_length, len(active_agent_runs[agent_run_id])):
- response = active_agent_runs[agent_run_id][i]
- yield f"data: {json.dumps(response)}\n\n"
-
- # Update current length
- current_length = len(active_agent_runs[agent_run_id])
-
- # Brief pause before checking again
- await asyncio.sleep(0.1)
- else:
- # If the run is not active or we don't have stored responses,
- # send a message indicating the run is not available for streaming
- logger.warning(f"Agent run {agent_run_id} not found in active runs")
- yield f"data: {json.dumps({'type': 'status', 'status': agent_run_data['status'], 'message': 'Run data not available for streaming'})}\n\n"
+ # Track if we've sent a completion message
+ sent_completion = False
- # Always send a completion status at the end
- yield f"data: {json.dumps({'type': 'status', 'status': 'completed'})}\n\n"
- logger.debug(f"Streaming complete for agent run: {agent_run_id}")
+ try:
+ # Check if this is an active run with stored responses
+ if agent_run_id in active_agent_runs:
+ # First, send all existing responses
+ stored_responses = active_agent_runs[agent_run_id]
+ logger.debug(f"Sending {len(stored_responses)} existing responses for agent run: {agent_run_id}")
+
+ for response in stored_responses:
+ yield f"data: {json.dumps(response)}\n\n"
+
+ # Check if this is a completion message
+ if response.get('type') == 'status':
+ if response.get('status') == 'completed' or response.get('status_type') == 'thread_run_end':
+ sent_completion = True
+
+ # If the run is still active (status is running), set up to stream new responses
+ if agent_run_data['status'] == 'running':
+ # Get the current length to know where to start watching for new responses
+ current_length = len(stored_responses)
+
+ # Setup a timeout mechanism
+ start_time = datetime.now(timezone.utc)
+ timeout_seconds = 300 # 5 minutes max wait time
+
+ # Keep checking for new responses
+ while agent_run_id in active_agent_runs:
+ # Check if there are new responses
+ if len(active_agent_runs[agent_run_id]) > current_length:
+ # Send all new responses
+ for i in range(current_length, len(active_agent_runs[agent_run_id])):
+ response = active_agent_runs[agent_run_id][i]
+ yield f"data: {json.dumps(response)}\n\n"
+
+ # Check if this is a completion message
+ if response.get('type') == 'status':
+ if response.get('status') == 'completed' or response.get('status_type') == 'thread_run_end':
+ sent_completion = True
+
+ # Update current length
+ current_length = len(active_agent_runs[agent_run_id])
+
+ # Check for timeout
+ elapsed = (datetime.now(timezone.utc) - start_time).total_seconds()
+ if elapsed > timeout_seconds:
+ logger.warning(f"Stream timeout after {timeout_seconds}s for agent run: {agent_run_id}")
+ break
+
+ # Brief pause before checking again
+ await asyncio.sleep(0.1)
+ else:
+ # If the run is not active or we don't have stored responses,
+ # send a message indicating the run is not available for streaming
+ logger.warning(f"Agent run {agent_run_id} not found in active runs")
+ yield f"data: {json.dumps({'type': 'status', 'status': agent_run_data['status'], 'message': 'Run data not available for streaming'})}\n\n"
+
+ # Always send a completion status at the end if we haven't already
+ if not sent_completion:
+ completion_status = 'completed'
+ # Use the actual status from database if available
+ if agent_run_data['status'] in ['failed', 'stopped']:
+ completion_status = agent_run_data['status']
+
+ yield f"data: {json.dumps({'type': 'status', 'status': completion_status, 'message': f'Stream ended with status: {completion_status}'})}\n\n"
+
+ logger.debug(f"Streaming complete for agent run: {agent_run_id}")
+ except Exception as e:
+ logger.error(f"Error in stream generator: {str(e)}", exc_info=True)
+ # Send error message if we encounter an exception
+ if not sent_completion:
+ yield f"data: {json.dumps({'type': 'status', 'status': 'error', 'message': f'Stream error: {str(e)}'})}\n\n"
# Return a streaming response
return StreamingResponse(
@@ -449,6 +490,7 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
# Tracking variables
total_responses = 0
start_time = datetime.now(timezone.utc)
+ thread_run_ended = False # Track if we received a thread_run_end signal
# Create a pubsub to listen for control messages
pubsub = None
@@ -582,6 +624,11 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
await update_agent_run_status(client, agent_run_id, "failed", error=error_msg, responses=all_responses)
break
+ # Check for thread_run_end signal from ResponseProcessor
+ if response.get('type') == 'status' and response.get('status_type') == 'thread_run_end':
+ logger.info(f"Received thread_run_end signal from ResponseProcessor for agent run: {agent_run_id}")
+ thread_run_ended = True
+
# Store response in memory
if agent_run_id in active_agent_runs:
active_agent_runs[agent_run_id].append(response)
@@ -675,5 +722,10 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
logger.debug(f"Deleted active run key for agent run: {agent_run_id} (instance: {instance_id})")
except Exception as e:
logger.warning(f"Error deleting active run key: {str(e)}")
+
+ # Remove from active_agent_runs to ensure stream stops
+ if agent_run_id in active_agent_runs:
+ del active_agent_runs[agent_run_id]
+ logger.debug(f"Removed agent run {agent_run_id} from active_agent_runs")
- logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
\ No newline at end of file
+ logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id}, thread_run_ended: {thread_run_ended})")
\ No newline at end of file
diff --git a/backend/agent/run.py b/backend/agent/run.py
index 4565ac36..9d7fb7fb 100644
--- a/backend/agent/run.py
+++ b/backend/agent/run.py
@@ -72,16 +72,14 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
}
break
- # Check if last message is from assistant using direct Supabase query
- latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(1).execute()
- if latest_message.data and len(latest_message.data) > 0:
- message_type = latest_message.data[0].get('type')
- if message_type == 'assistant':
- print(f"Last message was from assistant, stopping execution")
- continue_execution = False
- break
+ # Check for termination signals in the messages
+ should_terminate, termination_reason = await check_for_termination_signals(client, thread_id)
+ if should_terminate:
+ print(f"Terminating execution: {termination_reason}")
+ continue_execution = False
+ break
- # Get the latest message from messages table that its tpye is browser_state
+ # Get the latest browser state message if available
latest_browser_state = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
temporary_message = None
if latest_browser_state.data and len(latest_browser_state.data) > 0:
@@ -112,58 +110,178 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
print(f"Error parsing browser state: {e}")
# print(latest_browser_state.data[0])
- response = await thread_manager.run_thread(
- thread_id=thread_id,
- system_prompt=system_message,
- stream=stream,
- llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
- llm_temperature=0,
- llm_max_tokens=64000,
- tool_choice="auto",
- max_xml_tool_calls=1,
- temporary_message=temporary_message,
- processor_config=ProcessorConfig(
- xml_tool_calling=True,
- native_tool_calling=False,
- execute_tools=True,
- execute_on_stream=True,
- tool_execution_strategy="parallel",
- xml_adding_strategy="user_message"
- ),
- native_max_auto_continues=native_max_auto_continues,
- include_xml_examples=True,
- )
-
- if isinstance(response, dict) and "status" in response and response["status"] == "error":
- yield response
- break
-
- # Track if we see ask or complete tool calls
- last_tool_call = None
-
- async for chunk in response:
- # Check if this is a tool call chunk for ask or complete
- if chunk.get('type') == 'tool_call':
- tool_call = chunk.get('tool_call', {})
- function_name = tool_call.get('function', {}).get('name', '')
- if function_name in ['ask', 'complete']:
- last_tool_call = function_name
- # Check for XML versions like or in content chunks
- elif chunk.get('type') == 'content' and 'content' in chunk:
- content = chunk.get('content', '')
- if '' in content or '' in content:
- xml_tool = 'ask' if '' in content else 'complete'
- last_tool_call = xml_tool
- print(f"Agent used XML tool: {xml_tool}")
+ try:
+ # Track if we see ask or complete tool calls
+ last_tool_call = None
+ response = await thread_manager.run_thread(
+ thread_id=thread_id,
+ system_prompt=system_message,
+ stream=stream,
+ llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
+ llm_temperature=0,
+ llm_max_tokens=64000,
+ tool_choice="auto",
+ max_xml_tool_calls=1,
+ temporary_message=temporary_message,
+ processor_config=ProcessorConfig(
+ xml_tool_calling=True,
+ native_tool_calling=False,
+ execute_tools=True,
+ execute_on_stream=True,
+ tool_execution_strategy="parallel",
+ xml_adding_strategy="user_message"
+ ),
+ native_max_auto_continues=native_max_auto_continues,
+ include_xml_examples=True,
+ )
+
+ if isinstance(response, dict) and "status" in response and response["status"] == "error":
+ yield response
+ break
+
+ try:
+ # Store XML content across chunks for better detection
+ accumulated_xml_content = ""
+
+ async for chunk in response:
+ # Check if this is a tool call chunk for ask or complete
+ if chunk.get('type') == 'tool_call':
+ tool_call = chunk.get('tool_call', {})
+ function_name = tool_call.get('function', {}).get('name', '')
+ if function_name in ['ask', 'complete']:
+ last_tool_call = function_name
+ print(f"Detected native tool call: {function_name}")
+
+ # Check for XML versions like or in content chunks
+ elif chunk.get('type') == 'content' and 'content' in chunk:
+ content = chunk.get('content', '')
+ # Accumulate content for more reliable XML detection
+ accumulated_xml_content += content
+
+ # Check for complete XML tags
+ if '' in accumulated_xml_content and '' in accumulated_xml_content:
+ last_tool_call = 'ask'
+ print(f"Detected XML ask tool")
+
+ if '' in accumulated_xml_content and '' in accumulated_xml_content:
+ last_tool_call = 'complete'
+ print(f"Detected XML complete tool")
+
+ # Check if content has a tool call completion status
+ elif chunk.get('type') == 'tool_status':
+ status = chunk.get('status')
+ function_name = chunk.get('function_name', '')
+
+ if status == 'completed' and function_name in ['ask', 'complete']:
+ last_tool_call = function_name
+ print(f"Detected completed tool call status for: {function_name}")
+
+ # Check tool result messages for ask/complete tools
+ elif chunk.get('type') == 'tool_result':
+ function_name = chunk.get('name', '')
+ if function_name in ['ask', 'complete']:
+ last_tool_call = function_name
+ print(f"Detected tool result for: {function_name}")
+
+ # Always yield the chunk to the client
+ yield chunk
- yield chunk
-
- # Check if we should stop based on the last tool call
- if last_tool_call in ['ask', 'complete']:
- print(f"Agent decided to stop with tool: {last_tool_call}")
- continue_execution = False
-
+ # Check if we should stop immediately after processing this chunk
+ if last_tool_call in ['ask', 'complete']:
+ print(f"Agent decided to stop with tool: {last_tool_call}")
+ continue_execution = False
+
+ # Add a clear status message to the database to signal termination
+ await client.table('messages').insert({
+ 'thread_id': thread_id,
+ 'type': 'status',
+ 'content': json.dumps({
+ "status_type": "agent_termination",
+ "reason": f"Tool '{last_tool_call}' executed"
+ }),
+ 'is_llm_message': False,
+ 'metadata': json.dumps({"termination_signal": True})
+ }).execute()
+
+ # We don't break here to ensure all chunks are yielded,
+ # but the next iteration won't start due to continue_execution = False
+
+ except Exception as stream_error:
+ print(f"Error during stream processing: {str(stream_error)}")
+ yield {
+ "type": "status",
+ "status": "error",
+ "message": f"Stream processing error: {str(stream_error)}"
+ }
+ break
+
+ # Double-check termination condition after all chunks processed
+ if last_tool_call in ['ask', 'complete']:
+ print(f"Confirming termination after stream with tool: {last_tool_call}")
+ continue_execution = False
+ except Exception as e:
+ print(f"Error running thread manager: {str(e)}")
+ yield {
+ "type": "status",
+ "status": "error",
+ "message": f"Thread manager error: {str(e)}"
+ }
+ break
+async def check_for_termination_signals(client, thread_id):
+ """Check database for signals that should terminate the agent execution."""
+ try:
+ # Check the last message type first
+ latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(1).execute()
+ if latest_message.data and len(latest_message.data) > 0:
+ message_type = latest_message.data[0].get('type')
+
+ # If last message is from assistant, stop execution
+ if message_type == 'assistant':
+ return True, "Last message was from assistant"
+
+ # Check for tool-related termination signals
+ if message_type == 'tool':
+ try:
+ content = json.loads(latest_message.data[0].get('content', '{}'))
+ if content.get('name') in ['ask', 'complete']:
+ return True, f"Tool '{content.get('name')}' was executed"
+ except:
+ pass
+
+ # Check for special status messages with termination signals
+ if message_type == 'status':
+ try:
+ content = json.loads(latest_message.data[0].get('content', '{}'))
+ metadata = json.loads(latest_message.data[0].get('metadata', '{}'))
+
+ # Check for explicit termination signal in metadata
+ if metadata.get('termination_signal') == True:
+ return True, "Explicit termination signal found"
+
+ # Check for agent_termination status type
+ if content.get('status_type') == 'agent_termination':
+ return True, content.get('reason', 'Agent termination status found')
+ except:
+ pass
+
+ # Also look for specific ask/complete tool execution in recent messages
+ recent_tool_messages = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'tool').order('created_at', desc=True).limit(5).execute()
+ if recent_tool_messages.data:
+ for msg in recent_tool_messages.data:
+ try:
+ content = json.loads(msg.get('content', '{}'))
+ if isinstance(content, dict) and content.get('role') == 'tool':
+ tool_name = content.get('name', '')
+ if tool_name in ['ask', 'complete']:
+ return True, f"Recent '{tool_name}' tool execution found"
+ except:
+ continue
+
+ return False, None
+ except Exception as e:
+ print(f"Error checking for termination signals: {e}")
+ return False, None
# TESTING