mirror of https://github.com/kortix-ai/suna.git
porject_id and migration
This commit is contained in:
parent
1987da720c
commit
4a29872ceb
|
@ -259,7 +259,7 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id
|
|||
|
||||
# Run the agent in the background
|
||||
task = asyncio.create_task(
|
||||
run_agent_background(agent_run_id, thread_id, instance_id)
|
||||
run_agent_background(agent_run_id, thread_id, instance_id, project_id)
|
||||
)
|
||||
|
||||
# Set a callback to clean up when task is done
|
||||
|
@ -389,7 +389,7 @@ async def stream_agent_run(
|
|||
}
|
||||
)
|
||||
|
||||
async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str):
|
||||
async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str):
|
||||
"""Run the agent in the background and handle status updates."""
|
||||
logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
|
||||
client = await db.client
|
||||
|
@ -510,7 +510,7 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
|
|||
# Run the agent
|
||||
logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
|
||||
agent_gen = run_agent(thread_id, stream=True,
|
||||
thread_manager=thread_manager)
|
||||
thread_manager=thread_manager, project_id=project_id)
|
||||
|
||||
# Collect all responses to save to database
|
||||
all_responses = []
|
||||
|
|
|
@ -81,4 +81,5 @@ def get_system_prompt():
|
|||
'''
|
||||
Returns the system prompt with XML tool usage instructions.
|
||||
'''
|
||||
return SYSTEM_PROMPT + RESPONSE_FORMAT
|
||||
# return SYSTEM_PROMPT + RESPONSE_FORMAT
|
||||
return SYSTEM_PROMPT
|
|
@ -1,42 +1,54 @@
|
|||
import os
|
||||
import json
|
||||
import uuid
|
||||
from uuid import uuid4
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from agentpress.response_processor import ProcessorConfig
|
||||
from agent.tools.sb_browse_tool import SandboxBrowseTool
|
||||
from agent.tools.sb_shell_tool import SandboxShellTool
|
||||
from agent.tools.sb_website_tool import SandboxWebsiteTool
|
||||
from agent.tools.sb_files_tool import SandboxFilesTool
|
||||
from typing import Optional
|
||||
from agent.prompt import get_system_prompt
|
||||
from agentpress.response_processor import ProcessorConfig
|
||||
from dotenv import load_dotenv
|
||||
from agent.tools.utils.daytona_sandbox import create_sandbox
|
||||
|
||||
# Load environment variables
|
||||
from agent.tools.utils.daytona_sandbox import daytona, create_sandbox
|
||||
from daytona_api_client.models.workspace_state import WorkspaceState
|
||||
load_dotenv()
|
||||
|
||||
async def run_agent(thread_id: str, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25):
|
||||
async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25):
|
||||
"""Run the development agent with specified configuration."""
|
||||
|
||||
if not thread_manager:
|
||||
thread_manager = ThreadManager()
|
||||
|
||||
if True: # todo: change to of not sandbox running
|
||||
sandbox = create_sandbox("vvv")
|
||||
sandbox_id = sandbox.id
|
||||
sandbox_password = "vvv"
|
||||
client = await thread_manager.db.client
|
||||
## probably want to move to api.py
|
||||
project = await client.table('projects').select('*').eq('project_id', project_id).execute()
|
||||
if project.data[0]['sandbox_id']:
|
||||
sandbox_id = project.data[0]['sandbox_id']
|
||||
sandbox_pass = project.data[0]['sandbox_pass']
|
||||
sandbox = daytona.get_current_sandbox(sandbox_id)
|
||||
if sandbox.instance.state == WorkspaceState.ARCHIVED or sandbox.instance.state == WorkspaceState.STOPPED:
|
||||
try:
|
||||
daytona.start(sandbox)
|
||||
except Exception as e:
|
||||
print(f"Error starting sandbox: {e}")
|
||||
raise e
|
||||
else:
|
||||
sandbox_id = "sandbox-01efaaa5"
|
||||
sandbox_password = "vvv"
|
||||
sandbox_pass = str(uuid4())
|
||||
sandbox = create_sandbox(sandbox_pass)
|
||||
sandbox_id = sandbox.id
|
||||
await client.table('projects').update({
|
||||
'sandbox_id': sandbox_id,
|
||||
'sandbox_pass': sandbox_pass
|
||||
}).eq('project_id', project_id).execute()
|
||||
### ---
|
||||
|
||||
print("Adding tools to thread manager...")
|
||||
# thread_manager.add_tool(FilesTool)
|
||||
# thread_manager.add_tool(TerminalTool)
|
||||
# thread_manager.add_tool(CodeSearchTool)
|
||||
thread_manager.add_tool(SandboxBrowseTool, sandbox_id=sandbox_id, password=sandbox_password)
|
||||
thread_manager.add_tool(SandboxWebsiteTool, sandbox_id=sandbox_id, password=sandbox_password)
|
||||
thread_manager.add_tool(SandboxShellTool, sandbox_id=sandbox_id, password=sandbox_password)
|
||||
thread_manager.add_tool(SandboxFilesTool, sandbox_id=sandbox_id, password=sandbox_password)
|
||||
thread_manager.add_tool(SandboxBrowseTool, sandbox_id=sandbox_id, password=sandbox_pass)
|
||||
thread_manager.add_tool(SandboxWebsiteTool, sandbox_id=sandbox_id, password=sandbox_pass)
|
||||
thread_manager.add_tool(SandboxShellTool, sandbox_id=sandbox_id, password=sandbox_pass)
|
||||
thread_manager.add_tool(SandboxFilesTool, sandbox_id=sandbox_id, password=sandbox_pass)
|
||||
|
||||
system_message = { "role": "system", "content": get_system_prompt() }
|
||||
|
||||
|
@ -49,7 +61,7 @@ async def run_agent(thread_id: str, stream: bool = True, thread_manager: Optiona
|
|||
#groq/deepseek-r1-distill-llama-70b
|
||||
#bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||
|
||||
files_tool = SandboxFilesTool(sandbox_id=sandbox_id, password=sandbox_password)
|
||||
files_tool = SandboxFilesTool(sandbox_id=sandbox_id, password=sandbox_pass)
|
||||
|
||||
files_state = await files_tool.get_workspace_state()
|
||||
|
||||
|
@ -74,8 +86,8 @@ Current development environment workspace state:
|
|||
tool_choice="auto",
|
||||
max_xml_tool_calls=1,
|
||||
processor_config=ProcessorConfig(
|
||||
xml_tool_calling=True,
|
||||
native_tool_calling=False,
|
||||
xml_tool_calling=False,
|
||||
native_tool_calling=True,
|
||||
execute_tools=True,
|
||||
execute_on_stream=True,
|
||||
tool_execution_strategy="parallel",
|
||||
|
@ -103,7 +115,8 @@ async def test_agent():
|
|||
client = await DBConnection().client
|
||||
|
||||
try:
|
||||
thread_result = await client.table('threads').insert({}).execute()
|
||||
thread_result = await client.table('projects').insert({"name": "test", "user_id": "68e1da55-0749-49db-937a-ff56bf0269a0"}).execute()
|
||||
thread_result = await client.table('threads').insert({'project_id': thread_result.data[0]['project_id']}).execute()
|
||||
thread_data = thread_result.data[0] if thread_result.data else None
|
||||
|
||||
if not thread_data:
|
||||
|
@ -111,6 +124,7 @@ async def test_agent():
|
|||
return
|
||||
|
||||
thread_id = thread_data['thread_id']
|
||||
project_id = thread_data['project_id']
|
||||
except Exception as e:
|
||||
print(f"Error creating thread: {str(e)}")
|
||||
return
|
||||
|
@ -126,7 +140,7 @@ async def test_agent():
|
|||
|
||||
if not user_message.strip():
|
||||
print("\n🔄 Running agent...\n")
|
||||
await process_agent_response(thread_id, thread_manager)
|
||||
await process_agent_response(thread_id, project_id, thread_manager)
|
||||
continue
|
||||
|
||||
# Add the user message to the thread
|
||||
|
@ -141,17 +155,17 @@ async def test_agent():
|
|||
)
|
||||
|
||||
print("\n🔄 Running agent...\n")
|
||||
await process_agent_response(thread_id, thread_manager)
|
||||
await process_agent_response(thread_id, project_id, thread_manager)
|
||||
|
||||
print("\n👋 Test completed. Goodbye!")
|
||||
|
||||
async def process_agent_response(thread_id: str, thread_manager: ThreadManager):
|
||||
async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
|
||||
"""Process the streaming response from the agent."""
|
||||
chunk_counter = 0
|
||||
current_response = ""
|
||||
tool_call_counter = 0 # Track number of tool calls
|
||||
|
||||
async for chunk in run_agent(thread_id=thread_id, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
|
||||
async for chunk in run_agent(thread_id=thread_id, project_id=project_id, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
|
||||
chunk_counter += 1
|
||||
|
||||
if chunk.get('type') == 'content' and 'content' in chunk:
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
revoke delete on table "public"."agent_runs" from "anon";
|
||||
|
||||
revoke insert on table "public"."agent_runs" from "anon";
|
||||
|
||||
revoke references on table "public"."agent_runs" from "anon";
|
||||
|
||||
revoke select on table "public"."agent_runs" from "anon";
|
||||
|
||||
revoke trigger on table "public"."agent_runs" from "anon";
|
||||
|
||||
revoke truncate on table "public"."agent_runs" from "anon";
|
||||
|
||||
revoke update on table "public"."agent_runs" from "anon";
|
||||
|
||||
revoke delete on table "public"."messages" from "anon";
|
||||
|
||||
revoke insert on table "public"."messages" from "anon";
|
||||
|
||||
revoke references on table "public"."messages" from "anon";
|
||||
|
||||
revoke select on table "public"."messages" from "anon";
|
||||
|
||||
revoke trigger on table "public"."messages" from "anon";
|
||||
|
||||
revoke truncate on table "public"."messages" from "anon";
|
||||
|
||||
revoke update on table "public"."messages" from "anon";
|
||||
|
||||
revoke delete on table "public"."projects" from "anon";
|
||||
|
||||
revoke insert on table "public"."projects" from "anon";
|
||||
|
||||
revoke references on table "public"."projects" from "anon";
|
||||
|
||||
revoke select on table "public"."projects" from "anon";
|
||||
|
||||
revoke trigger on table "public"."projects" from "anon";
|
||||
|
||||
revoke truncate on table "public"."projects" from "anon";
|
||||
|
||||
revoke update on table "public"."projects" from "anon";
|
||||
|
||||
revoke delete on table "public"."threads" from "anon";
|
||||
|
||||
revoke insert on table "public"."threads" from "anon";
|
||||
|
||||
revoke references on table "public"."threads" from "anon";
|
||||
|
||||
revoke select on table "public"."threads" from "anon";
|
||||
|
||||
revoke trigger on table "public"."threads" from "anon";
|
||||
|
||||
revoke truncate on table "public"."threads" from "anon";
|
||||
|
||||
revoke update on table "public"."threads" from "anon";
|
||||
|
||||
alter table "public"."projects" add column "sandbox_id" text;
|
||||
|
||||
alter table "public"."projects" add column "sandbox_pass" text;
|
||||
|
||||
set check_function_bodies = off;
|
||||
|
||||
CREATE OR REPLACE FUNCTION public.get_llm_formatted_messages(p_thread_id uuid)
|
||||
RETURNS jsonb
|
||||
LANGUAGE plpgsql
|
||||
AS $function$
|
||||
DECLARE
|
||||
messages_array JSONB := '[]'::JSONB;
|
||||
BEGIN
|
||||
-- Check if thread exists
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM threads t
|
||||
WHERE t.thread_id = p_thread_id
|
||||
) THEN
|
||||
RAISE EXCEPTION 'Thread not found';
|
||||
END IF;
|
||||
|
||||
-- Parse content if it's stored as a string and return proper JSON objects
|
||||
WITH parsed_messages AS (
|
||||
SELECT
|
||||
CASE
|
||||
WHEN jsonb_typeof(content) = 'string' THEN content::text::jsonb
|
||||
ELSE content
|
||||
END AS parsed_content,
|
||||
created_at
|
||||
FROM messages
|
||||
WHERE thread_id = p_thread_id
|
||||
AND is_llm_message = TRUE
|
||||
),
|
||||
-- Process each message to ensure tool_calls function arguments are strings
|
||||
processed_messages AS (
|
||||
SELECT
|
||||
CASE
|
||||
-- When the message has tool_calls
|
||||
WHEN jsonb_path_exists(parsed_content, '$.tool_calls') THEN
|
||||
(
|
||||
WITH tool_calls AS (
|
||||
-- Extract and process each tool call
|
||||
SELECT
|
||||
jsonb_array_elements(parsed_content -> 'tool_calls') AS tool_call,
|
||||
i AS idx
|
||||
FROM generate_series(0, jsonb_array_length(parsed_content -> 'tool_calls') - 1) AS i
|
||||
),
|
||||
processed_tool_calls AS (
|
||||
SELECT
|
||||
idx,
|
||||
CASE
|
||||
-- If function arguments exist and is not a string, convert to JSON string
|
||||
WHEN jsonb_path_exists(tool_call, '$.function.arguments')
|
||||
AND jsonb_typeof(tool_call #> '{function,arguments}') != 'string' THEN
|
||||
jsonb_set(
|
||||
tool_call,
|
||||
'{function,arguments}',
|
||||
to_jsonb(tool_call #>> '{function,arguments}')
|
||||
)
|
||||
ELSE tool_call
|
||||
END AS processed_tool_call
|
||||
FROM tool_calls
|
||||
),
|
||||
-- Convert processed tool calls back to an array
|
||||
tool_calls_array AS (
|
||||
SELECT jsonb_agg(processed_tool_call ORDER BY idx) AS tool_calls_array
|
||||
FROM processed_tool_calls
|
||||
)
|
||||
-- Replace tool_calls in the original message
|
||||
SELECT jsonb_set(parsed_content, '{tool_calls}', tool_calls_array)
|
||||
FROM tool_calls_array
|
||||
)
|
||||
ELSE parsed_content
|
||||
END AS final_content,
|
||||
created_at
|
||||
FROM parsed_messages
|
||||
)
|
||||
-- Aggregate messages into an array
|
||||
SELECT JSONB_AGG(final_content ORDER BY created_at)
|
||||
INTO messages_array
|
||||
FROM processed_messages;
|
||||
|
||||
-- Handle the case when no messages are found
|
||||
IF messages_array IS NULL THEN
|
||||
RETURN '[]'::JSONB;
|
||||
END IF;
|
||||
|
||||
RETURN messages_array;
|
||||
END;
|
||||
$function$
|
||||
;
|
||||
|
||||
|
Loading…
Reference in New Issue