mirror of https://github.com/kortix-ai/suna.git
make add_tool() accept keyword arguments
This commit is contained in:
parent
d77a2dad7c
commit
daa90b5425
|
@ -14,13 +14,18 @@ class ThreadManager:
|
||||||
self.tool_registry = ToolRegistry()
|
self.tool_registry = ToolRegistry()
|
||||||
os.makedirs(self.threads_dir, exist_ok=True)
|
os.makedirs(self.threads_dir, exist_ok=True)
|
||||||
|
|
||||||
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None):
|
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Add a tool to the ThreadManager.
|
Add a tool to the ThreadManager.
|
||||||
If function_names is provided, only register those specific functions.
|
If function_names is provided, only register those specific functions.
|
||||||
If function_names is None, register all functions from the tool.
|
If function_names is None, register all functions from the tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_class: The tool class to register
|
||||||
|
function_names: Optional list of function names to register
|
||||||
|
**kwargs: Additional keyword arguments passed to tool initialization
|
||||||
"""
|
"""
|
||||||
self.tool_registry.register_tool(tool_class, function_names)
|
self.tool_registry.register_tool(tool_class, function_names, **kwargs)
|
||||||
|
|
||||||
async def create_thread(self) -> str:
|
async def create_thread(self) -> str:
|
||||||
thread_id = str(uuid.uuid4())
|
thread_id = str(uuid.uuid4())
|
||||||
|
|
|
@ -6,8 +6,16 @@ class ToolRegistry:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.tools: Dict[str, Dict[str, Any]] = {}
|
self.tools: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
def register_tool(self, tool_cls: Type[Tool], function_names: Optional[List[str]] = None):
|
def register_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
|
||||||
tool_instance = tool_cls()
|
"""
|
||||||
|
Register a tool with optional function name filtering and initialization parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_class: The tool class to register
|
||||||
|
function_names: Optional list of function names to register
|
||||||
|
**kwargs: Additional keyword arguments passed to tool initialization
|
||||||
|
"""
|
||||||
|
tool_instance = tool_class(**kwargs)
|
||||||
schemas = tool_instance.get_schemas()
|
schemas = tool_instance.get_schemas()
|
||||||
|
|
||||||
if function_names is None:
|
if function_names is None:
|
||||||
|
@ -26,7 +34,7 @@ class ToolRegistry:
|
||||||
"schema": schemas[func_name]
|
"schema": schemas[func_name]
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Function '{func_name}' not found in {tool_cls.__name__}")
|
raise ValueError(f"Function '{func_name}' not found in {tool_class.__name__}")
|
||||||
|
|
||||||
def get_tool(self, tool_name: str) -> Dict[str, Any]:
|
def get_tool(self, tool_name: str) -> Dict[str, Any]:
|
||||||
return self.tools.get(tool_name, {})
|
return self.tools.get(tool_name, {})
|
||||||
|
|
|
@ -0,0 +1,368 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import List, Dict, Any, Optional, Callable, Type
|
||||||
|
from agentpress.llm import make_llm_api_call
|
||||||
|
from agentpress.tool import Tool, ToolResult
|
||||||
|
from agentpress.tool_registry import ToolRegistry
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
class ThreadManager:
|
||||||
|
def __init__(self, threads_dir: str = 'threads'):
|
||||||
|
self.threads_dir = threads_dir
|
||||||
|
self.tool_registry = ToolRegistry()
|
||||||
|
os.makedirs(self.threads_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None):
|
||||||
|
"""
|
||||||
|
Add a tool to the ThreadManager.
|
||||||
|
If function_names is provided, only register those specific functions.
|
||||||
|
If function_names is None, register all functions from the tool.
|
||||||
|
"""
|
||||||
|
self.tool_registry.register_tool(tool_class, function_names)
|
||||||
|
|
||||||
|
async def create_thread(self) -> str:
|
||||||
|
thread_id = str(uuid.uuid4())
|
||||||
|
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
|
||||||
|
with open(thread_path, 'w') as f:
|
||||||
|
json.dump({"messages": []}, f)
|
||||||
|
return thread_id
|
||||||
|
|
||||||
|
async def add_message(self, thread_id: str, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
|
||||||
|
logging.info(f"Adding message to thread {thread_id} with images: {images}")
|
||||||
|
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(thread_path, 'r') as f:
|
||||||
|
thread_data = json.load(f)
|
||||||
|
|
||||||
|
messages = thread_data["messages"]
|
||||||
|
|
||||||
|
if message_data['role'] == 'user':
|
||||||
|
last_assistant_index = next((i for i in reversed(range(len(messages))) if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None)
|
||||||
|
|
||||||
|
if last_assistant_index is not None:
|
||||||
|
tool_call_count = len(messages[last_assistant_index]['tool_calls'])
|
||||||
|
tool_response_count = sum(1 for msg in messages[last_assistant_index+1:] if msg['role'] == 'tool')
|
||||||
|
|
||||||
|
if tool_call_count != tool_response_count:
|
||||||
|
await self.cleanup_incomplete_tool_calls(thread_id)
|
||||||
|
|
||||||
|
for key, value in message_data.items():
|
||||||
|
if isinstance(value, ToolResult):
|
||||||
|
message_data[key] = str(value)
|
||||||
|
|
||||||
|
if images:
|
||||||
|
if isinstance(message_data['content'], str):
|
||||||
|
message_data['content'] = [{"type": "text", "text": message_data['content']}]
|
||||||
|
elif not isinstance(message_data['content'], list):
|
||||||
|
message_data['content'] = []
|
||||||
|
|
||||||
|
for image in images:
|
||||||
|
image_content = {
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": f"data:{image['content_type']};base64,{image['base64']}",
|
||||||
|
"detail": "high"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
message_data['content'].append(image_content)
|
||||||
|
|
||||||
|
messages.append(message_data)
|
||||||
|
thread_data["messages"] = messages
|
||||||
|
|
||||||
|
with open(thread_path, 'w') as f:
|
||||||
|
json.dump(thread_data, f)
|
||||||
|
|
||||||
|
logging.info(f"Message added to thread {thread_id}: {message_data}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to add message to thread {thread_id}: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def list_messages(self, thread_id: str, hide_tool_msgs: bool = False, only_latest_assistant: bool = False, regular_list: bool = True) -> List[Dict[str, Any]]:
|
||||||
|
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(thread_path, 'r') as f:
|
||||||
|
thread_data = json.load(f)
|
||||||
|
messages = thread_data["messages"]
|
||||||
|
|
||||||
|
if only_latest_assistant:
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if msg.get('role') == 'assistant':
|
||||||
|
return [msg]
|
||||||
|
return []
|
||||||
|
|
||||||
|
filtered_messages = messages
|
||||||
|
|
||||||
|
if hide_tool_msgs:
|
||||||
|
filtered_messages = [
|
||||||
|
{k: v for k, v in msg.items() if k != 'tool_calls'}
|
||||||
|
for msg in filtered_messages
|
||||||
|
if msg.get('role') != 'tool'
|
||||||
|
]
|
||||||
|
|
||||||
|
if regular_list:
|
||||||
|
filtered_messages = [
|
||||||
|
msg for msg in filtered_messages
|
||||||
|
if msg.get('role') in ['system', 'assistant', 'tool', 'user']
|
||||||
|
]
|
||||||
|
|
||||||
|
return filtered_messages
|
||||||
|
except FileNotFoundError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def cleanup_incomplete_tool_calls(self, thread_id: str):
|
||||||
|
messages = await self.list_messages(thread_id)
|
||||||
|
last_assistant_message = next((m for m in reversed(messages) if m['role'] == 'assistant' and 'tool_calls' in m), None)
|
||||||
|
|
||||||
|
if last_assistant_message:
|
||||||
|
tool_calls = last_assistant_message.get('tool_calls', [])
|
||||||
|
tool_responses = [m for m in messages[messages.index(last_assistant_message)+1:] if m['role'] == 'tool']
|
||||||
|
|
||||||
|
if len(tool_calls) != len(tool_responses):
|
||||||
|
failed_tool_results = []
|
||||||
|
for tool_call in tool_calls[len(tool_responses):]:
|
||||||
|
failed_tool_result = {
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call['id'],
|
||||||
|
"name": tool_call['function']['name'],
|
||||||
|
"content": "ToolResult(success=False, output='Execution interrupted. Session was stopped.')"
|
||||||
|
}
|
||||||
|
failed_tool_results.append(failed_tool_result)
|
||||||
|
|
||||||
|
assistant_index = messages.index(last_assistant_message)
|
||||||
|
messages[assistant_index+1:assistant_index+1] = failed_tool_results
|
||||||
|
|
||||||
|
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
|
||||||
|
with open(thread_path, 'w') as f:
|
||||||
|
json.dump({"messages": messages}, f)
|
||||||
|
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def run_thread(self, thread_id: str, system_message: Dict[str, Any], model_name: str, temperature: float = 0, max_tokens: Optional[int] = None, tool_choice: str = "auto", additional_message: Optional[Dict[str, Any]] = None, execute_tools_async: bool = True, execute_model_tool_calls: bool = True, use_tools: bool = False) -> Dict[str, Any]:
|
||||||
|
|
||||||
|
messages = await self.list_messages(thread_id)
|
||||||
|
prepared_messages = [system_message] + messages
|
||||||
|
|
||||||
|
if additional_message:
|
||||||
|
prepared_messages.append(additional_message)
|
||||||
|
|
||||||
|
tools = self.tool_registry.get_all_tool_schemas() if use_tools else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm_response = await make_llm_api_call(
|
||||||
|
prepared_messages,
|
||||||
|
model_name,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice if use_tools else None,
|
||||||
|
stream=False
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error in API call: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"run_thread_params": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"system_message": system_message,
|
||||||
|
"model_name": model_name,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"tool_choice": tool_choice,
|
||||||
|
"additional_message": additional_message,
|
||||||
|
"execute_tools_async": execute_tools_async,
|
||||||
|
"execute_model_tool_calls": execute_model_tool_calls,
|
||||||
|
"use_tools": use_tools
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if use_tools and execute_model_tool_calls:
|
||||||
|
await self.handle_response_with_tools(thread_id, llm_response, execute_tools_async)
|
||||||
|
else:
|
||||||
|
await self.handle_response_without_tools(thread_id, llm_response)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"llm_response": llm_response,
|
||||||
|
"run_thread_params": {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"system_message": system_message,
|
||||||
|
"model_name": model_name,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"tool_choice": tool_choice,
|
||||||
|
"additional_message": additional_message,
|
||||||
|
"execute_tools_async": execute_tools_async,
|
||||||
|
"execute_model_tool_calls": execute_model_tool_calls,
|
||||||
|
"use_tools": use_tools
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def handle_response_without_tools(self, thread_id: str, response: Any):
|
||||||
|
response_content = response.choices[0].message['content']
|
||||||
|
await self.add_message(thread_id, {"role": "assistant", "content": response_content})
|
||||||
|
|
||||||
|
async def handle_response_with_tools(self, thread_id: str, response: Any, execute_tools_async: bool):
|
||||||
|
try:
|
||||||
|
response_message = response.choices[0].message
|
||||||
|
tool_calls = response_message.get('tool_calls', [])
|
||||||
|
|
||||||
|
assistant_message = self.create_assistant_message_with_tools(response_message)
|
||||||
|
await self.add_message(thread_id, assistant_message)
|
||||||
|
|
||||||
|
available_functions = self.get_available_functions()
|
||||||
|
|
||||||
|
if tool_calls:
|
||||||
|
if execute_tools_async:
|
||||||
|
tool_results = await self.execute_tools_async(tool_calls, available_functions, thread_id)
|
||||||
|
else:
|
||||||
|
tool_results = await self.execute_tools_sync(tool_calls, available_functions, thread_id)
|
||||||
|
|
||||||
|
for result in tool_results:
|
||||||
|
await self.add_message(thread_id, result)
|
||||||
|
|
||||||
|
except AttributeError as e:
|
||||||
|
logging.error(f"AttributeError: {e}")
|
||||||
|
response_content = response.choices[0].message['content']
|
||||||
|
await self.add_message(thread_id, {"role": "assistant", "content": response_content or ""})
|
||||||
|
|
||||||
|
def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
|
||||||
|
message = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": response_message.get('content') or "",
|
||||||
|
}
|
||||||
|
tool_calls = response_message.get('tool_calls')
|
||||||
|
if tool_calls:
|
||||||
|
message["tool_calls"] = [
|
||||||
|
{
|
||||||
|
"id": tool_call.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool_call.function.name,
|
||||||
|
"arguments": tool_call.function.arguments
|
||||||
|
}
|
||||||
|
} for tool_call in tool_calls
|
||||||
|
]
|
||||||
|
return message
|
||||||
|
|
||||||
|
def get_available_functions(self) -> Dict[str, Callable]:
|
||||||
|
available_functions = {}
|
||||||
|
for tool_name, tool_info in self.tool_registry.get_all_tools().items():
|
||||||
|
tool_instance = tool_info['instance']
|
||||||
|
for func_name, func in tool_instance.__class__.__dict__.items():
|
||||||
|
if callable(func) and not func_name.startswith("__"):
|
||||||
|
available_functions[func_name] = getattr(tool_instance, func_name)
|
||||||
|
return available_functions
|
||||||
|
|
||||||
|
async def execute_tools_async(self, tool_calls, available_functions, thread_id):
|
||||||
|
async def execute_single_tool(tool_call):
|
||||||
|
function_name = tool_call.function.name
|
||||||
|
function_args = json.loads(tool_call.function.arguments)
|
||||||
|
tool_call_id = tool_call.id
|
||||||
|
|
||||||
|
function_to_call = available_functions.get(function_name)
|
||||||
|
if function_to_call:
|
||||||
|
return await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)
|
||||||
|
else:
|
||||||
|
logging.warning(f"Function {function_name} not found in available functions")
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool_results = await asyncio.gather(*[execute_single_tool(tool_call) for tool_call in tool_calls])
|
||||||
|
return [result for result in tool_results if result]
|
||||||
|
|
||||||
|
async def execute_tools_sync(self, tool_calls, available_functions, thread_id):
|
||||||
|
tool_results = []
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
function_name = tool_call.function.name
|
||||||
|
function_args = json.loads(tool_call.function.arguments)
|
||||||
|
tool_call_id = tool_call.id
|
||||||
|
|
||||||
|
function_to_call = available_functions.get(function_name)
|
||||||
|
if function_to_call:
|
||||||
|
result = await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)
|
||||||
|
if result:
|
||||||
|
tool_results.append(result)
|
||||||
|
else:
|
||||||
|
logging.warning(f"Function {function_name} not found in available functions")
|
||||||
|
|
||||||
|
return tool_results
|
||||||
|
|
||||||
|
async def execute_tool(self, function_to_call, function_args, function_name, tool_call_id):
|
||||||
|
try:
|
||||||
|
function_response = await function_to_call(**function_args)
|
||||||
|
except Exception as e:
|
||||||
|
error_message = f"Error in {function_name}: {str(e)}"
|
||||||
|
function_response = ToolResult(success=False, output=error_message)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
"name": function_name,
|
||||||
|
"content": str(function_response),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
|
||||||
|
try:
|
||||||
|
with open(thread_path, 'r') as f:
|
||||||
|
return json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
from tools.files_tool import FilesTool
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
manager = ThreadManager()
|
||||||
|
|
||||||
|
manager.add_tool(FilesTool, ['create_file', 'read_file'])
|
||||||
|
|
||||||
|
thread_id = await manager.create_thread()
|
||||||
|
|
||||||
|
await manager.add_message(thread_id, {"role": "user", "content": "Please create a file with a random name with the content 'Hello, world!'"})
|
||||||
|
|
||||||
|
system_message = {"role": "system", "content": "You are a helpful assistant that can create, read, update, and delete files."}
|
||||||
|
model_name = "gpt-4o"
|
||||||
|
|
||||||
|
# Test with tools
|
||||||
|
response_with_tools = await manager.run_thread(
|
||||||
|
thread_id=thread_id,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=model_name,
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=150,
|
||||||
|
tool_choice="auto",
|
||||||
|
additional_message=None,
|
||||||
|
execute_tools_async=True,
|
||||||
|
execute_model_tool_calls=True,
|
||||||
|
use_tools=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Response with tools:", response_with_tools)
|
||||||
|
|
||||||
|
# Test without tools
|
||||||
|
response_without_tools = await manager.run_thread(
|
||||||
|
thread_id=thread_id,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=model_name,
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=150,
|
||||||
|
additional_message={"role": "user", "content": "What's the capital of France?"},
|
||||||
|
use_tools=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Response without tools:", response_without_tools)
|
||||||
|
|
||||||
|
# List messages in the thread
|
||||||
|
messages = await manager.list_messages(thread_id)
|
||||||
|
print("\nMessages in the thread:")
|
||||||
|
for msg in messages:
|
||||||
|
print(f"{msg['role'].capitalize()}: {msg['content']}")
|
||||||
|
|
||||||
|
# Run the async main function
|
||||||
|
asyncio.run(main())
|
Loading…
Reference in New Issue