mirror of https://github.com/kortix-ai/suna.git
porject_id and migration
This commit is contained in:
parent
1987da720c
commit
4a29872ceb
|
@ -259,7 +259,7 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id
|
||||||
|
|
||||||
# Run the agent in the background
|
# Run the agent in the background
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
run_agent_background(agent_run_id, thread_id, instance_id)
|
run_agent_background(agent_run_id, thread_id, instance_id, project_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set a callback to clean up when task is done
|
# Set a callback to clean up when task is done
|
||||||
|
@ -389,7 +389,7 @@ async def stream_agent_run(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str):
|
async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str):
|
||||||
"""Run the agent in the background and handle status updates."""
|
"""Run the agent in the background and handle status updates."""
|
||||||
logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
|
logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
|
||||||
client = await db.client
|
client = await db.client
|
||||||
|
@ -510,7 +510,7 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
|
||||||
# Run the agent
|
# Run the agent
|
||||||
logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
|
logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
|
||||||
agent_gen = run_agent(thread_id, stream=True,
|
agent_gen = run_agent(thread_id, stream=True,
|
||||||
thread_manager=thread_manager)
|
thread_manager=thread_manager, project_id=project_id)
|
||||||
|
|
||||||
# Collect all responses to save to database
|
# Collect all responses to save to database
|
||||||
all_responses = []
|
all_responses = []
|
||||||
|
|
|
@ -81,4 +81,5 @@ def get_system_prompt():
|
||||||
'''
|
'''
|
||||||
Returns the system prompt with XML tool usage instructions.
|
Returns the system prompt with XML tool usage instructions.
|
||||||
'''
|
'''
|
||||||
return SYSTEM_PROMPT + RESPONSE_FORMAT
|
# return SYSTEM_PROMPT + RESPONSE_FORMAT
|
||||||
|
return SYSTEM_PROMPT
|
|
@ -1,42 +1,54 @@
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
from uuid import uuid4
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from agentpress.thread_manager import ThreadManager
|
from agentpress.thread_manager import ThreadManager
|
||||||
|
from agentpress.response_processor import ProcessorConfig
|
||||||
from agent.tools.sb_browse_tool import SandboxBrowseTool
|
from agent.tools.sb_browse_tool import SandboxBrowseTool
|
||||||
from agent.tools.sb_shell_tool import SandboxShellTool
|
from agent.tools.sb_shell_tool import SandboxShellTool
|
||||||
from agent.tools.sb_website_tool import SandboxWebsiteTool
|
from agent.tools.sb_website_tool import SandboxWebsiteTool
|
||||||
from agent.tools.sb_files_tool import SandboxFilesTool
|
from agent.tools.sb_files_tool import SandboxFilesTool
|
||||||
from typing import Optional
|
|
||||||
from agent.prompt import get_system_prompt
|
from agent.prompt import get_system_prompt
|
||||||
from agentpress.response_processor import ProcessorConfig
|
from agent.tools.utils.daytona_sandbox import daytona, create_sandbox
|
||||||
from dotenv import load_dotenv
|
from daytona_api_client.models.workspace_state import WorkspaceState
|
||||||
from agent.tools.utils.daytona_sandbox import create_sandbox
|
|
||||||
|
|
||||||
# Load environment variables
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
async def run_agent(thread_id: str, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25):
|
async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25):
|
||||||
"""Run the development agent with specified configuration."""
|
"""Run the development agent with specified configuration."""
|
||||||
|
|
||||||
if not thread_manager:
|
if not thread_manager:
|
||||||
thread_manager = ThreadManager()
|
thread_manager = ThreadManager()
|
||||||
|
|
||||||
if True: # todo: change to of not sandbox running
|
client = await thread_manager.db.client
|
||||||
sandbox = create_sandbox("vvv")
|
## probably want to move to api.py
|
||||||
sandbox_id = sandbox.id
|
project = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
||||||
sandbox_password = "vvv"
|
if project.data[0]['sandbox_id']:
|
||||||
|
sandbox_id = project.data[0]['sandbox_id']
|
||||||
|
sandbox_pass = project.data[0]['sandbox_pass']
|
||||||
|
sandbox = daytona.get_current_sandbox(sandbox_id)
|
||||||
|
if sandbox.instance.state == WorkspaceState.ARCHIVED or sandbox.instance.state == WorkspaceState.STOPPED:
|
||||||
|
try:
|
||||||
|
daytona.start(sandbox)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error starting sandbox: {e}")
|
||||||
|
raise e
|
||||||
else:
|
else:
|
||||||
sandbox_id = "sandbox-01efaaa5"
|
sandbox_pass = str(uuid4())
|
||||||
sandbox_password = "vvv"
|
sandbox = create_sandbox(sandbox_pass)
|
||||||
|
sandbox_id = sandbox.id
|
||||||
|
await client.table('projects').update({
|
||||||
|
'sandbox_id': sandbox_id,
|
||||||
|
'sandbox_pass': sandbox_pass
|
||||||
|
}).eq('project_id', project_id).execute()
|
||||||
|
### ---
|
||||||
|
|
||||||
print("Adding tools to thread manager...")
|
print("Adding tools to thread manager...")
|
||||||
# thread_manager.add_tool(FilesTool)
|
thread_manager.add_tool(SandboxBrowseTool, sandbox_id=sandbox_id, password=sandbox_pass)
|
||||||
# thread_manager.add_tool(TerminalTool)
|
thread_manager.add_tool(SandboxWebsiteTool, sandbox_id=sandbox_id, password=sandbox_pass)
|
||||||
# thread_manager.add_tool(CodeSearchTool)
|
thread_manager.add_tool(SandboxShellTool, sandbox_id=sandbox_id, password=sandbox_pass)
|
||||||
thread_manager.add_tool(SandboxBrowseTool, sandbox_id=sandbox_id, password=sandbox_password)
|
thread_manager.add_tool(SandboxFilesTool, sandbox_id=sandbox_id, password=sandbox_pass)
|
||||||
thread_manager.add_tool(SandboxWebsiteTool, sandbox_id=sandbox_id, password=sandbox_password)
|
|
||||||
thread_manager.add_tool(SandboxShellTool, sandbox_id=sandbox_id, password=sandbox_password)
|
|
||||||
thread_manager.add_tool(SandboxFilesTool, sandbox_id=sandbox_id, password=sandbox_password)
|
|
||||||
|
|
||||||
system_message = { "role": "system", "content": get_system_prompt() }
|
system_message = { "role": "system", "content": get_system_prompt() }
|
||||||
|
|
||||||
|
@ -49,7 +61,7 @@ async def run_agent(thread_id: str, stream: bool = True, thread_manager: Optiona
|
||||||
#groq/deepseek-r1-distill-llama-70b
|
#groq/deepseek-r1-distill-llama-70b
|
||||||
#bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0
|
#bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||||
|
|
||||||
files_tool = SandboxFilesTool(sandbox_id=sandbox_id, password=sandbox_password)
|
files_tool = SandboxFilesTool(sandbox_id=sandbox_id, password=sandbox_pass)
|
||||||
|
|
||||||
files_state = await files_tool.get_workspace_state()
|
files_state = await files_tool.get_workspace_state()
|
||||||
|
|
||||||
|
@ -74,8 +86,8 @@ Current development environment workspace state:
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
max_xml_tool_calls=1,
|
max_xml_tool_calls=1,
|
||||||
processor_config=ProcessorConfig(
|
processor_config=ProcessorConfig(
|
||||||
xml_tool_calling=True,
|
xml_tool_calling=False,
|
||||||
native_tool_calling=False,
|
native_tool_calling=True,
|
||||||
execute_tools=True,
|
execute_tools=True,
|
||||||
execute_on_stream=True,
|
execute_on_stream=True,
|
||||||
tool_execution_strategy="parallel",
|
tool_execution_strategy="parallel",
|
||||||
|
@ -103,7 +115,8 @@ async def test_agent():
|
||||||
client = await DBConnection().client
|
client = await DBConnection().client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
thread_result = await client.table('threads').insert({}).execute()
|
thread_result = await client.table('projects').insert({"name": "test", "user_id": "68e1da55-0749-49db-937a-ff56bf0269a0"}).execute()
|
||||||
|
thread_result = await client.table('threads').insert({'project_id': thread_result.data[0]['project_id']}).execute()
|
||||||
thread_data = thread_result.data[0] if thread_result.data else None
|
thread_data = thread_result.data[0] if thread_result.data else None
|
||||||
|
|
||||||
if not thread_data:
|
if not thread_data:
|
||||||
|
@ -111,6 +124,7 @@ async def test_agent():
|
||||||
return
|
return
|
||||||
|
|
||||||
thread_id = thread_data['thread_id']
|
thread_id = thread_data['thread_id']
|
||||||
|
project_id = thread_data['project_id']
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating thread: {str(e)}")
|
print(f"Error creating thread: {str(e)}")
|
||||||
return
|
return
|
||||||
|
@ -126,7 +140,7 @@ async def test_agent():
|
||||||
|
|
||||||
if not user_message.strip():
|
if not user_message.strip():
|
||||||
print("\n🔄 Running agent...\n")
|
print("\n🔄 Running agent...\n")
|
||||||
await process_agent_response(thread_id, thread_manager)
|
await process_agent_response(thread_id, project_id, thread_manager)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Add the user message to the thread
|
# Add the user message to the thread
|
||||||
|
@ -141,17 +155,17 @@ async def test_agent():
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\n🔄 Running agent...\n")
|
print("\n🔄 Running agent...\n")
|
||||||
await process_agent_response(thread_id, thread_manager)
|
await process_agent_response(thread_id, project_id, thread_manager)
|
||||||
|
|
||||||
print("\n👋 Test completed. Goodbye!")
|
print("\n👋 Test completed. Goodbye!")
|
||||||
|
|
||||||
async def process_agent_response(thread_id: str, thread_manager: ThreadManager):
|
async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
|
||||||
"""Process the streaming response from the agent."""
|
"""Process the streaming response from the agent."""
|
||||||
chunk_counter = 0
|
chunk_counter = 0
|
||||||
current_response = ""
|
current_response = ""
|
||||||
tool_call_counter = 0 # Track number of tool calls
|
tool_call_counter = 0 # Track number of tool calls
|
||||||
|
|
||||||
async for chunk in run_agent(thread_id=thread_id, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
|
async for chunk in run_agent(thread_id=thread_id, project_id=project_id, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
|
||||||
chunk_counter += 1
|
chunk_counter += 1
|
||||||
|
|
||||||
if chunk.get('type') == 'content' and 'content' in chunk:
|
if chunk.get('type') == 'content' and 'content' in chunk:
|
||||||
|
|
|
@ -0,0 +1,149 @@
|
||||||
|
revoke delete on table "public"."agent_runs" from "anon";
|
||||||
|
|
||||||
|
revoke insert on table "public"."agent_runs" from "anon";
|
||||||
|
|
||||||
|
revoke references on table "public"."agent_runs" from "anon";
|
||||||
|
|
||||||
|
revoke select on table "public"."agent_runs" from "anon";
|
||||||
|
|
||||||
|
revoke trigger on table "public"."agent_runs" from "anon";
|
||||||
|
|
||||||
|
revoke truncate on table "public"."agent_runs" from "anon";
|
||||||
|
|
||||||
|
revoke update on table "public"."agent_runs" from "anon";
|
||||||
|
|
||||||
|
revoke delete on table "public"."messages" from "anon";
|
||||||
|
|
||||||
|
revoke insert on table "public"."messages" from "anon";
|
||||||
|
|
||||||
|
revoke references on table "public"."messages" from "anon";
|
||||||
|
|
||||||
|
revoke select on table "public"."messages" from "anon";
|
||||||
|
|
||||||
|
revoke trigger on table "public"."messages" from "anon";
|
||||||
|
|
||||||
|
revoke truncate on table "public"."messages" from "anon";
|
||||||
|
|
||||||
|
revoke update on table "public"."messages" from "anon";
|
||||||
|
|
||||||
|
revoke delete on table "public"."projects" from "anon";
|
||||||
|
|
||||||
|
revoke insert on table "public"."projects" from "anon";
|
||||||
|
|
||||||
|
revoke references on table "public"."projects" from "anon";
|
||||||
|
|
||||||
|
revoke select on table "public"."projects" from "anon";
|
||||||
|
|
||||||
|
revoke trigger on table "public"."projects" from "anon";
|
||||||
|
|
||||||
|
revoke truncate on table "public"."projects" from "anon";
|
||||||
|
|
||||||
|
revoke update on table "public"."projects" from "anon";
|
||||||
|
|
||||||
|
revoke delete on table "public"."threads" from "anon";
|
||||||
|
|
||||||
|
revoke insert on table "public"."threads" from "anon";
|
||||||
|
|
||||||
|
revoke references on table "public"."threads" from "anon";
|
||||||
|
|
||||||
|
revoke select on table "public"."threads" from "anon";
|
||||||
|
|
||||||
|
revoke trigger on table "public"."threads" from "anon";
|
||||||
|
|
||||||
|
revoke truncate on table "public"."threads" from "anon";
|
||||||
|
|
||||||
|
revoke update on table "public"."threads" from "anon";
|
||||||
|
|
||||||
|
alter table "public"."projects" add column "sandbox_id" text;
|
||||||
|
|
||||||
|
alter table "public"."projects" add column "sandbox_pass" text;
|
||||||
|
|
||||||
|
set check_function_bodies = off;
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION public.get_llm_formatted_messages(p_thread_id uuid)
|
||||||
|
RETURNS jsonb
|
||||||
|
LANGUAGE plpgsql
|
||||||
|
AS $function$
|
||||||
|
DECLARE
|
||||||
|
messages_array JSONB := '[]'::JSONB;
|
||||||
|
BEGIN
|
||||||
|
-- Check if thread exists
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1 FROM threads t
|
||||||
|
WHERE t.thread_id = p_thread_id
|
||||||
|
) THEN
|
||||||
|
RAISE EXCEPTION 'Thread not found';
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Parse content if it's stored as a string and return proper JSON objects
|
||||||
|
WITH parsed_messages AS (
|
||||||
|
SELECT
|
||||||
|
CASE
|
||||||
|
WHEN jsonb_typeof(content) = 'string' THEN content::text::jsonb
|
||||||
|
ELSE content
|
||||||
|
END AS parsed_content,
|
||||||
|
created_at
|
||||||
|
FROM messages
|
||||||
|
WHERE thread_id = p_thread_id
|
||||||
|
AND is_llm_message = TRUE
|
||||||
|
),
|
||||||
|
-- Process each message to ensure tool_calls function arguments are strings
|
||||||
|
processed_messages AS (
|
||||||
|
SELECT
|
||||||
|
CASE
|
||||||
|
-- When the message has tool_calls
|
||||||
|
WHEN jsonb_path_exists(parsed_content, '$.tool_calls') THEN
|
||||||
|
(
|
||||||
|
WITH tool_calls AS (
|
||||||
|
-- Extract and process each tool call
|
||||||
|
SELECT
|
||||||
|
jsonb_array_elements(parsed_content -> 'tool_calls') AS tool_call,
|
||||||
|
i AS idx
|
||||||
|
FROM generate_series(0, jsonb_array_length(parsed_content -> 'tool_calls') - 1) AS i
|
||||||
|
),
|
||||||
|
processed_tool_calls AS (
|
||||||
|
SELECT
|
||||||
|
idx,
|
||||||
|
CASE
|
||||||
|
-- If function arguments exist and is not a string, convert to JSON string
|
||||||
|
WHEN jsonb_path_exists(tool_call, '$.function.arguments')
|
||||||
|
AND jsonb_typeof(tool_call #> '{function,arguments}') != 'string' THEN
|
||||||
|
jsonb_set(
|
||||||
|
tool_call,
|
||||||
|
'{function,arguments}',
|
||||||
|
to_jsonb(tool_call #>> '{function,arguments}')
|
||||||
|
)
|
||||||
|
ELSE tool_call
|
||||||
|
END AS processed_tool_call
|
||||||
|
FROM tool_calls
|
||||||
|
),
|
||||||
|
-- Convert processed tool calls back to an array
|
||||||
|
tool_calls_array AS (
|
||||||
|
SELECT jsonb_agg(processed_tool_call ORDER BY idx) AS tool_calls_array
|
||||||
|
FROM processed_tool_calls
|
||||||
|
)
|
||||||
|
-- Replace tool_calls in the original message
|
||||||
|
SELECT jsonb_set(parsed_content, '{tool_calls}', tool_calls_array)
|
||||||
|
FROM tool_calls_array
|
||||||
|
)
|
||||||
|
ELSE parsed_content
|
||||||
|
END AS final_content,
|
||||||
|
created_at
|
||||||
|
FROM parsed_messages
|
||||||
|
)
|
||||||
|
-- Aggregate messages into an array
|
||||||
|
SELECT JSONB_AGG(final_content ORDER BY created_at)
|
||||||
|
INTO messages_array
|
||||||
|
FROM processed_messages;
|
||||||
|
|
||||||
|
-- Handle the case when no messages are found
|
||||||
|
IF messages_array IS NULL THEN
|
||||||
|
RETURN '[]'::JSONB;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
RETURN messages_array;
|
||||||
|
END;
|
||||||
|
$function$
|
||||||
|
;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue