From daa90b5425b16fb6ced9fc24ff48c9a2b804f4ac Mon Sep 17 00:00:00 2001 From: marko-kraemer Date: Sat, 2 Nov 2024 23:56:31 +0100 Subject: [PATCH] make add_tool() accept keyword arguments --- agentpress/thread_manager.py | 9 +- agentpress/tool_registry.py | 14 +- core/thread_manager.py | 368 +++++++++++++++++++++++++++++++++++ 3 files changed, 386 insertions(+), 5 deletions(-) create mode 100644 core/thread_manager.py diff --git a/agentpress/thread_manager.py b/agentpress/thread_manager.py index b36f5017..67921266 100644 --- a/agentpress/thread_manager.py +++ b/agentpress/thread_manager.py @@ -14,13 +14,18 @@ class ThreadManager: self.tool_registry = ToolRegistry() os.makedirs(self.threads_dir, exist_ok=True) - def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None): + def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs): """ Add a tool to the ThreadManager. If function_names is provided, only register those specific functions. If function_names is None, register all functions from the tool. + + Args: + tool_class: The tool class to register + function_names: Optional list of function names to register + **kwargs: Additional keyword arguments passed to tool initialization """ - self.tool_registry.register_tool(tool_class, function_names) + self.tool_registry.register_tool(tool_class, function_names, **kwargs) async def create_thread(self) -> str: thread_id = str(uuid.uuid4()) diff --git a/agentpress/tool_registry.py b/agentpress/tool_registry.py index f1d15b2d..78b7a017 100644 --- a/agentpress/tool_registry.py +++ b/agentpress/tool_registry.py @@ -6,8 +6,16 @@ class ToolRegistry: def __init__(self): self.tools: Dict[str, Dict[str, Any]] = {} - def register_tool(self, tool_cls: Type[Tool], function_names: Optional[List[str]] = None): - tool_instance = tool_cls() + def register_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs): + """ + Register a tool with optional function name filtering and initialization parameters. + + Args: + tool_class: The tool class to register + function_names: Optional list of function names to register + **kwargs: Additional keyword arguments passed to tool initialization + """ + tool_instance = tool_class(**kwargs) schemas = tool_instance.get_schemas() if function_names is None: @@ -26,7 +34,7 @@ class ToolRegistry: "schema": schemas[func_name] } else: - raise ValueError(f"Function '{func_name}' not found in {tool_cls.__name__}") + raise ValueError(f"Function '{func_name}' not found in {tool_class.__name__}") def get_tool(self, tool_name: str) -> Dict[str, Any]: return self.tools.get(tool_name, {}) diff --git a/core/thread_manager.py b/core/thread_manager.py new file mode 100644 index 00000000..b36f5017 --- /dev/null +++ b/core/thread_manager.py @@ -0,0 +1,368 @@ +import json +import logging +import asyncio +import os +from typing import List, Dict, Any, Optional, Callable, Type +from agentpress.llm import make_llm_api_call +from agentpress.tool import Tool, ToolResult +from agentpress.tool_registry import ToolRegistry +import uuid + +class ThreadManager: + def __init__(self, threads_dir: str = 'threads'): + self.threads_dir = threads_dir + self.tool_registry = ToolRegistry() + os.makedirs(self.threads_dir, exist_ok=True) + + def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None): + """ + Add a tool to the ThreadManager. + If function_names is provided, only register those specific functions. + If function_names is None, register all functions from the tool. + """ + self.tool_registry.register_tool(tool_class, function_names) + + async def create_thread(self) -> str: + thread_id = str(uuid.uuid4()) + thread_path = os.path.join(self.threads_dir, f"{thread_id}.json") + with open(thread_path, 'w') as f: + json.dump({"messages": []}, f) + return thread_id + + async def add_message(self, thread_id: str, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None): + logging.info(f"Adding message to thread {thread_id} with images: {images}") + thread_path = os.path.join(self.threads_dir, f"{thread_id}.json") + + try: + with open(thread_path, 'r') as f: + thread_data = json.load(f) + + messages = thread_data["messages"] + + if message_data['role'] == 'user': + last_assistant_index = next((i for i in reversed(range(len(messages))) if messages[i]['role'] == 'assistant' and 'tool_calls' in messages[i]), None) + + if last_assistant_index is not None: + tool_call_count = len(messages[last_assistant_index]['tool_calls']) + tool_response_count = sum(1 for msg in messages[last_assistant_index+1:] if msg['role'] == 'tool') + + if tool_call_count != tool_response_count: + await self.cleanup_incomplete_tool_calls(thread_id) + + for key, value in message_data.items(): + if isinstance(value, ToolResult): + message_data[key] = str(value) + + if images: + if isinstance(message_data['content'], str): + message_data['content'] = [{"type": "text", "text": message_data['content']}] + elif not isinstance(message_data['content'], list): + message_data['content'] = [] + + for image in images: + image_content = { + "type": "image_url", + "image_url": { + "url": f"data:{image['content_type']};base64,{image['base64']}", + "detail": "high" + } + } + message_data['content'].append(image_content) + + messages.append(message_data) + thread_data["messages"] = messages + + with open(thread_path, 'w') as f: + json.dump(thread_data, f) + + logging.info(f"Message added to thread {thread_id}: {message_data}") + except Exception as e: + logging.error(f"Failed to add message to thread {thread_id}: {e}") + raise e + + async def list_messages(self, thread_id: str, hide_tool_msgs: bool = False, only_latest_assistant: bool = False, regular_list: bool = True) -> List[Dict[str, Any]]: + thread_path = os.path.join(self.threads_dir, f"{thread_id}.json") + + try: + with open(thread_path, 'r') as f: + thread_data = json.load(f) + messages = thread_data["messages"] + + if only_latest_assistant: + for msg in reversed(messages): + if msg.get('role') == 'assistant': + return [msg] + return [] + + filtered_messages = messages + + if hide_tool_msgs: + filtered_messages = [ + {k: v for k, v in msg.items() if k != 'tool_calls'} + for msg in filtered_messages + if msg.get('role') != 'tool' + ] + + if regular_list: + filtered_messages = [ + msg for msg in filtered_messages + if msg.get('role') in ['system', 'assistant', 'tool', 'user'] + ] + + return filtered_messages + except FileNotFoundError: + return [] + + async def cleanup_incomplete_tool_calls(self, thread_id: str): + messages = await self.list_messages(thread_id) + last_assistant_message = next((m for m in reversed(messages) if m['role'] == 'assistant' and 'tool_calls' in m), None) + + if last_assistant_message: + tool_calls = last_assistant_message.get('tool_calls', []) + tool_responses = [m for m in messages[messages.index(last_assistant_message)+1:] if m['role'] == 'tool'] + + if len(tool_calls) != len(tool_responses): + failed_tool_results = [] + for tool_call in tool_calls[len(tool_responses):]: + failed_tool_result = { + "role": "tool", + "tool_call_id": tool_call['id'], + "name": tool_call['function']['name'], + "content": "ToolResult(success=False, output='Execution interrupted. Session was stopped.')" + } + failed_tool_results.append(failed_tool_result) + + assistant_index = messages.index(last_assistant_message) + messages[assistant_index+1:assistant_index+1] = failed_tool_results + + thread_path = os.path.join(self.threads_dir, f"{thread_id}.json") + with open(thread_path, 'w') as f: + json.dump({"messages": messages}, f) + + return True + return False + + async def run_thread(self, thread_id: str, system_message: Dict[str, Any], model_name: str, temperature: float = 0, max_tokens: Optional[int] = None, tool_choice: str = "auto", additional_message: Optional[Dict[str, Any]] = None, execute_tools_async: bool = True, execute_model_tool_calls: bool = True, use_tools: bool = False) -> Dict[str, Any]: + + messages = await self.list_messages(thread_id) + prepared_messages = [system_message] + messages + + if additional_message: + prepared_messages.append(additional_message) + + tools = self.tool_registry.get_all_tool_schemas() if use_tools else None + + try: + llm_response = await make_llm_api_call( + prepared_messages, + model_name, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + tool_choice=tool_choice if use_tools else None, + stream=False + ) + except Exception as e: + logging.error(f"Error in API call: {str(e)}") + return { + "status": "error", + "message": str(e), + "run_thread_params": { + "thread_id": thread_id, + "system_message": system_message, + "model_name": model_name, + "temperature": temperature, + "max_tokens": max_tokens, + "tool_choice": tool_choice, + "additional_message": additional_message, + "execute_tools_async": execute_tools_async, + "execute_model_tool_calls": execute_model_tool_calls, + "use_tools": use_tools + } + } + + if use_tools and execute_model_tool_calls: + await self.handle_response_with_tools(thread_id, llm_response, execute_tools_async) + else: + await self.handle_response_without_tools(thread_id, llm_response) + + return { + "llm_response": llm_response, + "run_thread_params": { + "thread_id": thread_id, + "system_message": system_message, + "model_name": model_name, + "temperature": temperature, + "max_tokens": max_tokens, + "tool_choice": tool_choice, + "additional_message": additional_message, + "execute_tools_async": execute_tools_async, + "execute_model_tool_calls": execute_model_tool_calls, + "use_tools": use_tools + } + } + + async def handle_response_without_tools(self, thread_id: str, response: Any): + response_content = response.choices[0].message['content'] + await self.add_message(thread_id, {"role": "assistant", "content": response_content}) + + async def handle_response_with_tools(self, thread_id: str, response: Any, execute_tools_async: bool): + try: + response_message = response.choices[0].message + tool_calls = response_message.get('tool_calls', []) + + assistant_message = self.create_assistant_message_with_tools(response_message) + await self.add_message(thread_id, assistant_message) + + available_functions = self.get_available_functions() + + if tool_calls: + if execute_tools_async: + tool_results = await self.execute_tools_async(tool_calls, available_functions, thread_id) + else: + tool_results = await self.execute_tools_sync(tool_calls, available_functions, thread_id) + + for result in tool_results: + await self.add_message(thread_id, result) + + except AttributeError as e: + logging.error(f"AttributeError: {e}") + response_content = response.choices[0].message['content'] + await self.add_message(thread_id, {"role": "assistant", "content": response_content or ""}) + + def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]: + message = { + "role": "assistant", + "content": response_message.get('content') or "", + } + tool_calls = response_message.get('tool_calls') + if tool_calls: + message["tool_calls"] = [ + { + "id": tool_call.id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments + } + } for tool_call in tool_calls + ] + return message + + def get_available_functions(self) -> Dict[str, Callable]: + available_functions = {} + for tool_name, tool_info in self.tool_registry.get_all_tools().items(): + tool_instance = tool_info['instance'] + for func_name, func in tool_instance.__class__.__dict__.items(): + if callable(func) and not func_name.startswith("__"): + available_functions[func_name] = getattr(tool_instance, func_name) + return available_functions + + async def execute_tools_async(self, tool_calls, available_functions, thread_id): + async def execute_single_tool(tool_call): + function_name = tool_call.function.name + function_args = json.loads(tool_call.function.arguments) + tool_call_id = tool_call.id + + function_to_call = available_functions.get(function_name) + if function_to_call: + return await self.execute_tool(function_to_call, function_args, function_name, tool_call_id) + else: + logging.warning(f"Function {function_name} not found in available functions") + return None + + tool_results = await asyncio.gather(*[execute_single_tool(tool_call) for tool_call in tool_calls]) + return [result for result in tool_results if result] + + async def execute_tools_sync(self, tool_calls, available_functions, thread_id): + tool_results = [] + for tool_call in tool_calls: + function_name = tool_call.function.name + function_args = json.loads(tool_call.function.arguments) + tool_call_id = tool_call.id + + function_to_call = available_functions.get(function_name) + if function_to_call: + result = await self.execute_tool(function_to_call, function_args, function_name, tool_call_id) + if result: + tool_results.append(result) + else: + logging.warning(f"Function {function_name} not found in available functions") + + return tool_results + + async def execute_tool(self, function_to_call, function_args, function_name, tool_call_id): + try: + function_response = await function_to_call(**function_args) + except Exception as e: + error_message = f"Error in {function_name}: {str(e)}" + function_response = ToolResult(success=False, output=error_message) + + return { + "role": "tool", + "tool_call_id": tool_call_id, + "name": function_name, + "content": str(function_response), + } + + async def get_thread(self, thread_id: str) -> Optional[Dict[str, Any]]: + thread_path = os.path.join(self.threads_dir, f"{thread_id}.json") + try: + with open(thread_path, 'r') as f: + return json.load(f) + except FileNotFoundError: + return None + +if __name__ == "__main__": + import asyncio + from tools.files_tool import FilesTool + + async def main(): + manager = ThreadManager() + + manager.add_tool(FilesTool, ['create_file', 'read_file']) + + thread_id = await manager.create_thread() + + await manager.add_message(thread_id, {"role": "user", "content": "Please create a file with a random name with the content 'Hello, world!'"}) + + system_message = {"role": "system", "content": "You are a helpful assistant that can create, read, update, and delete files."} + model_name = "gpt-4o" + + # Test with tools + response_with_tools = await manager.run_thread( + thread_id=thread_id, + system_message=system_message, + model_name=model_name, + temperature=0.7, + max_tokens=150, + tool_choice="auto", + additional_message=None, + execute_tools_async=True, + execute_model_tool_calls=True, + use_tools=True + ) + + print("Response with tools:", response_with_tools) + + # Test without tools + response_without_tools = await manager.run_thread( + thread_id=thread_id, + system_message=system_message, + model_name=model_name, + temperature=0.7, + max_tokens=150, + additional_message={"role": "user", "content": "What's the capital of France?"}, + use_tools=False + ) + + print("Response without tools:", response_without_tools) + + # List messages in the thread + messages = await manager.list_messages(thread_id) + print("\nMessages in the thread:") + for msg in messages: + print(f"{msg['role'].capitalize()}: {msg['content']}") + + # Run the async main function + asyncio.run(main())