mirror of https://github.com/kortix-ai/suna.git
Run Thread response streaming + Tool Parser, Tool Exec refactor (#10)
This commit is contained in:
parent 816c287a76 · commit 25e1afaa99

@@ -163,5 +163,5 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/

-# AgentPress Threads
+# AgentPress
 /threads

@@ -0,0 +1,10 @@ (new file)

0.1.8
- New Tool Parser Base Class
- XML Tool Parser
- New Tool Executor Base Class
- Execute_tools_on_stream
- Documentation

0.1.7
- Streaming Responses with Tool Calls
README.md (14 lines changed)

@@ -89,7 +89,7 @@ async def main():
     },
     model_name="gpt-4",
     use_tools=True,
-    execute_model_tool_calls=True
+    execute_tool_calls=True
 )
 print("Response:", response)
@@ -160,18 +160,10 @@ pip install poetry
 poetry install
 ```

-3. Build the package:
+3. For quick testing, you can install directly from the current directory:
 ```bash
-poetry build
+pip install -e .
 ```
-It will return the built package name with the version number.
-
-4. Install the package with the correct version number, here for example its 0.1.3 `agentpress-0.1.3-py3-none-any.whl`:
-```bash
-pip install /Users/markokraemer/Projects/agentpress/dist/agentpress-0.1.3-py3-none-any.whl --force-reinstall
-```
-Then you can test that version.

 ## License
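For a runnable picture of the renamed flag, here is a minimal sketch against the ThreadManager API shown in this diff (the thread setup lines are illustrative):

```python
import asyncio
from agentpress.thread_manager import ThreadManager

async def main():
    manager = ThreadManager()
    thread_id = await manager.create_thread()
    await manager.add_message(thread_id, {"role": "user", "content": "Hello"})

    # execute_tool_calls replaces the old execute_model_tool_calls flag
    response = await manager.run_thread(
        thread_id=thread_id,
        system_message={"role": "system", "content": "You are a helpful assistant."},
        model_name="gpt-4",
        use_tools=True,
        execute_tool_calls=True,
    )
    print("Response:", response)

asyncio.run(main())
```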

@@ -4,6 +4,9 @@ from agentpress.thread_manager import ThreadManager
 from tools.files_tool import FilesTool
 from agentpress.state_manager import StateManager
 from tools.terminal_tool import TerminalTool
 import logging
 from typing import AsyncGenerator
 import sys

 async def run_agent(thread_id: str, max_iterations: int = 5):
     # Initialize managers and tools
@@ -111,13 +114,45 @@ Current development environment workspace state:
         temperature=0.1,
         max_tokens=8096,
         tool_choice="auto",
-        additional_message=state_message,
+        temporary_message=state_message,
         execute_tools_async=True,
         use_tools=True,
-        execute_model_tool_calls=True
+        execute_tool_calls=True,
+        stream=True,
+        execute_tools_on_stream=True
     )

-    print(response)
+    # Handle streaming response
+    if isinstance(response, AsyncGenerator):
+        print("\n🤖 Assistant is responding:")
+        content_buffer = ""
+        try:
+            async for chunk in response:
+                if hasattr(chunk.choices[0], 'delta'):
+                    delta = chunk.choices[0].delta
+
+                    # Handle content streaming
+                    if hasattr(delta, 'content') and delta.content is not None:
+                        print(delta.content, end='', flush=True)
+
+                    # Handle tool calls
+                    if hasattr(delta, 'tool_calls') and delta.tool_calls:
+                        for tool_call in delta.tool_calls:
+                            # Print tool name when it first appears
+                            if tool_call.function and tool_call.function.name:
+                                print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True)
+
+                            # Print arguments as they stream in
+                            if tool_call.function and tool_call.function.arguments:
+                                print(f"  {tool_call.function.arguments}", end='', flush=True)
+
+            print("\n✨ Response completed\n")
+
+        except Exception as e:
+            print(f"\n❌ Error processing stream: {e}", file=sys.stderr)
+            logging.error(f"Error processing stream: {e}")
+    else:
+        print("\n❌ Non-streaming response received:", response)

     # Call after_iteration without arguments
     await after_iteration()
@@ -134,7 +169,7 @@ if __name__ == "__main__":
         thread_id,
         {
             "role": "user",
-            "content": "Let's create a identical 1to1 Airbnb Clone using HTML, CSS, Javascript. Use images from pixabay, pexels, and co."
+            "content": "Create a simple landing page with a header, hero section, and footer. Use modern CSS styling."
         }
     )

@@ -6,7 +6,6 @@ import openai
 from openai import OpenAIError
 import asyncio
 import logging
 # import agentops

 OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
 ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY')
@@ -17,13 +16,62 @@ os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
 os.environ['ANTHROPIC_API_KEY'] = ANTHROPIC_API_KEY
 os.environ['GROQ_API_KEY'] = GROQ_API_KEY

 # agentops.init(AGENTOPS_API_KEY)
 # os.environ['LITELLM_LOG'] = 'DEBUG'

-async def make_llm_api_call(messages, model_name, response_format=None, temperature=0, max_tokens=None, tools=None, tool_choice="auto", api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None):
-    litellm.set_verbose = True
+async def make_llm_api_call(
+    messages: list,
+    model_name: str,
+    response_format: Any = None,
+    temperature: float = 0,
+    max_tokens: int = None,
+    tools: list = None,
+    tool_choice: str = "auto",
+    api_key: str = None,
+    api_base: str = None,
+    agentops_session: Any = None,
+    stream: bool = False,
+    top_p: float = None
+) -> Union[Dict[str, Any], Any]:
+    """
+    Make an API call to a language model using litellm.
+
+    This function provides a unified interface for making calls to various LLM providers
+    (OpenAI, Anthropic, Groq, etc.) with support for streaming, tool calls, and retry logic.
+
+    Args:
+        messages (list): List of message dictionaries for the conversation
+        model_name (str): Name of the model to use (e.g., "gpt-4", "claude-3")
+        response_format (Any, optional): Desired format for the response
+        temperature (float, optional): Sampling temperature. Defaults to 0
+        max_tokens (int, optional): Maximum tokens in the response
+        tools (list, optional): List of tool definitions for function calling
+        tool_choice (str, optional): How to select tools ("auto" or "none")
+        api_key (str, optional): Override default API key
+        api_base (str, optional): Override default API base URL
+        agentops_session (Any, optional): Session for agentops integration
+        stream (bool, optional): Whether to stream the response. Defaults to False
+        top_p (float, optional): Top-p sampling parameter
+
+    Returns:
+        Union[Dict[str, Any], Any]: API response, either complete or streaming
+
+    Raises:
+        Exception: If API call fails after retries
+    """
+    # litellm.set_verbose = True

     async def attempt_api_call(api_call_func, max_attempts=3):
+        """
+        Attempt an API call with retries.
+
+        Args:
+            api_call_func: Async function that makes the API call
+            max_attempts (int): Maximum number of retry attempts
+
+        Returns:
+            API response if successful
+
+        Raises:
+            Exception: If all retry attempts fail
+        """
         for attempt in range(max_attempts):
             try:
                 return await api_call_func()
@@ -39,6 +87,12 @@ async def make_llm_api_call(messages, model_name, response_format=None, temperat
         raise Exception("Failed to make API call after multiple attempts.")

     async def api_call():
+        """
+        Prepare and execute the API call with the specified parameters.
+
+        Returns:
+            API response from the language model
+        """
         api_call_params = {
             "model": model_name,
             "messages": messages,
@@ -48,13 +102,13 @@ async def make_llm_api_call(messages, model_name, response_format=None, temperat
             "stream": stream,
         }

-        # Add api_key and api_base if provided
+        # Add optional parameters if provided
         if api_key:
             api_call_params["api_key"] = api_key
         if api_base:
             api_call_params["api_base"] = api_base

-        # Use 'max_completion_tokens' for 'o1' models, otherwise use 'max_tokens'
+        # Handle token limits differently for different models
         if 'o1' in model_name:
             if max_tokens is not None:
                 api_call_params["max_completion_tokens"] = max_tokens
@@ -63,10 +117,10 @@ async def make_llm_api_call(messages, model_name, response_format=None, temperat
             api_call_params["max_tokens"] = max_tokens

         if tools:
-            # Use the existing method of adding tools
             api_call_params["tools"] = tools
             api_call_params["tool_choice"] = tool_choice

+        # Add special headers for Claude models
         if "claude" in model_name.lower() or "anthropic" in model_name.lower():
             api_call_params["extra_headers"] = {
                 "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"
@@ -75,6 +129,7 @@ async def make_llm_api_call(messages, model_name, response_format=None, temperat
         # Log the API request
         # logging.info(f"Sending API request: {json.dumps(api_call_params, indent=2)}")

+        # Make the API call using either agentops session or direct litellm
         if agentops_session:
             response = await agentops_session.patch(litellm.acompletion)(**api_call_params)
         else:
@@ -87,10 +142,15 @@ async def make_llm_api_call(messages, model_name, response_format=None, temperat
     return await attempt_api_call(api_call)

-# Sample Usage
 if __name__ == "__main__":
     import asyncio

     async def test_llm_api_call(stream=True):
+        """
+        Test function for the LLM API call functionality.
+
+        Args:
+            stream (bool): Whether to test streaming mode
+        """
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": "Complex essay on economics"}
@@ -106,17 +166,14 @@ if __name__ == "__main__":
                 if isinstance(chunk, dict) and 'choices' in chunk:
                     content = chunk['choices'][0]['delta'].get('content', '')
                 else:
-                    # For non-dict responses (like ModelResponse objects)
                     content = chunk.choices[0].delta.content

                 if content:
                     buffer += content
                     # Print complete words/sentences when we hit whitespace
                     if content[-1].isspace():
                         print(buffer, end='', flush=True)
                         buffer = ""

             # Print any remaining content
             if buffer:
                 print(buffer, flush=True)
             print("\n✨ Stream completed.\n")
@@ -125,12 +182,7 @@ if __name__ == "__main__":
         if isinstance(response, dict) and 'choices' in response:
             print(response['choices'][0]['message']['content'])
         else:
-            # For non-dict responses (like ModelResponse objects)
             print(response.choices[0].message.content)
         print()

-    # Example usage:
-    # asyncio.run(test_llm_api_call(stream=True))  # For streaming
-    # asyncio.run(test_llm_api_call(stream=False))  # For non-streaming

     asyncio.run(test_llm_api_call())
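A minimal streaming consumer for the function above; the model string is illustrative, and any litellm-supported model should behave the same way:

```python
import asyncio
from agentpress.llm import make_llm_api_call

async def demo():
    # stream=True returns an async iterator of chunks rather than a full response
    response = await make_llm_api_call(
        messages=[{"role": "user", "content": "Say hello"}],
        model_name="gpt-4o-mini",
        stream=True,
    )
    async for chunk in response:
        delta = chunk.choices[0].delta
        if getattr(delta, "content", None):
            print(delta.content, end="", flush=True)

asyncio.run(demo())
```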

@@ -6,12 +6,26 @@ from asyncio import Lock
 from contextlib import asynccontextmanager

 class StateManager:
+    """
+    Manages persistent state storage for AgentPress components.
+
+    The StateManager provides thread-safe access to a JSON-based state store,
+    allowing components to save and retrieve data across sessions. It handles
+    concurrent access using asyncio locks and provides atomic operations for
+    state modifications.
+
+    Attributes:
+        lock (Lock): Asyncio lock for thread-safe state access
+        store_file (str): Path to the JSON file storing the state
+    """

     def __init__(self, store_file: str = "state.json"):
         """
         Initialize StateManager with custom store file name.

         Args:
-            store_file: Name of the JSON file to store state (default: "state.json")
+            store_file (str): Path to the JSON file to store state.
+                Defaults to "state.json" in the current directory.
         """
         self.lock = Lock()
         self.store_file = store_file
@@ -19,6 +33,19 @@ class StateManager:

     @asynccontextmanager
     async def store_scope(self):
+        """
+        Context manager for atomic state operations.
+
+        Provides thread-safe access to the state store, handling file I/O
+        and ensuring proper cleanup. Automatically loads the current state
+        and saves changes when the context exits.
+
+        Yields:
+            dict: The current state store contents
+
+        Raises:
+            Exception: If there are errors reading from or writing to the store file
+        """
         try:
             # Read current state
             if os.path.exists(self.store_file):
@@ -42,8 +69,14 @@ class StateManager:
         Store any JSON-serializable data with a simple key.

         Args:
-            key: Simple string key like "config" or "settings"
-            data: Any JSON-serializable data (dict, list, str, int, bool, etc)
+            key (str): Simple string key like "config" or "settings"
+            data (Any): Any JSON-serializable data (dict, list, str, int, bool, etc)
+
+        Returns:
+            Any: The stored data
+
+        Raises:
+            Exception: If there are errors during storage operation
         """
         async with self.lock:
             async with self.store_scope() as store:
@@ -60,7 +93,13 @@ class StateManager:
         Get data for a key.

         Args:
-            key: Simple string key like "config" or "settings"
+            key (str): Simple string key like "config" or "settings"
+
+        Returns:
+            Any: The stored data for the key, or None if key not found
+
+        Note:
+            This operation is read-only and doesn't require locking
         """
         async with self.store_scope() as store:
             if key in store:
@@ -71,7 +110,15 @@ class StateManager:
             return None

     async def delete(self, key: str):
-        """Delete data for a key"""
+        """
+        Delete data for a key.
+
+        Args:
+            key (str): Simple string key like "config" or "settings"
+
+        Note:
+            No error is raised if the key doesn't exist
+        """
         async with self.lock:
             async with self.store_scope() as store:
                 if key in store:
@@ -81,13 +128,26 @@ class StateManager:
                 logging.info(f"Key not found for deletion: {key}")

     async def export_store(self) -> dict:
-        """Export entire store"""
+        """
+        Export entire store.
+
+        Returns:
+            dict: Complete contents of the state store
+
+        Note:
+            This operation is read-only and returns a copy of the store
+        """
         async with self.store_scope() as store:
             logging.info(f"Store content: {store}")
             return store

     async def clear_store(self):
-        """Clear entire store"""
+        """
+        Clear entire store.
+
+        Removes all data from the store, resetting it to an empty state.
+        This operation is atomic and thread-safe.
+        """
         async with self.lock:
             async with self.store_scope() as store:
                 store.clear()
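A short usage sketch of the StateManager API documented above. The setter and getter appear in this diff only through their docstrings, so the method names `set` and `get` are assumptions:

```python
import asyncio
from agentpress.state_manager import StateManager

async def demo():
    state = StateManager("state.json")
    await state.set("config", {"theme": "dark"})  # assumed setter name
    print(await state.get("config"))              # assumed getter name
    await state.delete("config")                  # no error if the key is missing
    print(await state.export_store())             # full store contents

asyncio.run(demo())
```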

@@ -6,12 +6,45 @@ from typing import List, Dict, Any, Optional, Callable, Type, Union, AsyncGenera
 from agentpress.llm import make_llm_api_call
 from agentpress.tool import Tool, ToolResult
 from agentpress.tool_registry import ToolRegistry
+from agentpress.tool_parser import ToolParser, StandardToolParser
+from agentpress.tool_executor import ToolExecutor, StandardToolExecutor, SequentialToolExecutor
 import uuid

 class ThreadManager:
-    def __init__(self, threads_dir: str = "threads"):
+    """
+    Manages conversation threads with LLM models and tool execution.
+
+    The ThreadManager handles:
+    - Creating and managing conversation threads
+    - Adding/retrieving messages in threads
+    - Executing LLM calls with optional tool usage
+    - Managing tool registration and execution
+    - Supporting both streaming and non-streaming responses
+
+    Attributes:
+        threads_dir (str): Directory where thread files are stored
+        tool_registry (ToolRegistry): Registry for managing available tools
+        tool_parser (ToolParser): Parser for handling tool calls/responses
+        tool_executor (ToolExecutor): Executor for running tool functions
+    """
+
+    def __init__(
+        self,
+        threads_dir: str = "threads",
+        tool_parser: Optional[ToolParser] = None,
+        tool_executor: Optional[ToolExecutor] = None
+    ):
+        """Initialize ThreadManager with optional custom tool parser and executor.
+
+        Args:
+            threads_dir (str): Directory to store thread files
+            tool_parser (Optional[ToolParser]): Custom tool parser implementation
+            tool_executor (Optional[ToolExecutor]): Custom tool executor implementation
+        """
         self.threads_dir = threads_dir
         self.tool_registry = ToolRegistry()
+        self.tool_parser = tool_parser or StandardToolParser()
+        self.tool_executor = tool_executor or StandardToolExecutor()
         os.makedirs(self.threads_dir, exist_ok=True)

     def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
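The reworked constructor turns parsing and execution into injectable strategies. A sketch with the defaults spelled out explicitly:

```python
from agentpress.thread_manager import ThreadManager
from agentpress.tool_parser import StandardToolParser
from agentpress.tool_executor import SequentialToolExecutor

# Defaults are StandardToolParser / StandardToolExecutor; here we opt into
# sequential execution so tools run one at a time instead of concurrently.
manager = ThreadManager(
    threads_dir="threads",
    tool_parser=StandardToolParser(),
    tool_executor=SequentialToolExecutor(),
)
```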
@@ -28,6 +61,12 @@ class ThreadManager:
         self.tool_registry.register_tool(tool_class, function_names, **kwargs)

     async def create_thread(self) -> str:
+        """
+        Create a new conversation thread.
+
+        Returns:
+            str: Unique thread ID for the created thread
+        """
         thread_id = str(uuid.uuid4())
         thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
         with open(thread_path, 'w') as f:
@@ -35,6 +74,18 @@ class ThreadManager:
         return thread_id

     async def add_message(self, thread_id: str, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
+        """
+        Add a message to an existing thread.
+
+        Args:
+            thread_id (str): ID of the thread to add message to
+            message_data (Dict[str, Any]): Message data including role and content
+            images (Optional[List[Dict[str, Any]]]): List of image data to include.
+                Each image dict should contain 'content_type' and 'base64' keys
+
+        Raises:
+            Exception: If message addition fails
+        """
         logging.info(f"Adding message to thread {thread_id} with images: {images}")
         thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
@@ -86,6 +137,18 @@ class ThreadManager:
             raise e

     async def list_messages(self, thread_id: str, hide_tool_msgs: bool = False, only_latest_assistant: bool = False, regular_list: bool = True) -> List[Dict[str, Any]]:
+        """
+        Retrieve messages from a thread with optional filtering.
+
+        Args:
+            thread_id (str): ID of the thread to retrieve messages from
+            hide_tool_msgs (bool): If True, excludes tool messages and tool calls
+            only_latest_assistant (bool): If True, returns only the most recent assistant message
+            regular_list (bool): If True, only includes standard message types
+
+        Returns:
+            List[Dict[str, Any]]: List of messages matching the filter criteria
+        """
         thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")

         try:
@@ -147,16 +210,51 @@ class ThreadManager:
             return True
         return False

-    async def run_thread(self, thread_id: str, system_message: Dict[str, Any], model_name: str, temperature: float = 0, max_tokens: Optional[int] = None, tool_choice: str = "auto", additional_message: Optional[Dict[str, Any]] = None, execute_tools_async: bool = True, execute_model_tool_calls: bool = True, use_tools: bool = False, stream: bool = False) -> Union[Dict[str, Any], AsyncGenerator]:
+    async def run_thread(
+        self,
+        thread_id: str,
+        system_message: Dict[str, Any],
+        model_name: str,
+        temperature: float = 0,
+        max_tokens: Optional[int] = None,
+        tool_choice: str = "auto",
+        temporary_message: Optional[Dict[str, Any]] = None,
+        use_tools: bool = False,
+        execute_tools_async: bool = True,
+        execute_tool_calls: bool = True,
+        stream: bool = False,
+        execute_tools_on_stream: bool = False
+    ) -> Union[Dict[str, Any], AsyncGenerator]:
         """
-        Run a thread with the given parameters. If stream=True, returns an AsyncGenerator that yields response chunks.
-        Otherwise returns a Dict with the complete response.
+        Run a conversation thread with the specified parameters.
+
+        Args:
+            thread_id (str): ID of the thread to run
+            system_message (Dict[str, Any]): System message to guide model behavior
+            model_name (str): Name of the LLM model to use
+            temperature (float): Sampling temperature for model responses
+            max_tokens (Optional[int]): Maximum tokens in model response
+            tool_choice (str): How tools should be selected ('auto' or 'none')
+            temporary_message (Optional[Dict[str, Any]]): Extra message appended to the LLM API request for this call only, without adding it permanently to the thread
+            use_tools (bool): Whether to enable tool usage
+            execute_tools_async (bool): Whether to execute tools concurrently; if False, they run sequentially
+            execute_tool_calls (bool): Whether to execute parsed tool calls
+            stream (bool): Whether to stream the response
+            execute_tools_on_stream (bool): Whether to execute tools during streaming, or wait for the full response before executing
+
+        Returns:
+            Union[Dict[str, Any], AsyncGenerator]:
+                - Dict with response data for non-streaming
+                - AsyncGenerator yielding chunks for streaming
+
+        Raises:
+            Exception: If API call or tool execution fails
         """
         messages = await self.list_messages(thread_id)
         prepared_messages = [system_message] + messages

-        if additional_message:
-            prepared_messages.append(additional_message)
+        if temporary_message:
+            prepared_messages.append(temporary_message)

         tools = self.tool_registry.get_all_tool_schemas() if use_tools else None
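A sketch of the two new knobs in practice, using the API as defined above (model string illustrative):

```python
import asyncio
from agentpress.thread_manager import ThreadManager

async def demo():
    manager = ThreadManager()
    thread_id = await manager.create_thread()
    await manager.add_message(thread_id, {"role": "user", "content": "List the files."})

    # temporary_message rides along with this request only; it is never saved to
    # the thread. execute_tools_on_stream executes each tool call as soon as it
    # is complete instead of waiting for the stream to finish.
    response = await manager.run_thread(
        thread_id=thread_id,
        system_message={"role": "system", "content": "You are a helpful assistant."},
        model_name="gpt-4o-mini",
        temporary_message={"role": "user", "content": "Workspace state: (injected per request)"},
        use_tools=True,
        execute_tool_calls=True,
        stream=True,
        execute_tools_on_stream=True,
    )
    async for chunk in response:  # stream=True returns an async generator
        delta = chunk.choices[0].delta
        if getattr(delta, "content", None):
            print(delta.content, end="", flush=True)

asyncio.run(demo())
```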
@@ -172,10 +270,17 @@ class ThreadManager:
         )

         if stream:
-            return self._handle_streaming_response(thread_id, llm_response, use_tools, execute_model_tool_calls, execute_tools_async)
+            return self._handle_streaming_response(
+                thread_id=thread_id,
+                response_stream=llm_response,
+                use_tools=use_tools,
+                execute_tool_calls=execute_tool_calls,
+                execute_tools_async=execute_tools_async,
+                execute_tools_on_stream=execute_tools_on_stream
+            )

-        # For non-streaming, handle the response as before
-        if use_tools and execute_model_tool_calls:
+        # For non-streaming, handle the response
+        if use_tools and execute_tool_calls:
             await self.handle_response_with_tools(thread_id, llm_response, execute_tools_async)
         else:
             await self.handle_response_without_tools(thread_id, llm_response)
@@ -189,11 +294,12 @@ class ThreadManager:
                 "temperature": temperature,
                 "max_tokens": max_tokens,
                 "tool_choice": tool_choice,
-                "additional_message": additional_message,
+                "temporary_message": temporary_message,
                 "execute_tools_async": execute_tools_async,
-                "execute_model_tool_calls": execute_model_tool_calls,
+                "execute_tool_calls": execute_tool_calls,
                 "use_tools": use_tools,
-                "stream": stream
+                "stream": stream,
+                "execute_tools_on_stream": execute_tools_on_stream
             }
         }
@@ -209,147 +315,145 @@ class ThreadManager:
                 "temperature": temperature,
                 "max_tokens": max_tokens,
                 "tool_choice": tool_choice,
-                "additional_message": additional_message,
+                "temporary_message": temporary_message,
                 "execute_tools_async": execute_tools_async,
-                "execute_model_tool_calls": execute_model_tool_calls,
+                "execute_tool_calls": execute_tool_calls,
                 "use_tools": use_tools,
-                "stream": stream
+                "stream": stream,
+                "execute_tools_on_stream": execute_tools_on_stream
             }
         }

-    async def _handle_streaming_response(self, thread_id: str, response_stream: AsyncGenerator, use_tools: bool, execute_model_tool_calls: bool, execute_tools_async: bool) -> AsyncGenerator:
-        """Handle streaming response and tool execution"""
-        tool_calls_map = {}  # Map to store tool calls by index
-        content_buffer = ""
-
+    async def _handle_streaming_response(
+        self,
+        thread_id: str,
+        response_stream: AsyncGenerator,
+        use_tools: bool,
+        execute_tool_calls: bool,
+        execute_tools_async: bool,
+        execute_tools_on_stream: bool
+    ) -> AsyncGenerator:
+        """Handle streaming response and tool execution."""
+        tool_calls_buffer = {}  # Buffer to store tool calls by index
+        executed_tool_calls = set()  # Track which tool calls have been executed
+        available_functions = self.get_available_functions() if use_tools else {}
+        content_buffer = ""  # Buffer for content
+        current_assistant_message = None  # Track current assistant message
+        pending_tool_calls = []  # Store tool calls for non-streaming execution
+
+        async def execute_tool_calls(tool_calls):
+            if execute_tools_async:
+                return await self.tool_executor.execute_tool_calls(
+                    tool_calls=tool_calls,
+                    available_functions=available_functions,
+                    thread_id=thread_id,
+                    executed_tool_calls=executed_tool_calls
+                )
+            else:
+                sequential_executor = SequentialToolExecutor()
+                return await sequential_executor.execute_tool_calls(
+                    tool_calls=tool_calls,
+                    available_functions=available_functions,
+                    thread_id=thread_id,
+                    executed_tool_calls=executed_tool_calls
+                )

         async def process_chunk(chunk):
-            nonlocal content_buffer
+            nonlocal content_buffer, current_assistant_message, pending_tool_calls

-            # Process tool calls in the chunk
-            if hasattr(chunk.choices[0], 'delta'):
-                delta = chunk.choices[0].delta
-
-                # Handle content if present
-                if hasattr(delta, 'content') and delta.content:
-                    content_buffer += delta.content
-
-                # Handle tool calls
-                if hasattr(delta, 'tool_calls') and delta.tool_calls:
-                    for tool_call in delta.tool_calls:
-                        idx = tool_call.index
-                        if idx not in tool_calls_map:
-                            tool_calls_map[idx] = {
-                                'id': tool_call.id if tool_call.id else None,
-                                'type': 'function',
-                                'function': {
-                                    'name': tool_call.function.name if tool_call.function.name else None,
-                                    'arguments': ''
-                                }
-                            }
-
-                        current_tool = tool_calls_map[idx]
-                        if tool_call.id:
-                            current_tool['id'] = tool_call.id
-                        if tool_call.function.name:
-                            current_tool['function']['name'] = tool_call.function.name
-                        if tool_call.function.arguments:
-                            current_tool['function']['arguments'] += tool_call.function.arguments
-
-            # If this is the final chunk with tool_calls finish_reason
-            if chunk.choices[0].finish_reason == 'tool_calls' and use_tools and execute_model_tool_calls:
-                try:
-                    # Convert tool_calls_map to list and sort by index
-                    tool_calls = [tool_calls_map[idx] for idx in sorted(tool_calls_map.keys())]
-
-                    # Create assistant message with tool calls and any content
-                    assistant_message = {
-                        "role": "assistant",
-                        "content": content_buffer,
-                        "tool_calls": tool_calls
-                    }
-                    await self.add_message(thread_id, assistant_message)
-
-                    # Process the complete tool calls
-                    processed_tool_calls = []
-                    for tool_call in tool_calls:
-                        try:
-                            args_str = tool_call['function']['arguments']
-                            # Try to parse the string as JSON
-                            tool_call['function']['arguments'] = json.loads(args_str)
-                            processed_tool_calls.append(tool_call)
-                            logging.info(f"Processed tool call: {tool_call}")
-                        except json.JSONDecodeError as e:
-                            logging.error(f"Error parsing tool call arguments: {e}, args: {args_str}")
-                            continue
-
-                    # Execute tools with the processed tool calls
-                    available_functions = self.get_available_functions()
-                    if execute_tools_async:
-                        tool_results = await self.execute_tools_async(processed_tool_calls, available_functions, thread_id)
-                    else:
-                        tool_results = await self.execute_tools_sync(processed_tool_calls, available_functions, thread_id)
-
-                    # Add tool results
-                    for result in tool_results:
-                        await self.add_message(thread_id, result)
-                        logging.info(f"Tool execution result: {result}")
-                except Exception as e:
-                    logging.error(f"Error executing tools: {str(e)}")
-                    logging.error(f"Tool calls: {tool_calls}")
+            # Parse the chunk using the tool parser
+            parsed_message, is_complete = await self.tool_parser.parse_stream(chunk, tool_calls_buffer)
+
+            # If we have a message with tool calls
+            if parsed_message and 'tool_calls' in parsed_message and parsed_message['tool_calls']:
+                # Update or create the assistant message
+                if not current_assistant_message:
+                    current_assistant_message = parsed_message
+                    await self.add_message(thread_id, current_assistant_message)
+                else:
+                    current_assistant_message['tool_calls'] = parsed_message['tool_calls']
+                    await self._update_message(thread_id, current_assistant_message)
+
+                # Get new tool calls that haven't been executed
+                new_tool_calls = [
+                    tool_call for tool_call in parsed_message['tool_calls']
+                    if tool_call['id'] not in executed_tool_calls
+                ]
+
+                if new_tool_calls:
+                    if execute_tools_on_stream:
+                        # Execute tools immediately during streaming
+                        tool_results = await execute_tool_calls(new_tool_calls)
+                        for result in tool_results:
+                            await self.add_message(thread_id, result)
+                            executed_tool_calls.add(result['tool_call_id'])
+                    else:
+                        # Store tool calls for later execution
+                        pending_tool_calls.extend(new_tool_calls)
+
+            # Handle end of response
+            if chunk.choices[0].finish_reason:
+                if not execute_tools_on_stream and pending_tool_calls:
+                    # Execute all pending tool calls at the end
+                    tool_results = await execute_tool_calls(pending_tool_calls)
+                    for result in tool_results:
+                        await self.add_message(thread_id, result)
+                        executed_tool_calls.add(result['tool_call_id'])
+                    pending_tool_calls.clear()

             return chunk

         async for chunk in response_stream:
             processed_chunk = await process_chunk(chunk)
             yield processed_chunk

+    async def _update_message(self, thread_id: str, message: Dict[str, Any]):
+        """Update an existing message in the thread."""
+        thread_path = os.path.join(self.threads_dir, f"{thread_id}.json")
+        try:
+            with open(thread_path, 'r') as f:
+                thread_data = json.load(f)
+
+            # Find and update the last assistant message
+            for i in reversed(range(len(thread_data["messages"]))):
+                if thread_data["messages"][i]["role"] == "assistant":
+                    thread_data["messages"][i] = message
+                    break
+
+            with open(thread_path, 'w') as f:
+                json.dump(thread_data, f)
+        except Exception as e:
+            logging.error(f"Error updating message in thread {thread_id}: {e}")
+            raise e

     async def handle_response_without_tools(self, thread_id: str, response: Any):
         response_content = response.choices[0].message['content']
         await self.add_message(thread_id, {"role": "assistant", "content": response_content})

     async def handle_response_with_tools(self, thread_id: str, response: Any, execute_tools_async: bool):
         try:
             response_message = response.choices[0].message
-            tool_calls = response_message.get('tool_calls', [])
-
-            assistant_message = self.create_assistant_message_with_tools(response_message)
+            # Parse the response using the tool parser
+            assistant_message = await self.tool_parser.parse_response(response)
             await self.add_message(thread_id, assistant_message)

-            available_functions = self.get_available_functions()
-
-            if tool_calls:
+            # Execute tools if present
+            if 'tool_calls' in assistant_message and assistant_message['tool_calls']:
+                available_functions = self.get_available_functions()
                 if execute_tools_async:
-                    tool_results = await self.execute_tools_async(tool_calls, available_functions, thread_id)
+                    tool_results = await self.execute_tools_async(assistant_message['tool_calls'], available_functions, thread_id)
                 else:
-                    tool_results = await self.execute_tools_sync(tool_calls, available_functions, thread_id)
+                    tool_results = await self.execute_tools_sync(assistant_message['tool_calls'], available_functions, thread_id)

                 for result in tool_results:
                     await self.add_message(thread_id, result)
                     logging.info(f"Tool execution result: {result}")

         except AttributeError as e:
             logging.error(f"AttributeError: {e}")
             response_content = response.choices[0].message['content']
         except Exception as e:
             logging.error(f"Error in handle_response_with_tools: {e}")
             logging.error(f"Response: {response}")
             response_content = response.choices[0].message.get('content', '')
             await self.add_message(thread_id, {"role": "assistant", "content": response_content or ""})

-    def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
-        message = {
-            "role": "assistant",
-            "content": response_message.get('content') or "",
-        }
-        tool_calls = response_message.get('tool_calls')
-        if tool_calls:
-            message["tool_calls"] = [
-                {
-                    "id": tool_call.id,
-                    "type": "function",
-                    "function": {
-                        "name": tool_call.function.name,
-                        "arguments": tool_call.function.arguments
-                    }
-                } for tool_call in tool_calls
-            ]
-        return message

     def get_available_functions(self) -> Dict[str, Callable]:
         available_functions = {}
         for tool_name, tool_info in self.tool_registry.get_all_tools().items():
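One design note worth spelling out: `parse_stream` re-reports every complete tool call on each chunk, so the handler relies on the `executed_tool_calls` set to avoid running the same call twice. A toy illustration of that dedup pattern:

```python
# Each "reported" list mimics what the parser returns on successive chunks.
executed = set()
for reported in (["call_1"], ["call_1", "call_2"], ["call_1", "call_2"]):
    new = [c for c in reported if c not in executed]  # skip already-run calls
    executed.update(new)
    print("executing:", new)  # ['call_1'], then ['call_2'], then []
```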
@@ -359,55 +463,105 @@ class ThreadManager:
             available_functions[func_name] = getattr(tool_instance, func_name)
         return available_functions

-    async def execute_tools_async(self, tool_calls, available_functions, thread_id):
-        async def execute_single_tool(tool_call):
+    async def execute_tools_async(self, tool_calls: List[Dict[str, Any]], available_functions: Dict[str, Callable], thread_id: str) -> List[Dict[str, Any]]:
+        """
+        Execute multiple tool calls concurrently.
+
+        Args:
+            tool_calls (List[Dict[str, Any]]): List of tool calls to execute
+            available_functions (Dict[str, Callable]): Map of function names to implementations
+            thread_id (str): ID of the thread requesting tool execution
+
+        Returns:
+            List[Dict[str, Any]]: Results from tool executions
+        """
+        async def execute_single_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
             try:
-                if isinstance(tool_call, dict):
-                    function_name = tool_call['function']['name']
-                    function_args = tool_call['function']['arguments']  # Already a dict
-                    tool_call_id = tool_call['id']
-                else:
-                    function_name = tool_call.function.name
-                    function_args = json.loads(tool_call.function.arguments) if isinstance(tool_call.function.arguments, str) else tool_call.function.arguments
-                    tool_call_id = tool_call.id
-
+                function_name = tool_call['function']['name']
+                function_args = tool_call['function']['arguments']
+                if isinstance(function_args, str):
+                    function_args = json.loads(function_args)
+
                 function_to_call = available_functions.get(function_name)
-                if function_to_call:
-                    return await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)
-                else:
-                    logging.warning(f"Function {function_name} not found in available functions")
-                    return None
+                if not function_to_call:
+                    error_msg = f"Function {function_name} not found"
+                    logging.error(error_msg)
+                    return {
+                        "role": "tool",
+                        "tool_call_id": tool_call['id'],
+                        "name": function_name,
+                        "content": str(ToolResult(success=False, output=error_msg))
+                    }
+
+                result = await function_to_call(**function_args)
+                logging.info(f"Tool execution result for {function_name}: {result}")
+
+                return {
+                    "role": "tool",
+                    "tool_call_id": tool_call['id'],
+                    "name": function_name,
+                    "content": str(result)
+                }
             except Exception as e:
-                logging.error(f"Error executing tool: {str(e)}")
-                return None
+                error_msg = f"Error executing {function_name}: {str(e)}"
+                logging.error(error_msg)
+                return {
+                    "role": "tool",
+                    "tool_call_id": tool_call['id'],
+                    "name": function_name,
+                    "content": str(ToolResult(success=False, output=error_msg))
+                }

-        tool_results = await asyncio.gather(*[execute_single_tool(tool_call) for tool_call in tool_calls])
-        return [result for result in tool_results if result]
+        tasks = [execute_single_tool(tool_call) for tool_call in tool_calls]
+        results = await asyncio.gather(*tasks)
+        return results

-    async def execute_tools_sync(self, tool_calls, available_functions, thread_id):
-        tool_results = []
+    async def execute_tools_sync(self, tool_calls: List[Dict[str, Any]], available_functions: Dict[str, Callable], thread_id: str) -> List[Dict[str, Any]]:
+        """
+        Execute multiple tool calls sequentially.
+
+        Args:
+            tool_calls (List[Dict[str, Any]]): List of tool calls to execute
+            available_functions (Dict[str, Callable]): Map of function names to implementations
+            thread_id (str): ID of the thread requesting tool execution
+
+        Returns:
+            List[Dict[str, Any]]: Results from tool executions
+        """
+        results = []
         for tool_call in tool_calls:
             try:
-                if isinstance(tool_call, dict):
-                    function_name = tool_call['function']['name']
-                    function_args = tool_call['function']['arguments']  # Already a dict
-                    tool_call_id = tool_call['id']
-                else:
-                    function_name = tool_call.function.name
-                    function_args = json.loads(tool_call.function.arguments) if isinstance(tool_call.function.arguments, str) else tool_call.function.arguments
-                    tool_call_id = tool_call.id
-
+                function_name = tool_call['function']['name']
+                function_args = tool_call['function']['arguments']
+                if isinstance(function_args, str):
+                    function_args = json.loads(function_args)
+
                 function_to_call = available_functions.get(function_name)
-                if function_to_call:
-                    result = await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)
-                    if result:
-                        tool_results.append(result)
+                if not function_to_call:
+                    error_msg = f"Function {function_name} not found"
+                    logging.error(error_msg)
+                    result = ToolResult(success=False, output=error_msg)
                 else:
-                    logging.warning(f"Function {function_name} not found in available functions")
+                    result = await function_to_call(**function_args)
+                    logging.info(f"Tool execution result for {function_name}: {result}")
+
+                results.append({
+                    "role": "tool",
+                    "tool_call_id": tool_call['id'],
+                    "name": function_name,
+                    "content": str(result)
+                })
             except Exception as e:
-                logging.error(f"Error executing tool: {str(e)}")
+                error_msg = f"Error executing {function_name}: {str(e)}"
+                logging.error(error_msg)
+                results.append({
+                    "role": "tool",
+                    "tool_call_id": tool_call['id'],
+                    "name": function_name,
+                    "content": str(ToolResult(success=False, output=error_msg))
+                })

-        return tool_results
+        return results

     async def execute_tool(self, function_to_call, function_args, function_name, tool_call_id):
         try:
@@ -443,14 +597,15 @@ if __name__ == "__main__":
         # Add a test message
         await manager.add_message(thread_id, {
             "role": "user",
-            "content": "Please create a file with a random name with the content 'Hello, world!' Explain what robotics is in a short message to me.."
+            "content": "Please create 10x files – Each should be a chapter of a book about an Introduction to Robotics.."
         })

         system_message = {
             "role": "system",
             "content": "You are a helpful assistant that can create, read, update, and delete files."
         }
-        model_name = "gpt-4o-mini"
+        model_name = "anthropic/claude-3-5-haiku-latest"
+        # model_name = "gpt-4o-mini"

         # Test with tools (non-streaming)
         print("\n🤖 Testing non-streaming response with tools:")
@@ -461,7 +616,7 @@ if __name__ == "__main__":
             temperature=0.7,
             stream=False,
             use_tools=True,
-            execute_model_tool_calls=True
+            execute_tool_calls=True
         )

         # Print the non-streaming response
@@ -480,7 +635,8 @@ if __name__ == "__main__":
             temperature=0.7,
             stream=True,
             use_tools=True,
-            execute_model_tool_calls=True
+            execute_tool_calls=True,
+            execute_tools_on_stream=True
         )

         buffer = ""
@@ -503,5 +659,4 @@ if __name__ == "__main__":
             print(buffer, flush=True)
         print("\n✨ Stream completed.\n")

-    # Run the async main function
     asyncio.run(main())

@@ -0,0 +1,303 @@ (new file)
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Set, Callable, Optional
import json
import logging
import asyncio
from agentpress.tool import ToolResult

class ToolExecutor(ABC):
    """
    Abstract base class for tool execution strategies.

    Tool executors are responsible for running tool functions based on LLM tool calls.
    They handle both synchronous and streaming execution modes, managing the lifecycle
    of tool calls and their results.
    """

    @abstractmethod
    async def execute_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Optional[Set[str]] = None
    ) -> List[Dict[str, Any]]:
        """
        Execute a list of tool calls and return their results.

        Args:
            tool_calls (List[Dict[str, Any]]): List of tool calls to execute
            available_functions (Dict[str, Callable]): Map of function names to implementations
            thread_id (str): ID of the thread requesting execution
            executed_tool_calls (Optional[Set[str]]): Set tracking already executed calls

        Returns:
            List[Dict[str, Any]]: Results from tool executions
        """
        pass

    @abstractmethod
    async def execute_streaming_tool_calls(
        self,
        tool_calls_buffer: Dict[int, Dict],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Set[str]
    ) -> List[Dict[str, Any]]:
        """
        Execute tool calls from a streaming buffer when they're complete.

        Args:
            tool_calls_buffer (Dict[int, Dict]): Buffer containing tool calls
            available_functions (Dict[str, Callable]): Map of function names to implementations
            thread_id (str): ID of the thread requesting execution
            executed_tool_calls (Set[str]): Set tracking already executed calls

        Returns:
            List[Dict[str, Any]]: Results from completed tool executions
        """
        pass

class StandardToolExecutor(ToolExecutor):
    """
    Standard implementation of tool execution.

    Executes tool calls concurrently using asyncio.gather(). Handles both streaming
    and non-streaming execution modes, with support for tracking executed calls to
    prevent duplicates.
    """

    async def execute_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Optional[Set[str]] = None
    ) -> List[Dict[str, Any]]:
        """
        Execute all tool calls asynchronously.

        Args:
            tool_calls (List[Dict[str, Any]]): List of tool calls to execute
            available_functions (Dict[str, Callable]): Map of function names to implementations
            thread_id (str): ID of the thread requesting execution
            executed_tool_calls (Optional[Set[str]]): Set tracking already executed calls

        Returns:
            List[Dict[str, Any]]: Results from tool executions, including error responses
                for failed executions
        """
        if executed_tool_calls is None:
            executed_tool_calls = set()

        async def execute_single_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]:
            if tool_call['id'] in executed_tool_calls:
                return None

            try:
                function_name = tool_call['function']['name']
                function_args = tool_call['function']['arguments']
                if isinstance(function_args, str):
                    function_args = json.loads(function_args)

                function_to_call = available_functions.get(function_name)
                if not function_to_call:
                    error_msg = f"Function {function_name} not found"
                    logging.error(error_msg)
                    return {
                        "role": "tool",
                        "tool_call_id": tool_call['id'],
                        "name": function_name,
                        "content": str(ToolResult(success=False, output=error_msg))
                    }

                result = await function_to_call(**function_args)
                logging.info(f"Tool execution result for {function_name}: {result}")
                executed_tool_calls.add(tool_call['id'])

                return {
                    "role": "tool",
                    "tool_call_id": tool_call['id'],
                    "name": function_name,
                    "content": str(result)
                }
            except Exception as e:
                error_msg = f"Error executing {function_name}: {str(e)}"
                logging.error(error_msg)
                return {
                    "role": "tool",
                    "tool_call_id": tool_call['id'],
                    "name": function_name,
                    "content": str(ToolResult(success=False, output=error_msg))
                }

        tasks = [execute_single_tool(tool_call) for tool_call in tool_calls]
        results = await asyncio.gather(*tasks)
        return [r for r in results if r is not None]

    async def execute_streaming_tool_calls(
        self,
        tool_calls_buffer: Dict[int, Dict],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Set[str]
    ) -> List[Dict[str, Any]]:
        """
        Execute complete tool calls from the streaming buffer.

        Args:
            tool_calls_buffer (Dict[int, Dict]): Buffer containing tool calls
            available_functions (Dict[str, Callable]): Map of function names to implementations
            thread_id (str): ID of the thread requesting execution
            executed_tool_calls (Set[str]): Set tracking already executed calls

        Returns:
            List[Dict[str, Any]]: Results from completed tool executions

        Note:
            Only executes tool calls that are complete (have all required fields)
            and haven't been executed before.
        """
        complete_tool_calls = []

        # Find complete tool calls that haven't been executed
        for idx, tool_call in tool_calls_buffer.items():
            if (tool_call.get('id') and
                tool_call['function'].get('name') and
                tool_call['function'].get('arguments') and
                tool_call['id'] not in executed_tool_calls):
                try:
                    # Verify arguments are complete JSON
                    if isinstance(tool_call['function']['arguments'], str):
                        json.loads(tool_call['function']['arguments'])
                    complete_tool_calls.append(tool_call)
                except json.JSONDecodeError:
                    continue

        if complete_tool_calls:
            return await self.execute_tool_calls(
                complete_tool_calls,
                available_functions,
                thread_id,
                executed_tool_calls
            )

        return []

class SequentialToolExecutor(ToolExecutor):
    """
    Sequential implementation of tool execution.

    Executes tool calls one at a time in sequence. This can be useful when tools
    need to be executed in a specific order or when concurrent execution might
    cause conflicts.
    """

    async def execute_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Optional[Set[str]] = None
    ) -> List[Dict[str, Any]]:
        """
        Execute tool calls sequentially.

        Args:
            tool_calls (List[Dict[str, Any]]): List of tool calls to execute
            available_functions (Dict[str, Callable]): Map of function names to implementations
            thread_id (str): ID of the thread requesting execution
            executed_tool_calls (Optional[Set[str]]): Set tracking already executed calls

        Returns:
            List[Dict[str, Any]]: Results from tool executions in order of execution
        """
        if executed_tool_calls is None:
            executed_tool_calls = set()

        results = []
        for tool_call in tool_calls:
            if tool_call['id'] in executed_tool_calls:
                continue

            try:
                function_name = tool_call['function']['name']
                function_args = tool_call['function']['arguments']
                if isinstance(function_args, str):
                    function_args = json.loads(function_args)

                function_to_call = available_functions.get(function_name)
                if not function_to_call:
                    error_msg = f"Function {function_name} not found"
                    logging.error(error_msg)
                    result = ToolResult(success=False, output=error_msg)
                else:
                    result = await function_to_call(**function_args)
                    logging.info(f"Tool execution result for {function_name}: {result}")
                    executed_tool_calls.add(tool_call['id'])

                results.append({
                    "role": "tool",
                    "tool_call_id": tool_call['id'],
                    "name": function_name,
                    "content": str(result)
                })
            except Exception as e:
                error_msg = f"Error executing {function_name}: {str(e)}"
                logging.error(error_msg)
                results.append({
                    "role": "tool",
                    "tool_call_id": tool_call['id'],
                    "name": function_name,
                    "content": str(ToolResult(success=False, output=error_msg))
                })

        return results

    async def execute_streaming_tool_calls(
        self,
        tool_calls_buffer: Dict[int, Dict],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Set[str]
    ) -> List[Dict[str, Any]]:
        """
        Execute complete tool calls from the streaming buffer sequentially.

        Args:
            tool_calls_buffer (Dict[int, Dict]): Buffer containing tool calls
            available_functions (Dict[str, Callable]): Map of function names to implementations
            thread_id (str): ID of the thread requesting execution
            executed_tool_calls (Set[str]): Set tracking already executed calls

        Returns:
            List[Dict[str, Any]]: Results from completed tool executions in order

        Note:
            Only executes tool calls that are complete and haven't been executed before,
            maintaining the order of the original buffer indices.
        """
        complete_tool_calls = []

        # Find complete tool calls that haven't been executed
        for idx, tool_call in tool_calls_buffer.items():
            if (tool_call.get('id') and
                tool_call['function'].get('name') and
                tool_call['function'].get('arguments') and
                tool_call['id'] not in executed_tool_calls):
                try:
                    if isinstance(tool_call['function']['arguments'], str):
                        json.loads(tool_call['function']['arguments'])
                    complete_tool_calls.append(tool_call)
                except json.JSONDecodeError:
                    continue

        if complete_tool_calls:
            return await self.execute_tool_calls(
                complete_tool_calls,
                available_functions,
                thread_id,
                executed_tool_calls
            )

        return []
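Because executors are now a base class, alternative strategies can be plugged in. A hypothetical subclass that logs before delegating to the standard concurrent executor:

```python
from typing import Any, Callable, Dict, List, Optional, Set
from agentpress.tool_executor import StandardToolExecutor

class LoggingToolExecutor(StandardToolExecutor):
    """Hypothetical executor that logs each batch before delegating."""

    async def execute_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Optional[Set[str]] = None,
    ) -> List[Dict[str, Any]]:
        # Log the batch size, then let the standard executor do the work
        print(f"[{thread_id}] executing {len(tool_calls)} tool call(s)")
        return await super().execute_tool_calls(
            tool_calls, available_functions, thread_id, executed_tool_calls
        )
```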
@@ -1 +1,163 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import json
import logging

class ToolParser(ABC):
    """
    Abstract base class defining the interface for parsing tool calls from LLM responses.

    Tool parsers are responsible for extracting and formatting tool calls from both
    streaming and non-streaming LLM responses. They handle the conversion of raw
    LLM output into structured tool call data that can be executed.
    """

    @abstractmethod
    async def parse_response(self, response: Any) -> Dict[str, Any]:
        """
        Parse a complete LLM response and return the assistant message with tool calls.

        Args:
            response (Any): The complete response from the LLM

        Returns:
            Dict[str, Any]: Parsed assistant message containing content and tool calls
        """
        pass

    @abstractmethod
    async def parse_stream(self, response_chunk: Any, tool_calls_buffer: Dict[int, Dict]) -> tuple[Optional[Dict[str, Any]], bool]:
        """
        Parse a streaming response chunk and update the tool calls buffer.

        Args:
            response_chunk (Any): A single chunk from the streaming response
            tool_calls_buffer (Dict[int, Dict]): Buffer storing incomplete tool calls

        Returns:
            tuple(message, is_complete):
                - message (Optional[Dict[str, Any]]): The parsed assistant message with
                  tool calls if complete, None otherwise
                - is_complete (bool): Boolean indicating if tool calls parsing is complete
        """
        pass

class StandardToolParser(ToolParser):
    """
    Standard implementation of tool parsing for OpenAI-compatible API responses.

    Handles both streaming and non-streaming responses, extracting tool calls
    and formatting them for execution. Supports incremental parsing of streaming
    tool calls and validation of tool call arguments.
    """

    async def parse_response(self, response: Any) -> Dict[str, Any]:
        """
        Parse a complete LLM response into an assistant message with tool calls.

        Args:
            response (Any): Complete response from the LLM API

        Returns:
            Dict[str, Any]: Formatted assistant message containing:
                - role: "assistant"
                - content: Text content of the response
                - tool_calls: List of parsed tool calls (if present)
        """
        response_message = response.choices[0].message
        message = {
            "role": "assistant",
            "content": response_message.get('content') or "",
        }

        tool_calls = response_message.get('tool_calls')
        if tool_calls:
            message["tool_calls"] = [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments
                    }
                } for tool_call in tool_calls
            ]

        return message

    async def parse_stream(self, chunk: Any, tool_calls_buffer: Dict[int, Dict]) -> tuple[Optional[Dict[str, Any]], bool]:
        """
        Parse a streaming response chunk and update the tool calls buffer.
        Returns a message when a complete tool call is detected.
        """
        content_chunk = ""
        is_complete = False
        has_complete_tool_call = False

        if hasattr(chunk.choices[0], 'delta'):
            delta = chunk.choices[0].delta

            # Handle content if present
            if hasattr(delta, 'content') and delta.content:
                content_chunk = delta.content

            # Handle tool calls
            if hasattr(delta, 'tool_calls') and delta.tool_calls:
                for tool_call in delta.tool_calls:
                    idx = tool_call.index
                    if idx not in tool_calls_buffer:
                        tool_calls_buffer[idx] = {
                            'id': tool_call.id if hasattr(tool_call, 'id') and tool_call.id else None,
                            'type': 'function',
                            'function': {
                                'name': tool_call.function.name if hasattr(tool_call.function, 'name') and tool_call.function.name else None,
                                'arguments': ''
                            }
                        }

                    current_tool = tool_calls_buffer[idx]
                    if hasattr(tool_call, 'id') and tool_call.id:
                        current_tool['id'] = tool_call.id
                    if hasattr(tool_call.function, 'name') and tool_call.function.name:
                        current_tool['function']['name'] = tool_call.function.name
                    if hasattr(tool_call.function, 'arguments') and tool_call.function.arguments:
                        current_tool['function']['arguments'] += tool_call.function.arguments

                    # Check if this tool call is complete
                    if (current_tool['id'] and
                        current_tool['function']['name'] and
                        current_tool['function']['arguments']):
                        try:
                            # Validate JSON arguments
                            json.loads(current_tool['function']['arguments'])
                            has_complete_tool_call = True
                        except json.JSONDecodeError:
                            pass

        # Check if this is the final chunk
        if hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
            is_complete = True

        # Return a message if we have complete tool calls or it's the final chunk
        if has_complete_tool_call or is_complete:
            # Get all complete tool calls
            complete_tool_calls = []
            for idx, tool_call in tool_calls_buffer.items():
                try:
                    if (tool_call['id'] and
                        tool_call['function']['name'] and
                        tool_call['function']['arguments']):
                        # Validate JSON
                        json.loads(tool_call['function']['arguments'])
                        complete_tool_calls.append(tool_call)
                except json.JSONDecodeError:
                    continue

            if complete_tool_calls:
                return {
                    "role": "assistant",
                    "content": content_chunk,
                    "tool_calls": complete_tool_calls
                }, is_complete

        return None, is_complete
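Driving the parser by hand clarifies the contract: the caller owns the buffer dict and threads it through every chunk; a message comes back only once buffered calls validate as complete JSON. A sketch (the stream can be any OpenAI-compatible streaming response):

```python
from agentpress.tool_parser import StandardToolParser

async def collect_tool_calls(response_stream):
    parser = StandardToolParser()
    buffer = {}  # index -> partial tool call, filled in across chunks
    async for chunk in response_stream:
        message, is_complete = await parser.parse_stream(chunk, buffer)
        if message:
            print("complete tool calls:", message["tool_calls"])
        if is_complete:
            break
```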

@@ -3,7 +3,20 @@ from agentpress.tool import Tool

 class ToolRegistry:
+    """
+    Registry for managing and accessing tools in the AgentPress system.
+
+    The ToolRegistry maintains a collection of tool instances and their schemas,
+    allowing for selective registration of tool functions and easy access to
+    tool capabilities.
+
+    Attributes:
+        tools (Dict[str, Dict[str, Any]]): Dictionary mapping function names to
+            their tool instances and schemas
+    """

     def __init__(self):
+        """Initialize an empty tool registry."""
         self.tools: Dict[str, Dict[str, Any]] = {}

     def register_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
@@ -11,9 +24,13 @@ class ToolRegistry:
         Register a tool with optional function name filtering and initialization parameters.

         Args:
-            tool_class: The tool class to register
-            function_names: Optional list of function names to register
+            tool_class (Type[Tool]): The tool class to register
+            function_names (Optional[List[str]]): Optional list of specific function names to register.
+                If None, all functions from the tool will be registered.
             **kwargs: Additional keyword arguments passed to tool initialization
+
+        Raises:
+            ValueError: If a specified function name is not found in the tool class
         """
         tool_instance = tool_class(**kwargs)
         schemas = tool_instance.get_schemas()
@@ -37,10 +54,34 @@ class ToolRegistry:
             raise ValueError(f"Function '{func_name}' not found in {tool_class.__name__}")

     def get_tool(self, tool_name: str) -> Dict[str, Any]:
+        """
+        Get a specific tool by name.
+
+        Args:
+            tool_name (str): Name of the tool function to retrieve
+
+        Returns:
+            Dict[str, Any]: Dictionary containing the tool instance and schema,
+                or an empty dict if the tool is not found
+        """
         return self.tools.get(tool_name, {})

     def get_all_tools(self) -> Dict[str, Dict[str, Any]]:
+        """
+        Get all registered tools.
+
+        Returns:
+            Dict[str, Dict[str, Any]]: Dictionary mapping tool names to their
+                instances and schemas
+        """
         return self.tools

     def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
+        """
+        Get schemas for all registered tools.
+
+        Returns:
+            List[Dict[str, Any]]: List of OpenAPI-compatible schemas for all
+                registered tool functions
+        """
         return [tool_info['schema'] for tool_info in self.tools.values()]
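A registration sketch; `FilesTool` is the tool used elsewhere in this commit, but the specific function names here are hypothetical:

```python
from agentpress.tool_registry import ToolRegistry
from tools.files_tool import FilesTool

# Selective registration: only the named functions are exposed to the LLM.
registry = ToolRegistry()
registry.register_tool(FilesTool, function_names=["create_file", "read_file"])
schemas = registry.get_all_tool_schemas()  # OpenAPI-style schemas for the LLM
```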

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "agentpress"
-version = "0.1.6"
+version = "0.1.7"
 description = "Building blocks for AI Agents"
 authors = ["marko-kraemer <mail@markokraemer.com>"]
 readme = "README.md"
state.json (22 lines removed)

@@ -1,22 +0,0 @@ (file deleted)
-{
-    "files": {
-        "random_message.txt": {
-            "content": "Hello, world!"
-        },
-        "random_file_2.txt": {
-            "content": "Hello, world!"
-        },
-        "robotics_explanation_2.txt": {
-            "content": "Robotics is a branch of engineering and science that involves the design, construction, operation, and use of robots. It combines elements of mechanical engineering, electrical engineering, computer science, and artificial intelligence to create machines that can perform tasks autonomously or semi-autonomously."
-        },
-        "random_file_1.txt": {
-            "content": "Hello, world!"
-        },
-        "robotics_explanation.txt": {
-            "content": "Robotics is a branch of engineering and science that involves the design, construction, operation, and use of robots. It combines elements of computer science, mechanical engineering, and electrical engineering to create machines that can perform tasks autonomously or semi-autonomously, often mimicking human actions."
-        },
-        "hello_message.txt": {
-            "content": "Hello, world!"
-        }
-    }
-}