mirror of https://github.com/kortix-ai/suna.git
Merge branch 'main' of https://github.com/escapade-mckv/suna into feat/ui
This commit is contained in:
commit
d2bbd1bd27
|
@ -18,6 +18,7 @@ from agent.run import run_agent
|
|||
from utils.auth_utils import get_current_user_id_from_jwt, get_user_id_from_stream_auth, verify_thread_access
|
||||
from utils.logger import logger
|
||||
from services.billing import check_billing_status
|
||||
from utils.config import config
|
||||
from sandbox.sandbox import create_sandbox, get_or_start_sandbox
|
||||
from services.llm import make_llm_api_call
|
||||
|
||||
|
@ -31,17 +32,32 @@ instance_id = None # Global instance ID for this backend instance
|
|||
REDIS_RESPONSE_LIST_TTL = 3600 * 24
|
||||
|
||||
MODEL_NAME_ALIASES = {
|
||||
# Short names to full names
|
||||
"sonnet-3.7": "anthropic/claude-3-7-sonnet-latest",
|
||||
"gpt-4.1": "openai/gpt-4.1-2025-04-14",
|
||||
"gpt-4o": "openai/gpt-4o",
|
||||
"gpt-4-turbo": "openai/gpt-4-turbo",
|
||||
"gpt-4": "openai/gpt-4",
|
||||
"gemini-flash-2.5": "openrouter/google/gemini-2.5-flash-preview",
|
||||
"grok-3": "xai/grok-3-fast-latest",
|
||||
"deepseek": "deepseek/deepseek-chat",
|
||||
"grok-3-mini": "xai/grok-3-mini-fast-beta",
|
||||
"qwen3-4b": "openrouter/qwen/qwen3-4b:free",
|
||||
|
||||
# Also include full names as keys to ensure they map to themselves
|
||||
"anthropic/claude-3-7-sonnet-latest": "anthropic/claude-3-7-sonnet-latest",
|
||||
"openai/gpt-4.1-2025-04-14": "openai/gpt-4.1-2025-04-14",
|
||||
"openai/gpt-4o": "openai/gpt-4o",
|
||||
"openai/gpt-4-turbo": "openai/gpt-4-turbo",
|
||||
"openai/gpt-4": "openai/gpt-4",
|
||||
"openrouter/google/gemini-2.5-flash-preview": "openrouter/google/gemini-2.5-flash-preview",
|
||||
"xai/grok-3-fast-latest": "xai/grok-3-fast-latest",
|
||||
"deepseek/deepseek-chat": "deepseek/deepseek-chat",
|
||||
"xai/grok-3-mini-fast-beta": "xai/grok-3-mini-fast-beta",
|
||||
}
|
||||
|
||||
class AgentStartRequest(BaseModel):
|
||||
model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
|
||||
model_name: Optional[str] = None # Will be set from config.MODEL_TO_USE in the endpoint
|
||||
enable_thinking: Optional[bool] = False
|
||||
reasoning_effort: Optional[str] = 'low'
|
||||
stream: Optional[bool] = True
|
||||
|
@ -236,27 +252,27 @@ async def restore_running_agent_runs():
|
|||
for run in running_agent_runs.data:
|
||||
agent_run_id = run['id']
|
||||
logger.warning(f"Found running agent run {agent_run_id} from before server restart")
|
||||
|
||||
|
||||
# Clean up Redis resources for this run
|
||||
try:
|
||||
# Clean up active run key
|
||||
active_run_key = f"active_run:{instance_id}:{agent_run_id}"
|
||||
await redis.delete(active_run_key)
|
||||
|
||||
|
||||
# Clean up response list
|
||||
response_list_key = f"agent_run:{agent_run_id}:responses"
|
||||
await redis.delete(response_list_key)
|
||||
|
||||
|
||||
# Clean up control channels
|
||||
control_channel = f"agent_run:{agent_run_id}:control"
|
||||
instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}"
|
||||
await redis.delete(control_channel)
|
||||
await redis.delete(instance_control_channel)
|
||||
|
||||
|
||||
logger.info(f"Cleaned up Redis resources for agent run {agent_run_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up Redis resources for agent run {agent_run_id}: {e}")
|
||||
|
||||
|
||||
# Call stop_agent_run to handle status update and cleanup
|
||||
await stop_agent_run(agent_run_id, error_message="Server restarted while agent was running")
|
||||
|
||||
|
@ -356,7 +372,22 @@ async def start_agent(
|
|||
if not instance_id:
|
||||
raise HTTPException(status_code=500, detail="Agent API not initialized with instance ID")
|
||||
|
||||
logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}, context_manager={body.enable_context_manager} (Instance: {instance_id})")
|
||||
# Use model from config if not specified in the request
|
||||
model_name = body.model_name
|
||||
logger.info(f"Original model_name from request: {model_name}")
|
||||
|
||||
if model_name is None:
|
||||
model_name = config.MODEL_TO_USE
|
||||
logger.info(f"Using model from config: {model_name}")
|
||||
|
||||
# Log the model name after alias resolution
|
||||
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
|
||||
logger.info(f"Resolved model name: {resolved_model}")
|
||||
|
||||
# Update model_name to use the resolved version
|
||||
model_name = resolved_model
|
||||
|
||||
logger.info(f"Starting new agent for thread: {thread_id} with config: model={model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}, context_manager={body.enable_context_manager} (Instance: {instance_id})")
|
||||
client = await db.client
|
||||
|
||||
await verify_thread_access(client, thread_id, user_id)
|
||||
|
@ -401,7 +432,7 @@ async def start_agent(
|
|||
run_agent_background(
|
||||
agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
|
||||
project_id=project_id, sandbox=sandbox,
|
||||
model_name=MODEL_NAME_ALIASES.get(body.model_name, body.model_name),
|
||||
model_name=model_name, # Already resolved above
|
||||
enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort,
|
||||
stream=body.stream, enable_context_manager=body.enable_context_manager
|
||||
)
|
||||
|
@ -643,7 +674,9 @@ async def run_agent_background(
|
|||
enable_context_manager: bool
|
||||
):
|
||||
"""Run the agent in the background using Redis for state."""
|
||||
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})")
|
||||
logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})")
|
||||
logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})")
|
||||
|
||||
client = await db.client
|
||||
start_time = datetime.now(timezone.utc)
|
||||
total_responses = 0
|
||||
|
@ -854,7 +887,7 @@ async def generate_and_update_project_name(project_id: str, prompt: str):
|
|||
@router.post("/agent/initiate", response_model=InitiateAgentResponse)
|
||||
async def initiate_agent_with_files(
|
||||
prompt: str = Form(...),
|
||||
model_name: Optional[str] = Form("anthropic/claude-3-7-sonnet-latest"),
|
||||
model_name: Optional[str] = Form(None), # Default to None to use config.MODEL_TO_USE
|
||||
enable_thinking: Optional[bool] = Form(False),
|
||||
reasoning_effort: Optional[str] = Form("low"),
|
||||
stream: Optional[bool] = Form(True),
|
||||
|
@ -867,6 +900,20 @@ async def initiate_agent_with_files(
|
|||
if not instance_id:
|
||||
raise HTTPException(status_code=500, detail="Agent API not initialized with instance ID")
|
||||
|
||||
# Use model from config if not specified in the request
|
||||
logger.info(f"Original model_name from request: {model_name}")
|
||||
|
||||
if model_name is None:
|
||||
model_name = config.MODEL_TO_USE
|
||||
logger.info(f"Using model from config: {model_name}")
|
||||
|
||||
# Log the model name after alias resolution
|
||||
resolved_model = MODEL_NAME_ALIASES.get(model_name, model_name)
|
||||
logger.info(f"Resolved model name: {resolved_model}")
|
||||
|
||||
# Update model_name to use the resolved version
|
||||
model_name = resolved_model
|
||||
|
||||
logger.info(f"[\033[91mDEBUG\033[0m] Initiating new agent with prompt and {len(files)} files (Instance: {instance_id}), model: {model_name}, enable_thinking: {enable_thinking}")
|
||||
client = await db.client
|
||||
account_id = user_id # In Basejump, personal account_id is the same as user_id
|
||||
|
@ -987,7 +1034,7 @@ async def initiate_agent_with_files(
|
|||
run_agent_background(
|
||||
agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
|
||||
project_id=project_id, sandbox=sandbox,
|
||||
model_name=MODEL_NAME_ALIASES.get(model_name, model_name),
|
||||
model_name=model_name, # Already resolved above
|
||||
enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
|
||||
stream=stream, enable_context_manager=enable_context_manager
|
||||
)
|
||||
|
|
|
@ -39,7 +39,8 @@ async def run_agent(
|
|||
enable_context_manager: bool = True
|
||||
):
|
||||
"""Run the development agent with specified configuration."""
|
||||
|
||||
print(f"🚀 Starting agent with model: {model_name}")
|
||||
|
||||
thread_manager = ThreadManager()
|
||||
|
||||
client = await thread_manager.db.client
|
||||
|
@ -53,12 +54,12 @@ async def run_agent(
|
|||
project = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
||||
if not project.data or len(project.data) == 0:
|
||||
raise ValueError(f"Project {project_id} not found")
|
||||
|
||||
|
||||
project_data = project.data[0]
|
||||
sandbox_info = project_data.get('sandbox', {})
|
||||
if not sandbox_info.get('id'):
|
||||
raise ValueError(f"No sandbox found for project {project_id}")
|
||||
|
||||
|
||||
# Initialize tools with project_id instead of sandbox object
|
||||
# This ensures each tool independently verifies it's operating on the correct project
|
||||
thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager)
|
||||
|
@ -69,7 +70,6 @@ async def run_agent(
|
|||
thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool
|
||||
thread_manager.add_tool(WebSearchTool)
|
||||
thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
|
||||
|
||||
# Add data providers tool if RapidAPI key is available
|
||||
if config.RAPID_API_KEY:
|
||||
thread_manager.add_tool(DataProvidersTool)
|
||||
|
@ -78,7 +78,7 @@ async def run_agent(
|
|||
|
||||
iteration_count = 0
|
||||
continue_execution = True
|
||||
|
||||
|
||||
while continue_execution and iteration_count < max_iterations:
|
||||
iteration_count += 1
|
||||
# logger.debug(f"Running iteration {iteration_count}...")
|
||||
|
@ -95,14 +95,14 @@ async def run_agent(
|
|||
}
|
||||
break
|
||||
# Check if last message is from assistant using direct Supabase query
|
||||
latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).in_('type', ['assistant', 'tool', 'user']).order('created_at', desc=True).limit(1).execute()
|
||||
latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).in_('type', ['assistant', 'tool', 'user']).order('created_at', desc=True).limit(1).execute()
|
||||
if latest_message.data and len(latest_message.data) > 0:
|
||||
message_type = latest_message.data[0].get('type')
|
||||
if message_type == 'assistant':
|
||||
print(f"Last message was from assistant, stopping execution")
|
||||
continue_execution = False
|
||||
break
|
||||
|
||||
|
||||
# ---- Temporary Message Handling (Browser State & Image Context) ----
|
||||
temporary_message = None
|
||||
temp_message_content_list = [] # List to hold text/image blocks
|
||||
|
@ -133,7 +133,7 @@ async def run_agent(
|
|||
})
|
||||
else:
|
||||
logger.warning("Browser state found but no screenshot base64 data.")
|
||||
|
||||
|
||||
await client.table('messages').delete().eq('message_id', latest_browser_state_msg.data[0]["message_id"]).execute()
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing browser state: {e}")
|
||||
|
@ -160,7 +160,7 @@ async def run_agent(
|
|||
})
|
||||
else:
|
||||
logger.warning(f"Image context found for '{file_path}' but missing base64 or mime_type.")
|
||||
|
||||
|
||||
await client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute()
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing image context: {e}")
|
||||
|
@ -171,7 +171,18 @@ async def run_agent(
|
|||
# logger.debug(f"Constructed temporary message with {len(temp_message_content_list)} content blocks.")
|
||||
# ---- End Temporary Message Handling ----
|
||||
|
||||
max_tokens = 64000 if "sonnet" in model_name.lower() else None
|
||||
# Set max_tokens based on model
|
||||
max_tokens = None
|
||||
if "sonnet" in model_name.lower():
|
||||
max_tokens = 64000
|
||||
elif "gpt-4" in model_name.lower():
|
||||
max_tokens = 4096
|
||||
|
||||
# # Configure tool calling based on model type
|
||||
# use_xml_tool_calling = "anthropic" in model_name.lower() or "claude" in model_name.lower()
|
||||
# use_native_tool_calling = "openai" in model_name.lower() or "gpt" in model_name.lower()
|
||||
|
||||
# # model_name = "openrouter/qwen/qwen3-235b-a22b"
|
||||
|
||||
response = await thread_manager.run_thread(
|
||||
thread_id=thread_id,
|
||||
|
@ -197,14 +208,14 @@ async def run_agent(
|
|||
reasoning_effort=reasoning_effort,
|
||||
enable_context_manager=enable_context_manager
|
||||
)
|
||||
|
||||
|
||||
if isinstance(response, dict) and "status" in response and response["status"] == "error":
|
||||
yield response
|
||||
break
|
||||
|
||||
yield response
|
||||
return
|
||||
|
||||
# Track if we see ask, complete, or web-browser-takeover tool calls
|
||||
last_tool_call = None
|
||||
|
||||
|
||||
async for chunk in response:
|
||||
# print(f"CHUNK: {chunk}") # Uncomment for detailed chunk logging
|
||||
|
||||
|
@ -217,7 +228,7 @@ async def run_agent(
|
|||
assistant_content_json = json.loads(content)
|
||||
else:
|
||||
assistant_content_json = content
|
||||
|
||||
|
||||
# The actual text content is nested within
|
||||
assistant_text = assistant_content_json.get('content', '')
|
||||
if isinstance(assistant_text, str): # Ensure it's a string
|
||||
|
@ -229,7 +240,7 @@ async def run_agent(
|
|||
xml_tool = 'complete'
|
||||
elif '</web-browser-takeover>' in assistant_text:
|
||||
xml_tool = 'web-browser-takeover'
|
||||
|
||||
|
||||
last_tool_call = xml_tool
|
||||
print(f"Agent used XML tool: {xml_tool}")
|
||||
except json.JSONDecodeError:
|
||||
|
@ -237,9 +248,31 @@ async def run_agent(
|
|||
print(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}")
|
||||
except Exception as e:
|
||||
print(f"Error processing assistant chunk: {e}")
|
||||
|
||||
|
||||
# # Check for native function calls (OpenAI format)
|
||||
# elif chunk.get('type') == 'status' and 'content' in chunk:
|
||||
# try:
|
||||
# # Parse the status content
|
||||
# status_content = chunk.get('content', '{}')
|
||||
# if isinstance(status_content, str):
|
||||
# status_content = json.loads(status_content)
|
||||
|
||||
# # Check if this is a tool call status
|
||||
# status_type = status_content.get('status_type')
|
||||
# function_name = status_content.get('function_name', '')
|
||||
|
||||
# # Check for special function names that should stop execution
|
||||
# if status_type == 'tool_started' and function_name in ['ask', 'complete', 'web-browser-takeover']:
|
||||
# last_tool_call = function_name
|
||||
# print(f"Agent used native function call: {function_name}")
|
||||
# except json.JSONDecodeError:
|
||||
# # Handle cases where content might not be valid JSON
|
||||
# print(f"Warning: Could not parse status content JSON: {chunk.get('content')}")
|
||||
# except Exception as e:
|
||||
# print(f"Error processing status chunk: {e}")
|
||||
|
||||
yield chunk
|
||||
|
||||
|
||||
# Check if we should stop based on the last tool call
|
||||
if last_tool_call in ['ask', 'complete', 'web-browser-takeover']:
|
||||
print(f"Agent decided to stop with tool: {last_tool_call}")
|
||||
|
@ -252,30 +285,30 @@ async def run_agent(
|
|||
# """Test function to run the agent with a sample query"""
|
||||
# from agentpress.thread_manager import ThreadManager
|
||||
# from services.supabase import DBConnection
|
||||
|
||||
|
||||
# # Initialize ThreadManager
|
||||
# thread_manager = ThreadManager()
|
||||
|
||||
|
||||
# # Create a test thread directly with Postgres function
|
||||
# client = await DBConnection().client
|
||||
|
||||
|
||||
# try:
|
||||
# # Get user's personal account
|
||||
# account_result = await client.rpc('get_personal_account').execute()
|
||||
|
||||
|
||||
# # if not account_result.data:
|
||||
# # print("Error: No personal account found")
|
||||
# # return
|
||||
|
||||
|
||||
# account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59"
|
||||
|
||||
|
||||
# if not account_id:
|
||||
# print("Error: Could not get account ID")
|
||||
# return
|
||||
|
||||
|
||||
# # Find or create a test project in the user's account
|
||||
# project_result = await client.table('projects').select('*').eq('name', 'test11').eq('account_id', account_id).execute()
|
||||
|
||||
|
||||
# if project_result.data and len(project_result.data) > 0:
|
||||
# # Use existing test project
|
||||
# project_id = project_result.data[0]['project_id']
|
||||
|
@ -283,42 +316,42 @@ async def run_agent(
|
|||
# else:
|
||||
# # Create new test project if none exists
|
||||
# project_result = await client.table('projects').insert({
|
||||
# "name": "test11",
|
||||
# "name": "test11",
|
||||
# "account_id": account_id
|
||||
# }).execute()
|
||||
# project_id = project_result.data[0]['project_id']
|
||||
# print(f"\n✨ Created new test project: {project_id}")
|
||||
|
||||
|
||||
# # Create a thread for this project
|
||||
# thread_result = await client.table('threads').insert({
|
||||
# 'project_id': project_id,
|
||||
# 'account_id': account_id
|
||||
# }).execute()
|
||||
# thread_data = thread_result.data[0] if thread_result.data else None
|
||||
|
||||
|
||||
# if not thread_data:
|
||||
# print("Error: No thread data returned")
|
||||
# return
|
||||
|
||||
|
||||
# thread_id = thread_data['thread_id']
|
||||
# except Exception as e:
|
||||
# print(f"Error setting up thread: {str(e)}")
|
||||
# return
|
||||
|
||||
|
||||
# print(f"\n🤖 Agent Thread Created: {thread_id}\n")
|
||||
|
||||
|
||||
# # Interactive message input loop
|
||||
# while True:
|
||||
# # Get user input
|
||||
# user_message = input("\n💬 Enter your message (or 'exit' to quit): ")
|
||||
# if user_message.lower() == 'exit':
|
||||
# break
|
||||
|
||||
|
||||
# if not user_message.strip():
|
||||
# print("\n🔄 Running agent...\n")
|
||||
# await process_agent_response(thread_id, project_id, thread_manager)
|
||||
# continue
|
||||
|
||||
|
||||
# # Add the user message to the thread
|
||||
# await thread_manager.add_message(
|
||||
# thread_id=thread_id,
|
||||
|
@ -329,10 +362,10 @@ async def run_agent(
|
|||
# },
|
||||
# is_llm_message=True
|
||||
# )
|
||||
|
||||
|
||||
# print("\n🔄 Running agent...\n")
|
||||
# await process_agent_response(thread_id, project_id, thread_manager)
|
||||
|
||||
|
||||
# print("\n👋 Test completed. Goodbye!")
|
||||
|
||||
# async def process_agent_response(
|
||||
|
@ -349,21 +382,21 @@ async def run_agent(
|
|||
# chunk_counter = 0
|
||||
# current_response = ""
|
||||
# tool_usage_counter = 0 # Renamed from tool_call_counter as we track usage via status
|
||||
|
||||
|
||||
# # Create a test sandbox for processing with a unique test prefix to avoid conflicts with production sandboxes
|
||||
# sandbox_pass = str(uuid4())
|
||||
# sandbox = create_sandbox(sandbox_pass)
|
||||
|
||||
|
||||
# # Store the original ID so we can refer to it
|
||||
# original_sandbox_id = sandbox.id
|
||||
|
||||
|
||||
# # Generate a clear test identifier
|
||||
# test_prefix = f"test_{uuid4().hex[:8]}_"
|
||||
# logger.info(f"Created test sandbox with ID {original_sandbox_id} and test prefix {test_prefix}")
|
||||
|
||||
|
||||
# # Log the sandbox URL for debugging
|
||||
# print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m")
|
||||
|
||||
|
||||
# async for chunk in run_agent(
|
||||
# thread_id=thread_id,
|
||||
# project_id=project_id,
|
||||
|
@ -388,7 +421,7 @@ async def run_agent(
|
|||
# content_json = json.loads(content)
|
||||
# else:
|
||||
# content_json = content
|
||||
|
||||
|
||||
# actual_content = content_json.get('content', '')
|
||||
# # Print the actual assistant text content as it comes
|
||||
# if actual_content:
|
||||
|
@ -418,7 +451,7 @@ async def run_agent(
|
|||
# # Add timestamp and format tool result nicely
|
||||
# tool_name = "UnknownTool" # Try to get from metadata if available
|
||||
# result_content = "No content"
|
||||
|
||||
|
||||
# # Parse metadata - handle both string and dict formats
|
||||
# metadata = chunk.get('metadata', {})
|
||||
# if isinstance(metadata, str):
|
||||
|
@ -426,7 +459,7 @@ async def run_agent(
|
|||
# metadata = json.loads(metadata)
|
||||
# except json.JSONDecodeError:
|
||||
# metadata = {}
|
||||
|
||||
|
||||
# linked_assistant_msg_id = metadata.get('assistant_message_id')
|
||||
# parsing_details = metadata.get('parsing_details')
|
||||
# if parsing_details:
|
||||
|
@ -434,12 +467,12 @@ async def run_agent(
|
|||
|
||||
# try:
|
||||
# # Content is a JSON string or object
|
||||
# content = chunk.get('content', '{}')
|
||||
# content = chunk.get('content', '{}')
|
||||
# if isinstance(content, str):
|
||||
# content_json = json.loads(content)
|
||||
# else:
|
||||
# content_json = content
|
||||
|
||||
|
||||
# # The actual tool result is nested inside content.content
|
||||
# tool_result_str = content_json.get('content', '')
|
||||
# # Extract the actual tool result string (remove outer <tool_result> tag if present)
|
||||
|
@ -471,7 +504,7 @@ async def run_agent(
|
|||
# status_content = chunk.get('content', '{}')
|
||||
# if isinstance(status_content, str):
|
||||
# status_content = json.loads(status_content)
|
||||
|
||||
|
||||
# status_type = status_content.get('status_type')
|
||||
# function_name = status_content.get('function_name', '')
|
||||
# xml_tag_name = status_content.get('xml_tag_name', '') # Get XML tag if available
|
||||
|
@ -502,10 +535,10 @@ async def run_agent(
|
|||
|
||||
|
||||
# # Removed elif chunk.get('type') == 'tool_call': block
|
||||
|
||||
|
||||
# # Update final message
|
||||
# print(f"\n\n✅ Agent run completed with {tool_usage_counter} tool executions")
|
||||
|
||||
|
||||
# # Try to clean up the test sandbox if possible
|
||||
# try:
|
||||
# # Attempt to delete/archive the sandbox to clean up resources
|
||||
|
@ -518,9 +551,9 @@ async def run_agent(
|
|||
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
|
||||
|
||||
# # Configure any environment variables or setup needed for testing
|
||||
# load_dotenv() # Ensure environment variables are loaded
|
||||
|
||||
|
||||
# # Run the test function
|
||||
# asyncio.run(test_agent())
|
|
@ -4,14 +4,14 @@ from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema
|
|||
|
||||
class MessageTool(Tool):
|
||||
"""Tool for user communication and interaction.
|
||||
|
||||
|
||||
This tool provides methods for asking questions, with support for
|
||||
attachments and user takeover suggestions.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
# Commented out as we are just doing this via prompt as there is no need to call it as a tool
|
||||
|
||||
@openapi_schema({
|
||||
|
@ -55,25 +55,25 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
<!-- 4. Validating critical assumptions -->
|
||||
<!-- 5. Getting missing information -->
|
||||
<!-- IMPORTANT: Always if applicable include representable files as attachments - this includes HTML files, presentations, writeups, visualizations, reports, and any other viewable content -->
|
||||
|
||||
|
||||
<ask attachments="recipes/chocolate_cake.txt,photos/cake_examples.jpg">
|
||||
I'm planning to bake the chocolate cake for your birthday party. The recipe mentions "rich frosting" but doesn't specify what type. Could you clarify your preferences? For example:
|
||||
1. Would you prefer buttercream or cream cheese frosting?
|
||||
2. Do you want any specific flavor added to the frosting (vanilla, coffee, etc.)?
|
||||
3. Should I add any decorative toppings like sprinkles or fruit?
|
||||
4. Do you have any dietary restrictions I should be aware of?
|
||||
|
||||
|
||||
This information will help me make sure the cake meets your expectations for the celebration.
|
||||
</ask>
|
||||
'''
|
||||
)
|
||||
async def ask(self, text: str, attachments: Optional[Union[str, List[str]]] = None) -> ToolResult:
|
||||
"""Ask the user a question and wait for a response.
|
||||
|
||||
|
||||
Args:
|
||||
text: The question to present to the user
|
||||
attachments: Optional file paths or URLs to attach to the question
|
||||
|
||||
|
||||
Returns:
|
||||
ToolResult indicating the question was successfully sent
|
||||
"""
|
||||
|
@ -81,7 +81,7 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
# Convert single attachment to list for consistent handling
|
||||
if attachments and isinstance(attachments, str):
|
||||
attachments = [attachments]
|
||||
|
||||
|
||||
return self.success_response({"status": "Awaiting user response..."})
|
||||
except Exception as e:
|
||||
return self.fail_response(f"Error asking user: {str(e)}")
|
||||
|
@ -122,24 +122,24 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
<!-- 1. CAPTCHA or human verification required -->
|
||||
<!-- 2. Anti-bot measures preventing access -->
|
||||
<!-- 3. Authentication requiring human input -->
|
||||
|
||||
|
||||
<web-browser-takeover>
|
||||
I've encountered a CAPTCHA verification on the page. Please:
|
||||
1. Solve the CAPTCHA puzzle
|
||||
2. Let me know once you've completed it
|
||||
3. I'll then continue with the automated process
|
||||
|
||||
|
||||
If you encounter any issues or need to take additional steps, please let me know.
|
||||
</web-browser-takeover>
|
||||
'''
|
||||
)
|
||||
async def web_browser_takeover(self, text: str, attachments: Optional[Union[str, List[str]]] = None) -> ToolResult:
|
||||
"""Request user takeover of browser interaction.
|
||||
|
||||
|
||||
Args:
|
||||
text: Instructions for the user about what actions to take
|
||||
attachments: Optional file paths or URLs to attach to the request
|
||||
|
||||
|
||||
Returns:
|
||||
ToolResult indicating the takeover request was successfully sent
|
||||
"""
|
||||
|
@ -147,7 +147,7 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
# Convert single attachment to list for consistent handling
|
||||
if attachments and isinstance(attachments, str):
|
||||
attachments = [attachments]
|
||||
|
||||
|
||||
return self.success_response({"status": "Awaiting user browser takeover..."})
|
||||
except Exception as e:
|
||||
return self.fail_response(f"Error requesting browser takeover: {str(e)}")
|
||||
|
@ -184,7 +184,7 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
# ],
|
||||
# example='''
|
||||
|
||||
# Inform the user about progress, completion of a major step, or important context. Use this tool: 1) To provide updates between major sections of work, 2) After accomplishing significant milestones, 3) When transitioning to a new phase of work, 4) To confirm actions were completed successfully, 5) To provide context about upcoming steps. IMPORTANT: Use FREQUENTLY throughout execution to provide UI context to the user. The user CANNOT respond to this tool - they can only respond to the 'ask' tool. Use this tool to keep the user informed without requiring their input."
|
||||
# Inform the user about progress, completion of a major step, or important context. Use this tool: 1) To provide updates between major sections of work, 2) After accomplishing significant milestones, 3) When transitioning to a new phase of work, 4) To confirm actions were completed successfully, 5) To provide context about upcoming steps. IMPORTANT: Use FREQUENTLY throughout execution to provide UI context to the user. The user CANNOT respond to this tool - they can only respond to the 'ask' tool. Use this tool to keep the user informed without requiring their input."
|
||||
|
||||
# <!-- Use inform FREQUENTLY to provide UI context and progress updates - THE USER CANNOT RESPOND to this tool -->
|
||||
# <!-- The user can ONLY respond to the ask tool, not to inform -->
|
||||
|
@ -195,24 +195,24 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
# <!-- 4. Providing context about upcoming steps -->
|
||||
# <!-- 5. Sharing significant intermediate results -->
|
||||
# <!-- 6. Providing regular UI updates throughout execution -->
|
||||
|
||||
|
||||
# <inform attachments="analysis_results.csv,summary_chart.png">
|
||||
# I've completed the data analysis of the sales figures. Key findings include:
|
||||
# - Q4 sales were 28% higher than Q3
|
||||
# - Product line A showed the strongest performance
|
||||
# - Three regions missed their targets
|
||||
|
||||
|
||||
# I'll now proceed with creating the executive summary report based on these findings.
|
||||
# </inform>
|
||||
# '''
|
||||
# )
|
||||
# async def inform(self, text: str, attachments: Optional[Union[str, List[str]]] = None) -> ToolResult:
|
||||
# """Inform the user about progress or important updates without requiring a response.
|
||||
|
||||
|
||||
# Args:
|
||||
# text: The information to present to the user
|
||||
# attachments: Optional file paths or URLs to attach
|
||||
|
||||
|
||||
# Returns:
|
||||
# ToolResult indicating the information was successfully sent
|
||||
# """
|
||||
|
@ -220,7 +220,7 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
# # Convert single attachment to list for consistent handling
|
||||
# if attachments and isinstance(attachments, str):
|
||||
# attachments = [attachments]
|
||||
|
||||
|
||||
# return self.success_response({"status": "Information sent"})
|
||||
# except Exception as e:
|
||||
# return self.fail_response(f"Error informing user: {str(e)}")
|
||||
|
@ -231,7 +231,9 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
"name": "complete",
|
||||
"description": "A special tool to indicate you have completed all tasks and are about to enter complete state. Use ONLY when: 1) All tasks in todo.md are marked complete [x], 2) The user's original request has been fully addressed, 3) There are no pending actions or follow-ups required, 4) You've delivered all final outputs and results to the user. IMPORTANT: This is the ONLY way to properly terminate execution. Never use this tool unless ALL tasks are complete and verified. Always ensure you've provided all necessary outputs and references before using this tool.",
|
||||
"parameters": {
|
||||
"type": "object"
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -246,7 +248,7 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
<!-- 3. All outputs and results delivered -->
|
||||
<!-- 4. No pending actions or follow-ups -->
|
||||
<!-- 5. All tasks verified and validated -->
|
||||
|
||||
|
||||
<complete>
|
||||
<!-- This tool indicates successful completion of all tasks -->
|
||||
<!-- The system will stop execution after this tool is used -->
|
||||
|
@ -255,7 +257,7 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
)
|
||||
async def complete(self) -> ToolResult:
|
||||
"""Indicate that the agent has completed all tasks and is entering complete state.
|
||||
|
||||
|
||||
Returns:
|
||||
ToolResult indicating successful transition to complete state
|
||||
"""
|
||||
|
@ -267,22 +269,22 @@ Ask user a question and wait for response. Use for: 1) Requesting clarification
|
|||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
|
||||
async def test_message_tool():
|
||||
message_tool = MessageTool()
|
||||
|
||||
|
||||
# Test question
|
||||
ask_result = await message_tool.ask(
|
||||
text="Would you like to proceed with the next phase?",
|
||||
attachments="summary.pdf"
|
||||
)
|
||||
print("Question result:", ask_result)
|
||||
|
||||
|
||||
# Test inform
|
||||
inform_result = await message_tool.inform(
|
||||
text="Completed analysis of data. Processing results now.",
|
||||
attachments="analysis.pdf"
|
||||
)
|
||||
print("Inform result:", inform_result)
|
||||
|
||||
|
||||
asyncio.run(test_message_tool())
|
||||
|
|
|
@ -17,8 +17,8 @@ from agentpress.tool import Tool
|
|||
from agentpress.tool_registry import ToolRegistry
|
||||
from agentpress.context_manager import ContextManager
|
||||
from agentpress.response_processor import (
|
||||
ResponseProcessor,
|
||||
ProcessorConfig
|
||||
ResponseProcessor,
|
||||
ProcessorConfig
|
||||
)
|
||||
from services.supabase import DBConnection
|
||||
from utils.logger import logger
|
||||
|
@ -28,7 +28,7 @@ ToolChoice = Literal["auto", "required", "none"]
|
|||
|
||||
class ThreadManager:
|
||||
"""Manages conversation threads with LLM models and tool execution.
|
||||
|
||||
|
||||
Provides comprehensive conversation management, handling message threading,
|
||||
tool registration, and LLM interactions with support for both standard and
|
||||
XML-based tool execution patterns.
|
||||
|
@ -36,7 +36,7 @@ class ThreadManager:
|
|||
|
||||
def __init__(self):
|
||||
"""Initialize ThreadManager.
|
||||
|
||||
|
||||
"""
|
||||
self.db = DBConnection()
|
||||
self.tool_registry = ToolRegistry()
|
||||
|
@ -51,10 +51,10 @@ class ThreadManager:
|
|||
self.tool_registry.register_tool(tool_class, function_names, **kwargs)
|
||||
|
||||
async def add_message(
|
||||
self,
|
||||
thread_id: str,
|
||||
type: str,
|
||||
content: Union[Dict[str, Any], List[Any], str],
|
||||
self,
|
||||
thread_id: str,
|
||||
type: str,
|
||||
content: Union[Dict[str, Any], List[Any], str],
|
||||
is_llm_message: bool = False,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
|
@ -72,7 +72,7 @@ class ThreadManager:
|
|||
"""
|
||||
logger.debug(f"Adding message of type '{type}' to thread {thread_id}")
|
||||
client = await self.db.client
|
||||
|
||||
|
||||
# Prepare data for insertion
|
||||
data_to_insert = {
|
||||
'thread_id': thread_id,
|
||||
|
@ -81,12 +81,12 @@ class ThreadManager:
|
|||
'is_llm_message': is_llm_message,
|
||||
'metadata': json.dumps(metadata or {}), # Ensure metadata is always a JSON object
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# Add returning='representation' to get the inserted row data including the id
|
||||
result = await client.table('messages').insert(data_to_insert, returning='representation').execute()
|
||||
logger.info(f"Successfully added message to thread {thread_id}")
|
||||
|
||||
|
||||
if result.data and len(result.data) > 0 and isinstance(result.data[0], dict) and 'message_id' in result.data[0]:
|
||||
return result.data[0]
|
||||
else:
|
||||
|
@ -98,26 +98,26 @@ class ThreadManager:
|
|||
|
||||
async def get_llm_messages(self, thread_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all messages for a thread.
|
||||
|
||||
|
||||
This method uses the SQL function which handles context truncation
|
||||
by considering summary messages.
|
||||
|
||||
|
||||
Args:
|
||||
thread_id: The ID of the thread to get messages for.
|
||||
|
||||
|
||||
Returns:
|
||||
List of message objects.
|
||||
"""
|
||||
logger.debug(f"Getting messages for thread {thread_id}")
|
||||
client = await self.db.client
|
||||
|
||||
|
||||
try:
|
||||
result = await client.rpc('get_llm_formatted_messages', {'p_thread_id': thread_id}).execute()
|
||||
|
||||
|
||||
# Parse the returned data which might be stringified JSON
|
||||
if not result.data:
|
||||
return []
|
||||
|
||||
|
||||
# Return properly parsed JSON objects
|
||||
messages = []
|
||||
for item in result.data:
|
||||
|
@ -140,7 +140,7 @@ class ThreadManager:
|
|||
tool_call['function']['arguments'] = json.dumps(tool_call['function']['arguments'])
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get messages for thread {thread_id}: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
@ -164,7 +164,7 @@ class ThreadManager:
|
|||
enable_context_manager: bool = True
|
||||
) -> Union[Dict[str, Any], AsyncGenerator]:
|
||||
"""Run a conversation thread with LLM integration and tool execution.
|
||||
|
||||
|
||||
Args:
|
||||
thread_id: The ID of the thread to run
|
||||
system_prompt: System message to set the assistant's behavior
|
||||
|
@ -175,30 +175,31 @@ class ThreadManager:
|
|||
llm_max_tokens: Maximum tokens in the LLM response
|
||||
processor_config: Configuration for the response processor
|
||||
tool_choice: Tool choice preference ("auto", "required", "none")
|
||||
native_max_auto_continues: Maximum number of automatic continuations when
|
||||
native_max_auto_continues: Maximum number of automatic continuations when
|
||||
finish_reason="tool_calls" (0 disables auto-continue)
|
||||
max_xml_tool_calls: Maximum number of XML tool calls to allow (0 = no limit)
|
||||
include_xml_examples: Whether to include XML tool examples in the system prompt
|
||||
enable_thinking: Whether to enable thinking before making a decision
|
||||
reasoning_effort: The effort level for reasoning
|
||||
enable_context_manager: Whether to enable automatic context summarization.
|
||||
|
||||
|
||||
Returns:
|
||||
An async generator yielding response chunks or error dict
|
||||
"""
|
||||
|
||||
|
||||
logger.info(f"Starting thread execution for thread {thread_id}")
|
||||
logger.debug(f"Parameters: model={llm_model}, temperature={llm_temperature}, max_tokens={llm_max_tokens}")
|
||||
logger.debug(f"Auto-continue: max={native_max_auto_continues}, XML tool limit={max_xml_tool_calls}")
|
||||
|
||||
# Use a default config if none was provided (needed for XML examples check)
|
||||
if processor_config is None:
|
||||
processor_config = ProcessorConfig()
|
||||
logger.info(f"Using model: {llm_model}")
|
||||
# Log parameters
|
||||
logger.info(f"Parameters: model={llm_model}, temperature={llm_temperature}, max_tokens={llm_max_tokens}")
|
||||
logger.info(f"Auto-continue: max={native_max_auto_continues}, XML tool limit={max_xml_tool_calls}")
|
||||
|
||||
# Log model info
|
||||
logger.info(f"🤖 Thread {thread_id}: Using model {llm_model}")
|
||||
|
||||
# Apply max_xml_tool_calls if specified and not already set in config
|
||||
if max_xml_tool_calls > 0 and not processor_config.max_xml_tool_calls:
|
||||
processor_config.max_xml_tool_calls = max_xml_tool_calls
|
||||
|
||||
|
||||
# Create a working copy of the system prompt to potentially modify
|
||||
working_system_prompt = system_prompt.copy()
|
||||
|
||||
|
@ -236,30 +237,30 @@ Here are the XML tools available with examples:
|
|||
logger.warning("System prompt content is a list but no text block found to append XML examples.")
|
||||
else:
|
||||
logger.warning(f"System prompt content is of unexpected type ({type(system_content)}), cannot add XML examples.")
|
||||
|
||||
|
||||
# Control whether we need to auto-continue due to tool_calls finish reason
|
||||
auto_continue = True
|
||||
auto_continue_count = 0
|
||||
|
||||
|
||||
# Define inner function to handle a single run
|
||||
async def _run_once(temp_msg=None):
|
||||
try:
|
||||
# Ensure processor_config is available in this scope
|
||||
nonlocal processor_config
|
||||
nonlocal processor_config
|
||||
# Note: processor_config is now guaranteed to exist due to check above
|
||||
|
||||
|
||||
# 1. Get messages from thread for LLM call
|
||||
messages = await self.get_llm_messages(thread_id)
|
||||
|
||||
|
||||
# 2. Check token count before proceeding
|
||||
token_count = 0
|
||||
try:
|
||||
from litellm import token_counter
|
||||
# Use the potentially modified working_system_prompt for token counting
|
||||
token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
|
||||
token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
|
||||
token_threshold = self.context_manager.token_threshold
|
||||
logger.info(f"Thread {thread_id} token count: {token_count}/{token_threshold} ({(token_count/token_threshold)*100:.1f}%)")
|
||||
|
||||
|
||||
# if token_count >= token_threshold and enable_context_manager:
|
||||
# logger.info(f"Thread token count ({token_count}) exceeds threshold ({token_threshold}), summarizing...")
|
||||
# summarized = await self.context_manager.check_and_summarize_if_needed(
|
||||
|
@ -272,26 +273,26 @@ Here are the XML tools available with examples:
|
|||
# logger.info("Summarization complete, fetching updated messages with summary")
|
||||
# messages = await self.get_llm_messages(thread_id)
|
||||
# # Recount tokens after summarization, using the modified prompt
|
||||
# new_token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
|
||||
# new_token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
|
||||
# logger.info(f"After summarization: token count reduced from {token_count} to {new_token_count}")
|
||||
# else:
|
||||
# logger.warning("Summarization failed or wasn't needed - proceeding with original messages")
|
||||
# elif not enable_context_manager:
|
||||
# elif not enable_context_manager:
|
||||
# logger.info("Automatic summarization disabled. Skipping token count check and summarization.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting tokens or summarizing: {str(e)}")
|
||||
|
||||
|
||||
# 3. Prepare messages for LLM call + add temporary message if it exists
|
||||
# Use the working_system_prompt which may contain the XML examples
|
||||
prepared_messages = [working_system_prompt]
|
||||
|
||||
prepared_messages = [working_system_prompt]
|
||||
|
||||
# Find the last user message index
|
||||
last_user_index = -1
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.get('role') == 'user':
|
||||
last_user_index = i
|
||||
|
||||
|
||||
# Insert temporary message before the last user message if it exists
|
||||
if temp_msg and last_user_index >= 0:
|
||||
prepared_messages.extend(messages[:last_user_index])
|
||||
|
@ -341,7 +342,7 @@ Here are the XML tools available with examples:
|
|||
prompt_messages=prepared_messages,
|
||||
llm_model=llm_model
|
||||
)
|
||||
|
||||
|
||||
return response_generator
|
||||
else:
|
||||
logger.debug("Processing non-streaming response")
|
||||
|
@ -358,31 +359,31 @@ Here are the XML tools available with examples:
|
|||
except Exception as e:
|
||||
logger.error(f"Error setting up non-streaming response: {str(e)}", exc_info=True)
|
||||
raise # Re-raise the exception to be caught by the outer handler
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_thread: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
|
||||
# Define a wrapper generator that handles auto-continue logic
|
||||
async def auto_continue_wrapper():
|
||||
nonlocal auto_continue, auto_continue_count
|
||||
|
||||
|
||||
while auto_continue and (native_max_auto_continues == 0 or auto_continue_count < native_max_auto_continues):
|
||||
# Reset auto_continue for this iteration
|
||||
auto_continue = False
|
||||
|
||||
|
||||
# Run the thread once, passing the potentially modified system prompt
|
||||
# Pass temp_msg only on the first iteration
|
||||
response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None)
|
||||
|
||||
response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None)
|
||||
|
||||
# Handle error responses
|
||||
if isinstance(response_gen, dict) and "status" in response_gen and response_gen["status"] == "error":
|
||||
yield response_gen
|
||||
return
|
||||
|
||||
|
||||
# Process each chunk
|
||||
async for chunk in response_gen:
|
||||
# Check if this is a finish reason chunk with tool_calls or xml_tool_limit_reached
|
||||
|
@ -400,27 +401,27 @@ Here are the XML tools available with examples:
|
|||
logger.info(f"Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
|
||||
auto_continue = False
|
||||
# Still yield the chunk to inform the client
|
||||
|
||||
|
||||
# Otherwise just yield the chunk normally
|
||||
yield chunk
|
||||
|
||||
|
||||
# If not auto-continuing, we're done
|
||||
if not auto_continue:
|
||||
break
|
||||
|
||||
|
||||
# If we've reached the max auto-continues, log a warning
|
||||
if auto_continue and auto_continue_count >= native_max_auto_continues:
|
||||
logger.warning(f"Reached maximum auto-continue limit ({native_max_auto_continues}), stopping.")
|
||||
yield {
|
||||
"type": "content",
|
||||
"type": "content",
|
||||
"content": f"\n[Agent reached maximum auto-continue limit of {native_max_auto_continues}]"
|
||||
}
|
||||
|
||||
|
||||
# If auto-continue is disabled (max=0), just run once
|
||||
if native_max_auto_continues == 0:
|
||||
logger.info("Auto-continue is disabled (native_max_auto_continues=0)")
|
||||
# Pass the potentially modified system prompt and temp message
|
||||
return await _run_once(temporary_message)
|
||||
|
||||
return await _run_once(temporary_message)
|
||||
|
||||
# Otherwise return the auto-continue wrapper generator
|
||||
return auto_continue_wrapper()
|
||||
|
|
|
@ -46,17 +46,17 @@ def setup_api_keys() -> None:
|
|||
logger.debug(f"API key set for provider: {provider}")
|
||||
else:
|
||||
logger.warning(f"No API key found for provider: {provider}")
|
||||
|
||||
|
||||
# Set up OpenRouter API base if not already set
|
||||
if config.OPENROUTER_API_KEY and config.OPENROUTER_API_BASE:
|
||||
os.environ['OPENROUTER_API_BASE'] = config.OPENROUTER_API_BASE
|
||||
logger.debug(f"Set OPENROUTER_API_BASE to {config.OPENROUTER_API_BASE}")
|
||||
|
||||
|
||||
# Set up AWS Bedrock credentials
|
||||
aws_access_key = config.AWS_ACCESS_KEY_ID
|
||||
aws_secret_key = config.AWS_SECRET_ACCESS_KEY
|
||||
aws_region = config.AWS_REGION_NAME
|
||||
|
||||
|
||||
if aws_access_key and aws_secret_key and aws_region:
|
||||
logger.debug(f"AWS credentials set for Bedrock in region: {aws_region}")
|
||||
# Configure LiteLLM to use AWS credentials
|
||||
|
@ -132,11 +132,11 @@ def prepare_params(
|
|||
"anthropic-beta": "output-128k-2025-02-19"
|
||||
}
|
||||
logger.debug("Added Claude-specific headers")
|
||||
|
||||
|
||||
# Add OpenRouter-specific parameters
|
||||
if model_name.startswith("openrouter/"):
|
||||
logger.debug(f"Preparing OpenRouter parameters for model: {model_name}")
|
||||
|
||||
|
||||
# Add optional site URL and app name from config
|
||||
site_url = config.OR_SITE_URL
|
||||
app_name = config.OR_APP_NAME
|
||||
|
@ -148,11 +148,11 @@ def prepare_params(
|
|||
extra_headers["X-Title"] = app_name
|
||||
params["extra_headers"] = extra_headers
|
||||
logger.debug(f"Added OpenRouter site URL and app name to headers")
|
||||
|
||||
|
||||
# Add Bedrock-specific parameters
|
||||
if model_name.startswith("bedrock/"):
|
||||
logger.debug(f"Preparing AWS Bedrock parameters for model: {model_name}")
|
||||
|
||||
|
||||
if not model_id and "anthropic.claude-3-7-sonnet" in model_name:
|
||||
params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
|
||||
logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")
|
||||
|
@ -256,7 +256,7 @@ async def make_llm_api_call(
|
|||
) -> Union[Dict[str, Any], AsyncGenerator]:
|
||||
"""
|
||||
Make an API call to a language model using LiteLLM.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries for the conversation
|
||||
model_name: Name of the model to use (e.g., "gpt-4", "claude-3", "openrouter/openai/gpt-4", "bedrock/anthropic.claude-3-sonnet-20240229-v1:0")
|
||||
|
@ -272,16 +272,17 @@ async def make_llm_api_call(
|
|||
model_id: Optional ARN for Bedrock inference profiles
|
||||
enable_thinking: Whether to enable thinking
|
||||
reasoning_effort: Level of reasoning effort
|
||||
|
||||
|
||||
Returns:
|
||||
Union[Dict[str, Any], AsyncGenerator]: API response or stream
|
||||
|
||||
|
||||
Raises:
|
||||
LLMRetryError: If API call fails after retries
|
||||
LLMError: For other API-related errors
|
||||
"""
|
||||
# debug <timestamp>.json messages
|
||||
logger.debug(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})")
|
||||
# debug <timestamp>.json messages
|
||||
logger.info(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})")
|
||||
logger.info(f"📡 API Call: Using model {model_name}")
|
||||
params = prepare_params(
|
||||
messages=messages,
|
||||
model_name=model_name,
|
||||
|
@ -303,20 +304,20 @@ async def make_llm_api_call(
|
|||
try:
|
||||
logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}")
|
||||
# logger.debug(f"API request parameters: {json.dumps(params, indent=2)}")
|
||||
|
||||
|
||||
response = await litellm.acompletion(**params)
|
||||
logger.debug(f"Successfully received API response from {model_name}")
|
||||
logger.debug(f"Response: {response}")
|
||||
return response
|
||||
|
||||
|
||||
except (litellm.exceptions.RateLimitError, OpenAIError, json.JSONDecodeError) as e:
|
||||
last_error = e
|
||||
await handle_error(e, attempt, MAX_RETRIES)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during API call: {str(e)}", exc_info=True)
|
||||
raise LLMError(f"API call failed: {str(e)}")
|
||||
|
||||
|
||||
error_msg = f"Failed to make API call after {MAX_RETRIES} attempts"
|
||||
if last_error:
|
||||
error_msg += f". Last error: {str(last_error)}"
|
||||
|
@ -332,7 +333,7 @@ async def test_openrouter():
|
|||
test_messages = [
|
||||
{"role": "user", "content": "Hello, can you give me a quick test response?"}
|
||||
]
|
||||
|
||||
|
||||
try:
|
||||
# Test with standard OpenRouter model
|
||||
print("\n--- Testing standard OpenRouter model ---")
|
||||
|
@ -343,7 +344,7 @@ async def test_openrouter():
|
|||
max_tokens=100
|
||||
)
|
||||
print(f"Response: {response.choices[0].message.content}")
|
||||
|
||||
|
||||
# Test with deepseek model
|
||||
print("\n--- Testing deepseek model ---")
|
||||
response = await make_llm_api_call(
|
||||
|
@ -354,7 +355,7 @@ async def test_openrouter():
|
|||
)
|
||||
print(f"Response: {response.choices[0].message.content}")
|
||||
print(f"Model used: {response.model}")
|
||||
|
||||
|
||||
# Test with Mistral model
|
||||
print("\n--- Testing Mistral model ---")
|
||||
response = await make_llm_api_call(
|
||||
|
@ -365,7 +366,7 @@ async def test_openrouter():
|
|||
)
|
||||
print(f"Response: {response.choices[0].message.content}")
|
||||
print(f"Model used: {response.model}")
|
||||
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error testing OpenRouter: {str(e)}")
|
||||
|
@ -376,8 +377,8 @@ async def test_bedrock():
|
|||
test_messages = [
|
||||
{"role": "user", "content": "Hello, can you give me a quick test response?"}
|
||||
]
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
response = await make_llm_api_call(
|
||||
model_name="bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
model_id="arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
|
@ -388,7 +389,7 @@ async def test_bedrock():
|
|||
)
|
||||
print(f"Response: {response.choices[0].message.content}")
|
||||
print(f"Model used: {response.model}")
|
||||
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error testing Bedrock: {str(e)}")
|
||||
|
@ -396,9 +397,9 @@ async def test_bedrock():
|
|||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
|
||||
test_success = asyncio.run(test_bedrock())
|
||||
|
||||
|
||||
if test_success:
|
||||
print("\n✅ integration test completed successfully!")
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,342 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Script to archive sandboxes for projects that are older than 1 day.
|
||||
|
||||
Usage:
|
||||
python archive_old_sandboxes.py [--days N] [--dry-run]
|
||||
|
||||
This script:
|
||||
1. Gets all projects from the projects table
|
||||
2. Filters projects created more than N days ago (default: 1 day)
|
||||
3. Archives the sandboxes for those projects
|
||||
|
||||
Make sure your environment variables are properly set:
|
||||
- SUPABASE_URL
|
||||
- SUPABASE_SERVICE_ROLE_KEY
|
||||
- DAYTONA_SERVER_URL
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
from typing import List, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load script-specific environment variables
|
||||
load_dotenv(".env")
|
||||
|
||||
from services.supabase import DBConnection
|
||||
from sandbox.sandbox import daytona
|
||||
from utils.logger import logger
|
||||
|
||||
# Global DB connection to reuse
|
||||
db_connection = None
|
||||
|
||||
|
||||
async def get_old_projects(days_threshold: int = 1) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query all projects created more than N days ago.
|
||||
|
||||
Args:
|
||||
days_threshold: Number of days threshold (default: 1)
|
||||
|
||||
Returns:
|
||||
List of projects with their sandbox information
|
||||
"""
|
||||
global db_connection
|
||||
if db_connection is None:
|
||||
db_connection = DBConnection()
|
||||
|
||||
client = await db_connection.client
|
||||
|
||||
# Print the Supabase URL being used
|
||||
print(f"Using Supabase URL: {os.getenv('SUPABASE_URL')}")
|
||||
|
||||
# Calculate the date threshold
|
||||
threshold_date = (datetime.now() - timedelta(days=days_threshold)).isoformat()
|
||||
|
||||
# Initialize variables for pagination
|
||||
all_projects = []
|
||||
page_size = 1000
|
||||
current_page = 0
|
||||
has_more = True
|
||||
|
||||
logger.info(f"Starting to fetch projects older than {days_threshold} day(s)")
|
||||
print(f"Looking for projects created before: {threshold_date}")
|
||||
|
||||
# Paginate through all projects
|
||||
while has_more:
|
||||
# Query projects with pagination
|
||||
start_range = current_page * page_size
|
||||
end_range = start_range + page_size - 1
|
||||
|
||||
logger.info(f"Fetching projects page {current_page+1} (range: {start_range}-{end_range})")
|
||||
|
||||
try:
|
||||
result = await client.table('projects').select(
|
||||
'project_id',
|
||||
'name',
|
||||
'created_at',
|
||||
'account_id',
|
||||
'sandbox'
|
||||
).range(start_range, end_range).execute()
|
||||
|
||||
# Debug info - print raw response
|
||||
print(f"Response data length: {len(result.data)}")
|
||||
|
||||
if not result.data:
|
||||
print("No more data returned from query, ending pagination")
|
||||
has_more = False
|
||||
else:
|
||||
# Print a sample project to see the actual data structure
|
||||
if current_page == 0 and result.data:
|
||||
print(f"Sample project data: {result.data[0]}")
|
||||
|
||||
all_projects.extend(result.data)
|
||||
current_page += 1
|
||||
|
||||
# Progress update
|
||||
logger.info(f"Loaded {len(all_projects)} projects so far")
|
||||
print(f"Loaded {len(all_projects)} projects so far...")
|
||||
|
||||
# Check if we've reached the end - if we got fewer results than the page size
|
||||
if len(result.data) < page_size:
|
||||
print(f"Got {len(result.data)} records which is less than page size {page_size}, ending pagination")
|
||||
has_more = False
|
||||
else:
|
||||
print(f"Full page returned ({len(result.data)} records), continuing to next page")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during pagination: {str(e)}")
|
||||
print(f"Error during pagination: {str(e)}")
|
||||
has_more = False # Stop on error
|
||||
|
||||
# Print the query result summary
|
||||
total_projects = len(all_projects)
|
||||
print(f"Found {total_projects} total projects in database")
|
||||
logger.info(f"Total projects found in database: {total_projects}")
|
||||
|
||||
if not all_projects:
|
||||
logger.info("No projects found in database")
|
||||
return []
|
||||
|
||||
# Filter projects that are older than the threshold and have sandbox information
|
||||
old_projects_with_sandboxes = [
|
||||
project for project in all_projects
|
||||
if project.get('created_at') and project.get('created_at') < threshold_date
|
||||
and project.get('sandbox') and project['sandbox'].get('id')
|
||||
]
|
||||
|
||||
logger.info(f"Found {len(old_projects_with_sandboxes)} old projects with sandboxes")
|
||||
|
||||
# Print a few sample old projects for debugging
|
||||
if old_projects_with_sandboxes:
|
||||
print("\nSample of old projects with sandboxes:")
|
||||
for i, project in enumerate(old_projects_with_sandboxes[:3]):
|
||||
print(f" {i+1}. {project.get('name')} (Created: {project.get('created_at')})")
|
||||
print(f" Sandbox ID: {project['sandbox'].get('id')}")
|
||||
if i >= 2:
|
||||
break
|
||||
|
||||
return old_projects_with_sandboxes
|
||||
|
||||
|
||||
async def archive_sandbox(project: Dict[str, Any], dry_run: bool) -> bool:
|
||||
"""
|
||||
Archive a single sandbox.
|
||||
|
||||
Args:
|
||||
project: Project information containing sandbox to archive
|
||||
dry_run: If True, only simulate archiving
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
sandbox_id = project['sandbox'].get('id')
|
||||
project_name = project.get('name', 'Unknown')
|
||||
project_id = project.get('project_id', 'Unknown')
|
||||
created_at = project.get('created_at', 'Unknown')
|
||||
|
||||
try:
|
||||
logger.info(f"Checking sandbox {sandbox_id} for project '{project_name}' (ID: {project_id}, Created: {created_at})")
|
||||
|
||||
if dry_run:
|
||||
logger.info(f"DRY RUN: Would archive sandbox {sandbox_id}")
|
||||
print(f"Would archive sandbox {sandbox_id} for project '{project_name}' (Created: {created_at})")
|
||||
return True
|
||||
|
||||
# Get the sandbox
|
||||
sandbox = daytona.get_current_sandbox(sandbox_id)
|
||||
|
||||
# Check sandbox state - it must be stopped before archiving
|
||||
sandbox_info = sandbox.info()
|
||||
|
||||
# Log the current state
|
||||
logger.info(f"Sandbox {sandbox_id} is in '{sandbox_info.state}' state")
|
||||
|
||||
# Only archive if the sandbox is in the stopped state
|
||||
if sandbox_info.state == "stopped":
|
||||
logger.info(f"Archiving sandbox {sandbox_id} as it is in stopped state")
|
||||
sandbox.archive()
|
||||
logger.info(f"Successfully archived sandbox {sandbox_id}")
|
||||
return True
|
||||
else:
|
||||
logger.info(f"Skipping sandbox {sandbox_id} as it is not in stopped state (current: {sandbox_info.state})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_type = type(e).__name__
|
||||
stack_trace = traceback.format_exc()
|
||||
|
||||
# Log detailed error information
|
||||
logger.error(f"Error processing sandbox {sandbox_id}: {str(e)}")
|
||||
logger.error(f"Error type: {error_type}")
|
||||
logger.error(f"Stack trace:\n{stack_trace}")
|
||||
|
||||
# If the exception has a response attribute (like in HTTP errors), log it
|
||||
if hasattr(e, 'response'):
|
||||
try:
|
||||
response_data = e.response.json() if hasattr(e.response, 'json') else str(e.response)
|
||||
logger.error(f"Response data: {response_data}")
|
||||
except Exception:
|
||||
logger.error(f"Could not parse response data from error")
|
||||
|
||||
print(f"Failed to process sandbox {sandbox_id}: {error_type} - {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def process_sandboxes(old_projects: List[Dict[str, Any]], dry_run: bool) -> tuple[int, int]:
|
||||
"""
|
||||
Process all sandboxes sequentially.
|
||||
|
||||
Args:
|
||||
old_projects: List of projects older than the threshold
|
||||
dry_run: Whether to actually archive sandboxes or just simulate
|
||||
|
||||
Returns:
|
||||
Tuple of (processed_count, failed_count)
|
||||
"""
|
||||
processed_count = 0
|
||||
failed_count = 0
|
||||
|
||||
if dry_run:
|
||||
logger.info(f"DRY RUN: Would archive {len(old_projects)} sandboxes")
|
||||
else:
|
||||
logger.info(f"Archiving {len(old_projects)} sandboxes")
|
||||
|
||||
print(f"Processing {len(old_projects)} sandboxes...")
|
||||
|
||||
# Process each sandbox sequentially
|
||||
for i, project in enumerate(old_projects):
|
||||
success = await archive_sandbox(project, dry_run)
|
||||
|
||||
if success:
|
||||
processed_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
# Print progress periodically
|
||||
if (i + 1) % 20 == 0 or (i + 1) == len(old_projects):
|
||||
progress = (i + 1) / len(old_projects) * 100
|
||||
print(f"Progress: {i + 1}/{len(old_projects)} sandboxes processed ({progress:.1f}%)")
|
||||
print(f" - Processed: {processed_count}, Failed: {failed_count}")
|
||||
|
||||
return processed_count, failed_count
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function to run the script."""
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description='Archive sandboxes for projects older than N days')
|
||||
parser.add_argument('--days', type=int, default=1, help='Age threshold in days (default: 1)')
|
||||
parser.add_argument('--dry-run', action='store_true', help='Show what would be archived without actually archiving')
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(f"Starting sandbox cleanup for projects older than {args.days} day(s)")
|
||||
if args.dry_run:
|
||||
logger.info("DRY RUN MODE - No sandboxes will be archived")
|
||||
|
||||
# Print environment info
|
||||
print(f"Environment Mode: {os.getenv('ENV_MODE', 'Not set')}")
|
||||
print(f"Daytona Server: {os.getenv('DAYTONA_SERVER_URL', 'Not set')}")
|
||||
|
||||
try:
|
||||
# Initialize global DB connection
|
||||
global db_connection
|
||||
db_connection = DBConnection()
|
||||
|
||||
# Get all projects older than the threshold
|
||||
old_projects = await get_old_projects(args.days)
|
||||
|
||||
if not old_projects:
|
||||
logger.info(f"No projects older than {args.days} day(s) with sandboxes to process")
|
||||
print(f"No projects older than {args.days} day(s) with sandboxes to archive.")
|
||||
return
|
||||
|
||||
# Print summary of what will be processed
|
||||
print("\n===== SANDBOX CLEANUP SUMMARY =====")
|
||||
print(f"Projects older than {args.days} day(s): {len(old_projects)}")
|
||||
print(f"Sandboxes that will be archived: {len(old_projects)}")
|
||||
print("===================================")
|
||||
|
||||
logger.info(f"Found {len(old_projects)} projects older than {args.days} day(s)")
|
||||
|
||||
# Ask for confirmation before proceeding
|
||||
if not args.dry_run:
|
||||
print("\n⚠️ WARNING: You are about to archive sandboxes for old projects ⚠️")
|
||||
print("This action cannot be undone!")
|
||||
confirmation = input("\nAre you sure you want to proceed with archiving? (TRUE/FALSE): ").strip().upper()
|
||||
|
||||
if confirmation != "TRUE":
|
||||
print("Archiving cancelled. Exiting script.")
|
||||
logger.info("Archiving cancelled by user")
|
||||
return
|
||||
|
||||
print("\nProceeding with sandbox archiving...\n")
|
||||
logger.info("User confirmed sandbox archiving")
|
||||
|
||||
# List a sample of projects to be processed
|
||||
for i, project in enumerate(old_projects[:5]): # Just show first 5 for brevity
|
||||
created_at = project.get('created_at', 'Unknown')
|
||||
project_name = project.get('name', 'Unknown')
|
||||
project_id = project.get('project_id', 'Unknown')
|
||||
sandbox_id = project['sandbox'].get('id')
|
||||
|
||||
print(f"{i+1}. Project: {project_name}")
|
||||
print(f" Project ID: {project_id}")
|
||||
print(f" Created At: {created_at}")
|
||||
print(f" Sandbox ID: {sandbox_id}")
|
||||
|
||||
if len(old_projects) > 5:
|
||||
print(f" ... and {len(old_projects) - 5} more projects")
|
||||
|
||||
# Process all sandboxes
|
||||
processed_count, failed_count = await process_sandboxes(old_projects, args.dry_run)
|
||||
|
||||
# Print final summary
|
||||
print("\nSandbox Cleanup Summary:")
|
||||
print(f"Total projects older than {args.days} day(s): {len(old_projects)}")
|
||||
print(f"Total sandboxes processed: {len(old_projects)}")
|
||||
|
||||
if args.dry_run:
|
||||
print(f"DRY RUN: No sandboxes were actually archived")
|
||||
else:
|
||||
print(f"Successfully processed: {processed_count}")
|
||||
print(f"Failed to process: {failed_count}")
|
||||
|
||||
logger.info("Sandbox cleanup completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during sandbox cleanup: {str(e)}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
# Clean up database connection
|
||||
if db_connection:
|
||||
await DBConnection.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
|
@ -4,13 +4,15 @@ import { useEffect, useState } from 'react';
|
|||
import { SidebarLeft } from '@/components/sidebar/sidebar-left';
|
||||
import { SidebarInset, SidebarProvider } from '@/components/ui/sidebar';
|
||||
// import { PricingAlert } from "@/components/billing/pricing-alert"
|
||||
import { MaintenanceAlert } from '@/components/maintenance-alert';
|
||||
import { useAccounts } from '@/hooks/use-accounts';
|
||||
import { useAuth } from '@/components/AuthProvider';
|
||||
import { useRouter } from 'next/navigation';
|
||||
import { Loader2 } from 'lucide-react';
|
||||
import { checkApiHealth } from '@/lib/api';
|
||||
import { MaintenancePage } from '@/components/maintenance/maintenance-page';
|
||||
import { MaintenanceAlert } from "@/components/maintenance-alert"
|
||||
import { useAccounts } from "@/hooks/use-accounts"
|
||||
import { useAuth } from "@/components/AuthProvider"
|
||||
import { useRouter } from "next/navigation"
|
||||
import { Loader2 } from "lucide-react"
|
||||
import { checkApiHealth } from "@/lib/api"
|
||||
import { MaintenancePage } from "@/components/maintenance/maintenance-page"
|
||||
import { DeleteOperationProvider } from "@/contexts/DeleteOperationContext"
|
||||
import { StatusOverlay } from "@/components/ui/status-overlay"
|
||||
|
||||
interface DashboardLayoutProps {
|
||||
children: React.ReactNode;
|
||||
|
@ -78,24 +80,31 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) {
|
|||
}
|
||||
|
||||
return (
|
||||
<SidebarProvider>
|
||||
<SidebarLeft />
|
||||
<SidebarInset>
|
||||
<div className="bg-background">{children}</div>
|
||||
</SidebarInset>
|
||||
|
||||
{/* <PricingAlert
|
||||
open={showPricingAlert}
|
||||
onOpenChange={setShowPricingAlert}
|
||||
closeable={false}
|
||||
accountId={personalAccount?.account_id}
|
||||
/> */}
|
||||
|
||||
<MaintenanceAlert
|
||||
open={showMaintenanceAlert}
|
||||
onOpenChange={setShowMaintenanceAlert}
|
||||
closeable={true}
|
||||
/>
|
||||
</SidebarProvider>
|
||||
);
|
||||
<DeleteOperationProvider>
|
||||
<SidebarProvider>
|
||||
<SidebarLeft />
|
||||
<SidebarInset>
|
||||
<div className="bg-background">
|
||||
{children}
|
||||
</div>
|
||||
</SidebarInset>
|
||||
|
||||
{/* <PricingAlert
|
||||
open={showPricingAlert}
|
||||
onOpenChange={setShowPricingAlert}
|
||||
closeable={false}
|
||||
accountId={personalAccount?.account_id}
|
||||
/> */}
|
||||
|
||||
<MaintenanceAlert
|
||||
open={showMaintenanceAlert}
|
||||
onOpenChange={setShowMaintenanceAlert}
|
||||
closeable={true}
|
||||
/>
|
||||
|
||||
{/* Status overlay for deletion operations */}
|
||||
<StatusOverlay />
|
||||
</SidebarProvider>
|
||||
</DeleteOperationProvider>
|
||||
)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
'use client';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useEffect, useState, useRef } from "react"
|
||||
import {
|
||||
ArrowUpRight,
|
||||
Link as LinkIcon,
|
||||
|
@ -32,10 +32,12 @@ import {
|
|||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from '@/components/ui/tooltip';
|
||||
import { getProjects, getThreads, Project } from '@/lib/api';
|
||||
import Link from 'next/link';
|
||||
TooltipTrigger
|
||||
} from "@/components/ui/tooltip"
|
||||
import { getProjects, getThreads, Project, deleteThread } from "@/lib/api"
|
||||
import Link from "next/link"
|
||||
import { DeleteConfirmationDialog } from "@/components/thread/DeleteConfirmationDialog"
|
||||
import { useDeleteOperation } from '@/contexts/DeleteOperationContext'
|
||||
|
||||
// Thread with associated project info for display in sidebar
|
||||
type ThreadWithProject = {
|
||||
|
@ -47,12 +49,18 @@ type ThreadWithProject = {
|
|||
};
|
||||
|
||||
export function NavAgents() {
|
||||
const { isMobile, state } = useSidebar();
|
||||
const [threads, setThreads] = useState<ThreadWithProject[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [loadingThreadId, setLoadingThreadId] = useState<string | null>(null);
|
||||
const pathname = usePathname();
|
||||
const router = useRouter();
|
||||
const { isMobile, state } = useSidebar()
|
||||
const [threads, setThreads] = useState<ThreadWithProject[]>([])
|
||||
const [isLoading, setIsLoading] = useState(true)
|
||||
const [loadingThreadId, setLoadingThreadId] = useState<string | null>(null)
|
||||
const pathname = usePathname()
|
||||
const router = useRouter()
|
||||
const [isDeleteDialogOpen, setIsDeleteDialogOpen] = useState(false)
|
||||
const [threadToDelete, setThreadToDelete] = useState<{ id: string; name: string } | null>(null)
|
||||
const [isDeleting, setIsDeleting] = useState(false)
|
||||
const isNavigatingRef = useRef(false)
|
||||
const { performDelete, isOperationInProgress } = useDeleteOperation();
|
||||
const isPerformingActionRef = useRef(false);
|
||||
|
||||
// Helper to sort threads by updated_at (most recent first)
|
||||
const sortThreads = (
|
||||
|
@ -198,15 +206,84 @@ export function NavAgents() {
|
|||
setLoadingThreadId(null);
|
||||
}, [pathname]);
|
||||
|
||||
// Add event handler for completed navigation
|
||||
useEffect(() => {
|
||||
const handleNavigationComplete = () => {
|
||||
console.log("NAVIGATION - Navigation event completed");
|
||||
document.body.style.pointerEvents = "auto";
|
||||
isNavigatingRef.current = false;
|
||||
};
|
||||
|
||||
window.addEventListener("popstate", handleNavigationComplete);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener("popstate", handleNavigationComplete);
|
||||
// Ensure we clean up any leftover styles
|
||||
document.body.style.pointerEvents = "auto";
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Reset isNavigatingRef when pathname changes
|
||||
useEffect(() => {
|
||||
isNavigatingRef.current = false;
|
||||
document.body.style.pointerEvents = "auto";
|
||||
}, [pathname]);
|
||||
|
||||
// Function to handle thread click with loading state
|
||||
const handleThreadClick = (
|
||||
e: React.MouseEvent<HTMLAnchorElement>,
|
||||
threadId: string,
|
||||
url: string,
|
||||
) => {
|
||||
e.preventDefault();
|
||||
setLoadingThreadId(threadId);
|
||||
router.push(url);
|
||||
const handleThreadClick = (e: React.MouseEvent<HTMLAnchorElement>, threadId: string, url: string) => {
|
||||
e.preventDefault()
|
||||
setLoadingThreadId(threadId)
|
||||
router.push(url)
|
||||
}
|
||||
|
||||
// Function to handle thread deletion
|
||||
const handleDeleteThread = async (threadId: string, threadName: string) => {
|
||||
setThreadToDelete({ id: threadId, name: threadName });
|
||||
setIsDeleteDialogOpen(true);
|
||||
};
|
||||
|
||||
const confirmDelete = async () => {
|
||||
if (!threadToDelete || isPerformingActionRef.current) return;
|
||||
|
||||
// Mark action in progress
|
||||
isPerformingActionRef.current = true;
|
||||
|
||||
// Close dialog first for immediate feedback
|
||||
setIsDeleteDialogOpen(false);
|
||||
|
||||
const threadId = threadToDelete.id;
|
||||
const isActive = pathname?.includes(threadId);
|
||||
|
||||
// Store threadToDelete in a local variable since it might be cleared
|
||||
const deletedThread = { ...threadToDelete };
|
||||
|
||||
// Log operation start
|
||||
console.log("DELETION - Starting thread deletion process", {
|
||||
threadId: deletedThread.id,
|
||||
isCurrentThread: isActive
|
||||
});
|
||||
|
||||
// Use the centralized deletion system with completion callback
|
||||
await performDelete(
|
||||
threadId,
|
||||
isActive,
|
||||
async () => {
|
||||
// Delete the thread
|
||||
await deleteThread(threadId);
|
||||
|
||||
// Update the thread list
|
||||
setThreads(prev => prev.filter(t => t.threadId !== threadId));
|
||||
|
||||
// Show success message
|
||||
toast.success("Conversation deleted successfully");
|
||||
},
|
||||
// Completion callback to reset local state
|
||||
() => {
|
||||
setThreadToDelete(null);
|
||||
setIsDeleting(false);
|
||||
isPerformingActionRef.current = false;
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
|
@ -351,7 +428,7 @@ export function NavAgents() {
|
|||
</a>
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={() => handleDeleteThread(thread.threadId, thread.projectName)}>
|
||||
<Trash2 className="text-muted-foreground" />
|
||||
<span>Delete</span>
|
||||
</DropdownMenuItem>
|
||||
|
@ -372,6 +449,16 @@ export function NavAgents() {
|
|||
</SidebarMenuItem>
|
||||
)}
|
||||
</SidebarMenu>
|
||||
|
||||
{threadToDelete && (
|
||||
<DeleteConfirmationDialog
|
||||
isOpen={isDeleteDialogOpen}
|
||||
onClose={() => setIsDeleteDialogOpen(false)}
|
||||
onConfirm={confirmDelete}
|
||||
threadName={threadToDelete.name}
|
||||
isDeleting={isDeleting}
|
||||
/>
|
||||
)}
|
||||
</SidebarGroup>
|
||||
);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
"use client"
|
||||
|
||||
import React from "react"
|
||||
import { Loader2 } from "lucide-react"
|
||||
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogTitle,
|
||||
} from "@/components/ui/alert-dialog"
|
||||
|
||||
interface DeleteConfirmationDialogProps {
|
||||
isOpen: boolean
|
||||
onClose: () => void
|
||||
onConfirm: () => void
|
||||
threadName: string
|
||||
isDeleting: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirmation dialog for deleting a conversation
|
||||
*/
|
||||
export function DeleteConfirmationDialog({
|
||||
isOpen,
|
||||
onClose,
|
||||
onConfirm,
|
||||
threadName,
|
||||
isDeleting,
|
||||
}: DeleteConfirmationDialogProps) {
|
||||
return (
|
||||
<AlertDialog open={isOpen} onOpenChange={onClose}>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogTitle>Delete conversation</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
Are you sure you want to delete the conversation{" "}
|
||||
<span className="font-semibold">"{threadName}"</span>?
|
||||
<br />
|
||||
This action cannot be undone.
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel disabled={isDeleting}>Cancel</AlertDialogCancel>
|
||||
<AlertDialogAction
|
||||
onClick={(e) => {
|
||||
e.preventDefault()
|
||||
onConfirm()
|
||||
}}
|
||||
disabled={isDeleting}
|
||||
className="bg-destructive text-white hover:bg-destructive/90"
|
||||
>
|
||||
{isDeleting ? (
|
||||
<>
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
Deleting...
|
||||
</>
|
||||
) : (
|
||||
"Delete"
|
||||
)}
|
||||
</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
)
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
import React from 'react';
|
||||
import { Loader2, CheckCircle, AlertCircle } from 'lucide-react';
|
||||
import { useDeleteOperation } from '@/contexts/DeleteOperationContext';
|
||||
|
||||
export function StatusOverlay() {
|
||||
const { state } = useDeleteOperation();
|
||||
|
||||
if (state.operation === 'none' || !state.isDeleting) return null;
|
||||
|
||||
return (
|
||||
<div className="fixed bottom-4 right-4 z-50 flex items-center gap-2 bg-background/90 backdrop-blur p-3 rounded-lg shadow-lg border border-border">
|
||||
{state.operation === 'pending' && (
|
||||
<>
|
||||
<Loader2 className="h-5 w-5 text-muted-foreground animate-spin" />
|
||||
<span className="text-sm">Processing...</span>
|
||||
</>
|
||||
)}
|
||||
|
||||
{state.operation === 'success' && (
|
||||
<>
|
||||
<CheckCircle className="h-5 w-5 text-green-500" />
|
||||
<span className="text-sm">Completed</span>
|
||||
</>
|
||||
)}
|
||||
|
||||
{state.operation === 'error' && (
|
||||
<>
|
||||
<AlertCircle className="h-5 w-5 text-destructive" />
|
||||
<span className="text-sm">Failed</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
|
@ -0,0 +1,195 @@
|
|||
import React, { createContext, useContext, useReducer, useEffect, useRef } from 'react';
|
||||
|
||||
type DeleteState = {
|
||||
isDeleting: boolean;
|
||||
targetId: string | null;
|
||||
isActive: boolean;
|
||||
operation: 'none' | 'pending' | 'success' | 'error';
|
||||
};
|
||||
|
||||
type DeleteAction =
|
||||
| { type: 'START_DELETE'; id: string; isActive: boolean }
|
||||
| { type: 'DELETE_SUCCESS' }
|
||||
| { type: 'DELETE_ERROR' }
|
||||
| { type: 'RESET' };
|
||||
|
||||
const initialState: DeleteState = {
|
||||
isDeleting: false,
|
||||
targetId: null,
|
||||
isActive: false,
|
||||
operation: 'none'
|
||||
};
|
||||
|
||||
function deleteReducer(state: DeleteState, action: DeleteAction): DeleteState {
|
||||
switch (action.type) {
|
||||
case 'START_DELETE':
|
||||
return {
|
||||
...state,
|
||||
isDeleting: true,
|
||||
targetId: action.id,
|
||||
isActive: action.isActive,
|
||||
operation: 'pending'
|
||||
};
|
||||
case 'DELETE_SUCCESS':
|
||||
return {
|
||||
...state,
|
||||
operation: 'success'
|
||||
};
|
||||
case 'DELETE_ERROR':
|
||||
return {
|
||||
...state,
|
||||
isDeleting: false,
|
||||
operation: 'error'
|
||||
};
|
||||
case 'RESET':
|
||||
return initialState;
|
||||
default:
|
||||
return state;
|
||||
}
|
||||
}
|
||||
|
||||
type DeleteOperationContextType = {
|
||||
state: DeleteState;
|
||||
dispatch: React.Dispatch<DeleteAction>;
|
||||
performDelete: (
|
||||
id: string,
|
||||
isActive: boolean,
|
||||
deleteFunction: () => Promise<void>,
|
||||
onComplete?: () => void
|
||||
) => Promise<void>;
|
||||
isOperationInProgress: React.MutableRefObject<boolean>;
|
||||
};
|
||||
|
||||
const DeleteOperationContext = createContext<DeleteOperationContextType | undefined>(undefined);
|
||||
|
||||
export function DeleteOperationProvider({ children }: { children: React.ReactNode }) {
|
||||
const [state, dispatch] = useReducer(deleteReducer, initialState);
|
||||
const isOperationInProgress = useRef(false);
|
||||
|
||||
// Listen for state changes to handle navigation
|
||||
useEffect(() => {
|
||||
if (state.operation === 'success' && state.isActive) {
|
||||
// Delay navigation to allow UI feedback
|
||||
const timer = setTimeout(() => {
|
||||
try {
|
||||
// Use window.location for reliable navigation
|
||||
window.location.pathname = '/dashboard';
|
||||
} catch (error) {
|
||||
console.error("Navigation error:", error);
|
||||
}
|
||||
}, 500);
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
}, [state.operation, state.isActive]);
|
||||
|
||||
// Auto-reset after operations complete
|
||||
useEffect(() => {
|
||||
if (state.operation === 'success' && !state.isActive) {
|
||||
const timer = setTimeout(() => {
|
||||
dispatch({ type: 'RESET' });
|
||||
// Ensure pointer events are restored
|
||||
document.body.style.pointerEvents = "auto";
|
||||
isOperationInProgress.current = false;
|
||||
|
||||
// Restore sidebar menu interactivity
|
||||
const sidebarMenu = document.querySelector(".sidebar-menu");
|
||||
if (sidebarMenu) {
|
||||
sidebarMenu.classList.remove("pointer-events-none");
|
||||
}
|
||||
}, 1000);
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
|
||||
if (state.operation === 'error') {
|
||||
// Reset on error immediately
|
||||
document.body.style.pointerEvents = "auto";
|
||||
isOperationInProgress.current = false;
|
||||
|
||||
// Restore sidebar menu interactivity
|
||||
const sidebarMenu = document.querySelector(".sidebar-menu");
|
||||
if (sidebarMenu) {
|
||||
sidebarMenu.classList.remove("pointer-events-none");
|
||||
}
|
||||
}
|
||||
}, [state.operation, state.isActive]);
|
||||
|
||||
const performDelete = async (
|
||||
id: string,
|
||||
isActive: boolean,
|
||||
deleteFunction: () => Promise<void>,
|
||||
onComplete?: () => void
|
||||
) => {
|
||||
// Prevent multiple operations
|
||||
if (isOperationInProgress.current) return;
|
||||
isOperationInProgress.current = true;
|
||||
|
||||
// Disable pointer events during operation
|
||||
document.body.style.pointerEvents = "none";
|
||||
|
||||
// Disable sidebar menu interactions
|
||||
const sidebarMenu = document.querySelector(".sidebar-menu");
|
||||
if (sidebarMenu) {
|
||||
sidebarMenu.classList.add("pointer-events-none");
|
||||
}
|
||||
|
||||
dispatch({ type: 'START_DELETE', id, isActive });
|
||||
|
||||
try {
|
||||
// Execute the delete operation
|
||||
await deleteFunction();
|
||||
|
||||
// Use precise timing for UI updates
|
||||
setTimeout(() => {
|
||||
dispatch({ type: 'DELETE_SUCCESS' });
|
||||
|
||||
// For non-active threads, restore interaction with delay
|
||||
if (!isActive) {
|
||||
setTimeout(() => {
|
||||
document.body.style.pointerEvents = "auto";
|
||||
|
||||
if (sidebarMenu) {
|
||||
sidebarMenu.classList.remove("pointer-events-none");
|
||||
}
|
||||
|
||||
// Call the completion callback
|
||||
if (onComplete) onComplete();
|
||||
}, 100);
|
||||
}
|
||||
}, 50);
|
||||
} catch (error) {
|
||||
console.error("Delete operation failed:", error);
|
||||
|
||||
// Reset states on error
|
||||
document.body.style.pointerEvents = "auto";
|
||||
isOperationInProgress.current = false;
|
||||
|
||||
if (sidebarMenu) {
|
||||
sidebarMenu.classList.remove("pointer-events-none");
|
||||
}
|
||||
|
||||
dispatch({ type: 'DELETE_ERROR' });
|
||||
|
||||
// Call the completion callback
|
||||
if (onComplete) onComplete();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<DeleteOperationContext.Provider value={{
|
||||
state,
|
||||
dispatch,
|
||||
performDelete,
|
||||
isOperationInProgress
|
||||
}}>
|
||||
{children}
|
||||
</DeleteOperationContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useDeleteOperation() {
|
||||
const context = useContext(DeleteOperationContext);
|
||||
if (context === undefined) {
|
||||
throw new Error('useDeleteOperation must be used within a DeleteOperationProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
|
@ -1244,6 +1244,53 @@ export const toggleThreadPublicStatus = async (
|
|||
return updateThread(threadId, { is_public: isPublic });
|
||||
};
|
||||
|
||||
export const deleteThread = async (threadId: string): Promise<void> => {
|
||||
try {
|
||||
const supabase = createClient();
|
||||
|
||||
// First delete all agent runs associated with this thread
|
||||
console.log(`Deleting all agent runs for thread ${threadId}`);
|
||||
const { error: agentRunsError } = await supabase
|
||||
.from('agent_runs')
|
||||
.delete()
|
||||
.eq('thread_id', threadId);
|
||||
|
||||
if (agentRunsError) {
|
||||
console.error('Error deleting agent runs:', agentRunsError);
|
||||
throw new Error(`Error deleting agent runs: ${agentRunsError.message}`);
|
||||
}
|
||||
|
||||
// Then delete all messages associated with the thread
|
||||
console.log(`Deleting all messages for thread ${threadId}`);
|
||||
const { error: messagesError } = await supabase
|
||||
.from('messages')
|
||||
.delete()
|
||||
.eq('thread_id', threadId);
|
||||
|
||||
if (messagesError) {
|
||||
console.error('Error deleting messages:', messagesError);
|
||||
throw new Error(`Error deleting messages: ${messagesError.message}`);
|
||||
}
|
||||
|
||||
// Finally, delete the thread itself
|
||||
console.log(`Deleting thread ${threadId}`);
|
||||
const { error: threadError } = await supabase
|
||||
.from('threads')
|
||||
.delete()
|
||||
.eq('thread_id', threadId);
|
||||
|
||||
if (threadError) {
|
||||
console.error('Error deleting thread:', threadError);
|
||||
throw new Error(`Error deleting thread: ${threadError.message}`);
|
||||
}
|
||||
|
||||
console.log(`Thread ${threadId} successfully deleted with all related items`);
|
||||
} catch (error) {
|
||||
console.error('Error deleting thread and related items:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
// Function to get public projects
|
||||
export const getPublicProjects = async (): Promise<Project[]> => {
|
||||
try {
|
||||
|
|
Loading…
Reference in New Issue