2024-11-18 13:54:26 +08:00
|
|
|
|
"""
|
|
|
|
|
Conversation thread management system for AgentPress.
|
|
|
|
|
|
|
|
|
|
This module provides comprehensive conversation management, including:
|
|
|
|
|
- Thread creation and persistence
|
|
|
|
|
- Message handling with support for text and images
|
|
|
|
|
- Tool registration and execution
|
|
|
|
|
- LLM interaction with streaming support
|
|
|
|
|
- Error handling and cleanup
|
|
|
|
|
"""
|
|
|
|
|
|
2024-10-06 01:04:15 +08:00
|
|
|
|
import json
|
|
|
|
|
import logging
|
2024-11-19 10:46:25 +08:00
|
|
|
|
import asyncio
|
2024-11-18 13:03:58 +08:00
|
|
|
|
import uuid
|
2024-11-12 19:37:47 +08:00
|
|
|
|
from typing import List, Dict, Any, Optional, Type, Union, AsyncGenerator
|
2024-10-10 22:21:39 +08:00
|
|
|
|
from agentpress.llm import make_llm_api_call
|
2024-10-23 09:28:12 +08:00
|
|
|
|
from agentpress.tool import Tool, ToolResult
|
2024-10-10 22:21:39 +08:00
|
|
|
|
from agentpress.tool_registry import ToolRegistry
|
2024-11-19 10:46:25 +08:00
|
|
|
|
from agentpress.processor.llm_response_processor import LLMResponseProcessor
|
|
|
|
|
from agentpress.processor.base_processors import ToolParserBase, ToolExecutorBase, ResultsAdderBase
|
|
|
|
|
from agentpress.db_connection import DBConnection
|
2024-10-10 22:21:39 +08:00
|
|
|
|
|
2024-11-19 10:46:25 +08:00
|
|
|
|
from agentpress.processor.xml.xml_tool_parser import XMLToolParser
|
|
|
|
|
from agentpress.processor.xml.xml_tool_executor import XMLToolExecutor
|
|
|
|
|
from agentpress.processor.xml.xml_results_adder import XMLResultsAdder
|
|
|
|
|
from agentpress.processor.standard.standard_tool_parser import StandardToolParser
|
|
|
|
|
from agentpress.processor.standard.standard_tool_executor import StandardToolExecutor
|
|
|
|
|
from agentpress.processor.standard.standard_results_adder import StandardResultsAdder
|
2024-11-18 11:21:08 +08:00
|
|
|
|
|
2024-10-06 01:04:15 +08:00
|
|
|
|
class ThreadManager:
    """Manages conversation threads with LLM models and tool execution.

    Provides comprehensive conversation management, handling message threading,
    tool registration, and LLM interactions with support for both standard and
    XML-based tool execution patterns.

    Each thread is a row in the ``threads`` table whose ``messages`` column holds
    the JSON-encoded list of message dicts.
    """

    def __init__(self):
        """Initialize ThreadManager with a DB connection and a fresh ToolRegistry."""
        self.db = DBConnection()
        self.tool_registry = ToolRegistry()

    def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
        """Add a tool to the ThreadManager.

        Args:
            tool_class: Tool class to register.
            function_names: Optional subset of the tool's functions to register.
            **kwargs: Forwarded unchanged to ``ToolRegistry.register_tool``.
        """
        self.tool_registry.register_tool(tool_class, function_names, **kwargs)

    async def create_thread(self) -> str:
        """Create a new, empty conversation thread and return its UUID string."""
        thread_id = str(uuid.uuid4())
        await self.db.execute(
            "INSERT INTO threads (thread_id, messages) VALUES (?, ?)",
            (thread_id, json.dumps([]))
        )
        return thread_id

    async def _fetch_thread_messages(self, thread_id: str) -> Optional[List[Dict[str, Any]]]:
        """Return the thread's stored message list, unfiltered, or None if the thread does not exist."""
        row = await self.db.fetch_one(
            "SELECT messages FROM threads WHERE thread_id = ?",
            (thread_id,)
        )
        if not row:
            return None
        return json.loads(row[0])

    @staticmethod
    async def _save_thread_messages(conn: Any, thread_id: str, messages: List[Dict[str, Any]]) -> None:
        """Overwrite the thread's messages (and bump updated_at) using ``conn``."""
        await conn.execute(
            """
            UPDATE threads
            SET messages = ?, updated_at = CURRENT_TIMESTAMP
            WHERE thread_id = ?
            """,
            (json.dumps(messages), thread_id)
        )

    async def add_message(self, thread_id: str, message_data: Dict[str, Any], images: Optional[List[Dict[str, Any]]] = None):
        """Add a message to an existing thread.

        Args:
            thread_id: ID of the thread to append to.
            message_data: Message dict; must contain a 'role' key. It is modified
                in place: ToolResult values are converted to strings and, when
                images are given, 'content' becomes a list of content parts.
            images: Optional list of dicts with 'content_type' and 'base64' keys;
                each is appended to the content as an ``image_url`` data URL.

        Raises:
            ValueError: If the thread does not exist.
            KeyError: If ``message_data`` has no 'role'.
        """
        # Log only the image count: the payloads are base64 blobs that would flood the logs.
        logging.info(f"Adding message to thread {thread_id} with {len(images) if images else 0} image(s)")

        try:
            # Before a new user turn, give any unanswered tool calls from the last
            # assistant message a synthetic failure result. This runs before our own
            # transaction because cleanup_incomplete_tool_calls opens one itself.
            if message_data['role'] == 'user':
                await self.cleanup_incomplete_tool_calls(thread_id)

            # Convert ToolResult instances to strings so the message is JSON-serializable.
            for key, value in message_data.items():
                if isinstance(value, ToolResult):
                    message_data[key] = str(value)

            # Handle image attachments
            if images:
                content = message_data.get('content')
                if isinstance(content, str):
                    message_data['content'] = [{"type": "text", "text": content}]
                elif not isinstance(content, list):
                    message_data['content'] = []

                for image in images:
                    message_data['content'].append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{image['content_type']};base64,{image['base64']}",
                            "detail": "high"
                        }
                    })

            async with self.db.transaction() as conn:
                messages = await self._fetch_thread_messages(thread_id)
                if messages is None:
                    raise ValueError(f"Thread {thread_id} not found")
                messages.append(message_data)
                await self._save_thread_messages(conn, thread_id, messages)

            logging.info(f"Message added to thread {thread_id} (role: {message_data.get('role')})")

        except Exception as e:
            logging.error(f"Failed to add message to thread {thread_id}: {e}")
            raise

    async def get_messages(
        self,
        thread_id: str,
        hide_tool_msgs: bool = False,
        only_latest_assistant: bool = False,
        regular_list: bool = True
    ) -> List[Dict[str, Any]]:
        """Retrieve messages from a thread with optional filtering.

        Args:
            thread_id: ID of the thread to read.
            hide_tool_msgs: Drop 'tool' role messages and strip 'tool_calls' keys.
            only_latest_assistant: Return only the most recent assistant message
                (as a one-element list, or [] if none). Other filters are not applied.
            regular_list: Keep only system/assistant/tool/user role messages.

        Returns:
            The (possibly filtered) message list; [] if the thread does not exist.
        """
        messages = await self._fetch_thread_messages(thread_id)
        if messages is None:
            return []

        if only_latest_assistant:
            for msg in reversed(messages):
                if msg.get('role') == 'assistant':
                    return [msg]
            return []

        if hide_tool_msgs:
            messages = [
                {k: v for k, v in msg.items() if k != 'tool_calls'}
                for msg in messages
                if msg.get('role') != 'tool'
            ]

        if regular_list:
            messages = [
                msg for msg in messages
                if msg.get('role') in ['system', 'assistant', 'tool', 'user']
            ]

        return messages

    async def _update_message(self, thread_id: str, message: Dict[str, Any]):
        """Replace the last assistant message in the thread with ``message``.

        Does nothing if the thread does not exist or has no assistant message.
        """
        async with self.db.transaction() as conn:
            messages = await self._fetch_thread_messages(thread_id)
            if messages is None:
                return

            for i in range(len(messages) - 1, -1, -1):
                if messages[i].get('role') == 'assistant':
                    messages[i] = message
                    break
            else:
                # No assistant message to replace; skip the no-op write.
                return

            await self._save_thread_messages(conn, thread_id, messages)

    async def cleanup_incomplete_tool_calls(self, thread_id: str) -> bool:
        """Add failure results for unanswered tool calls of the last tool-calling assistant message.

        Reads the full stored message list, not the role-filtered one, so that
        writing it back never drops messages. Responses are matched to calls by
        ``tool_call_id``. This stays correct when parallel execution returns
        responses out of order.

        Returns:
            True if synthetic failure results were inserted and saved, else False.
        """
        messages = await self._fetch_thread_messages(thread_id)
        if not messages:
            return False

        assistant_index = next(
            (i for i in range(len(messages) - 1, -1, -1)
             if messages[i].get('role') == 'assistant' and 'tool_calls' in messages[i]),
            None
        )
        if assistant_index is None:
            return False

        # tool_calls may be stored as None; treat that as "no calls".
        tool_calls = messages[assistant_index].get('tool_calls') or []
        responded_ids = {
            m.get('tool_call_id')
            for m in messages[assistant_index + 1:]
            if m.get('role') == 'tool'
        }
        missing_calls = [tc for tc in tool_calls if tc.get('id') not in responded_ids]
        if not missing_calls:
            return False

        failed_tool_results = [
            {
                "role": "tool",
                "tool_call_id": tool_call['id'],
                "name": tool_call['function']['name'],
                "content": "ToolResult(success=False, output='Execution interrupted. Session was stopped.')"
            }
            for tool_call in missing_calls
        ]
        # Insert directly after the assistant message so all tool results stay adjacent to it.
        messages[assistant_index + 1:assistant_index + 1] = failed_tool_results

        async with self.db.transaction() as conn:
            await self._save_thread_messages(conn, thread_id, messages)
        return True

    async def run_thread(
        self,
        thread_id: str,
        system_message: Dict[str, Any],
        model_name: str,
        temperature: float = 0,
        max_tokens: Optional[int] = None,
        tool_choice: str = "auto",
        temporary_message: Optional[Dict[str, Any]] = None,
        native_tool_calling: bool = False,
        xml_tool_calling: bool = False,
        execute_tools: bool = True,
        stream: bool = False,
        execute_tools_on_stream: bool = False,
        parallel_tool_execution: bool = False,
        tool_parser: Optional[ToolParserBase] = None,
        tool_executor: Optional[ToolExecutorBase] = None,
        results_adder: Optional[ResultsAdderBase] = None
    ) -> Union[Dict[str, Any], AsyncGenerator]:
        """Run a conversation thread with specified parameters.

        Args:
            thread_id: ID of the thread to run
            system_message: System message prepended to the thread's messages
            model_name: Name of the LLM model to use
            temperature: Sampling temperature passed to the LLM
            max_tokens: Maximum tokens in response
            tool_choice: Tool selection strategy ("auto" or "none"); only sent
                with native tool calling when at least one tool schema exists
            temporary_message: Optional message appended for this call only (not stored)
            native_tool_calling: Whether to use native LLM function calling
            xml_tool_calling: Whether to use XML-based tool calling
            execute_tools: Whether to execute tool calls
            stream: Whether to stream the response
            execute_tools_on_stream: Whether to execute tools during streaming
            parallel_tool_execution: Whether to execute tools in parallel
            tool_parser: Custom tool parser implementation
            tool_executor: Custom tool executor implementation
            results_adder: Custom results adder implementation

        Returns:
            When ``stream`` is True, the value of
            ``LLMResponseProcessor.process_stream``. Otherwise, the raw LLM
            response after it has been processed. If an exception occurs while
            preparing or requesting the completion, it is logged and
            ``{"status": "error", "message": str(e)}`` is returned instead of raised.

        Raises:
            ValueError: If both native and XML tool calling are enabled.
        """
        # Validate tool calling configuration
        if native_tool_calling and xml_tool_calling:
            raise ValueError("Cannot use both native LLM tool calling and XML tool calling simultaneously")

        # Initialize tool components if any tool calling is enabled
        if native_tool_calling or xml_tool_calling:
            if tool_parser is None:
                tool_parser = XMLToolParser(tool_registry=self.tool_registry) if xml_tool_calling else StandardToolParser()

            if tool_executor is None:
                tool_executor = XMLToolExecutor(parallel=parallel_tool_execution, tool_registry=self.tool_registry) if xml_tool_calling else StandardToolExecutor(parallel=parallel_tool_execution)

            if results_adder is None:
                results_adder = XMLResultsAdder(self) if xml_tool_calling else StandardResultsAdder(self)

        try:
            messages = await self.get_messages(thread_id)
            prepared_messages = [system_message] + messages
            if temporary_message:
                prepared_messages.append(temporary_message)

            openapi_tool_schemas = None
            if native_tool_calling:
                # Send no tools at all rather than an empty list: OpenAI-compatible APIs
                # reject ``tools=[]`` and a ``tool_choice`` without tools.
                openapi_tool_schemas = self.tool_registry.get_openapi_schemas() or None
                available_functions = self.tool_registry.get_available_functions()
            elif xml_tool_calling:
                available_functions = self.tool_registry.get_available_functions()
            else:
                available_functions = {}

            response_processor = LLMResponseProcessor(
                thread_id=thread_id,
                available_functions=available_functions,
                add_message_callback=self.add_message,
                update_message_callback=self._update_message,
                get_messages_callback=self.get_messages,
                parallel_tool_execution=parallel_tool_execution,
                tool_parser=tool_parser,
                tool_executor=tool_executor,
                results_adder=results_adder
            )

            llm_response = await self._run_thread_completion(
                messages=prepared_messages,
                model_name=model_name,
                temperature=temperature,
                max_tokens=max_tokens,
                tools=openapi_tool_schemas,
                tool_choice=tool_choice if openapi_tool_schemas else None,
                stream=stream
            )

            if stream:
                return response_processor.process_stream(
                    response_stream=llm_response,
                    execute_tools=execute_tools,
                    execute_tools_on_stream=execute_tools_on_stream
                )

            await response_processor.process_response(
                response=llm_response,
                execute_tools=execute_tools
            )

            return llm_response

        except Exception as e:
            # logging.exception keeps the traceback, which the error dict alone loses.
            logging.exception(f"Error in run_thread: {str(e)}")
            return {
                "status": "error",
                "message": str(e)
            }

    async def _run_thread_completion(
        self,
        messages: List[Dict[str, Any]],
        model_name: str,
        temperature: float,
        max_tokens: Optional[int],
        tools: Optional[List[Dict[str, Any]]],
        tool_choice: Optional[str],
        stream: bool
    ) -> Union[Any, AsyncGenerator]:
        """Get a completion from the LLM API via ``make_llm_api_call``, forwarding all arguments."""
        return await make_llm_api_call(
            messages,
            model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            tools=tools,
            tool_choice=tool_choice,
            stream=stream
        )
|
|
|
|
|
|
2024-10-23 09:28:12 +08:00
|
|
|
|
if __name__ == "__main__":
    from agentpress.examples.example_agent.tools.files_tool import FilesTool

    async def main():
        """Demo: create a thread, add a user request, and stream a
        native-tool-calling response to stdout, then dump the thread."""
        # Initialize managers
        thread_manager = ThreadManager()

        # Register available tools
        thread_manager.add_tool(FilesTool)

        # Create a new thread
        thread_id = await thread_manager.create_thread()

        # Add a test message
        await thread_manager.add_message(thread_id, {
            "role": "user",
            "content": "Please create 10x files – Each should be a chapter of a book about an Introduction to Robotics.."
        })

        # Define system message
        system_message = {
            "role": "system",
            "content": "You are a helpful assistant that can create, read, update, and delete files."
        }

        # Test with streaming response and tool execution
        print("\n🤖 Testing streaming response with tools:")
        response = await thread_manager.run_thread(
            thread_id=thread_id,
            system_message=system_message,
            model_name="anthropic/claude-3-5-haiku-latest",
            temperature=0.7,
            max_tokens=4096,
            stream=True,
            native_tool_calling=True,
            execute_tools=True,
            execute_tools_on_stream=True,
            parallel_tool_execution=True
        )

        # Handle streaming response
        if isinstance(response, AsyncGenerator):
            print("\nAssistant is responding:")
            content_buffer = ""
            try:
                async for chunk in response:
                    # Skip chunks with no choices (missing or empty) instead of
                    # aborting the whole stream with an IndexError/AttributeError.
                    choices = getattr(chunk, 'choices', None)
                    if not choices or not hasattr(choices[0], 'delta'):
                        continue
                    delta = choices[0].delta

                    # Handle content streaming; flush on word/line boundaries.
                    if hasattr(delta, 'content') and delta.content is not None:
                        content_buffer += delta.content
                        if delta.content.endswith((' ', '\n')):
                            print(content_buffer, end='', flush=True)
                            content_buffer = ""

                    # Handle tool calls
                    if hasattr(delta, 'tool_calls') and delta.tool_calls:
                        for tool_call in delta.tool_calls:
                            # Print tool name when it first appears
                            if tool_call.function and tool_call.function.name:
                                print(f"\n🛠️ Tool Call: {tool_call.function.name}", flush=True)

                            # Print arguments as they stream in
                            if tool_call.function and tool_call.function.arguments:
                                print(f"  {tool_call.function.arguments}", end='', flush=True)

                # Print any remaining content
                if content_buffer:
                    print(content_buffer, flush=True)
                print("\n✨ Response completed\n")

            except Exception as e:
                print(f"\n❌ Error processing stream: {e}")
        elif isinstance(response, dict) and response.get("status") == "error":
            # run_thread reports failures by returning an error dict instead of raising.
            print(f"\n❌ Error running thread: {response.get('message')}")
        else:
            print("\n✨ Response completed\n")

        # Display final thread state
        messages = await thread_manager.get_messages(thread_id)
        print("\n📝 Final Thread State:")
        for msg in messages:
            role = msg.get('role', 'unknown')
            # content may be None (tool-call-only assistant turns) or a list of
            # content parts (image attachments); normalize to a string before slicing.
            content = msg.get('content') or ''
            if not isinstance(content, str):
                content = str(content)
            print(f"\n{role.upper()}: {content[:100]}...")

    asyncio.run(main())
|
2024-11-18 03:39:58 +08:00
|
|
|
|
|