diff --git a/backend/agent/api.py b/backend/agent/api.py index 6ab8891f..b09a5e28 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -157,6 +157,11 @@ async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None) except Exception as e: logger.error(f"Failed to find or signal active instances: {str(e)}") + # Make sure to remove from active_agent_runs + if agent_run_id in active_agent_runs: + del active_agent_runs[agent_run_id] + logger.debug(f"Removed agent run {agent_run_id} from active_agent_runs during stop") + logger.info(f"Successfully initiated stop process for agent run: {agent_run_id}") async def restore_running_agent_runs(): @@ -390,43 +395,79 @@ async def stream_agent_run( async def stream_generator(): logger.debug(f"Streaming responses for agent run: {agent_run_id}") - # Check if this is an active run with stored responses - if agent_run_id in active_agent_runs: - # First, send all existing responses - stored_responses = active_agent_runs[agent_run_id] - logger.debug(f"Sending {len(stored_responses)} existing responses for agent run: {agent_run_id}") - - for response in stored_responses: - yield f"data: {json.dumps(response)}\n\n" - - # If the run is still active (status is running), set up to stream new responses - if agent_run_data['status'] == 'running': - # Get the current length to know where to start watching for new responses - current_length = len(stored_responses) - - # Keep checking for new responses - while agent_run_id in active_agent_runs: - # Check if there are new responses - if len(active_agent_runs[agent_run_id]) > current_length: - # Send all new responses - for i in range(current_length, len(active_agent_runs[agent_run_id])): - response = active_agent_runs[agent_run_id][i] - yield f"data: {json.dumps(response)}\n\n" - - # Update current length - current_length = len(active_agent_runs[agent_run_id]) - - # Brief pause before checking again - await asyncio.sleep(0.1) - else: - # If the run is not active or we don't have stored responses, - # send a message indicating the run is not available for streaming - logger.warning(f"Agent run {agent_run_id} not found in active runs") - yield f"data: {json.dumps({'type': 'status', 'status': agent_run_data['status'], 'message': 'Run data not available for streaming'})}\n\n" + # Track if we've sent a completion message + sent_completion = False - # Always send a completion status at the end - yield f"data: {json.dumps({'type': 'status', 'status': 'completed'})}\n\n" - logger.debug(f"Streaming complete for agent run: {agent_run_id}") + try: + # Check if this is an active run with stored responses + if agent_run_id in active_agent_runs: + # First, send all existing responses + stored_responses = active_agent_runs[agent_run_id] + logger.debug(f"Sending {len(stored_responses)} existing responses for agent run: {agent_run_id}") + + for response in stored_responses: + yield f"data: {json.dumps(response)}\n\n" + + # Check if this is a completion message + if response.get('type') == 'status': + if response.get('status') == 'completed' or response.get('status_type') == 'thread_run_end': + sent_completion = True + + # If the run is still active (status is running), set up to stream new responses + if agent_run_data['status'] == 'running': + # Get the current length to know where to start watching for new responses + current_length = len(stored_responses) + + # Setup a timeout mechanism + start_time = datetime.now(timezone.utc) + timeout_seconds = 300 # 5 minutes max wait time + + # Keep checking for new responses + while agent_run_id in active_agent_runs: + # Check if there are new responses + if len(active_agent_runs[agent_run_id]) > current_length: + # Send all new responses + for i in range(current_length, len(active_agent_runs[agent_run_id])): + response = active_agent_runs[agent_run_id][i] + yield f"data: {json.dumps(response)}\n\n" + + # Check if this is a completion message + if response.get('type') == 'status': + if response.get('status') == 'completed' or response.get('status_type') == 'thread_run_end': + sent_completion = True + + # Update current length + current_length = len(active_agent_runs[agent_run_id]) + + # Check for timeout + elapsed = (datetime.now(timezone.utc) - start_time).total_seconds() + if elapsed > timeout_seconds: + logger.warning(f"Stream timeout after {timeout_seconds}s for agent run: {agent_run_id}") + break + + # Brief pause before checking again + await asyncio.sleep(0.1) + else: + # If the run is not active or we don't have stored responses, + # send a message indicating the run is not available for streaming + logger.warning(f"Agent run {agent_run_id} not found in active runs") + yield f"data: {json.dumps({'type': 'status', 'status': agent_run_data['status'], 'message': 'Run data not available for streaming'})}\n\n" + + # Always send a completion status at the end if we haven't already + if not sent_completion: + completion_status = 'completed' + # Use the actual status from database if available + if agent_run_data['status'] in ['failed', 'stopped']: + completion_status = agent_run_data['status'] + + yield f"data: {json.dumps({'type': 'status', 'status': completion_status, 'message': f'Stream ended with status: {completion_status}'})}\n\n" + + logger.debug(f"Streaming complete for agent run: {agent_run_id}") + except Exception as e: + logger.error(f"Error in stream generator: {str(e)}", exc_info=True) + # Send error message if we encounter an exception + if not sent_completion: + yield f"data: {json.dumps({'type': 'status', 'status': 'error', 'message': f'Stream error: {str(e)}'})}\n\n" # Return a streaming response return StreamingResponse( @@ -449,6 +490,7 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s # Tracking variables total_responses = 0 start_time = datetime.now(timezone.utc) + thread_run_ended = False # Track if we received a thread_run_end signal # Create a pubsub to listen for control messages pubsub = None @@ -582,6 +624,11 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s await update_agent_run_status(client, agent_run_id, "failed", error=error_msg, responses=all_responses) break + # Check for thread_run_end signal from ResponseProcessor + if response.get('type') == 'status' and response.get('status_type') == 'thread_run_end': + logger.info(f"Received thread_run_end signal from ResponseProcessor for agent run: {agent_run_id}") + thread_run_ended = True + # Store response in memory if agent_run_id in active_agent_runs: active_agent_runs[agent_run_id].append(response) @@ -675,5 +722,10 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s logger.debug(f"Deleted active run key for agent run: {agent_run_id} (instance: {instance_id})") except Exception as e: logger.warning(f"Error deleting active run key: {str(e)}") + + # Remove from active_agent_runs to ensure stream stops + if agent_run_id in active_agent_runs: + del active_agent_runs[agent_run_id] + logger.debug(f"Removed agent run {agent_run_id} from active_agent_runs") - logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})") \ No newline at end of file + logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id}, thread_run_ended: {thread_run_ended})") \ No newline at end of file diff --git a/backend/agent/run.py b/backend/agent/run.py index 4565ac36..9d7fb7fb 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -72,16 +72,14 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru } break - # Check if last message is from assistant using direct Supabase query - latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(1).execute() - if latest_message.data and len(latest_message.data) > 0: - message_type = latest_message.data[0].get('type') - if message_type == 'assistant': - print(f"Last message was from assistant, stopping execution") - continue_execution = False - break + # Check for termination signals in the messages + should_terminate, termination_reason = await check_for_termination_signals(client, thread_id) + if should_terminate: + print(f"Terminating execution: {termination_reason}") + continue_execution = False + break - # Get the latest message from messages table that its tpye is browser_state + # Get the latest browser state message if available latest_browser_state = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute() temporary_message = None if latest_browser_state.data and len(latest_browser_state.data) > 0: @@ -112,58 +110,178 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru print(f"Error parsing browser state: {e}") # print(latest_browser_state.data[0]) - response = await thread_manager.run_thread( - thread_id=thread_id, - system_prompt=system_message, - stream=stream, - llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"), - llm_temperature=0, - llm_max_tokens=64000, - tool_choice="auto", - max_xml_tool_calls=1, - temporary_message=temporary_message, - processor_config=ProcessorConfig( - xml_tool_calling=True, - native_tool_calling=False, - execute_tools=True, - execute_on_stream=True, - tool_execution_strategy="parallel", - xml_adding_strategy="user_message" - ), - native_max_auto_continues=native_max_auto_continues, - include_xml_examples=True, - ) - - if isinstance(response, dict) and "status" in response and response["status"] == "error": - yield response - break - - # Track if we see ask or complete tool calls - last_tool_call = None - - async for chunk in response: - # Check if this is a tool call chunk for ask or complete - if chunk.get('type') == 'tool_call': - tool_call = chunk.get('tool_call', {}) - function_name = tool_call.get('function', {}).get('name', '') - if function_name in ['ask', 'complete']: - last_tool_call = function_name - # Check for XML versions like or in content chunks - elif chunk.get('type') == 'content' and 'content' in chunk: - content = chunk.get('content', '') - if '' in content or '' in content: - xml_tool = 'ask' if '' in content else 'complete' - last_tool_call = xml_tool - print(f"Agent used XML tool: {xml_tool}") + try: + # Track if we see ask or complete tool calls + last_tool_call = None + response = await thread_manager.run_thread( + thread_id=thread_id, + system_prompt=system_message, + stream=stream, + llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"), + llm_temperature=0, + llm_max_tokens=64000, + tool_choice="auto", + max_xml_tool_calls=1, + temporary_message=temporary_message, + processor_config=ProcessorConfig( + xml_tool_calling=True, + native_tool_calling=False, + execute_tools=True, + execute_on_stream=True, + tool_execution_strategy="parallel", + xml_adding_strategy="user_message" + ), + native_max_auto_continues=native_max_auto_continues, + include_xml_examples=True, + ) + + if isinstance(response, dict) and "status" in response and response["status"] == "error": + yield response + break + + try: + # Store XML content across chunks for better detection + accumulated_xml_content = "" + + async for chunk in response: + # Check if this is a tool call chunk for ask or complete + if chunk.get('type') == 'tool_call': + tool_call = chunk.get('tool_call', {}) + function_name = tool_call.get('function', {}).get('name', '') + if function_name in ['ask', 'complete']: + last_tool_call = function_name + print(f"Detected native tool call: {function_name}") + + # Check for XML versions like or in content chunks + elif chunk.get('type') == 'content' and 'content' in chunk: + content = chunk.get('content', '') + # Accumulate content for more reliable XML detection + accumulated_xml_content += content + + # Check for complete XML tags + if '' in accumulated_xml_content and '' in accumulated_xml_content: + last_tool_call = 'ask' + print(f"Detected XML ask tool") + + if '' in accumulated_xml_content and '' in accumulated_xml_content: + last_tool_call = 'complete' + print(f"Detected XML complete tool") + + # Check if content has a tool call completion status + elif chunk.get('type') == 'tool_status': + status = chunk.get('status') + function_name = chunk.get('function_name', '') + + if status == 'completed' and function_name in ['ask', 'complete']: + last_tool_call = function_name + print(f"Detected completed tool call status for: {function_name}") + + # Check tool result messages for ask/complete tools + elif chunk.get('type') == 'tool_result': + function_name = chunk.get('name', '') + if function_name in ['ask', 'complete']: + last_tool_call = function_name + print(f"Detected tool result for: {function_name}") + + # Always yield the chunk to the client + yield chunk - yield chunk - - # Check if we should stop based on the last tool call - if last_tool_call in ['ask', 'complete']: - print(f"Agent decided to stop with tool: {last_tool_call}") - continue_execution = False - + # Check if we should stop immediately after processing this chunk + if last_tool_call in ['ask', 'complete']: + print(f"Agent decided to stop with tool: {last_tool_call}") + continue_execution = False + + # Add a clear status message to the database to signal termination + await client.table('messages').insert({ + 'thread_id': thread_id, + 'type': 'status', + 'content': json.dumps({ + "status_type": "agent_termination", + "reason": f"Tool '{last_tool_call}' executed" + }), + 'is_llm_message': False, + 'metadata': json.dumps({"termination_signal": True}) + }).execute() + + # We don't break here to ensure all chunks are yielded, + # but the next iteration won't start due to continue_execution = False + + except Exception as stream_error: + print(f"Error during stream processing: {str(stream_error)}") + yield { + "type": "status", + "status": "error", + "message": f"Stream processing error: {str(stream_error)}" + } + break + + # Double-check termination condition after all chunks processed + if last_tool_call in ['ask', 'complete']: + print(f"Confirming termination after stream with tool: {last_tool_call}") + continue_execution = False + except Exception as e: + print(f"Error running thread manager: {str(e)}") + yield { + "type": "status", + "status": "error", + "message": f"Thread manager error: {str(e)}" + } + break +async def check_for_termination_signals(client, thread_id): + """Check database for signals that should terminate the agent execution.""" + try: + # Check the last message type first + latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(1).execute() + if latest_message.data and len(latest_message.data) > 0: + message_type = latest_message.data[0].get('type') + + # If last message is from assistant, stop execution + if message_type == 'assistant': + return True, "Last message was from assistant" + + # Check for tool-related termination signals + if message_type == 'tool': + try: + content = json.loads(latest_message.data[0].get('content', '{}')) + if content.get('name') in ['ask', 'complete']: + return True, f"Tool '{content.get('name')}' was executed" + except: + pass + + # Check for special status messages with termination signals + if message_type == 'status': + try: + content = json.loads(latest_message.data[0].get('content', '{}')) + metadata = json.loads(latest_message.data[0].get('metadata', '{}')) + + # Check for explicit termination signal in metadata + if metadata.get('termination_signal') == True: + return True, "Explicit termination signal found" + + # Check for agent_termination status type + if content.get('status_type') == 'agent_termination': + return True, content.get('reason', 'Agent termination status found') + except: + pass + + # Also look for specific ask/complete tool execution in recent messages + recent_tool_messages = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'tool').order('created_at', desc=True).limit(5).execute() + if recent_tool_messages.data: + for msg in recent_tool_messages.data: + try: + content = json.loads(msg.get('content', '{}')) + if isinstance(content, dict) and content.get('role') == 'tool': + tool_name = content.get('name', '') + if tool_name in ['ask', 'complete']: + return True, f"Recent '{tool_name}' tool execution found" + except: + continue + + return False, None + except Exception as e: + print(f"Error checking for termination signals: {e}") + return False, None # TESTING