mirror of https://github.com/kortix-ai/suna.git
wip
This commit is contained in:
parent 8dce0d3254
commit f9ca36efa0
@@ -19,7 +19,6 @@ from agent.run import run_agent
 from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, verify_thread_access
 from utils.logger import logger
 from utils.billing import check_billing_status, get_account_id_from_thread
-from utils.db import update_agent_run_status
 from sandbox.sandbox import create_sandbox, get_or_start_sandbox
 from services.llm import make_llm_api_call

@@ -95,8 +94,7 @@ async def update_agent_run_status(
     client,
     agent_run_id: str,
     status: str,
-    error: Optional[str] = None,
-    responses: Optional[List[Any]] = None
+    error: Optional[str] = None
 ) -> bool:
     """
     Centralized function to update agent run status.
@@ -111,9 +109,6 @@ async def update_agent_run_status(
         if error:
             update_data["error"] = error

-        if responses:
-            update_data["responses"] = responses
-
         # Retry up to 3 times
         for retry in range(3):
             try:
@@ -134,18 +129,16 @@ async def update_agent_run_status(
                     if retry == 2: # Last retry
                         logger.error(f"Failed to update agent run status after all retries: {agent_run_id}")
                         return False
-            except Exception as db_error:
-                logger.error(f"Database error on retry {retry} updating status: {str(db_error)}")
-                if retry < 2: # Not the last retry yet
-                    await asyncio.sleep(0.5 * (2 ** retry)) # Exponential backoff
-                else:
-                    logger.error(f"Failed to update agent run status after all retries: {agent_run_id}", exc_info=True)
-                    return False
+            except Exception as e:
+                logger.error(f"Error updating agent run status on retry {retry}: {str(e)}")
+                if retry == 2: # Last retry
+                    raise
     except Exception as e:
-        logger.error(f"Unexpected error updating agent run status: {str(e)}", exc_info=True)
+        logger.error(f"Failed to update agent run status: {str(e)}")
         return False

     return False

 async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None):
     """Update database and publish stop signal to Redis."""
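Review note: the old error path slept between attempts (`0.5 * 2**retry`, i.e. 0.5s then 1s) and only gave up after the last one; the new code logs and re-raises on the final attempt, with no sleep, so retries now happen back-to-back. If the backoff behavior is ever wanted again, here is a minimal self-contained sketch of the dropped pattern (the `do_update` callable is a stand-in for the Supabase call, not a name from the codebase):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)

async def update_with_backoff(do_update, attempts: int = 3) -> bool:
    """Retry an async update with 0.5s, 1s, 2s... pauses between attempts."""
    for retry in range(attempts):
        try:
            if await do_update():
                return True
        except Exception as e:
            logger.error(f"Update failed on retry {retry}: {e}")
        if retry < attempts - 1:
            await asyncio.sleep(0.5 * (2 ** retry))  # Exponential backoff
    return False
```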
@@ -348,7 +341,7 @@ async def get_or_create_project_sandbox(client, project_id: str, sandbox_cache={
         try:
             logger.info(f"Creating new sandbox for project {project_id}")
             sandbox_pass = str(uuid.uuid4())
-            sandbox = create_sandbox(sandbox_pass)
+            sandbox = create_sandbox(sandbox_pass, sandbox_id=project_id)
             sandbox_id = sandbox.id

             logger.info(f"Created new sandbox {sandbox_id} with preview: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}")
@@ -751,27 +744,23 @@ async def run_agent_background(
             enable_context_manager=enable_context_manager
         )

-        # Collect all responses to save to database
-        all_responses = []
-
         async for response in agent_gen:
             # Check if stop signal received
             if stop_signal_received:
                 logger.info(f"Agent run stopped due to stop signal: {agent_run_id} (instance: {instance_id})")
-                await update_agent_run_status(client, agent_run_id, "stopped", responses=all_responses)
+                await update_agent_run_status(client, agent_run_id, "stopped")
                 break

             # Check for billing error status
             if response.get('type') == 'status' and response.get('status') == 'error':
                 error_msg = response.get('message', '')
                 logger.info(f"Agent run failed with error: {error_msg} (instance: {instance_id})")
-                await update_agent_run_status(client, agent_run_id, "failed", error=error_msg, responses=all_responses)
+                await update_agent_run_status(client, agent_run_id, "failed", error=error_msg)
                 break

             # Store response in memory
             if agent_run_id in active_agent_runs:
                 active_agent_runs[agent_run_id].append(response)
-                all_responses.append(response)
                 total_responses += 1

         # Signal all done if we weren't stopped
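Review note: with `all_responses` removed here (and in the two hunks below), streamed responses are no longer persisted through `update_agent_run_status`; they exist only in the in-memory `active_agent_runs` map while the run is live, and are lost on restart. A rough sketch of what remains of the flow, with the dict shape inferred from the appends in this hunk:

```python
from typing import Any, Dict, List

# agent_run_id -> buffered response chunks; in-memory only, lost on restart
active_agent_runs: Dict[str, List[Any]] = {}

def buffer_response(agent_run_id: str, response: Any) -> None:
    """Mirror of the append above: only runs still registered
    in active_agent_runs keep accumulating chunks."""
    if agent_run_id in active_agent_runs:
        active_agent_runs[agent_run_id].append(response)
```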
@@ -787,10 +776,9 @@ async def run_agent_background(
             }
             if agent_run_id in active_agent_runs:
                 active_agent_runs[agent_run_id].append(completion_message)
-                all_responses.append(completion_message)

             # Update the agent run status
-            await update_agent_run_status(client, agent_run_id, "completed", responses=all_responses)
+            await update_agent_run_status(client, agent_run_id, "completed")

             # Notify any clients monitoring the control channels that we're done
             try:
@@ -816,18 +804,13 @@ async def run_agent_background(
         }
         if agent_run_id in active_agent_runs:
             active_agent_runs[agent_run_id].append(error_response)
-        if 'all_responses' in locals():
-            all_responses.append(error_response)
-        else:
-            all_responses = [error_response]

         # Update the agent run with the error
         await update_agent_run_status(
             client,
             agent_run_id,
             "failed",
-            error=f"{error_message}\n{traceback_str}",
-            responses=all_responses
+            error=f"{error_message}\n{traceback_str}"
         )

         # Notify any clients of the error
@@ -908,8 +891,6 @@ async def generate_and_update_project_name(project_id: str, prompt: str):
         else:
             logger.warning(f"Failed to get valid response from LLM for project {project_id} naming. Response: {response}")

-        print(f"\n\n\nGenerated name: {generated_name}\n\n\n")
-        # Update database if name was generated
         if generated_name:
             update_result = await client.table('projects') \
                 .update({"name": generated_name}) \
@@ -195,281 +195,281 @@ async def run_agent(



-# TESTING
+# # TESTING

-async def test_agent():
-    """Test function to run the agent with a sample query"""
-    from agentpress.thread_manager import ThreadManager
-    from services.supabase import DBConnection
+# async def test_agent():
+    # """Test function to run the agent with a sample query"""
+    # from agentpress.thread_manager import ThreadManager
+    # from services.supabase import DBConnection

-    # Initialize ThreadManager
-    thread_manager = ThreadManager()
+    # # Initialize ThreadManager
+    # thread_manager = ThreadManager()

-    # Create a test thread directly with Postgres function
-    client = await DBConnection().client
+    # # Create a test thread directly with Postgres function
+    # client = await DBConnection().client

-    try:
-        # Get user's personal account
-        account_result = await client.rpc('get_personal_account').execute()
+    # try:
+        # # Get user's personal account
+        # account_result = await client.rpc('get_personal_account').execute()

-        # if not account_result.data:
-        # print("Error: No personal account found")
-        # return
+        # # if not account_result.data:
+        # # print("Error: No personal account found")
+        # # return

-        account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59"
+        # account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59"

-        if not account_id:
-            print("Error: Could not get account ID")
-            return
+        # if not account_id:
+            # print("Error: Could not get account ID")
+            # return

-        # Find or create a test project in the user's account
-        project_result = await client.table('projects').select('*').eq('name', 'test11').eq('account_id', account_id).execute()
+        # # Find or create a test project in the user's account
+        # project_result = await client.table('projects').select('*').eq('name', 'test11').eq('account_id', account_id).execute()

-        if project_result.data and len(project_result.data) > 0:
-            # Use existing test project
-            project_id = project_result.data[0]['project_id']
-            print(f"\n🔄 Using existing test project: {project_id}")
-        else:
-            # Create new test project if none exists
-            project_result = await client.table('projects').insert({
-                "name": "test11",
-                "account_id": account_id
-            }).execute()
-            project_id = project_result.data[0]['project_id']
-            print(f"\n✨ Created new test project: {project_id}")
+        # if project_result.data and len(project_result.data) > 0:
+            # # Use existing test project
+            # project_id = project_result.data[0]['project_id']
+            # print(f"\n🔄 Using existing test project: {project_id}")
+        # else:
+            # # Create new test project if none exists
+            # project_result = await client.table('projects').insert({
+                # "name": "test11",
+                # "account_id": account_id
+            # }).execute()
+            # project_id = project_result.data[0]['project_id']
+            # print(f"\n✨ Created new test project: {project_id}")

-        # Create a thread for this project
-        thread_result = await client.table('threads').insert({
-            'project_id': project_id,
-            'account_id': account_id
-        }).execute()
-        thread_data = thread_result.data[0] if thread_result.data else None
+        # # Create a thread for this project
+        # thread_result = await client.table('threads').insert({
+            # 'project_id': project_id,
+            # 'account_id': account_id
+        # }).execute()
+        # thread_data = thread_result.data[0] if thread_result.data else None

-        if not thread_data:
-            print("Error: No thread data returned")
-            return
+        # if not thread_data:
+            # print("Error: No thread data returned")
+            # return

-        thread_id = thread_data['thread_id']
-    except Exception as e:
-        print(f"Error setting up thread: {str(e)}")
-        return
+        # thread_id = thread_data['thread_id']
+    # except Exception as e:
+        # print(f"Error setting up thread: {str(e)}")
+        # return

-    print(f"\n🤖 Agent Thread Created: {thread_id}\n")
+    # print(f"\n🤖 Agent Thread Created: {thread_id}\n")

-    # Interactive message input loop
-    while True:
-        # Get user input
-        user_message = input("\n💬 Enter your message (or 'exit' to quit): ")
-        if user_message.lower() == 'exit':
-            break
+    # # Interactive message input loop
+    # while True:
+        # # Get user input
+        # user_message = input("\n💬 Enter your message (or 'exit' to quit): ")
+        # if user_message.lower() == 'exit':
+            # break

-        if not user_message.strip():
-            print("\n🔄 Running agent...\n")
-            await process_agent_response(thread_id, project_id, thread_manager)
-            continue
+        # if not user_message.strip():
+            # print("\n🔄 Running agent...\n")
+            # await process_agent_response(thread_id, project_id, thread_manager)
+            # continue

-        # Add the user message to the thread
-        await thread_manager.add_message(
-            thread_id=thread_id,
-            type="user",
-            content={
-                "role": "user",
-                "content": user_message
-            },
-            is_llm_message=True
-        )
+        # # Add the user message to the thread
+        # await thread_manager.add_message(
+            # thread_id=thread_id,
+            # type="user",
+            # content={
+                # "role": "user",
+                # "content": user_message
+            # },
+            # is_llm_message=True
+        # )

-        print("\n🔄 Running agent...\n")
-        await process_agent_response(thread_id, project_id, thread_manager)
+        # print("\n🔄 Running agent...\n")
+        # await process_agent_response(thread_id, project_id, thread_manager)

-    print("\n👋 Test completed. Goodbye!")
+    # print("\n👋 Test completed. Goodbye!")

-async def process_agent_response(
-    thread_id: str,
-    project_id: str,
-    thread_manager: ThreadManager,
-    stream: bool = True,
-    model_name: str = "anthropic/claude-3-7-sonnet-latest",
-    enable_thinking: Optional[bool] = False,
-    reasoning_effort: Optional[str] = 'low',
-    enable_context_manager: bool = True
-):
-    """Process the streaming response from the agent."""
-    chunk_counter = 0
-    current_response = ""
-    tool_usage_counter = 0 # Renamed from tool_call_counter as we track usage via status
+# async def process_agent_response(
+    # thread_id: str,
+    # project_id: str,
+    # thread_manager: ThreadManager,
+    # stream: bool = True,
+    # model_name: str = "anthropic/claude-3-7-sonnet-latest",
+    # enable_thinking: Optional[bool] = False,
+    # reasoning_effort: Optional[str] = 'low',
+    # enable_context_manager: bool = True
+# ):
+    # """Process the streaming response from the agent."""
+    # chunk_counter = 0
+    # current_response = ""
+    # tool_usage_counter = 0 # Renamed from tool_call_counter as we track usage via status

-    # Create a test sandbox for processing with a unique test prefix to avoid conflicts with production sandboxes
-    sandbox_pass = str(uuid4())
-    sandbox = create_sandbox(sandbox_pass)
+    # # Create a test sandbox for processing with a unique test prefix to avoid conflicts with production sandboxes
+    # sandbox_pass = str(uuid4())
+    # sandbox = create_sandbox(sandbox_pass)

-    # Store the original ID so we can refer to it
-    original_sandbox_id = sandbox.id
+    # # Store the original ID so we can refer to it
+    # original_sandbox_id = sandbox.id

-    # Generate a clear test identifier
-    test_prefix = f"test_{uuid4().hex[:8]}_"
-    logger.info(f"Created test sandbox with ID {original_sandbox_id} and test prefix {test_prefix}")
+    # # Generate a clear test identifier
+    # test_prefix = f"test_{uuid4().hex[:8]}_"
+    # logger.info(f"Created test sandbox with ID {original_sandbox_id} and test prefix {test_prefix}")

-    # Log the sandbox URL for debugging
-    print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m")
+    # # Log the sandbox URL for debugging
+    # print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m")

-    async for chunk in run_agent(
-        thread_id=thread_id,
-        project_id=project_id,
-        sandbox=sandbox,
-        stream=stream,
-        thread_manager=thread_manager,
-        native_max_auto_continues=25,
-        model_name=model_name,
-        enable_thinking=enable_thinking,
-        reasoning_effort=reasoning_effort,
-        enable_context_manager=enable_context_manager
-    ):
-        chunk_counter += 1
-        # print(f"CHUNK: {chunk}") # Uncomment for debugging
+    # async for chunk in run_agent(
+        # thread_id=thread_id,
+        # project_id=project_id,
+        # sandbox=sandbox,
+        # stream=stream,
+        # thread_manager=thread_manager,
+        # native_max_auto_continues=25,
+        # model_name=model_name,
+        # enable_thinking=enable_thinking,
+        # reasoning_effort=reasoning_effort,
+        # enable_context_manager=enable_context_manager
+    # ):
+        # chunk_counter += 1
+        # # print(f"CHUNK: {chunk}") # Uncomment for debugging

-        if chunk.get('type') == 'assistant':
-            # Try parsing the content JSON
-            try:
-                # Handle content as string or object
-                content = chunk.get('content', '{}')
-                if isinstance(content, str):
-                    content_json = json.loads(content)
-                else:
-                    content_json = content
+        # if chunk.get('type') == 'assistant':
+            # # Try parsing the content JSON
+            # try:
+                # # Handle content as string or object
+                # content = chunk.get('content', '{}')
+                # if isinstance(content, str):
+                    # content_json = json.loads(content)
+                # else:
+                    # content_json = content

-                actual_content = content_json.get('content', '')
-                # Print the actual assistant text content as it comes
-                if actual_content:
-                    # Check if it contains XML tool tags, if so, print the whole tag for context
-                    if '<' in actual_content and '>' in actual_content:
-                        # Avoid printing potentially huge raw content if it's not just text
-                        if len(actual_content) < 500: # Heuristic limit
-                            print(actual_content, end='', flush=True)
-                        else:
-                            # Maybe just print a summary if it's too long or contains complex XML
-                            if '</ask>' in actual_content: print("<ask>...</ask>", end='', flush=True)
-                            elif '</complete>' in actual_content: print("<complete>...</complete>", end='', flush=True)
-                            else: print("<tool_call>...</tool_call>", end='', flush=True) # Generic case
-                    else:
-                        # Regular text content
-                        print(actual_content, end='', flush=True)
-                    current_response += actual_content # Accumulate only text part
-            except json.JSONDecodeError:
-                # If content is not JSON (e.g., just a string chunk), print directly
-                raw_content = chunk.get('content', '')
-                print(raw_content, end='', flush=True)
-                current_response += raw_content
-            except Exception as e:
-                print(f"\nError processing assistant chunk: {e}\n")
+                # actual_content = content_json.get('content', '')
+                # # Print the actual assistant text content as it comes
+                # if actual_content:
+                    # # Check if it contains XML tool tags, if so, print the whole tag for context
+                    # if '<' in actual_content and '>' in actual_content:
+                        # # Avoid printing potentially huge raw content if it's not just text
+                        # if len(actual_content) < 500: # Heuristic limit
+                            # print(actual_content, end='', flush=True)
+                        # else:
+                            # # Maybe just print a summary if it's too long or contains complex XML
+                            # if '</ask>' in actual_content: print("<ask>...</ask>", end='', flush=True)
+                            # elif '</complete>' in actual_content: print("<complete>...</complete>", end='', flush=True)
+                            # else: print("<tool_call>...</tool_call>", end='', flush=True) # Generic case
+                    # else:
+                        # # Regular text content
+                        # print(actual_content, end='', flush=True)
+                    # current_response += actual_content # Accumulate only text part
+            # except json.JSONDecodeError:
+                # # If content is not JSON (e.g., just a string chunk), print directly
+                # raw_content = chunk.get('content', '')
+                # print(raw_content, end='', flush=True)
+                # current_response += raw_content
+            # except Exception as e:
+                # print(f"\nError processing assistant chunk: {e}\n")

-        elif chunk.get('type') == 'tool': # Updated from 'tool_result'
-            # Add timestamp and format tool result nicely
-            tool_name = "UnknownTool" # Try to get from metadata if available
-            result_content = "No content"
+        # elif chunk.get('type') == 'tool': # Updated from 'tool_result'
+            # # Add timestamp and format tool result nicely
+            # tool_name = "UnknownTool" # Try to get from metadata if available
+            # result_content = "No content"

-            # Parse metadata - handle both string and dict formats
-            metadata = chunk.get('metadata', {})
-            if isinstance(metadata, str):
-                try:
-                    metadata = json.loads(metadata)
-                except json.JSONDecodeError:
-                    metadata = {}
+            # # Parse metadata - handle both string and dict formats
+            # metadata = chunk.get('metadata', {})
+            # if isinstance(metadata, str):
+                # try:
+                    # metadata = json.loads(metadata)
+                # except json.JSONDecodeError:
+                    # metadata = {}

-            linked_assistant_msg_id = metadata.get('assistant_message_id')
-            parsing_details = metadata.get('parsing_details')
-            if parsing_details:
-                tool_name = parsing_details.get('xml_tag_name', 'UnknownTool') # Get name from parsing details
+            # linked_assistant_msg_id = metadata.get('assistant_message_id')
+            # parsing_details = metadata.get('parsing_details')
+            # if parsing_details:
+                # tool_name = parsing_details.get('xml_tag_name', 'UnknownTool') # Get name from parsing details

-            try:
-                # Content is a JSON string or object
-                content = chunk.get('content', '{}')
-                if isinstance(content, str):
-                    content_json = json.loads(content)
-                else:
-                    content_json = content
+            # try:
+                # # Content is a JSON string or object
+                # content = chunk.get('content', '{}')
+                # if isinstance(content, str):
+                    # content_json = json.loads(content)
+                # else:
+                    # content_json = content

-                # The actual tool result is nested inside content.content
-                tool_result_str = content_json.get('content', '')
-                # Extract the actual tool result string (remove outer <tool_result> tag if present)
-                match = re.search(rf'<{tool_name}>(.*?)</{tool_name}>', tool_result_str, re.DOTALL)
-                if match:
-                    result_content = match.group(1).strip()
-                    # Try to parse the result string itself as JSON for pretty printing
-                    try:
-                        result_obj = json.loads(result_content)
-                        result_content = json.dumps(result_obj, indent=2)
-                    except json.JSONDecodeError:
-                        # Keep as string if not JSON
-                        pass
-                else:
-                    # Fallback if tag extraction fails
-                    result_content = tool_result_str
+                # # The actual tool result is nested inside content.content
+                # tool_result_str = content_json.get('content', '')
+                # # Extract the actual tool result string (remove outer <tool_result> tag if present)
+                # match = re.search(rf'<{tool_name}>(.*?)</{tool_name}>', tool_result_str, re.DOTALL)
+                # if match:
+                    # result_content = match.group(1).strip()
+                    # # Try to parse the result string itself as JSON for pretty printing
+                    # try:
+                        # result_obj = json.loads(result_content)
+                        # result_content = json.dumps(result_obj, indent=2)
+                    # except json.JSONDecodeError:
+                        # # Keep as string if not JSON
+                        # pass
+                # else:
+                    # # Fallback if tag extraction fails
+                    # result_content = tool_result_str

-            except json.JSONDecodeError:
-                result_content = chunk.get('content', 'Error parsing tool content')
-            except Exception as e:
-                result_content = f"Error processing tool chunk: {e}"
+            # except json.JSONDecodeError:
+                # result_content = chunk.get('content', 'Error parsing tool content')
+            # except Exception as e:
+                # result_content = f"Error processing tool chunk: {e}"

-            print(f"\n\n🛠️ TOOL RESULT [{tool_name}] → {result_content}")
+            # print(f"\n\n🛠️ TOOL RESULT [{tool_name}] → {result_content}")

-        elif chunk.get('type') == 'status':
-            # Log tool status changes
-            try:
-                # Handle content as string or object
-                status_content = chunk.get('content', '{}')
-                if isinstance(status_content, str):
-                    status_content = json.loads(status_content)
+        # elif chunk.get('type') == 'status':
+            # # Log tool status changes
+            # try:
+                # # Handle content as string or object
+                # status_content = chunk.get('content', '{}')
+                # if isinstance(status_content, str):
+                    # status_content = json.loads(status_content)

-                status_type = status_content.get('status_type')
-                function_name = status_content.get('function_name', '')
-                xml_tag_name = status_content.get('xml_tag_name', '') # Get XML tag if available
-                tool_name = xml_tag_name or function_name # Prefer XML tag name
+                # status_type = status_content.get('status_type')
+                # function_name = status_content.get('function_name', '')
+                # xml_tag_name = status_content.get('xml_tag_name', '') # Get XML tag if available
+                # tool_name = xml_tag_name or function_name # Prefer XML tag name

-                if status_type == 'tool_started' and tool_name:
-                    tool_usage_counter += 1
-                    print(f"\n⏳ TOOL STARTING #{tool_usage_counter} [{tool_name}]")
-                    print(" " + "-" * 40)
-                    # Return to the current content display
-                    if current_response:
-                        print("\nContinuing response:", flush=True)
-                        print(current_response, end='', flush=True)
-                elif status_type == 'tool_completed' and tool_name:
-                    status_emoji = "✅"
-                    print(f"\n{status_emoji} TOOL COMPLETED: {tool_name}")
-                elif status_type == 'finish':
-                    finish_reason = status_content.get('finish_reason', '')
-                    if finish_reason:
-                        print(f"\n📌 Finished: {finish_reason}")
-                # else: # Print other status types if needed for debugging
-                # print(f"\nℹ️ STATUS: {chunk.get('content')}")
+                # if status_type == 'tool_started' and tool_name:
+                    # tool_usage_counter += 1
+                    # print(f"\n⏳ TOOL STARTING #{tool_usage_counter} [{tool_name}]")
+                    # print(" " + "-" * 40)
+                    # # Return to the current content display
+                    # if current_response:
+                        # print("\nContinuing response:", flush=True)
+                        # print(current_response, end='', flush=True)
+                # elif status_type == 'tool_completed' and tool_name:
+                    # status_emoji = "✅"
+                    # print(f"\n{status_emoji} TOOL COMPLETED: {tool_name}")
+                # elif status_type == 'finish':
+                    # finish_reason = status_content.get('finish_reason', '')
+                    # if finish_reason:
+                        # print(f"\n📌 Finished: {finish_reason}")
+                # # else: # Print other status types if needed for debugging
+                # # print(f"\nℹ️ STATUS: {chunk.get('content')}")

-            except json.JSONDecodeError:
-                print(f"\nWarning: Could not parse status content JSON: {chunk.get('content')}")
-            except Exception as e:
-                print(f"\nError processing status chunk: {e}")
+            # except json.JSONDecodeError:
+                # print(f"\nWarning: Could not parse status content JSON: {chunk.get('content')}")
+            # except Exception as e:
+                # print(f"\nError processing status chunk: {e}")


-        # Removed elif chunk.get('type') == 'tool_call': block
+        # # Removed elif chunk.get('type') == 'tool_call': block

-    # Update final message
-    print(f"\n\n✅ Agent run completed with {tool_usage_counter} tool executions")
+    # # Update final message
+    # print(f"\n\n✅ Agent run completed with {tool_usage_counter} tool executions")

-    # Try to clean up the test sandbox if possible
-    try:
-        # Attempt to delete/archive the sandbox to clean up resources
-        # Note: Actual deletion may depend on the Daytona SDK's capabilities
-        logger.info(f"Attempting to clean up test sandbox {original_sandbox_id}")
-        # If there's a method to archive/delete the sandbox, call it here
-        # Example: daytona.archive_sandbox(sandbox.id)
-    except Exception as e:
-        logger.warning(f"Failed to clean up test sandbox {original_sandbox_id}: {str(e)}")
+    # # Try to clean up the test sandbox if possible
+    # try:
+        # # Attempt to delete/archive the sandbox to clean up resources
+        # # Note: Actual deletion may depend on the Daytona SDK's capabilities
+        # logger.info(f"Attempting to clean up test sandbox {original_sandbox_id}")
+        # # If there's a method to archive/delete the sandbox, call it here
+        # # Example: daytona.archive_sandbox(sandbox.id)
+    # except Exception as e:
+        # logger.warning(f"Failed to clean up test sandbox {original_sandbox_id}: {str(e)}")

-if __name__ == "__main__":
-    import asyncio
+# if __name__ == "__main__":
    # import asyncio

-    # Configure any environment variables or setup needed for testing
-    load_dotenv() # Ensure environment variables are loaded
+    # # Configure any environment variables or setup needed for testing
+    # load_dotenv() # Ensure environment variables are loaded

-    # Run the test function
-    asyncio.run(test_agent())
+    # # Run the test function
+    # asyncio.run(test_agent())
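Review note: the entire interactive test harness in `run.py` is commented out line-for-line rather than deleted, including its `if __name__ == "__main__":` entry point, so running the module directly is now a no-op. The reusable pattern in it is simply draining `run_agent`'s async generator; a condensed, runnable sketch (the stub generator below stands in for `run_agent` and is not from the repo):

```python
import asyncio
from typing import Any, AsyncIterator, Dict

async def drain(agent_gen: AsyncIterator[Dict[str, Any]]) -> Dict[str, int]:
    """Consume an agent's chunk stream and tally chunks by their 'type' key."""
    counts: Dict[str, int] = {}
    async for chunk in agent_gen:
        kind = chunk.get('type', 'unknown')
        counts[kind] = counts.get(kind, 0) + 1
    return counts

async def _stub() -> AsyncIterator[Dict[str, Any]]:
    yield {'type': 'assistant', 'content': 'hi'}
    yield {'type': 'status', 'status': 'completed'}

print(asyncio.run(drain(_stub())))  # {'assistant': 1, 'status': 1}
```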
@@ -208,7 +208,7 @@ class ResponseProcessor:
                     tool_index += 1

                 if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
-                    logger.info(f"Reached XML tool call limit ({config.max_xml_tool_calls})")
+                    logger.debug(f"Reached XML tool call limit ({config.max_xml_tool_calls})")
                     finish_reason = "xml_tool_limit_reached"
                     break # Stop processing more XML chunks in this delta

@@ -1001,7 +1001,7 @@ class ResponseProcessor:
                 "arguments": params # The extracted parameters
             }

-            logger.info(f"Created tool call: {tool_call}")
+            logger.debug(f"Created tool call: {tool_call}")
             return tool_call, parsing_details # Return both dicts

         except Exception as e:
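Review note: both `ResponseProcessor` hunks demote per-tool-call logging from `info` to `debug`, quieting default log output. To see these messages again during development, raise verbosity for the one noisy logger rather than globally; a standard-library sketch (the logger name is illustrative, not read from the repo):

```python
import logging

logging.basicConfig(level=logging.INFO)  # Application default stays at INFO
# Opt back into the demoted messages for a single module:
logging.getLogger("agentpress.response_processor").setLevel(logging.DEBUG)
```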
@@ -86,8 +86,6 @@ class ThreadManager:
             # Add returning='representation' to get the inserted row data including the id
             result = await client.table('messages').insert(data_to_insert, returning='representation').execute()
             logger.info(f"Successfully added message to thread {thread_id}")

-            print(f"MESSAGE RESULT: {result}")
-
             if result.data and len(result.data) > 0 and isinstance(result.data[0], dict) and 'message_id' in result.data[0]:
                 return result.data[0]
@@ -2665,7 +2665,7 @@ files = [
     {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:493fe54318bed7d124ce272fc36adbf59d46729659b2c792e87c3b95649cdee9"},
     {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8aa362811ccdc1f8dadcc916c6d47e554169ab79559319ae9fae7d7752d0d60c"},
     {file = "rpds_py-0.24.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d8f9a6e7fd5434817526815f09ea27f2746c4a51ee11bb3439065f5fc754db58"},
-    {file = "rpds_py-0.24.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8205ee14463248d3349131bb8080efe15cd3ce83b8ef3ace63c7e976998e7124"},
+    {file = "rpds_py-0.24.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8205ee14463248d3349131bb8099efe15cd3ce83b8ef3ace63c7e976998e7124"},
     {file = "rpds_py-0.24.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:921ae54f9ecba3b6325df425cf72c074cd469dea843fb5743a26ca7fb2ccb149"},
     {file = "rpds_py-0.24.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:32bab0a56eac685828e00cc2f5d1200c548f8bc11f2e44abf311d6b548ce2e45"},
     {file = "rpds_py-0.24.0-cp39-cp39-win32.whl", hash = "sha256:f5c0ed12926dec1dfe7d645333ea59cf93f4d07750986a586f511c0bc61fe103"},
@@ -3622,4 +3622,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "8d1e50482e981a8647474dd33a068bce4da888729d70e17f6d92bc43f972689d"
+content-hash = "622a06feff14fc27c612f15e50be3375531175462c46fa57c3bcf33851e2a9c3"
@@ -57,6 +57,10 @@ agentpress = "agentpress.cli:main"
 [[tool.poetry.packages]]
 include = "agentpress"

+
+[tool.poetry.group.dev.dependencies]
+daytona-sdk = "^0.14.0"
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
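Review note: `daytona-sdk = "^0.14.0"` lands in a new `[tool.poetry.group.dev.dependencies]` table, and the `content-hash` change in the poetry.lock hunk above is the matching `poetry lock` refresh. Since `sandbox/sandbox.py` uses the Daytona SDK at runtime (`daytona.create`, `CreateSandboxParams`, the `Sandbox` type), parking it in the dev group rather than the main dependency list looks like a wip shortcut worth revisiting.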
@@ -11,7 +11,6 @@ from sandbox.sandbox import get_or_start_sandbox
 from services.supabase import DBConnection
 from agent.api import get_or_create_project_sandbox

-# TODO: ADD AUTHORIZATION TO ONLY HAVE ACCESS TO SANDBOXES OF PROJECTS U HAVE ACCESS TO

 # Initialize shared resources
 router = APIRouter(tags=["sandbox"])
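Review note: the line deleted here was a TODO about restricting sandbox access to authorized projects; no replacement authorization check appears anywhere else in this diff, so the concern it tracked is now undocumented.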
@@ -83,15 +83,21 @@ def start_supervisord_session(sandbox: Sandbox):
         logger.error(f"Error starting supervisord session: {str(e)}")
         raise e

-def create_sandbox(password: str):
+def create_sandbox(password: str, sandbox_id: str = None):
     """Create a new sandbox with all required services configured and running."""

-    logger.info("Creating new Daytona sandbox environment")
+    logger.debug("Creating new Daytona sandbox environment")
     logger.debug("Configuring sandbox with browser-use image and environment variables")

-    sandbox = daytona.create(CreateSandboxParams(
+    labels = None
+    if sandbox_id:
+        logger.debug(f"Using sandbox_id as label: {sandbox_id}")
+        labels = {'id': sandbox_id}
+
+    params = CreateSandboxParams(
         image="adamcohenhillel/kortix-suna:0.0.20",
         public=True,
+        labels=labels,
         env_vars={
             "CHROME_PERSISTENT_SESSION": "true",
             "RESOLUTION": "1024x768x24",
@@ -118,13 +124,16 @@ def create_sandbox(password: str):
             "memory": 4,
             "disk": 5,
         }
-    ))
-    logger.info(f"Sandbox created with ID: {sandbox.id}")
+    )
+
+    # Create the sandbox
+    sandbox = daytona.create(params)
+    logger.debug(f"Sandbox created with ID: {sandbox.id}")

     # Start supervisord in a session for new sandbox
     start_supervisord_session(sandbox)

-    logger.info(f"Sandbox environment successfully initialized")
+    logger.debug(f"Sandbox environment successfully initialized")
     return sandbox


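Review note: `create_sandbox` now assembles `CreateSandboxParams` first, attaching `labels={'id': sandbox_id}` when a sandbox ID is supplied, and only then calls `daytona.create(params)`; together with the `get_or_create_project_sandbox` hunk earlier, this stamps the owning project's ID onto the sandbox as a label. The two call shapes this diff exercises, as a sketch (`project_id` is a placeholder; the import path is the one used in `agent/api.py` above):

```python
import uuid
from sandbox.sandbox import create_sandbox

project_id = "00000000-0000-0000-0000-000000000000"  # placeholder UUID

# Ad-hoc sandbox, no label (the pre-existing call shape):
sandbox = create_sandbox(str(uuid.uuid4()))

# Project-scoped sandbox, labeled with the project ID (the new call shape):
sandbox = create_sandbox(str(uuid.uuid4()), sandbox_id=project_id)
```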
@@ -160,15 +169,15 @@ class SandboxToolsBase(Tool):
         website_url = website_link.url if hasattr(website_link, 'url') else str(website_link)

         # Log the actual URLs
-        logger.info(f"Sandbox VNC URL: {vnc_url}")
-        logger.info(f"Sandbox Website URL: {website_url}")
+        # logger.info(f"Sandbox VNC URL: {vnc_url}")
+        # logger.info(f"Sandbox Website URL: {website_url}")

-        if not SandboxToolsBase._urls_printed:
-            print("\033[95m***")
-            print(vnc_url)
-            print(website_url)
-            print("***\033[0m")
-            SandboxToolsBase._urls_printed = True
+        # if not SandboxToolsBase._urls_printed:
+            # print("\033[95m***")
+            # print(vnc_url)
+            # print(website_url)
+            # print("***\033[0m")
+            # SandboxToolsBase._urls_printed = True

     def clean_path(self, path: str) -> str:
         cleaned_path = clean_path(path, self.workspace_path)
@@ -3,7 +3,7 @@ from typing import Dict, Optional, Tuple

 # Define subscription tiers and their monthly limits (in minutes)
 SUBSCRIPTION_TIERS = {
-    'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 10},
+    'price_1RGJ9GG6l1KZGqIroxSqgphC': {'name': 'free', 'minutes': 1000000},
     'price_1RGJ9LG6l1KZGqIrd9pwzeNW': {'name': 'base', 'minutes': 300}, # 100 hours = 6000 minutes
     'price_1RGJ9JG6l1KZGqIrVUU4ZRv6': {'name': 'extra', 'minutes': 2400} # 100 hours = 6000 minutes
 }

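Review note: raising the free tier from 10 minutes to 1,000,000 (roughly 694 days, more than any month can contain) effectively disables the free-tier cap; presumably a temporary dev setting in this wip commit. The `# 100 hours = 6000 minutes` comments were already stale and stay that way: 300 minutes is 5 hours and 2400 minutes is 40 hours. A hypothetical helper showing how such limits would be consumed (`check_billing_status`'s real signature is not shown in this diff):

```python
# Assumes the SUBSCRIPTION_TIERS dict from the hunk above is in scope.
def minutes_remaining(price_id: str, minutes_used: float) -> float:
    """Hypothetical sketch of a tier check; the real logic lives in check_billing_status."""
    limit = SUBSCRIPTION_TIERS.get(price_id, {'name': 'unknown', 'minutes': 0})['minutes']
    return limit - minutes_used
```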
@@ -1,59 +0,0 @@
-from typing import Optional, List, Any
-from datetime import datetime, timezone
-import logging
-
-logger = logging.getLogger(__name__)
-
-async def update_agent_run_status(
-    client,
-    agent_run_id: str,
-    status: str,
-    error: Optional[str] = None,
-    responses: Optional[List[Any]] = None
-) -> bool:
-    """
-    Centralized function to update agent run status.
-    Returns True if update was successful.
-    """
-    try:
-        update_data = {
-            "status": status,
-            "completed_at": datetime.now(timezone.utc).isoformat()
-        }
-
-        if error:
-            update_data["error"] = error
-
-        if responses:
-            update_data["responses"] = responses
-
-        # Retry up to 3 times
-        for retry in range(3):
-            try:
-                update_result = await client.table('agent_runs').update(update_data).eq("id", agent_run_id).execute()
-
-                if hasattr(update_result, 'data') and update_result.data:
-                    logger.info(f"Successfully updated agent run status to '{status}' (retry {retry}): {agent_run_id}")
-
-                    # Verify the update
-                    verify_result = await client.table('agent_runs').select('status', 'completed_at').eq("id", agent_run_id).execute()
-                    if verify_result.data:
-                        actual_status = verify_result.data[0].get('status')
-                        completed_at = verify_result.data[0].get('completed_at')
-                        logger.info(f"Verified agent run update: status={actual_status}, completed_at={completed_at}")
-                    return True
-                else:
-                    logger.warning(f"Database update returned no data on retry {retry}: {update_result}")
-                    if retry == 2: # Last retry
-                        logger.error(f"Failed to update agent run status after all retries: {agent_run_id}")
-                        return False
-            except Exception as e:
-                logger.error(f"Error updating agent run status on retry {retry}: {str(e)}")
-                if retry == 2: # Last retry
-                    raise
-
-    except Exception as e:
-        logger.error(f"Failed to update agent run status: {str(e)}")
-        return False
-
-    return False
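Review note: the deleted file is evidently `utils/db.py` — the module whose `update_agent_run_status` import was removed in the first hunk — and the function now lives directly in `agent/api.py`, where the earlier hunks edit it in place. The api.py copy also drops the `responses` parameter and its `update_data["responses"]` branch, so nothing in the new code persists streamed responses to the `agent_runs` row.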