mirror of https://github.com/kortix-ai/suna.git
wip
parent d850800a5f
commit 4bbc03f674
@@ -32,7 +32,6 @@ instance_id = None # Global instance ID for this backend instance
# TTL for Redis response lists (24 hours)
REDIS_RESPONSE_LIST_TTL = 3600 * 24


class AgentStartRequest(BaseModel):
    model_name: Optional[str] = None # Will be set from config.MODEL_TO_USE in the endpoint
    enable_thinking: Optional[bool] = False
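
For context on the constant kept by this hunk: a 24-hour expiry on Redis response lists is typically applied at write time. A minimal sketch of that pattern with redis-py's asyncio client, where the key layout and chunk payload are illustrative assumptions rather than code from this commit:

import redis.asyncio as redis

REDIS_RESPONSE_LIST_TTL = 3600 * 24  # 24 hours, matching the constant above

async def append_response(client: redis.Redis, agent_run_id: str, chunk: str) -> None:
    key = f"agent_run:{agent_run_id}:responses"  # hypothetical key layout
    await client.rpush(key, chunk)  # append the streamed chunk to the list
    await client.expire(key, REDIS_RESPONSE_LIST_TTL)  # (re)set the 24h TTL

Refreshing the TTL on every append keeps active runs alive while letting abandoned lists expire on their own.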
@@ -1798,6 +1797,7 @@ async def delete_agent(agent_id: str, user_id: str = Depends(get_current_user_id
        logger.error(f"Error deleting agent {agent_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")


# Marketplace Models
class MarketplaceAgent(BaseModel):
    agent_id: str
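
Only the first field of the new MarketplaceAgent model is visible before the hunk is cut off. A minimal sketch of how such a Pydantic model is declared and instantiated; the example ID is made up, and any further fields are unknown here:

from pydantic import BaseModel

class MarketplaceAgent(BaseModel):
    agent_id: str  # the only field visible in this hunk

agent = MarketplaceAgent(agent_id="123e4567-e89b-12d3-a456-426614174000")
print(agent.agent_id)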
@@ -603,285 +603,4 @@ async def run_agent(
            break
        generation.end(output=full_response)

    langfuse.flush() # Flush Langfuse events at the end of the run



# # TESTING

# async def test_agent():
#     """Test function to run the agent with a sample query"""
#     from agentpress.thread_manager import ThreadManager
#     from services.supabase import DBConnection

#     # Initialize ThreadManager
#     thread_manager = ThreadManager()

#     # Create a test thread directly with Postgres function
#     client = await DBConnection().client

#     try:
#         # Get user's personal account
#         account_result = await client.rpc('get_personal_account').execute()

#         # if not account_result.data:
#         #     print("Error: No personal account found")
#         #     return

#         account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59"

#         if not account_id:
#             print("Error: Could not get account ID")
#             return

#         # Find or create a test project in the user's account
#         project_result = await client.table('projects').select('*').eq('name', 'test11').eq('account_id', account_id).execute()

#         if project_result.data and len(project_result.data) > 0:
#             # Use existing test project
#             project_id = project_result.data[0]['project_id']
#             print(f"\n🔄 Using existing test project: {project_id}")
#         else:
#             # Create new test project if none exists
#             project_result = await client.table('projects').insert({
#                 "name": "test11",
#                 "account_id": account_id
#             }).execute()
#             project_id = project_result.data[0]['project_id']
#             print(f"\n✨ Created new test project: {project_id}")

#         # Create a thread for this project
#         thread_result = await client.table('threads').insert({
#             'project_id': project_id,
#             'account_id': account_id
#         }).execute()
#         thread_data = thread_result.data[0] if thread_result.data else None

#         if not thread_data:
#             print("Error: No thread data returned")
#             return

#         thread_id = thread_data['thread_id']
#     except Exception as e:
#         print(f"Error setting up thread: {str(e)}")
#         return

#     print(f"\n🤖 Agent Thread Created: {thread_id}\n")

#     # Interactive message input loop
#     while True:
#         # Get user input
#         user_message = input("\n💬 Enter your message (or 'exit' to quit): ")
#         if user_message.lower() == 'exit':
#             break

#         if not user_message.strip():
#             print("\n🔄 Running agent...\n")
#             await process_agent_response(thread_id, project_id, thread_manager)
#             continue

#         # Add the user message to the thread
#         await thread_manager.add_message(
#             thread_id=thread_id,
#             type="user",
#             content={
#                 "role": "user",
#                 "content": user_message
#             },
#             is_llm_message=True
#         )

#         print("\n🔄 Running agent...\n")
#         await process_agent_response(thread_id, project_id, thread_manager)

#     print("\n👋 Test completed. Goodbye!")

# async def process_agent_response(
#     thread_id: str,
#     project_id: str,
#     thread_manager: ThreadManager,
#     stream: bool = True,
#     model_name: str = "anthropic/claude-3-7-sonnet-latest",
#     enable_thinking: Optional[bool] = False,
#     reasoning_effort: Optional[str] = 'low',
#     enable_context_manager: bool = True
# ):
#     """Process the streaming response from the agent."""
#     chunk_counter = 0
#     current_response = ""
#     tool_usage_counter = 0 # Renamed from tool_call_counter as we track usage via status

#     # Create a test sandbox for processing with a unique test prefix to avoid conflicts with production sandboxes
#     # sandbox_pass = str(uuid4())
#     # sandbox = await create_sandbox(sandbox_pass)

#     # Store the original ID so we can refer to it
#     # original_sandbox_id = sandbox.id

#     # Generate a clear test identifier
#     # test_prefix = f"test_{uuid4().hex[:8]}_"
#     logger.info(f"Created test sandbox with ID {original_sandbox_id} and test prefix {test_prefix}")

#     # Log the sandbox URL for debugging
#     # print(f"\033[91mTest sandbox created: {str(await sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m")

#     async for chunk in run_agent(
#         thread_id=thread_id,
#         project_id=project_id,
#         sandbox=sandbox,
#         stream=stream,
#         thread_manager=thread_manager,
#         native_max_auto_continues=25,
#         model_name=model_name,
#         enable_thinking=enable_thinking,
#         reasoning_effort=reasoning_effort,
#         enable_context_manager=enable_context_manager
#     ):
#         chunk_counter += 1
#         # print(f"CHUNK: {chunk}") # Uncomment for debugging

#         if chunk.get('type') == 'assistant':
#             # Try parsing the content JSON
#             try:
#                 # Handle content as string or object
#                 content = chunk.get('content', '{}')
#                 if isinstance(content, str):
#                     content_json = json.loads(content)
#                 else:
#                     content_json = content

#                 actual_content = content_json.get('content', '')
#                 # Print the actual assistant text content as it comes
#                 if actual_content:
#                     # Check if it contains XML tool tags, if so, print the whole tag for context
#                     if '<' in actual_content and '>' in actual_content:
#                         # Avoid printing potentially huge raw content if it's not just text
#                         if len(actual_content) < 500: # Heuristic limit
#                             print(actual_content, end='', flush=True)
#                         else:
#                             # Maybe just print a summary if it's too long or contains complex XML
#                             if '</ask>' in actual_content: print("<ask>...</ask>", end='', flush=True)
#                             elif '</complete>' in actual_content: print("<complete>...</complete>", end='', flush=True)
#                             else: print("<tool_call>...</tool_call>", end='', flush=True) # Generic case
#                     else:
#                         # Regular text content
#                         print(actual_content, end='', flush=True)
#                     current_response += actual_content # Accumulate only text part
#             except json.JSONDecodeError:
#                 # If content is not JSON (e.g., just a string chunk), print directly
#                 raw_content = chunk.get('content', '')
#                 print(raw_content, end='', flush=True)
#                 current_response += raw_content
#             except Exception as e:
#                 print(f"\nError processing assistant chunk: {e}\n")

#         elif chunk.get('type') == 'tool': # Updated from 'tool_result'
#             # Add timestamp and format tool result nicely
#             tool_name = "UnknownTool" # Try to get from metadata if available
#             result_content = "No content"

#             # Parse metadata - handle both string and dict formats
#             metadata = chunk.get('metadata', {})
#             if isinstance(metadata, str):
#                 try:
#                     metadata = json.loads(metadata)
#                 except json.JSONDecodeError:
#                     metadata = {}

#             linked_assistant_msg_id = metadata.get('assistant_message_id')
#             parsing_details = metadata.get('parsing_details')
#             if parsing_details:
#                 tool_name = parsing_details.get('xml_tag_name', 'UnknownTool') # Get name from parsing details

#             try:
#                 # Content is a JSON string or object
#                 content = chunk.get('content', '{}')
#                 if isinstance(content, str):
#                     content_json = json.loads(content)
#                 else:
#                     content_json = content

#                 # The actual tool result is nested inside content.content
#                 tool_result_str = content_json.get('content', '')
#                 # Extract the actual tool result string (remove outer <tool_result> tag if present)
#                 match = re.search(rf'<{tool_name}>(.*?)</{tool_name}>', tool_result_str, re.DOTALL)
#                 if match:
#                     result_content = match.group(1).strip()
#                     # Try to parse the result string itself as JSON for pretty printing
#                     try:
#                         result_obj = json.loads(result_content)
#                         result_content = json.dumps(result_obj, indent=2)
#                     except json.JSONDecodeError:
#                         # Keep as string if not JSON
#                         pass
#                 else:
#                     # Fallback if tag extraction fails
#                     result_content = tool_result_str

#             except json.JSONDecodeError:
#                 result_content = chunk.get('content', 'Error parsing tool content')
#             except Exception as e:
#                 result_content = f"Error processing tool chunk: {e}"

#             print(f"\n\n🛠️ TOOL RESULT [{tool_name}] → {result_content}")

#         elif chunk.get('type') == 'status':
#             # Log tool status changes
#             try:
#                 # Handle content as string or object
#                 status_content = chunk.get('content', '{}')
#                 if isinstance(status_content, str):
#                     status_content = json.loads(status_content)

#                 status_type = status_content.get('status_type')
#                 function_name = status_content.get('function_name', '')
#                 xml_tag_name = status_content.get('xml_tag_name', '') # Get XML tag if available
#                 tool_name = xml_tag_name or function_name # Prefer XML tag name

#                 if status_type == 'tool_started' and tool_name:
#                     tool_usage_counter += 1
#                     print(f"\n⏳ TOOL STARTING #{tool_usage_counter} [{tool_name}]")
#                     print(" " + "-" * 40)
#                     # Return to the current content display
#                     if current_response:
#                         print("\nContinuing response:", flush=True)
#                         print(current_response, end='', flush=True)
#                 elif status_type == 'tool_completed' and tool_name:
#                     status_emoji = "✅"
#                     print(f"\n{status_emoji} TOOL COMPLETED: {tool_name}")
#                 elif status_type == 'finish':
#                     finish_reason = status_content.get('finish_reason', '')
#                     if finish_reason:
#                         print(f"\n📌 Finished: {finish_reason}")
#                 # else: # Print other status types if needed for debugging
#                 #     print(f"\nℹ️ STATUS: {chunk.get('content')}")

#             except json.JSONDecodeError:
#                 print(f"\nWarning: Could not parse status content JSON: {chunk.get('content')}")
#             except Exception as e:
#                 print(f"\nError processing status chunk: {e}")


#     # Removed elif chunk.get('type') == 'tool_call': block

#     # Update final message
#     print(f"\n\n✅ Agent run completed with {tool_usage_counter} tool executions")

#     # Try to clean up the test sandbox if possible
#     try:
#         # Attempt to delete/archive the sandbox to clean up resources
#         # Note: Actual deletion may depend on the Daytona SDK's capabilities
#         logger.info(f"Attempting to clean up test sandbox {original_sandbox_id}")
#         # If there's a method to archive/delete the sandbox, call it here
#         # Example: daytona.archive_sandbox(sandbox.id)
#     except Exception as e:
#         logger.warning(f"Failed to clean up test sandbox {original_sandbox_id}: {str(e)}")

# if __name__ == "__main__":
#     import asyncio

#     # Configure any environment variables or setup needed for testing
#     load_dotenv() # Ensure environment variables are loaded

#     # Run the test function
#     asyncio.run(test_agent())
    langfuse.flush()
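
The net effect of this hunk: the long commented-out test harness is deleted, and run_agent now ends by flushing Langfuse after closing the generation span. A minimal sketch of that end-of-run pattern, assuming the Langfuse v2 Python SDK; the trace and generation names are illustrative:

from langfuse import Langfuse

langfuse = Langfuse()  # reads credentials from LANGFUSE_* environment variables

def finish_run(output: str) -> None:
    trace = langfuse.trace(name="agent_run")        # hypothetical trace name
    generation = trace.generation(name="llm_call")  # hypothetical span name
    generation.end(output=output)  # close the span with the final output
    langfuse.flush()  # block until queued events are delivered

Flushing explicitly matters in short-lived workers, where the process can exit before the SDK's background thread drains its event queue.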
@@ -9,15 +9,10 @@ import json
from typing import List, Dict, Any, Optional, Union

from litellm.utils import token_counter
from litellm.cost_calculator import completion_cost
from services.supabase import DBConnection
from services.llm import make_llm_api_call
from utils.logger import logger

# Constants for token management
DEFAULT_TOKEN_THRESHOLD = 120000 # 80k tokens threshold for summarization
SUMMARY_TARGET_TOKENS = 10000 # Target ~10k tokens for the summary message
RESERVE_TOKENS = 5000 # Reserve tokens for new messages
DEFAULT_TOKEN_THRESHOLD = 120000

class ContextManager:
    """Manages thread context including token counting and summarization."""
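
This hunk collapses three token constants into a single DEFAULT_TOKEN_THRESHOLD (the removed inline comment had drifted anyway, saying 80k against a value of 120000). For reference, litellm's token_counter is the helper that would consume such a threshold; a minimal sketch, with the model name and message shape as illustrative assumptions:

from litellm.utils import token_counter

DEFAULT_TOKEN_THRESHOLD = 120000

def needs_summarization(messages: list, model: str = "gpt-4o") -> bool:
    # token_counter accepts OpenAI-style message dicts
    total = token_counter(model=model, messages=messages)
    return total >= DEFAULT_TOKEN_THRESHOLD

print(needs_summarization([{"role": "user", "content": "hello"}]))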
@@ -126,12 +126,17 @@ async def log_requests_middleware(request: Request, call_next):
        raise

# Define allowed origins based on environment
allowed_origins = ["https://www.suna.so", "https://suna.so", "http://localhost:3000"]
allowed_origins = ["https://www.suna.so", "https://suna.so"]
allow_origin_regex = None

# Add staging-specific origins
if config.ENV_MODE == EnvMode.LOCAL:
    allowed_origins.append("http://localhost:3000")

# Add staging-specific origins
if config.ENV_MODE == EnvMode.STAGING:
    allowed_origins.append("https://staging.suna.so")
    allowed_origins.append("http://localhost:3000")
    allow_origin_regex = r"https://suna-.*-prjcts\.vercel\.app"

app.add_middleware(
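
This hunk moves http://localhost:3000 out of the default origin list and into the LOCAL and STAGING branches, and adds an allow_origin_regex for Vercel preview deployments. A minimal sketch of how these values are typically wired into FastAPI's CORS middleware; the keyword arguments beyond the two origin settings are common defaults, not confirmed by the diff:

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
allowed_origins = ["https://www.suna.so", "https://suna.so"]
allow_origin_regex = r"https://suna-.*-prjcts\.vercel\.app"

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,          # exact-match origins
    allow_origin_regex=allow_origin_regex,  # regex match for preview URLs
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)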
@@ -14,7 +14,7 @@ from services.supabase import DBConnection
from utils.auth_utils import get_current_user_id_from_jwt
from pydantic import BaseModel
from utils.constants import MODEL_ACCESS_TIERS, MODEL_NAME_ALIASES
from litellm import cost_per_token, model_cost
from litellm import cost_per_token
import time

# Initialize Stripe
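
The billing module's import line drops model_cost and keeps only cost_per_token. A minimal sketch of that litellm helper; the model name and token counts are illustrative:

from litellm import cost_per_token

# cost_per_token returns a (prompt_cost_usd, completion_cost_usd) tuple
prompt_cost, completion_cost = cost_per_token(
    model="gpt-4o",
    prompt_tokens=1200,
    completion_tokens=350,
)
print(f"run cost: ${prompt_cost + completion_cost:.6f}")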