# mirror of https://github.com/kortix-ai/suna.git
import json
import logging
import asyncio
import re
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional, Callable, AsyncGenerator, Union, Coroutine

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from agentpress.db import Database, Thread, ThreadRun, AgentRun
from agentpress.tool import Tool, ToolResult
from agentpress.llm import make_llm_api_call
from agentpress.tool_registry import ToolRegistry


class ThreadAgent:
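    """Drives an autonomous, multi-iteration run over a single thread.

    Each iteration creates a ThreadRun through the owning ThreadManager and calls
    the LLM; `continue_instructions` is injected as a user message before every
    iteration after the first. Optional lifecycle hooks (initializer,
    pre_iteration, after_iteration, finalizer) let callers adjust the agent's
    configuration as the run progresses.
    """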

    def __init__(self, thread_manager, thread_id: str, agent_run: AgentRun, **kwargs):
        self.thread_manager = thread_manager
        self.thread_id = thread_id
        self.agent_run = agent_run
        self.system_message = kwargs.get('system_message', {"role": "system", "content": ""})
        self.model_name = kwargs.get('model_name', "gpt-4")
        self.temperature = kwargs.get('temperature', 0.5)
        self.max_tokens = kwargs.get('max_tokens')
        self.tools = kwargs.get('tools')
        self.additional_system_message = kwargs.get('additional_system_message')
        self.additional_message = kwargs.get('additional_message')
        self.execute_tools_async = kwargs.get('execute_tools_async', True)
        self.top_p = kwargs.get('top_p')
        self.tool_choice = kwargs.get('tool_choice', "auto")
        self.response_format = kwargs.get('response_format')
        self.autonomous_iterations_amount = kwargs.get('autonomous_iterations_amount', 5)
        self.continue_instructions = kwargs['continue_instructions']  # Required; raises KeyError if missing

        # Optional lifecycle hooks
        self.initializer = kwargs.get('initializer')
        self.pre_iteration = kwargs.get('pre_iteration')
        self.after_iteration = kwargs.get('after_iteration')
        self.finalizer = kwargs.get('finalizer')

    async def run(self) -> Dict[str, Any]:
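        """Run up to `autonomous_iterations_amount` iterations over the thread.

        Returns a summary dict with "status", "iterations" (the per-iteration
        run_thread results), "total_iterations", and "final_config".
        """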
        if self.agent_run.status == "queued":
            self.agent_run.status = "in_progress"
            self.agent_run.started_at = int(datetime.utcnow().timestamp())
            await self.thread_manager.update_agent_run(self.agent_run)

        iteration_results = []
        final_status = "completed"

        if self.initializer:
            await self.initializer(self)

        try:
            for iteration in range(self.autonomous_iterations_amount):
                if await self.thread_manager.should_stop(self.thread_id, self.agent_run.id):
                    final_status = "stopped"
                    break

                if self.pre_iteration:
                    await self.pre_iteration(iteration, self)

                # Add continue_instructions as a user message after the first iteration
                if iteration > 0 and self.continue_instructions:
                    await self.thread_manager.add_message(
                        self.thread_id,
                        {"role": "user", "content": self.continue_instructions}
                    )

                # Create a new ThreadRun for this iteration
                thread_run_kwargs = {
                    k: v for k, v in self.__dict__.items()
                    if k not in ['thread_id', 'thread_manager', 'agent_run']
                    and not k.startswith('_') and not callable(v)
                }
                thread_run = await self.thread_manager.create_thread_run(self.thread_id, **thread_run_kwargs)

                result = await self.thread_manager.run_thread(
                    thread_id=self.thread_id,
                    thread_run=thread_run,
                    system_message=self.system_message,
                    model_name=self.model_name,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    tools=self.tools,
                    additional_system_message=self.additional_system_message,
                    additional_message=self.additional_message,
                    execute_tools_async=self.execute_tools_async,
                    top_p=self.top_p,
                    tool_choice=self.tool_choice,
                    response_format=self.response_format
                )

                iteration_results.append(result)
                self.agent_run.iterations_count += 1
                await self.thread_manager.update_agent_run(self.agent_run)

                if self.after_iteration:
                    await self.after_iteration(iteration, result, self)

                if result.get("status") in ("error", "stopped"):
                    final_status = result["status"]
                    break

                if await self.thread_manager.should_stop(self.thread_id, self.agent_run.id):
                    final_status = "stopped"
                    break

        except Exception as e:
            final_status = "failed"
            self.agent_run.last_error = str(e)
            logging.error(f"Error in thread agent run: {str(e)}")

        finally:
            if self.finalizer:
                await self.finalizer(final_status, self)

            self.agent_run.status = final_status
            if final_status == "completed":
                self.agent_run.completed_at = int(datetime.utcnow().timestamp())
            elif final_status in ["stopped", "cancelled"]:
                self.agent_run.cancelled_at = int(datetime.utcnow().timestamp())
            elif final_status == "failed":
                self.agent_run.failed_at = int(datetime.utcnow().timestamp())

            await self.thread_manager.update_agent_run(self.agent_run)

        return {
            "status": final_status,
            "iterations": iteration_results,
            "total_iterations": self.agent_run.iterations_count,
            "final_config": {
                k: v for k, v in self.__dict__.items()
                if not k.startswith('_') and not callable(v) and not isinstance(v, (ThreadManager, AgentRun))
            }
        }


class ThreadManager:
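    """Persists threads, messages, and run records, and orchestrates LLM calls.

    Messages are stored as a JSON-encoded list on each Thread row; ThreadRun and
    AgentRun rows track the status, timing, configuration, and errors of
    individual runs. Tool schemas and implementations are resolved through the
    shared ToolRegistry.
    """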

    def __init__(self, db: Database):
        self.db = db
        self.tool_registry = ToolRegistry()

    async def create_thread(self) -> int:
        async with self.db.get_async_session() as session:
            new_thread = Thread(
                messages=json.dumps([])
            )
            session.add(new_thread)
            await session.commit()
            await session.refresh(new_thread)  # Ensure thread_id is populated
            return new_thread.thread_id

    async def add_message(self, thread_id: int, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
        logging.info(f"Adding message to thread {thread_id} with images: {images}")
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                raise ValueError(f"Thread with id {thread_id} not found")

            try:
                messages = json.loads(thread.messages)

                # If we're adding a user message, check for dangling tool calls first
                if message_data['role'] == 'user':
                    # Find the last assistant message with tool calls
                    last_assistant_index = next(
                        (i for i in reversed(range(len(messages)))
                         if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]),
                        None
                    )

                    if last_assistant_index is not None:
                        tool_call_count = len(messages[last_assistant_index]['tool_calls'])
                        tool_response_count = sum(1 for msg in messages[last_assistant_index+1:] if msg['role'] == 'tool')

                        if tool_call_count != tool_response_count:
                            await self.cleanup_incomplete_tool_calls(thread_id)
                            # Reload the messages so the synthetic tool results added by
                            # cleanup are not overwritten by the stale copy loaded above
                            await session.refresh(thread)
                            messages = json.loads(thread.messages)

                # Convert ToolResult objects to strings
                for key, value in message_data.items():
                    if isinstance(value, ToolResult):
                        message_data[key] = str(value)

                # Process images if present
                if images:
                    if isinstance(message_data['content'], str):
                        message_data['content'] = [{"type": "text", "text": message_data['content']}]
                    elif not isinstance(message_data['content'], list):
                        message_data['content'] = []

                    for image in images:
                        image_content = {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{image['content_type']};base64,{image['base64']}",
                                "detail": "high"
                            }
                        }
                        message_data['content'].append(image_content)

                messages.append(message_data)
                thread.messages = json.dumps(messages)
                await session.commit()
                logging.info(f"Message added to thread {thread_id}: {message_data}")
            except Exception as e:
                await session.rollback()
                logging.error(f"Failed to add message to thread {thread_id}: {e}")
                raise e
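
    # Illustrative usage (a sketch; the base64 payload and content type below are
    # placeholders, not values from this codebase), given a ThreadManager `manager`:
    #
    #   await manager.add_message(
    #       thread_id,
    #       {"role": "user", "content": "What is in this image?"},
    #       images=[{"content_type": "image/png", "base64": "<base64-data>"}],
    #   )
    #
    # Each image is appended to the message content as an "image_url" part with a
    # data URL and "detail": "high".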

    async def get_message(self, thread_id: int, message_index: int) -> Optional[Dict[str, Any]]:
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                return None
            messages = json.loads(thread.messages)
            if message_index < len(messages):
                return messages[message_index]
            return None

    async def modify_message(self, thread_id: int, message_index: int, new_message_data: Dict[str, Any]):
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                raise ValueError(f"Thread with id {thread_id} not found")

            try:
                messages = json.loads(thread.messages)
                if message_index < len(messages):
                    messages[message_index] = new_message_data
                    thread.messages = json.dumps(messages)
                    await session.commit()
                else:
                    raise ValueError(f"Message index {message_index} is out of range")
            except Exception as e:
                await session.rollback()
                raise e

    async def remove_message(self, thread_id: int, message_index: int):
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                raise ValueError(f"Thread with id {thread_id} not found")

            try:
                messages = json.loads(thread.messages)
                if message_index < len(messages):
                    del messages[message_index]
                    thread.messages = json.dumps(messages)
                    await session.commit()
            except Exception as e:
                await session.rollback()
                raise e

    async def list_messages(self, thread_id: int, hide_tool_msgs: bool = False, only_latest_assistant: bool = False, regular_list: bool = True) -> List[Dict[str, Any]]:
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                return []
            messages = json.loads(thread.messages)

            if only_latest_assistant:
                for msg in reversed(messages):
                    if msg.get('role') == 'assistant':
                        return [msg]
                return []

            filtered_messages = messages  # Start from the full message list

            if hide_tool_msgs:
                filtered_messages = [
                    {k: v for k, v in msg.items() if k != 'tool_calls'}
                    for msg in filtered_messages
                    if msg.get('role') != 'tool'
                ]

            if regular_list:
                filtered_messages = [
                    msg for msg in filtered_messages
                    if msg.get('role') in ['system', 'assistant', 'tool', 'user']
                ]

            return filtered_messages
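
    # Illustrative usage, given a ThreadManager `manager`:
    #
    #   history = await manager.list_messages(thread_id, hide_tool_msgs=True)
    #
    # hide_tool_msgs drops tool messages and strips "tool_calls" from the rest;
    # only_latest_assistant=True instead returns just the most recent assistant message.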

    async def cleanup_incomplete_tool_calls(self, thread_id: int):
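        """Backfill failed tool responses for dangling tool calls.

        If the last assistant message carries more tool_calls than there are
        subsequent tool messages, synthetic "execution interrupted" tool results
        are inserted so the conversation stays well-formed for the next LLM call.
        Returns True if a repair was made, otherwise False.
        """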
        messages = await self.list_messages(thread_id)
        last_assistant_message = next((m for m in reversed(messages) if m['role'] == 'assistant' and 'tool_calls' in m), None)

        if last_assistant_message:
            tool_calls = last_assistant_message.get('tool_calls', [])
            tool_responses = [m for m in messages[messages.index(last_assistant_message)+1:] if m['role'] == 'tool']

            if len(tool_calls) != len(tool_responses):
                # Create failed ToolResults for incomplete tool calls
                failed_tool_results = []
                for tool_call in tool_calls[len(tool_responses):]:
                    failed_tool_result = {
                        "role": "tool",
                        "tool_call_id": tool_call['id'],
                        "name": tool_call['function']['name'],
                        "content": "ToolResult(success=False, output='Execution interrupted. Session was stopped.')"
                    }
                    failed_tool_results.append(failed_tool_result)

                # Insert failed tool results after the last assistant message
                assistant_index = messages.index(last_assistant_message)
                messages[assistant_index+1:assistant_index+1] = failed_tool_results

                async with self.db.get_async_session() as session:
                    thread = await session.get(Thread, thread_id)
                    if thread:
                        thread.messages = json.dumps(messages)
                        await session.commit()

                return True
        return False

    async def run_thread(self, thread_id: str, thread_run: ThreadRun, **kwargs) -> Dict[str, Any]:
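        """Run a single LLM turn for the thread and record it on `thread_run`.

        Recognised kwargs mirror ThreadAgent's configuration: system_message,
        model_name, temperature, max_tokens, tools, tool_choice, top_p,
        response_format, additional_system_message, additional_message,
        execute_tools_async, use_tool_parser, and hide_tool_msgs.
        """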
        try:
            thread_run.status = "in_progress"
            thread_run.started_at = int(datetime.utcnow().timestamp())
            await self.update_thread_run(thread_run)

            if await self.should_stop(thread_id, thread_run.id):
                thread_run.status = "stopped"
                thread_run.cancelled_at = int(datetime.utcnow().timestamp())
                await self.update_thread_run(thread_run)
                return {"status": "stopped", "message": "Thread run cancelled"}

            # Fetch full tool schemas based on the provided tool names
            full_tools = None
            if kwargs.get('tools'):
                full_tools = [
                    self.tool_registry.get_tool(tool_name)['schema']
                    for tool_name in kwargs['tools']
                    if self.tool_registry.get_tool(tool_name)
                ]

            # Build the system message for this run without mutating the caller's
            # dict, so additional_system_message is not appended again on every
            # iteration that reuses the same system_message object
            system_message = kwargs['system_message']
            if kwargs.get('additional_system_message'):
                system_message = {
                    **system_message,
                    "content": system_message['content'] + f"\n\n{kwargs['additional_system_message']}"
                }

            if await self.should_stop(thread_id, thread_run.id):
                thread_run.status = "stopped"
                thread_run.cancelled_at = int(datetime.utcnow().timestamp())
                await self.update_thread_run(thread_run)
                return {"status": "stopped", "message": "Thread run cancelled"}

            # use_tool_parser implies hiding raw tool messages from the prompt
            if kwargs.get('use_tool_parser'):
                kwargs['hide_tool_msgs'] = True

            await self.cleanup_incomplete_tool_calls(thread_id)

            # Prepare messages
            messages = await self.list_messages(thread_id, hide_tool_msgs=kwargs.get('hide_tool_msgs', False))
            prepared_messages = [system_message] + messages

            # Add the additional_message if provided
            if kwargs.get('additional_message'):
                prepared_messages.append(kwargs['additional_message'])

            response = await make_llm_api_call(
                prepared_messages,
                kwargs['model_name'],
                temperature=kwargs.get('temperature', 0.5),
                max_tokens=kwargs.get('max_tokens'),
                tools=full_tools,
                tool_choice=kwargs.get('tool_choice', "auto"),
                stream=False,
                top_p=kwargs.get('top_p'),
                response_format=kwargs.get('response_format')
            )

            usage = response.usage if hasattr(response, 'usage') else None
            usage_dict = self.serialize_usage(usage) if usage else None
            thread_run.usage = usage_dict

            # Add the assistant's message to the thread
            assistant_message = {
                "role": "assistant",
                "content": response.choices[0].message['content']
            }
            if 'tool_calls' in response.choices[0].message:
                assistant_message['tool_calls'] = response.choices[0].message['tool_calls']

            await self.add_message(thread_id, assistant_message)

            if kwargs.get('tools') is None or kwargs.get('use_tool_parser'):
                await self.handle_response_without_tools(thread_id, response, kwargs.get('use_tool_parser', False))
            else:
                await self.handle_response_with_tools(thread_id, response, kwargs.get('execute_tools_async', True))

            if await self.should_stop(thread_id, thread_run.id):
                thread_run.status = "stopped"
                thread_run.cancelled_at = int(datetime.utcnow().timestamp())
                await self.update_thread_run(thread_run)
                return {"status": "stopped", "message": "Thread run cancelled"}

            thread_run.status = "completed"
            thread_run.completed_at = int(datetime.utcnow().timestamp())
            await self.update_thread_run(thread_run)

            return {
                "id": thread_run.id,
                "status": thread_run.status,
                "choices": [self.serialize_choice(choice) for choice in response.choices],
                "usage": usage_dict,
                "model": kwargs['model_name'],
                "object": "chat.completion",
                "created": int(datetime.utcnow().timestamp())
            }
        except Exception as e:
            thread_run.status = "failed"
            thread_run.failed_at = int(datetime.utcnow().timestamp())
            thread_run.last_error = str(e)
            await self.update_thread_run(thread_run)
            raise

    def serialize_usage(self, usage):
        # Token-detail blocks are not returned by every provider/model, so fall
        # back to None rather than raising
        return {
            "completion_tokens": usage.completion_tokens,
            "prompt_tokens": usage.prompt_tokens,
            "total_tokens": usage.total_tokens,
            "completion_tokens_details": self.serialize_completion_tokens_details(getattr(usage, 'completion_tokens_details', None)),
            "prompt_tokens_details": self.serialize_prompt_tokens_details(getattr(usage, 'prompt_tokens_details', None))
        }

    def serialize_completion_tokens_details(self, details):
        if not details:
            return None
        return {
            "audio_tokens": details.audio_tokens,
            "reasoning_tokens": details.reasoning_tokens
        }

    def serialize_prompt_tokens_details(self, details):
        if not details:
            return None
        return {
            "audio_tokens": details.audio_tokens,
            "cached_tokens": details.cached_tokens
        }

    def serialize_choice(self, choice):
        return {
            "finish_reason": choice.finish_reason,
            "index": choice.index,
            "message": self.serialize_message(choice.message)
        }

    def serialize_message(self, message):
        return {
            "content": message.content,
            "role": message.role,
            "tool_calls": [self.serialize_tool_call(tc) for tc in message.tool_calls] if message.tool_calls else None
        }

    def serialize_tool_call(self, tool_call):
        return {
            "id": tool_call.id,
            "type": tool_call.type,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments
            }
        }

    async def update_thread_run(self, thread_run: ThreadRun):
        async with self.db.get_async_session() as session:
            session.add(thread_run)
            await session.commit()
            await session.refresh(thread_run)

    async def handle_response_without_tools(self, thread_id: int, response: Any, use_tool_parser: bool):
        response_content = response.choices[0].message['content']

        if use_tool_parser:
            await self.handle_tool_parser_response(thread_id, response_content)
        else:
            # The message has already been added in run_thread, so there is nothing to do here
            pass

    async def handle_tool_parser_response(self, thread_id: int, response_content: str):
        tool_call_match = re.search(r'\{[\s\S]*"function_calls"[\s\S]*\}', response_content)
        if tool_call_match:
            try:
                tool_call_json = json.loads(tool_call_match.group())
                tool_calls = tool_call_json.get('function_calls', [])

                assistant_message = {
                    "role": "assistant",
                    "content": response_content,
                    "tool_calls": [
                        {
                            "id": f"call_{i}",
                            "type": "function",
                            "function": {
                                "name": call['name'],
                                "arguments": json.dumps(call['arguments'])
                            }
                        } for i, call in enumerate(tool_calls)
                    ]
                }
                await self.add_message(thread_id, assistant_message)

                available_functions = self.get_available_functions()

                tool_results = await self.execute_tools(assistant_message['tool_calls'], available_functions, thread_id, execute_tools_async=True)

                await self.process_tool_results(thread_id, tool_results)

            except json.JSONDecodeError:
                logging.error("Failed to parse tool call JSON from response")
                await self.add_message(thread_id, {"role": "assistant", "content": response_content})
        else:
            await self.add_message(thread_id, {"role": "assistant", "content": response_content})

    async def handle_response_with_tools(self, thread_id: int, response: Any, execute_tools_async: bool):
        try:
            response_message = response.choices[0].message
            tool_calls = response_message.get('tool_calls', [])

            # The assistant message has already been added in the run_thread method

            available_functions = self.get_available_functions()

            # Note: no run id is available in this scope, so thread_id is passed for
            # both arguments; should_stop looks runs up by run id, so this check is
            # effectively a no-op here
            if await self.should_stop(thread_id, thread_id):
                return {"status": "stopped", "message": "Session cancelled"}

            if tool_calls:
                if execute_tools_async:
                    tool_results = await self.execute_tools_async(tool_calls, available_functions, thread_id)
                else:
                    tool_results = await self.execute_tools_sync(tool_calls, available_functions, thread_id)

                # Add tool results to messages
                for result in tool_results:
                    await self.add_message(thread_id, result)

            if await self.should_stop(thread_id, thread_id):
                return {"status": "stopped", "message": "Session cancelled after tool execution"}

        except AttributeError as e:
            logging.error(f"AttributeError: {e}")
            # The assistant message was already added in run_thread, so nothing to re-add here

    def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
        message = {
            "role": "assistant",
            "content": response_message.get('content') or "",
        }
        tool_calls = response_message.get('tool_calls')
        if tool_calls:
            message["tool_calls"] = [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                } for tool_call in tool_calls
            ]
        return message

    def get_available_functions(self) -> Dict[str, Callable]:
        available_functions = {}
        for tool_name, tool_info in self.tool_registry.get_all_tools().items():
            tool_instance = tool_info['instance']
            for func_name, func in tool_instance.__class__.__dict__.items():
                if callable(func) and not func_name.startswith("__"):
                    available_functions[func_name] = getattr(tool_instance, func_name)
        return available_functions
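
    # Note: tool methods are flattened into a single name -> coroutine map, so two
    # tools exposing a method with the same name will overwrite each other;
    # whichever tool is iterated last wins.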

    async def execute_tools(self, tool_calls: List[Any], available_functions: Dict[str, Callable], thread_id: int, execute_tools_async: bool) -> List[Dict[str, Any]]:
        if execute_tools_async:
            return await self.execute_tools_async(tool_calls, available_functions, thread_id)
        else:
            return await self.execute_tools_sync(tool_calls, available_functions, thread_id)

    async def execute_tools_async(self, tool_calls, available_functions, thread_id):
        async def execute_single_tool(tool_call):
            if await self.should_stop(thread_id, thread_id):
                return {"status": "stopped", "message": "Session cancelled"}

            # Tool calls arrive either as SDK objects (from the LLM response) or as
            # plain dicts (from the tool parser path); support both shapes
            if isinstance(tool_call, dict):
                function_name = tool_call['function']['name']
                function_args = json.loads(tool_call['function']['arguments'])
                tool_call_id = tool_call['id']
            else:
                function_name = tool_call.function.name
                function_args = json.loads(tool_call.function.arguments)
                tool_call_id = tool_call.id

            function_to_call = available_functions.get(function_name)
            if function_to_call:
                return await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)
            else:
                logging.warning(f"Function {function_name} not found in available functions")
                return None

        tool_results = await asyncio.gather(*[execute_single_tool(tool_call) for tool_call in tool_calls])
        return [result for result in tool_results if result]

    async def execute_tools_sync(self, tool_calls, available_functions, thread_id):
        tool_results = []
        for tool_call in tool_calls:
            if await self.should_stop(thread_id, thread_id):
                return [{"status": "stopped", "message": "Session cancelled"}]

            # Support both SDK objects and plain dicts, as in execute_tools_async
            if isinstance(tool_call, dict):
                function_name = tool_call['function']['name']
                function_args = json.loads(tool_call['function']['arguments'])
                tool_call_id = tool_call['id']
            else:
                function_name = tool_call.function.name
                function_args = json.loads(tool_call.function.arguments)
                tool_call_id = tool_call.id

            function_to_call = available_functions.get(function_name)
            if function_to_call:
                result = await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)
                if result:
                    tool_results.append(result)
            else:
                logging.warning(f"Function {function_name} not found in available functions")

        return tool_results

    async def process_tool_results(self, thread_id: int, tool_results: List[Dict[str, Any]]):
        for result in tool_results:
            # execute_tools already returns fully-formed tool messages, so append them directly
            await self.add_message(thread_id, result)

    async def execute_tool(self, function_to_call, function_args, function_name, tool_call_id):
        try:
            function_response = await function_to_call(**function_args)
        except Exception as e:
            error_message = f"Error in {function_name}: {str(e)}"
            function_response = ToolResult(success=False, output=error_message)

        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": function_name,
            "content": str(function_response),
        }

    async def should_stop(self, thread_id: str, run_id: str) -> bool:
        async with self.db.get_async_session() as session:
            run = await session.get(AgentRun, run_id)
            if run and run.status in ["stopped", "cancelled", "queued"]:
                return True
            return False

    async def stop_thread_run(self, thread_id: str, run_id: str) -> Dict[str, Any]:
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            if run and run.thread_id == thread_id and run.status == "in_progress":
                run.status = "stopping"
                await session.commit()
                return self.serialize_thread_run(run)
            return None

    async def stop_agent_run(self, thread_id: str, run_id: str) -> Dict[str, Any]:
        async with self.db.get_async_session() as session:
            agent_run = await session.get(AgentRun, run_id)
            if agent_run and agent_run.thread_id == thread_id and agent_run.status in ["in_progress", "queued"]:
                agent_run.status = "stopped"
                agent_run.cancelled_at = int(datetime.utcnow().timestamp())

                # Stop all associated ThreadRuns that are still queued or in progress
                associated_thread_runs = await session.execute(
                    select(ThreadRun).where(
                        (ThreadRun.thread_id == thread_id) &
                        (ThreadRun.created_at >= agent_run.created_at) &
                        (ThreadRun.status.in_(["queued", "in_progress"]))
                    )
                )
                for thread_run in associated_thread_runs.scalars():
                    thread_run.status = "stopped"
                    thread_run.cancelled_at = int(datetime.utcnow().timestamp())

                # Stop all associated AgentRuns that are still queued
                associated_agent_runs = await session.execute(
                    select(AgentRun).where(
                        (AgentRun.thread_id == thread_id) &
                        (AgentRun.created_at >= agent_run.created_at) &
                        (AgentRun.status == "queued")
                    )
                )
                for assoc_agent_run in associated_agent_runs.scalars():
                    assoc_agent_run.status = "stopped"
                    assoc_agent_run.cancelled_at = int(datetime.utcnow().timestamp())

                await session.commit()
                return self.serialize_agent_run(agent_run)
            return None

    async def save_thread_run(self, thread_id: str):
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                raise ValueError(f"Thread with id {thread_id} not found")

            messages = json.loads(thread.messages)
            creation_date = datetime.now().isoformat()

            # Get the latest ThreadRun for this thread
            stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
            result = await session.execute(stmt)
            latest_thread_run = result.scalar_one_or_none()

            if latest_thread_run:
                # Update the existing ThreadRun
                latest_thread_run.messages = json.dumps(messages)
                latest_thread_run.last_updated_date = creation_date
                await session.commit()
            else:
                # Create a new ThreadRun if none exists
                new_thread_run = ThreadRun(
                    thread_id=thread_id,
                    messages=json.dumps(messages),
                    creation_date=creation_date,
                    status='completed'
                )
                session.add(new_thread_run)
                await session.commit()

    async def get_thread(self, thread_id: int) -> Optional[Thread]:
        async with self.db.get_async_session() as session:
            return await session.get(Thread, thread_id)

    async def update_thread_run_with_error(self, thread_id: int, error_message: str):
        async with self.db.get_async_session() as session:
            # Order by created_at and store the error on last_error, consistent with
            # the other ThreadRun queries and serializers in this module
            stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
            result = await session.execute(stmt)
            thread_run = result.scalar_one_or_none()
            if thread_run:
                thread_run.status = 'error'
                thread_run.last_error = error_message  # Store the full error message
                await session.commit()

    async def get_threads(self) -> List[Thread]:
        async with self.db.get_async_session() as session:
            result = await session.execute(select(Thread).order_by(Thread.thread_id.desc()))
            return result.scalars().all()

    async def get_latest_thread_run(self, thread_id: str):
        async with self.db.get_async_session() as session:
            stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1)
            result = await session.execute(stmt)
            latest_run = result.scalar_one_or_none()
            if latest_run:
                return {
                    "id": latest_run.id,
                    "status": latest_run.status,
                    "error_message": latest_run.last_error,
                    "created_at": latest_run.created_at,
                    "started_at": latest_run.started_at,
                    "completed_at": latest_run.completed_at,
                    "cancelled_at": latest_run.cancelled_at,
                    "failed_at": latest_run.failed_at,
                    "model": latest_run.model,
                    "usage": latest_run.usage
                }
            return None

    async def get_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            if run and run.thread_id == thread_id:
                return {
                    "id": run.id,
                    "thread_id": run.thread_id,
                    "status": run.status,
                    "created_at": run.created_at,
                    "started_at": run.started_at,
                    "completed_at": run.completed_at,
                    "cancelled_at": run.cancelled_at,
                    "failed_at": run.failed_at,
                    "model": run.model,
                    "system_message": json.loads(run.system_message) if run.system_message else None,
                    "tools": json.loads(run.tools) if run.tools else None,
                    "usage": run.usage,
                    "temperature": run.temperature,
                    "top_p": run.top_p,
                    "max_tokens": run.max_tokens,
                    "tool_choice": run.tool_choice,
                    "execute_tools_async": run.execute_tools_async,
                    "response_format": json.loads(run.response_format) if run.response_format else None,
                    "last_error": run.last_error
                }
            return None

    async def cancel_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            if run and run.thread_id == thread_id and run.status == "in_progress":
                run.status = "cancelled"
                run.cancelled_at = int(datetime.utcnow().timestamp())
                await session.commit()
                return await self.get_run(thread_id, run_id)
            return None

    async def list_runs(self, thread_id: str, limit: int) -> List[Dict[str, Any]]:
        async with self.db.get_async_session() as session:
            thread_runs_stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(limit)
            agent_runs_stmt = select(AgentRun).where(AgentRun.thread_id == thread_id).order_by(AgentRun.created_at.desc()).limit(limit)

            thread_runs_result = await session.execute(thread_runs_stmt)
            agent_runs_result = await session.execute(agent_runs_stmt)

            thread_runs = thread_runs_result.scalars().all()
            agent_runs = agent_runs_result.scalars().all()

            all_runs = [self.serialize_thread_run(run) for run in thread_runs] + [self.serialize_agent_run(run) for run in agent_runs]
            all_runs.sort(key=lambda x: x['created_at'], reverse=True)

            return all_runs[:limit]

    async def create_agent_run(self, thread_id: str, **kwargs) -> AgentRun:
        run_id = str(uuid.uuid4())
        agent_run = AgentRun(
            id=run_id,
            thread_id=thread_id,
            status="queued",
            autonomous_iterations_amount=kwargs.get('autonomous_iterations_amount'),
            continue_instructions=kwargs.get('continue_instructions')
        )
        async with self.db.get_async_session() as session:
            session.add(agent_run)
            await session.commit()
            return agent_run

    async def update_agent_run(self, run: AgentRun):
        async with self.db.get_async_session() as session:
            await session.merge(run)
            await session.commit()

    async def create_thread_run(self, thread_id: str, **kwargs) -> ThreadRun:
        run_id = str(uuid.uuid4())
        thread_run = ThreadRun(
            id=run_id,
            thread_id=thread_id,
            status="queued",
            model=kwargs.get('model_name'),
            temperature=kwargs.get('temperature'),
            max_tokens=kwargs.get('max_tokens'),
            top_p=kwargs.get('top_p'),
            tool_choice=kwargs.get('tool_choice', "auto"),
            execute_tools_async=kwargs.get('execute_tools_async', True),
            system_message=json.dumps(kwargs.get('system_message')),
            tools=json.dumps(kwargs.get('tools')),
            response_format=json.dumps(kwargs.get('response_format'))
        )
        async with self.db.get_async_session() as session:
            session.add(thread_run)
            await session.commit()
            logging.info(f"Created ThreadRun {run_id} for thread {thread_id}. Total ThreadRuns: {await self.get_thread_run_count(thread_id)}")
            return thread_run

    async def get_thread_run_count(self, thread_id: str) -> int:
        async with self.db.get_async_session() as session:
            result = await session.execute(select(ThreadRun).filter_by(thread_id=thread_id))
            return len(result.scalars().all())

    async def run_thread_agent(self, thread_id: str, **kwargs) -> Dict[str, Any]:
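        """Create an AgentRun for the thread and drive a ThreadAgent over it.

        `continue_instructions` is required. Minimal sketch (argument values are
        illustrative only), given a ThreadManager `manager`:

            await manager.run_thread_agent(
                thread_id,
                system_message={"role": "system", "content": "..."},
                model_name="gpt-4o",
                autonomous_iterations_amount=3,
                continue_instructions="Keep going.",
            )
        """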
        if 'continue_instructions' not in kwargs or not kwargs['continue_instructions']:
            raise ValueError("continue_instructions is required for running a thread agent")

        agent_run = await self.create_agent_run(thread_id, **kwargs)
        agent = ThreadAgent(self, thread_id, agent_run, **kwargs)

        try:
            result = await agent.run()
            return result
        except Exception as e:
            agent_run.status = "failed"
            agent_run.failed_at = int(datetime.utcnow().timestamp())
            agent_run.last_error = str(e)
            await self.update_agent_run(agent_run)
            raise

    async def list_agent_runs(self, thread_id: str, limit: int) -> List[Dict[str, Any]]:
        async with self.db.get_async_session() as session:
            stmt = select(AgentRun).where(AgentRun.thread_id == thread_id).order_by(AgentRun.created_at.desc()).limit(limit)
            result = await session.execute(stmt)
            runs = result.scalars().all()
            return [
                {
                    "id": run.id,
                    "thread_id": run.thread_id,
                    "status": run.status,
                    "created_at": run.created_at,
                    "started_at": run.started_at,
                    "completed_at": run.completed_at,
                    "cancelled_at": run.cancelled_at,
                    "failed_at": run.failed_at,
                    "autonomous_iterations_amount": run.autonomous_iterations_amount,
                    "iterations_count": run.iterations_count,
                    "continue_instructions": run.continue_instructions,
                    "iterations": json.loads(run.iterations) if run.iterations else None,
                    "last_error": run.last_error
                }
                for run in runs
            ]

    async def get_thread_run_status(self, thread_id: str, run_id: str) -> Dict[str, Any]:
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            if run and run.thread_id == thread_id:
                return self.serialize_thread_run(run)
            return None

    async def get_agent_run_status(self, thread_id: str, run_id: str) -> Dict[str, Any]:
        async with self.db.get_async_session() as session:
            run = await session.get(AgentRun, run_id)
            if run and run.thread_id == thread_id:
                return self.serialize_agent_run(run)
            return None

    def serialize_thread_run(self, run: ThreadRun) -> Dict[str, Any]:
        return {
            "id": run.id,
            "thread_id": run.thread_id,
            "status": run.status,
            "created_at": run.created_at,
            "started_at": run.started_at,
            "completed_at": run.completed_at,
            "cancelled_at": run.cancelled_at,
            "failed_at": run.failed_at,
            "model": run.model,
            "temperature": run.temperature,
            "max_tokens": run.max_tokens,
            "top_p": run.top_p,
            "tool_choice": run.tool_choice,
            "execute_tools_async": run.execute_tools_async,
            "system_message": json.loads(run.system_message) if run.system_message else None,
            "tools": json.loads(run.tools) if run.tools else None,
            "usage": run.usage,
            "response_format": json.loads(run.response_format) if run.response_format else None,
            "last_error": run.last_error,
            "is_agent_run": False
        }

    def serialize_agent_run(self, run: AgentRun) -> Dict[str, Any]:
        return {
            "id": run.id,
            "thread_id": run.thread_id,
            "status": run.status,
            "created_at": run.created_at,
            "started_at": run.started_at,
            "completed_at": run.completed_at,
            "cancelled_at": run.cancelled_at,
            "failed_at": run.failed_at,
            "autonomous_iterations_amount": run.autonomous_iterations_amount,
            "iterations_count": run.iterations_count,
            "continue_instructions": run.continue_instructions,
            "iterations": json.loads(run.iterations) if run.iterations else None,
            "last_error": run.last_error,
            "is_agent_run": True
        }


if __name__ == "__main__":
    async def main():
        db = Database()
        manager = ThreadManager(db)

        thread_id = await manager.create_thread()
        await manager.add_message(thread_id, {"role": "user", "content": "Let's have a conversation about artificial intelligence."})

        async def initializer(agent):
            print("Initializing thread agent...")
            agent.temperature = 0.8

        async def pre_iteration(iteration: int, agent: ThreadAgent):
            print(f"Preparing iteration {iteration + 1}...")
            agent.max_tokens = 200 if iteration > 1 else 150

        async def after_iteration(iteration: int, result: Dict[str, Any], agent: ThreadAgent):
            print(f"Completed iteration {iteration + 1}. Status: {result.get('status')}")
            if "AI ethics" in result.get("content", ""):
                agent.continue_instructions = "Let's focus more on AI ethics in the next iteration."

        async def finalizer(status: str, agent: ThreadAgent):
            print(f"Thread agent finished with status: {status}")
            print(f"Final configuration: {agent.__dict__}")

        system_message = {"role": "system", "content": "You are an AI expert engaging in a conversation about artificial intelligence."}
        response = await manager.run_thread_agent(
            thread_id=thread_id,
            system_message=system_message,
            model_name="gpt-4o",
            temperature=0.7,
            max_tokens=150,
            autonomous_iterations_amount=3,
            continue_instructions="Continue the conversation about AI, introducing new aspects or asking thought-provoking questions.",
            initializer=initializer,
            pre_iteration=pre_iteration,
            after_iteration=after_iteration,
            finalizer=finalizer
        )

        print(f"Thread agent response: {response}")

        messages = await manager.list_messages(thread_id)
        print("\nFinal conversation:")
        for msg in messages:
            print(f"{msg['role'].capitalize()}: {msg['content']}")

    asyncio.run(main())