mirror of https://github.com/kortix-ai/suna.git
feat: implement background task for project name generation using LLM
This commit is contained in:
parent
25086ffa26
commit
4e69026f57
|
@ -20,6 +20,7 @@ from utils.logger import logger
|
||||||
from utils.billing import check_billing_status, get_account_id_from_thread
|
from utils.billing import check_billing_status, get_account_id_from_thread
|
||||||
from utils.db import update_agent_run_status
|
from utils.db import update_agent_run_status
|
||||||
from sandbox.sandbox import create_sandbox, get_or_start_sandbox
|
from sandbox.sandbox import create_sandbox, get_or_start_sandbox
|
||||||
|
from services.llm import make_llm_api_call
|
||||||
|
|
||||||
# Initialize shared resources
|
# Initialize shared resources
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -749,6 +750,72 @@ async def run_agent_background(
|
||||||
|
|
||||||
logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
|
logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
|
||||||
|
|
||||||
|
# New background task function
|
||||||
|
async def generate_and_update_project_name(project_id: str, prompt: str):
|
||||||
|
"""Generates a project name using an LLM and updates the database."""
|
||||||
|
logger.info(f"Starting background task to generate name for project: {project_id}")
|
||||||
|
try:
|
||||||
|
# Ensure db client is ready (may need re-initialization in background task context)
|
||||||
|
# Getting a fresh connection within the task might be safer
|
||||||
|
db_conn = DBConnection()
|
||||||
|
client = await db_conn.client
|
||||||
|
|
||||||
|
# Prepare LLM call
|
||||||
|
model_name = "openai/gpt-4o-mini" # Or claude-3-haiku
|
||||||
|
system_prompt = "You are a helpful assistant that generates extremely concise titles (2-4 words maximum) for chat threads based on the user's message. Respond with only the title, no other text or punctuation."
|
||||||
|
user_message = f"Generate an extremely brief title (2-4 words only) for a chat thread that starts with this message: \"{prompt}\""
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_message}
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.debug(f"Calling LLM ({model_name}) for project {project_id} naming.")
|
||||||
|
|
||||||
|
# Use make_llm_api_call (ensure it's compatible with background task context)
|
||||||
|
response = await make_llm_api_call(
|
||||||
|
messages=messages,
|
||||||
|
model_name=model_name,
|
||||||
|
max_tokens=20,
|
||||||
|
temperature=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract and clean the name
|
||||||
|
generated_name = None
|
||||||
|
if response and response.get('choices') and response['choices'][0].get('message'):
|
||||||
|
raw_name = response['choices'][0]['message'].get('content', '').strip()
|
||||||
|
# Simple cleaning: remove quotes and extra whitespace
|
||||||
|
cleaned_name = raw_name.strip('\'" \n\t')
|
||||||
|
if cleaned_name:
|
||||||
|
generated_name = cleaned_name
|
||||||
|
logger.info(f"LLM generated name for project {project_id}: '{generated_name}'")
|
||||||
|
else:
|
||||||
|
logger.warning(f"LLM returned an empty name for project {project_id}.")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Failed to get valid response from LLM for project {project_id} naming. Response: {response}")
|
||||||
|
|
||||||
|
print(f"\n\n\nGenerated name: {generated_name}\n\n\n")
|
||||||
|
# Update database if name was generated
|
||||||
|
if generated_name:
|
||||||
|
update_result = await client.table('projects') \
|
||||||
|
.update({"name": generated_name}) \
|
||||||
|
.eq("project_id", project_id) \
|
||||||
|
.execute()
|
||||||
|
|
||||||
|
if hasattr(update_result, 'data') and update_result.data:
|
||||||
|
logger.info(f"Successfully updated project {project_id} name to '{generated_name}'")
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to update project {project_id} name in database. Update result: {update_result}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"No generated name, skipping database update for project {project_id}.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in background naming task for project {project_id}: {str(e)}\n{traceback.format_exc()}")
|
||||||
|
finally:
|
||||||
|
if 'db_conn' in locals():
|
||||||
|
pass
|
||||||
|
logger.info(f"Finished background naming task for project: {project_id}")
|
||||||
|
|
||||||
@router.post("/agent/initiate", response_model=InitiateAgentResponse)
|
@router.post("/agent/initiate", response_model=InitiateAgentResponse)
|
||||||
async def initiate_agent_with_files(
|
async def initiate_agent_with_files(
|
||||||
prompt: str = Form(...),
|
prompt: str = Form(...),
|
||||||
|
@ -780,11 +847,14 @@ async def initiate_agent_with_files(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. Create Project
|
# 1. Create Project
|
||||||
project_name = f"Chat initiated on {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}"
|
# Use prompt for placeholder name
|
||||||
|
placeholder_name = f"{prompt[:30]}..." if len(prompt) > 30 else prompt
|
||||||
|
logger.info(f"Using placeholder name: '{placeholder_name}'")
|
||||||
|
|
||||||
project = await client.table('projects').insert({
|
project = await client.table('projects').insert({
|
||||||
"project_id": str(uuid.uuid4()),
|
"project_id": str(uuid.uuid4()),
|
||||||
"account_id": account_id,
|
"account_id": account_id,
|
||||||
"name": project_name,
|
"name": placeholder_name, # Use placeholder
|
||||||
"created_at": datetime.now(timezone.utc).isoformat()
|
"created_at": datetime.now(timezone.utc).isoformat()
|
||||||
}).execute()
|
}).execute()
|
||||||
|
|
||||||
|
@ -802,6 +872,16 @@ async def initiate_agent_with_files(
|
||||||
thread_id = thread.data[0]['thread_id']
|
thread_id = thread.data[0]['thread_id']
|
||||||
logger.info(f"Created new thread: {thread_id}")
|
logger.info(f"Created new thread: {thread_id}")
|
||||||
|
|
||||||
|
# ---- Trigger Background Naming Task ----
|
||||||
|
logger.info(f"Scheduling background task to generate name for project {project_id}")
|
||||||
|
asyncio.create_task(
|
||||||
|
generate_and_update_project_name(
|
||||||
|
project_id=project_id,
|
||||||
|
prompt=prompt
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# -----------------------------------------
|
||||||
|
|
||||||
# 3. Create Sandbox
|
# 3. Create Sandbox
|
||||||
sandbox_pass = str(uuid.uuid4())
|
sandbox_pass = str(uuid.uuid4())
|
||||||
sandbox = create_sandbox(sandbox_pass)
|
sandbox = create_sandbox(sandbox_pass)
|
||||||
|
@ -982,6 +1062,7 @@ async def initiate_agent_with_files(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Return immediately without waiting for the naming task
|
||||||
return {"thread_id": thread_id, "agent_run_id": agent_run_id}
|
return {"thread_id": thread_id, "agent_run_id": agent_run_id}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
Loading…
Reference in New Issue