2024-11-12 19:53:07 +08:00
|
|
|
import logging
|
2024-11-18 03:39:58 +08:00
|
|
|
from typing import Dict, Any, AsyncGenerator, Callable, List, Optional, Set
|
|
|
|
from abc import ABC, abstractmethod
|
2024-11-13 07:57:38 +08:00
|
|
|
import asyncio
|
2024-11-13 21:28:28 +08:00
|
|
|
import os
|
|
|
|
import json
|
2024-11-18 03:39:58 +08:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from agentpress.tool import ToolResult
|
2024-11-12 19:37:47 +08:00
|
|
|
|
2024-11-18 03:39:58 +08:00
|
|
|
# --- Tool Parser ---
|
|
|
|
|
|
|
|
class ToolParser(ABC):
    """Interface for extracting assistant messages and tool calls from LLM output.

    Implementations provide one coroutine for complete responses and one for
    individual chunks of a streamed response.
    """

    @abstractmethod
    async def parse_response(self, response: Any) -> Dict[str, Any]:
        """Build the assistant message (including any tool calls) from a complete response.

        Args:
            response: The complete LLM response object.

        Returns:
            The assistant message as a dictionary.
        """

    @abstractmethod
    async def parse_stream(self, response_chunk: Any, tool_calls_buffer: Dict[int, Dict]) -> tuple[Optional[Dict[str, Any]], bool]:
        """Process one streamed chunk, updating ``tool_calls_buffer`` along the way.

        Args:
            response_chunk: A single chunk of a streaming LLM response.
            tool_calls_buffer: Mutable mapping, keyed by integer, that holds
                the tool calls accumulated so far across chunks.

        Returns:
            A ``(message, flag)`` pair: an optional message dictionary and a
            boolean whose meaning is defined by the implementation.
        """
|
|
|
|
|
|
|
|
class StandardToolParser(ToolParser):
    """Tool-call parser for OpenAI-compatible chat completion responses.

    Handles both complete responses (``parse_response``) and streamed chunks
    (``parse_stream``), where tool-call fragments are accumulated in a
    caller-owned buffer until their arguments form valid JSON.
    """

    @staticmethod
    def _read_field(obj: Any, key: str, default: Any = None) -> Any:
        """Read ``key`` from an object exposing ``.get()`` or plain attributes.

        Objects with a callable ``get`` are queried through it (the original
        behaviour); anything else falls back to attribute access so that
        attribute-only message objects do not raise ``AttributeError``.
        """
        getter = getattr(obj, 'get', None)
        if callable(getter):
            return getter(key, default)
        return getattr(obj, key, default)

    @staticmethod
    def _is_complete(tool_call: Dict[str, Any]) -> bool:
        """Return True when a buffered call has an id, a name and JSON-parsable arguments."""
        function = tool_call['function']
        if not (tool_call['id'] and function['name'] and function['arguments']):
            return False
        try:
            json.loads(function['arguments'])
        except json.JSONDecodeError:
            # Arguments are still being streamed (or are malformed).
            return False
        return True

    async def parse_response(self, response: Any) -> Dict[str, Any]:
        """Build the assistant message from a complete (non-streaming) response.

        Args:
            response: Response object exposing ``choices[0].message``.

        Returns:
            A dict with ``role`` and ``content`` (empty string when the
            response has no content) plus ``tool_calls`` when the message
            carries any.
        """
        response_message = response.choices[0].message
        message: Dict[str, Any] = {
            "role": "assistant",
            "content": self._read_field(response_message, 'content') or "",
        }

        tool_calls = self._read_field(response_message, 'tool_calls')
        if tool_calls:
            message["tool_calls"] = [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    },
                }
                for tool_call in tool_calls
            ]

        return message

    async def parse_stream(self, chunk: Any, tool_calls_buffer: Dict[int, Dict]) -> tuple[Optional[Dict[str, Any]], bool]:
        """Fold one streamed chunk into ``tool_calls_buffer``.

        Args:
            chunk: Streamed chunk exposing ``choices[0].delta`` and
                ``choices[0].finish_reason``.
            tool_calls_buffer: Mutable mapping from tool-call index to the
                partially assembled call; updated in place.

        Returns:
            ``(message, is_complete)``. ``message`` is an assistant message
            holding every buffered tool call that is currently complete, or
            ``None`` when there is none to report. ``is_complete`` is True
            when the chunk carries a finish reason. The tool-call dicts in
            ``message`` are the same objects stored in the buffer.
        """
        # Some chunks carry no choices at all; there is nothing to parse.
        choices = getattr(chunk, 'choices', None)
        if not choices:
            return None, False

        choice = choices[0]
        content_chunk = ""
        has_complete_tool_call = False

        delta = getattr(choice, 'delta', None)
        if delta is not None:
            content_chunk = getattr(delta, 'content', None) or ""

            for tool_call in getattr(delta, 'tool_calls', None) or []:
                current_tool = tool_calls_buffer.setdefault(tool_call.index, {
                    'id': None,
                    'type': 'function',
                    'function': {'name': None, 'arguments': ''},
                })

                call_id = getattr(tool_call, 'id', None)
                if call_id:
                    current_tool['id'] = call_id

                function = getattr(tool_call, 'function', None)
                name = getattr(function, 'name', None)
                if name:
                    current_tool['function']['name'] = name
                # Argument text arrives in fragments; concatenate them.
                arguments = getattr(function, 'arguments', None)
                if arguments:
                    current_tool['function']['arguments'] += arguments

                if self._is_complete(current_tool):
                    has_complete_tool_call = True

        is_complete = bool(getattr(choice, 'finish_reason', None))

        if has_complete_tool_call or is_complete:
            complete_tool_calls = [
                buffered for buffered in tool_calls_buffer.values()
                if self._is_complete(buffered)
            ]
            if complete_tool_calls:
                return {
                    "role": "assistant",
                    "content": content_chunk,
                    "tool_calls": complete_tool_calls,
                }, is_complete

        return None, is_complete
|
|
|
|
|
|
|
|
# --- Tool Executor ---
|
|
|
|
|
|
|
|
class ToolExecutor:
    """Runs tool calls either concurrently or one after another.

    Each tool call is a dict with an ``id`` and a ``function`` entry holding
    ``name`` and ``arguments`` (a JSON string or an already-decoded dict).
    Every executed call yields a ``role: tool`` message; failures are
    reported as a stringified ``ToolResult(success=False, ...)`` instead of
    raising.
    """

    def __init__(self, parallel: bool = True):
        """Store the strategy: ``parallel=True`` runs a batch concurrently."""
        self.parallel = parallel

    async def execute_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Optional[Set[str]] = None
    ) -> List[Dict[str, Any]]:
        """Execute ``tool_calls`` using the configured strategy.

        Args:
            tool_calls: Tool-call dicts to run.
            available_functions: Mapping from function name to callable.
            thread_id: Passed through to the strategy methods.
            executed_tool_calls: Ids that were already executed; calls whose
                id is in this set are skipped, and ids of successfully
                executed calls are added to it. A fresh set is used when
                omitted.

        Returns:
            One tool message per call that was attempted, in input order.
        """
        if executed_tool_calls is None:
            executed_tool_calls = set()

        strategy = self._execute_parallel if self.parallel else self._execute_sequential
        return await strategy(
            tool_calls,
            available_functions,
            thread_id,
            executed_tool_calls
        )

    async def _run_tool_call(
        self,
        tool_call: Dict[str, Any],
        available_functions: Dict[str, Callable],
        executed_tool_calls: Set[str]
    ) -> Dict[str, Any]:
        """Run a single tool call and wrap its outcome in a tool message."""
        import inspect

        tool_call_id = tool_call['id']
        # Bound before the try block so the error handler can always refer to
        # it, even when reading the name itself is what failed.
        function_name = None
        try:
            function_name = tool_call['function']['name']
            function_args = tool_call['function']['arguments']
            if isinstance(function_args, str):
                function_args = json.loads(function_args)

            function_to_call = available_functions.get(function_name)
            if not function_to_call:
                error_msg = f"Function {function_name} not found"
                logging.error(error_msg)
                content = str(ToolResult(success=False, output=error_msg))
            else:
                result = function_to_call(**function_args)
                # Accept plain callables as well as coroutine functions.
                if inspect.isawaitable(result):
                    result = await result
                logging.info(f"Tool execution result for {function_name}: {result}")
                # Only successful executions are recorded as executed.
                executed_tool_calls.add(tool_call_id)
                content = str(result)
        except Exception as e:
            error_msg = f"Error executing {function_name}: {str(e)}"
            logging.error(error_msg)
            content = str(ToolResult(success=False, output=error_msg))

        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": function_name,
            "content": content
        }

    async def _execute_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Set[str]
    ) -> List[Dict[str, Any]]:
        """Run all not-yet-executed calls concurrently, preserving input order.

        A call id that appears more than once in the batch is executed once:
        the executed-set is only updated after a call finishes, so without
        this filter concurrent duplicates would all run.
        """
        scheduled_ids = set()
        runnable = []
        for tool_call in tool_calls:
            call_id = tool_call['id']
            if call_id in executed_tool_calls or call_id in scheduled_ids:
                continue
            scheduled_ids.add(call_id)
            runnable.append(tool_call)

        results = await asyncio.gather(*(
            self._run_tool_call(tool_call, available_functions, executed_tool_calls)
            for tool_call in runnable
        ))
        return list(results)

    async def _execute_sequential(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Set[str]
    ) -> List[Dict[str, Any]]:
        """Run not-yet-executed calls one at a time, in input order."""
        results = []
        for tool_call in tool_calls:
            # Re-checked per call: an earlier call in this batch may have
            # just recorded the same id.
            if tool_call['id'] in executed_tool_calls:
                continue
            results.append(
                await self._run_tool_call(tool_call, available_functions, executed_tool_calls)
            )
        return results
|
|
|
|
|
|
|
|
# --- Response Processor ---
|
|
|
|
|
|
|
|
@dataclass
class ProcessedResponse:
    """Value object bundling the pieces of a processed LLM response.

    Only ``content`` is required; both list fields default to ``None``.
    """

    # Text content of the response.
    content: str
    # Tool-call dictionaries, if any (presumably those requested by the
    # assistant -- confirm against callers).
    tool_calls: Optional[List[Dict[str, Any]]] = None
    # Tool-result dictionaries, if any (presumably the outputs of executed
    # tool calls -- confirm against callers).
    tool_results: Optional[List[Dict[str, Any]]] = None
|
|
|
|
|
|
|
|
class LLMResponseProcessor:
    """Processes LLM responses for one thread and manages tool execution.

    Messages are persisted through the injected callbacks
    (``add_message``, ``update_message``, ``list_messages``); tool calls are
    parsed with ``StandardToolParser`` and run with ``ToolExecutor``.
    Streaming state (content and tool-call buffers) lives on the instance
    and is not reset between calls.
    """

    def __init__(
        self,
        thread_id: str,
        available_functions: Optional[Dict] = None,
        add_message_callback: Optional[Callable] = None,
        update_message_callback: Optional[Callable] = None,
        list_messages_callback: Optional[Callable] = None,
        parallel_tool_execution: bool = True,
        threads_dir: str = "threads"
    ):
        """Configure the processor.

        Args:
            thread_id: Identifier passed to every callback and to the executor.
            available_functions: Mapping from tool name to callable; defaults
                to an empty mapping.
            add_message_callback: Awaitable ``(thread_id, message)`` used to
                store a new message.
            update_message_callback: Awaitable ``(thread_id, message)`` used
                to replace the assistant message while streaming.
            list_messages_callback: Awaitable ``(thread_id)`` returning the
                stored messages; used to avoid storing a tool result twice.
            parallel_tool_execution: Forwarded to ``ToolExecutor``.
            threads_dir: Stored on the instance; not used by this class.
        """
        self.thread_id = thread_id
        self.tool_executor = ToolExecutor(parallel=parallel_tool_execution)
        self.tool_parser = StandardToolParser()
        self.available_functions = available_functions or {}
        self.add_message = add_message_callback
        self.update_message = update_message_callback
        self.list_messages = list_messages_callback
        self.threads_dir = threads_dir

        # State tracking for streaming responses
        self.tool_calls_buffer = {}
        self.processed_tool_calls = set()
        self.content_buffer = ""
        self.tool_calls_accumulated = []
        self.message_added = False

    async def process_stream(
        self,
        response_stream: AsyncGenerator,
        execute_tools: bool = True,
        immediate_execution: bool = True
    ) -> AsyncGenerator:
        """Process a streaming LLM response and handle tool execution.

        Every chunk is yielded to the caller immediately; message bookkeeping
        and tool execution for that chunk run in a background task. The
        background tasks are serialized with a lock so chunks are applied in
        arrival order and shared state is never modified by two tasks at
        once (otherwise the assistant message could be added more than once
        and the same tool call executed repeatedly).

        Args:
            response_stream: Async iterator of streamed chunks.
            execute_tools: When False, tool calls are recorded on the
                assistant message but never executed.
            immediate_execution: When True, tool calls run as soon as they
                are complete; otherwise they are queued and run once the
                stream reports a finish reason.

        Yields:
            The chunks of ``response_stream``, unchanged.

        Raises:
            Exception: Errors raised while iterating ``response_stream`` are
                logged, outstanding background tasks are cancelled, and the
                error is re-raised. Errors inside a background task are
                logged and swallowed.
        """
        pending_tool_calls = []
        background_tasks = set()
        tool_results = []  # Tool results waiting to be stored
        state_lock = asyncio.Lock()  # Serializes the per-chunk handlers

        async def handle_message_management(chunk):
            try:
                async with state_lock:
                    # Chunks without choices carry nothing to record.
                    choices = getattr(chunk, 'choices', None)
                    if not choices:
                        return
                    choice = choices[0]
                    delta = getattr(choice, 'delta', None)

                    # Accumulate content
                    text = getattr(delta, 'content', None)
                    if text:
                        self.content_buffer += text

                    # Parse tool calls only if present in chunk
                    if delta is not None and hasattr(delta, 'tool_calls'):
                        parsed_message, _ = await self.tool_parser.parse_stream(
                            chunk,
                            self.tool_calls_buffer
                        )
                        if parsed_message and 'tool_calls' in parsed_message:
                            self.tool_calls_accumulated = parsed_message['tool_calls']

                    # Handle tool execution and results
                    if execute_tools and self.tool_calls_accumulated:
                        new_tool_calls = [
                            tool_call for tool_call in self.tool_calls_accumulated
                            if tool_call['id'] not in self.processed_tool_calls
                        ]

                        if new_tool_calls and immediate_execution:
                            results = await self.tool_executor.execute_tool_calls(
                                tool_calls=new_tool_calls,
                                available_functions=self.available_functions,
                                thread_id=self.thread_id,
                                executed_tool_calls=self.processed_tool_calls
                            )
                            tool_results.extend(results)
                            for result in results:
                                self.processed_tool_calls.add(result['tool_call_id'])
                        elif new_tool_calls:
                            # Queue each call once; the same calls show up as
                            # "new" on every chunk until they are executed.
                            queued_ids = {call['id'] for call in pending_tool_calls}
                            pending_tool_calls.extend(
                                call for call in new_tool_calls
                                if call['id'] not in queued_ids
                            )

                    # NOTE(review): tool results are stored before the
                    # assistant message is added/updated below -- confirm the
                    # callbacks tolerate that ordering.
                    for result in tool_results:
                        # Without a list callback the duplicate check is
                        # skipped rather than aborting the whole handler.
                        stored = await self.list_messages(self.thread_id) if self.list_messages else []
                        if not any(msg.get('tool_call_id') == result['tool_call_id'] for msg in stored):
                            await self.add_message(self.thread_id, result)
                    tool_results.clear()  # Clear processed results

                    # Then add/update assistant message
                    message = {
                        "role": "assistant",
                        "content": self.content_buffer
                    }
                    if self.tool_calls_accumulated:
                        message["tool_calls"] = self.tool_calls_accumulated

                    if not self.message_added:
                        await self.add_message(self.thread_id, message)
                        self.message_added = True
                    else:
                        await self.update_message(self.thread_id, message)

                    # Handle stream completion
                    if getattr(choice, 'finish_reason', None):
                        if not immediate_execution and pending_tool_calls:
                            results = await self.tool_executor.execute_tool_calls(
                                tool_calls=pending_tool_calls,
                                available_functions=self.available_functions,
                                thread_id=self.thread_id,
                                executed_tool_calls=self.processed_tool_calls
                            )
                            for result in results:
                                await self.add_message(self.thread_id, result)
                                self.processed_tool_calls.add(result['tool_call_id'])
                            pending_tool_calls.clear()

            except Exception as e:
                logging.error(f"Error in background task: {e}")
                return

        try:
            async for chunk in response_stream:
                # Create and track background task
                task = asyncio.create_task(handle_message_management(chunk))
                background_tasks.add(task)
                task.add_done_callback(background_tasks.discard)

                # Immediately yield the chunk
                yield chunk

            # Wait for all background tasks to complete
            if background_tasks:
                await asyncio.gather(*background_tasks, return_exceptions=True)

        except Exception as e:
            logging.error(f"Error in stream processing: {e}")
            # Clean up any remaining background tasks
            for task in list(background_tasks):
                if not task.done():
                    task.cancel()
            raise

    async def process_response(
        self,
        response: Any,
        execute_tools: bool = True
    ) -> None:
        """Process a complete (non-streaming) LLM response and execute tools.

        The parsed assistant message is stored first; if it carries tool
        calls and ``execute_tools`` is True they are executed with the
        configured strategy and each result is stored as its own message.

        On failure the error is logged. If the assistant message had not
        been stored yet, a plain assistant message with the response content
        is stored as a fallback; if it was already stored, nothing further is
        added so the thread does not end up with a duplicate.

        Args:
            response: Complete response exposing ``choices[0].message``.
            execute_tools: Whether tool calls in the response are executed.
        """
        assistant_stored = False
        try:
            assistant_message = await self.tool_parser.parse_response(response)
            await self.add_message(self.thread_id, assistant_message)
            assistant_stored = True

            if execute_tools and assistant_message.get('tool_calls'):
                results = await self.tool_executor.execute_tool_calls(
                    tool_calls=assistant_message['tool_calls'],
                    available_functions=self.available_functions,
                    thread_id=self.thread_id,
                    executed_tool_calls=self.processed_tool_calls
                )

                for result in results:
                    await self.add_message(self.thread_id, result)
                    logging.info(f"Tool execution result: {result}")

        except Exception as e:
            logging.error(f"Error processing response: {e}")
            if assistant_stored:
                return
            raw_message = response.choices[0].message
            # Support both ``.get()``-style and attribute-only message objects.
            getter = getattr(raw_message, 'get', None)
            if callable(getter):
                response_content = getter('content', '')
            else:
                response_content = getattr(raw_message, 'content', '')
            await self.add_message(self.thread_id, {
                "role": "assistant",
                "content": response_content or ""
            })
|