mirror of https://github.com/kortix-ai/suna.git
wip
This commit is contained in:
parent 582f0d228b
commit 7c104baf73
@@ -157,6 +157,11 @@ async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None)
     except Exception as e:
         logger.error(f"Failed to find or signal active instances: {str(e)}")

+    # Make sure to remove from active_agent_runs
+    if agent_run_id in active_agent_runs:
+        del active_agent_runs[agent_run_id]
+        logger.debug(f"Removed agent run {agent_run_id} from active_agent_runs during stop")
+
     logger.info(f"Successfully initiated stop process for agent run: {agent_run_id}")

 async def restore_running_agent_runs():
@@ -390,43 +395,79 @@ async def stream_agent_run(
     async def stream_generator():
         logger.debug(f"Streaming responses for agent run: {agent_run_id}")

-        # Check if this is an active run with stored responses
-        if agent_run_id in active_agent_runs:
-            # First, send all existing responses
-            stored_responses = active_agent_runs[agent_run_id]
-            logger.debug(f"Sending {len(stored_responses)} existing responses for agent run: {agent_run_id}")
-
-            for response in stored_responses:
-                yield f"data: {json.dumps(response)}\n\n"
-
-            # If the run is still active (status is running), set up to stream new responses
-            if agent_run_data['status'] == 'running':
-                # Get the current length to know where to start watching for new responses
-                current_length = len(stored_responses)
-
-                # Keep checking for new responses
-                while agent_run_id in active_agent_runs:
-                    # Check if there are new responses
-                    if len(active_agent_runs[agent_run_id]) > current_length:
-                        # Send all new responses
-                        for i in range(current_length, len(active_agent_runs[agent_run_id])):
-                            response = active_agent_runs[agent_run_id][i]
-                            yield f"data: {json.dumps(response)}\n\n"
-
-                        # Update current length
-                        current_length = len(active_agent_runs[agent_run_id])
-
-                    # Brief pause before checking again
-                    await asyncio.sleep(0.1)
-        else:
-            # If the run is not active or we don't have stored responses,
-            # send a message indicating the run is not available for streaming
-            logger.warning(f"Agent run {agent_run_id} not found in active runs")
-            yield f"data: {json.dumps({'type': 'status', 'status': agent_run_data['status'], 'message': 'Run data not available for streaming'})}\n\n"
-
-        # Always send a completion status at the end
-        yield f"data: {json.dumps({'type': 'status', 'status': 'completed'})}\n\n"
-        logger.debug(f"Streaming complete for agent run: {agent_run_id}")
+        # Track if we've sent a completion message
+        sent_completion = False
+
+        try:
+            # Check if this is an active run with stored responses
+            if agent_run_id in active_agent_runs:
+                # First, send all existing responses
+                stored_responses = active_agent_runs[agent_run_id]
+                logger.debug(f"Sending {len(stored_responses)} existing responses for agent run: {agent_run_id}")
+
+                for response in stored_responses:
+                    yield f"data: {json.dumps(response)}\n\n"
+
+                    # Check if this is a completion message
+                    if response.get('type') == 'status':
+                        if response.get('status') == 'completed' or response.get('status_type') == 'thread_run_end':
+                            sent_completion = True
+
+                # If the run is still active (status is running), set up to stream new responses
+                if agent_run_data['status'] == 'running':
+                    # Get the current length to know where to start watching for new responses
+                    current_length = len(stored_responses)
+
+                    # Setup a timeout mechanism
+                    start_time = datetime.now(timezone.utc)
+                    timeout_seconds = 300  # 5 minutes max wait time
+
+                    # Keep checking for new responses
+                    while agent_run_id in active_agent_runs:
+                        # Check if there are new responses
+                        if len(active_agent_runs[agent_run_id]) > current_length:
+                            # Send all new responses
+                            for i in range(current_length, len(active_agent_runs[agent_run_id])):
+                                response = active_agent_runs[agent_run_id][i]
+                                yield f"data: {json.dumps(response)}\n\n"
+
+                                # Check if this is a completion message
+                                if response.get('type') == 'status':
+                                    if response.get('status') == 'completed' or response.get('status_type') == 'thread_run_end':
+                                        sent_completion = True
+
+                            # Update current length
+                            current_length = len(active_agent_runs[agent_run_id])
+
+                        # Check for timeout
+                        elapsed = (datetime.now(timezone.utc) - start_time).total_seconds()
+                        if elapsed > timeout_seconds:
+                            logger.warning(f"Stream timeout after {timeout_seconds}s for agent run: {agent_run_id}")
+                            break
+
+                        # Brief pause before checking again
+                        await asyncio.sleep(0.1)
+            else:
+                # If the run is not active or we don't have stored responses,
+                # send a message indicating the run is not available for streaming
+                logger.warning(f"Agent run {agent_run_id} not found in active runs")
+                yield f"data: {json.dumps({'type': 'status', 'status': agent_run_data['status'], 'message': 'Run data not available for streaming'})}\n\n"
+
+            # Always send a completion status at the end if we haven't already
+            if not sent_completion:
+                completion_status = 'completed'
+                # Use the actual status from database if available
+                if agent_run_data['status'] in ['failed', 'stopped']:
+                    completion_status = agent_run_data['status']
+
+                yield f"data: {json.dumps({'type': 'status', 'status': completion_status, 'message': f'Stream ended with status: {completion_status}'})}\n\n"
+
+            logger.debug(f"Streaming complete for agent run: {agent_run_id}")
+        except Exception as e:
+            logger.error(f"Error in stream generator: {str(e)}", exc_info=True)
+            # Send error message if we encounter an exception
+            if not sent_completion:
+                yield f"data: {json.dumps({'type': 'status', 'status': 'error', 'message': f'Stream error: {str(e)}'})}\n\n"

     # Return a streaming response
     return StreamingResponse(
@@ -449,6 +490,7 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
     # Tracking variables
     total_responses = 0
     start_time = datetime.now(timezone.utc)
+    thread_run_ended = False  # Track if we received a thread_run_end signal

     # Create a pubsub to listen for control messages
     pubsub = None
@@ -582,6 +624,11 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
                 await update_agent_run_status(client, agent_run_id, "failed", error=error_msg, responses=all_responses)
                 break

+            # Check for thread_run_end signal from ResponseProcessor
+            if response.get('type') == 'status' and response.get('status_type') == 'thread_run_end':
+                logger.info(f"Received thread_run_end signal from ResponseProcessor for agent run: {agent_run_id}")
+                thread_run_ended = True
+
             # Store response in memory
             if agent_run_id in active_agent_runs:
                 active_agent_runs[agent_run_id].append(response)
@@ -675,5 +722,10 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
             logger.debug(f"Deleted active run key for agent run: {agent_run_id} (instance: {instance_id})")
         except Exception as e:
             logger.warning(f"Error deleting active run key: {str(e)}")

-        logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
+        # Remove from active_agent_runs to ensure stream stops
+        if agent_run_id in active_agent_runs:
+            del active_agent_runs[agent_run_id]
+            logger.debug(f"Removed agent run {agent_run_id} from active_agent_runs")
+
+        logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id}, thread_run_ended: {thread_run_ended})")
@@ -72,16 +72,14 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
            }
            break

-        # Check if last message is from assistant using direct Supabase query
-        latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(1).execute()
-        if latest_message.data and len(latest_message.data) > 0:
-            message_type = latest_message.data[0].get('type')
-            if message_type == 'assistant':
-                print(f"Last message was from assistant, stopping execution")
-                continue_execution = False
-                break
+        # Check for termination signals in the messages
+        should_terminate, termination_reason = await check_for_termination_signals(client, thread_id)
+        if should_terminate:
+            print(f"Terminating execution: {termination_reason}")
+            continue_execution = False
+            break

-        # Get the latest message from messages table that its tpye is browser_state
+        # Get the latest browser state message if available
         latest_browser_state = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
         temporary_message = None
         if latest_browser_state.data and len(latest_browser_state.data) > 0:
@@ -112,58 +110,178 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
                print(f"Error parsing browser state: {e}")
            # print(latest_browser_state.data[0])

-        response = await thread_manager.run_thread(
-            thread_id=thread_id,
-            system_prompt=system_message,
-            stream=stream,
-            llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
-            llm_temperature=0,
-            llm_max_tokens=64000,
-            tool_choice="auto",
-            max_xml_tool_calls=1,
-            temporary_message=temporary_message,
-            processor_config=ProcessorConfig(
-                xml_tool_calling=True,
-                native_tool_calling=False,
-                execute_tools=True,
-                execute_on_stream=True,
-                tool_execution_strategy="parallel",
-                xml_adding_strategy="user_message"
-            ),
-            native_max_auto_continues=native_max_auto_continues,
-            include_xml_examples=True,
-        )
-
-        if isinstance(response, dict) and "status" in response and response["status"] == "error":
-            yield response
-            break
-
-        # Track if we see ask or complete tool calls
-        last_tool_call = None
-
-        async for chunk in response:
-            # Check if this is a tool call chunk for ask or complete
-            if chunk.get('type') == 'tool_call':
-                tool_call = chunk.get('tool_call', {})
-                function_name = tool_call.get('function', {}).get('name', '')
-                if function_name in ['ask', 'complete']:
-                    last_tool_call = function_name
-            # Check for XML versions like <ask> or <complete> in content chunks
-            elif chunk.get('type') == 'content' and 'content' in chunk:
-                content = chunk.get('content', '')
-                if '</ask>' in content or '</complete>' in content:
-                    xml_tool = 'ask' if '</ask>' in content else 'complete'
-                    last_tool_call = xml_tool
-                    print(f"Agent used XML tool: {xml_tool}")
-
-            yield chunk
-
-        # Check if we should stop based on the last tool call
-        if last_tool_call in ['ask', 'complete']:
-            print(f"Agent decided to stop with tool: {last_tool_call}")
-            continue_execution = False
+        try:
+            # Track if we see ask or complete tool calls
+            last_tool_call = None
+            response = await thread_manager.run_thread(
+                thread_id=thread_id,
+                system_prompt=system_message,
+                stream=stream,
+                llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
+                llm_temperature=0,
+                llm_max_tokens=64000,
+                tool_choice="auto",
+                max_xml_tool_calls=1,
+                temporary_message=temporary_message,
+                processor_config=ProcessorConfig(
+                    xml_tool_calling=True,
+                    native_tool_calling=False,
+                    execute_tools=True,
+                    execute_on_stream=True,
+                    tool_execution_strategy="parallel",
+                    xml_adding_strategy="user_message"
+                ),
+                native_max_auto_continues=native_max_auto_continues,
+                include_xml_examples=True,
+            )
+
+            if isinstance(response, dict) and "status" in response and response["status"] == "error":
+                yield response
+                break
+
+            try:
+                # Store XML content across chunks for better detection
+                accumulated_xml_content = ""
+
+                async for chunk in response:
+                    # Check if this is a tool call chunk for ask or complete
+                    if chunk.get('type') == 'tool_call':
+                        tool_call = chunk.get('tool_call', {})
+                        function_name = tool_call.get('function', {}).get('name', '')
+                        if function_name in ['ask', 'complete']:
+                            last_tool_call = function_name
+                            print(f"Detected native tool call: {function_name}")
+
+                    # Check for XML versions like <ask> or <complete> in content chunks
+                    elif chunk.get('type') == 'content' and 'content' in chunk:
+                        content = chunk.get('content', '')
+                        # Accumulate content for more reliable XML detection
+                        accumulated_xml_content += content
+
+                        # Check for complete XML tags
+                        if '<ask>' in accumulated_xml_content and '</ask>' in accumulated_xml_content:
+                            last_tool_call = 'ask'
+                            print(f"Detected XML ask tool")
+
+                        if '<complete>' in accumulated_xml_content and '</complete>' in accumulated_xml_content:
+                            last_tool_call = 'complete'
+                            print(f"Detected XML complete tool")
+
+                    # Check if content has a tool call completion status
+                    elif chunk.get('type') == 'tool_status':
+                        status = chunk.get('status')
+                        function_name = chunk.get('function_name', '')
+
+                        if status == 'completed' and function_name in ['ask', 'complete']:
+                            last_tool_call = function_name
+                            print(f"Detected completed tool call status for: {function_name}")
+
+                    # Check tool result messages for ask/complete tools
+                    elif chunk.get('type') == 'tool_result':
+                        function_name = chunk.get('name', '')
+                        if function_name in ['ask', 'complete']:
+                            last_tool_call = function_name
+                            print(f"Detected tool result for: {function_name}")
+
+                    # Always yield the chunk to the client
+                    yield chunk
+
+                    # Check if we should stop immediately after processing this chunk
+                    if last_tool_call in ['ask', 'complete']:
+                        print(f"Agent decided to stop with tool: {last_tool_call}")
+                        continue_execution = False
+
+                        # Add a clear status message to the database to signal termination
+                        await client.table('messages').insert({
+                            'thread_id': thread_id,
+                            'type': 'status',
+                            'content': json.dumps({
+                                "status_type": "agent_termination",
+                                "reason": f"Tool '{last_tool_call}' executed"
+                            }),
+                            'is_llm_message': False,
+                            'metadata': json.dumps({"termination_signal": True})
+                        }).execute()
+
+                        # We don't break here to ensure all chunks are yielded,
+                        # but the next iteration won't start due to continue_execution = False
+
+            except Exception as stream_error:
+                print(f"Error during stream processing: {str(stream_error)}")
+                yield {
+                    "type": "status",
+                    "status": "error",
+                    "message": f"Stream processing error: {str(stream_error)}"
+                }
+                break
+
+            # Double-check termination condition after all chunks processed
+            if last_tool_call in ['ask', 'complete']:
+                print(f"Confirming termination after stream with tool: {last_tool_call}")
+                continue_execution = False
+        except Exception as e:
+            print(f"Error running thread manager: {str(e)}")
+            yield {
+                "type": "status",
+                "status": "error",
+                "message": f"Thread manager error: {str(e)}"
+            }
+            break
+
+async def check_for_termination_signals(client, thread_id):
+    """Check database for signals that should terminate the agent execution."""
+    try:
+        # Check the last message type first
+        latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).order('created_at', desc=True).limit(1).execute()
+        if latest_message.data and len(latest_message.data) > 0:
+            message_type = latest_message.data[0].get('type')
+
+            # If last message is from assistant, stop execution
+            if message_type == 'assistant':
+                return True, "Last message was from assistant"
+
+            # Check for tool-related termination signals
+            if message_type == 'tool':
+                try:
+                    content = json.loads(latest_message.data[0].get('content', '{}'))
+                    if content.get('name') in ['ask', 'complete']:
+                        return True, f"Tool '{content.get('name')}' was executed"
+                except:
+                    pass
+
+            # Check for special status messages with termination signals
+            if message_type == 'status':
+                try:
+                    content = json.loads(latest_message.data[0].get('content', '{}'))
+                    metadata = json.loads(latest_message.data[0].get('metadata', '{}'))
+
+                    # Check for explicit termination signal in metadata
+                    if metadata.get('termination_signal') == True:
+                        return True, "Explicit termination signal found"
+
+                    # Check for agent_termination status type
+                    if content.get('status_type') == 'agent_termination':
+                        return True, content.get('reason', 'Agent termination status found')
+                except:
+                    pass
+
+        # Also look for specific ask/complete tool execution in recent messages
+        recent_tool_messages = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'tool').order('created_at', desc=True).limit(5).execute()
+        if recent_tool_messages.data:
+            for msg in recent_tool_messages.data:
+                try:
+                    content = json.loads(msg.get('content', '{}'))
+                    if isinstance(content, dict) and content.get('role') == 'tool':
+                        tool_name = content.get('name', '')
+                        if tool_name in ['ask', 'complete']:
+                            return True, f"Recent '{tool_name}' tool execution found"
+                except:
+                    continue
+
+        return False, None
+    except Exception as e:
+        print(f"Error checking for termination signals: {e}")
+        return False, None


 # TESTING