"""Response processing for AgentPress threads.

Parses LLM responses (both streaming and complete), extracts tool calls,
executes them against a registry of functions, and persists the resulting
messages via caller-supplied callbacks.
"""
import asyncio
import json
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Any, AsyncGenerator, Callable, List, Optional, Set

from agentpress.tool import ToolResult

# --- Tool Parser ---
class ToolParser(ABC):
    """Interface for extracting tool calls from LLM responses.

    Implementations translate provider-specific response objects into the
    plain-dict assistant-message format used throughout AgentPress.
    """

    @abstractmethod
    async def parse_response(self, response: Any) -> Dict[str, Any]:
        """Parse a complete (non-streaming) LLM response.

        Returns the assistant message dict, including any tool calls.
        """
        ...

    @abstractmethod
    async def parse_stream(self, response_chunk: Any, tool_calls_buffer: Dict[int, Dict]) -> tuple[Optional[Dict[str, Any]], bool]:
        """Parse one streaming chunk, accumulating partial tool calls in *tool_calls_buffer*.

        Returns ``(assistant_message_or_None, stream_is_complete)``.
        """
        ...
class StandardToolParser(ToolParser):
    """Tool parser for OpenAI-compatible chat-completion responses.

    Handles both complete responses and streaming deltas; streaming state
    is held by the caller in an index-keyed buffer of partial tool calls.
    """

    async def parse_response(self, response: Any) -> Dict[str, Any]:
        """Build the assistant message dict from a complete response."""
        raw_message = response.choices[0].message
        assistant: Dict[str, Any] = {
            "role": "assistant",
            "content": raw_message.get('content') or "",
        }
        raw_calls = raw_message.get('tool_calls')
        if raw_calls:
            assistant["tool_calls"] = [
                {
                    "id": call.id,
                    "type": "function",
                    "function": {
                        "name": call.function.name,
                        "arguments": call.function.arguments,
                    },
                }
                for call in raw_calls
            ]
        return assistant

    @staticmethod
    def _is_complete_call(call: Dict[str, Any]) -> bool:
        """True when a buffered call has an id, a name, and valid-JSON arguments."""
        if not (call['id'] and call['function']['name'] and call['function']['arguments']):
            return False
        try:
            json.loads(call['function']['arguments'])
        except json.JSONDecodeError:
            return False
        return True

    async def parse_stream(self, chunk: Any, tool_calls_buffer: Dict[int, Dict]) -> tuple[Optional[Dict[str, Any]], bool]:
        """Fold one streaming chunk into *tool_calls_buffer*.

        Returns ``(message, is_complete)`` where *message* is an assistant
        message carrying every fully-formed tool call buffered so far (or
        ``None`` when none is complete), and *is_complete* reflects the
        chunk's ``finish_reason``.
        """
        text_piece = ""
        is_complete = False
        saw_ready_call = False

        choice = chunk.choices[0]
        if hasattr(choice, 'delta'):
            delta = choice.delta
            if getattr(delta, 'content', None):
                text_piece = delta.content
            if getattr(delta, 'tool_calls', None):
                for delta_call in delta.tool_calls:
                    # One buffer slot per tool-call index; arguments arrive
                    # in fragments and are concatenated across chunks.
                    slot = tool_calls_buffer.setdefault(delta_call.index, {
                        'id': None,
                        'type': 'function',
                        'function': {'name': None, 'arguments': ''},
                    })
                    if getattr(delta_call, 'id', None):
                        slot['id'] = delta_call.id
                    func = delta_call.function
                    if getattr(func, 'name', None):
                        slot['function']['name'] = func.name
                    if getattr(func, 'arguments', None):
                        slot['function']['arguments'] += func.arguments
                    if self._is_complete_call(slot):
                        saw_ready_call = True

        if getattr(choice, 'finish_reason', None):
            is_complete = True

        if saw_ready_call or is_complete:
            ready = [call for call in tool_calls_buffer.values()
                     if self._is_complete_call(call)]
            if ready:
                return {
                    "role": "assistant",
                    "content": text_piece,
                    "tool_calls": ready,
                }, is_complete

        return None, is_complete
# --- Tool Executor ---
class ToolExecutor:
    """Executes assistant tool calls against a registry of async functions.

    Each tool call is dispatched to ``available_functions[name]`` and the
    result is wrapped in a ``role: "tool"`` message. Execution is either
    concurrent (``asyncio.gather``) or sequential, chosen at construction.
    """

    def __init__(self, parallel: bool = True):
        # When True, each batch of tool calls runs concurrently.
        self.parallel = parallel

    async def execute_tool_calls(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Optional[Set[str]] = None
    ) -> List[Dict[str, Any]]:
        """Run *tool_calls*, skipping ids already in *executed_tool_calls*.

        Args:
            tool_calls: Assistant tool-call dicts (``id``/``function`` keys).
            available_functions: Mapping of function name -> async callable.
            thread_id: Owning thread id (passed through for bookkeeping).
            executed_tool_calls: Mutable set of already-run call ids; a new
                set is created when omitted. Ids of successfully executed
                calls are added to it.

        Returns:
            A list of ``role: "tool"`` result messages.
        """
        if executed_tool_calls is None:
            executed_tool_calls = set()
        if self.parallel:
            return await self._execute_parallel(
                tool_calls, available_functions, thread_id, executed_tool_calls
            )
        return await self._execute_sequential(
            tool_calls, available_functions, thread_id, executed_tool_calls
        )

    @staticmethod
    def _result_message(tool_call_id: str, name: str, content: Any) -> Dict[str, Any]:
        """Build the standard ``role: "tool"`` result message."""
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": name,
            "content": str(content),
        }

    async def _execute_parallel(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Set[str]
    ) -> List[Dict[str, Any]]:
        """Run every pending tool call concurrently and gather the results."""

        async def execute_single_tool(tool_call: Dict[str, Any]) -> Optional[Dict[str, Any]]:
            if tool_call['id'] in executed_tool_calls:
                return None
            # FIX: extract the name before the try block. Previously it was
            # assigned inside `try` but referenced in `except`, so a malformed
            # call dict raised a NameError in the handler, masking the real
            # error and (via gather) failing the whole batch.
            function_name = tool_call.get('function', {}).get('name', '<unknown>')
            try:
                function_args = tool_call['function']['arguments']
                if isinstance(function_args, str):
                    function_args = json.loads(function_args)
                function_to_call = available_functions.get(function_name)
                if not function_to_call:
                    error_msg = f"Function {function_name} not found"
                    logging.error(error_msg)
                    return self._result_message(
                        tool_call['id'], function_name,
                        ToolResult(success=False, output=error_msg))
                result = await function_to_call(**function_args)
                logging.info(f"Tool execution result for {function_name}: {result}")
                executed_tool_calls.add(tool_call['id'])
                return self._result_message(tool_call['id'], function_name, result)
            except Exception as e:
                error_msg = f"Error executing {function_name}: {str(e)}"
                logging.error(error_msg)
                return self._result_message(
                    tool_call['id'], function_name,
                    ToolResult(success=False, output=error_msg))

        tasks = [execute_single_tool(tool_call) for tool_call in tool_calls]
        results = await asyncio.gather(*tasks)
        # Drop the None placeholders produced by already-executed calls.
        return [r for r in results if r is not None]

    async def _execute_sequential(
        self,
        tool_calls: List[Dict[str, Any]],
        available_functions: Dict[str, Callable],
        thread_id: str,
        executed_tool_calls: Set[str]
    ) -> List[Dict[str, Any]]:
        """Run pending tool calls one at a time, preserving input order."""
        results = []
        for tool_call in tool_calls:
            if tool_call['id'] in executed_tool_calls:
                continue
            # FIX: hoisted before `try` for the same NameError reason as in
            # the parallel path.
            function_name = tool_call.get('function', {}).get('name', '<unknown>')
            try:
                function_args = tool_call['function']['arguments']
                if isinstance(function_args, str):
                    function_args = json.loads(function_args)
                function_to_call = available_functions.get(function_name)
                if not function_to_call:
                    error_msg = f"Function {function_name} not found"
                    logging.error(error_msg)
                    result = ToolResult(success=False, output=error_msg)
                else:
                    result = await function_to_call(**function_args)
                    logging.info(f"Tool execution result for {function_name}: {result}")
                # NOTE(review): unlike the parallel path, a not-found function
                # is still marked executed here; preserved as-is — confirm
                # whether the two paths should agree.
                executed_tool_calls.add(tool_call['id'])
                results.append(self._result_message(tool_call['id'], function_name, result))
            except Exception as e:
                error_msg = f"Error executing {function_name}: {str(e)}"
                logging.error(error_msg)
                results.append(self._result_message(
                    tool_call['id'], function_name,
                    ToolResult(success=False, output=error_msg)))
        return results
# --- Response Processor ---
@dataclass
class ProcessedResponse:
    """Container for processed LLM response data.

    NOTE(review): not constructed anywhere in this module; presumably used
    by callers elsewhere — confirm before removing.
    """
    # Accumulated assistant text content.
    content: str
    # Assistant tool-call dicts, when the response requested any.
    tool_calls: Optional[List[Dict[str, Any]]] = None
    # ``role: "tool"`` result messages produced by executing the calls.
    tool_results: Optional[List[Dict[str, Any]]] = None
class LLMResponseProcessor:
    """Handles LLM response processing and tool execution management.

    Persists assistant and tool messages through the caller-supplied
    ``add_message`` / ``update_message`` / ``list_messages`` callbacks and
    delegates parsing to :class:`StandardToolParser` and execution to
    :class:`ToolExecutor`.
    """

    def __init__(
        self,
        thread_id: str,
        available_functions: Dict = None,
        add_message_callback: Callable = None,
        update_message_callback: Callable = None,
        list_messages_callback: Callable = None,
        parallel_tool_execution: bool = True,
        threads_dir: str = "threads"
    ):
        # Thread this processor writes all messages to.
        self.thread_id = thread_id
        self.tool_executor = ToolExecutor(parallel=parallel_tool_execution)
        self.tool_parser = StandardToolParser()
        self.available_functions = available_functions or {}

        # Persistence callbacks; all are async and may be None — callers are
        # expected to supply them before processing (not validated here).
        self.add_message = add_message_callback
        self.update_message = update_message_callback
        self.list_messages = list_messages_callback
        self.threads_dir = threads_dir

        # State tracking for streaming responses
        self.tool_calls_buffer = {}          # index -> partial tool-call dict
        self.processed_tool_calls = set()    # tool_call ids already executed
        self.content_buffer = ""             # accumulated assistant text
        self.tool_calls_accumulated = []     # latest complete tool-call list
        self.message_added = False           # assistant message persisted yet?

    async def process_stream(
        self,
        response_stream: AsyncGenerator,
        execute_tools: bool = True,
        immediate_execution: bool = True
    ) -> AsyncGenerator:
        """
        Process streaming LLM response and handle tool execution.

        Yields chunks immediately while managing message state and tool
        execution in per-chunk background tasks.

        NOTE(review): one background task is spawned per chunk and they may
        interleave; correctness relies on tasks mostly completing in order.
        Shared state (``content_buffer``, ``tool_calls_accumulated``,
        ``tool_results``) is mutated without synchronization — confirm this
        is acceptable for the event-loop scheduling in use.
        """
        # Tool calls deferred until stream end when immediate_execution=False.
        pending_tool_calls = []
        background_tasks = set()
        tool_results = []  # Track tool results

        async def handle_message_management(chunk):
            # Runs as a background task per chunk; errors are logged, never
            # raised, so the chunk-yielding loop is never interrupted.
            try:
                nonlocal tool_results
                # Accumulate content
                if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
                    self.content_buffer += chunk.choices[0].delta.content

                # Parse tool calls only if present in chunk
                if hasattr(chunk.choices[0].delta, 'tool_calls'):
                    parsed_message, is_complete = await self.tool_parser.parse_stream(
                        chunk,
                        self.tool_calls_buffer
                    )
                    if parsed_message and 'tool_calls' in parsed_message:
                        # Replaces (not appends) — parse_stream returns the
                        # full list of complete calls each time.
                        self.tool_calls_accumulated = parsed_message['tool_calls']

                # Handle tool execution and results
                if execute_tools and self.tool_calls_accumulated:
                    new_tool_calls = [
                        tool_call for tool_call in self.tool_calls_accumulated
                        if tool_call['id'] not in self.processed_tool_calls
                    ]
                    if new_tool_calls and immediate_execution:
                        results = await self.tool_executor.execute_tool_calls(
                            tool_calls=new_tool_calls,
                            available_functions=self.available_functions,
                            thread_id=self.thread_id,
                            executed_tool_calls=self.processed_tool_calls
                        )
                        tool_results.extend(results)
                        for result in results:
                            self.processed_tool_calls.add(result['tool_call_id'])
                    elif new_tool_calls:
                        # Deferred mode: execute at stream completion instead.
                        pending_tool_calls.extend(new_tool_calls)

                    # Persist each tool result once, deduplicating against
                    # messages already stored on the thread.
                    for result in tool_results:
                        if not any(msg.get('tool_call_id') == result['tool_call_id']
                                   for msg in await self.list_messages(self.thread_id)):
                            await self.add_message(self.thread_id, result)
                    tool_results = []  # Clear processed results

                # Then add/update assistant message
                message = {
                    "role": "assistant",
                    "content": self.content_buffer
                }
                if self.tool_calls_accumulated:
                    message["tool_calls"] = self.tool_calls_accumulated

                # First chunk creates the message; later chunks update it in
                # place so the thread holds a single evolving assistant turn.
                if not self.message_added:
                    await self.add_message(self.thread_id, message)
                    self.message_added = True
                else:
                    await self.update_message(self.thread_id, message)

                # Handle stream completion
                if chunk.choices[0].finish_reason:
                    if not immediate_execution and pending_tool_calls:
                        results = await self.tool_executor.execute_tool_calls(
                            tool_calls=pending_tool_calls,
                            available_functions=self.available_functions,
                            thread_id=self.thread_id,
                            executed_tool_calls=self.processed_tool_calls
                        )
                        for result in results:
                            await self.add_message(self.thread_id, result)
                            self.processed_tool_calls.add(result['tool_call_id'])
                        pending_tool_calls.clear()
            except Exception as e:
                logging.error(f"Error in background task: {e}")
                return

        try:
            async for chunk in response_stream:
                # Create and track background task
                task = asyncio.create_task(handle_message_management(chunk))
                background_tasks.add(task)
                task.add_done_callback(background_tasks.discard)

                # Immediately yield the chunk
                yield chunk

            # Wait for all background tasks to complete
            if background_tasks:
                await asyncio.gather(*background_tasks, return_exceptions=True)
        except Exception as e:
            logging.error(f"Error in stream processing: {e}")
            # Clean up any remaining background tasks
            for task in background_tasks:
                if not task.done():
                    task.cancel()
            raise

    async def process_response(
        self,
        response: Any,
        execute_tools: bool = True
    ) -> None:
        """
        Process complete LLM response and execute tools.

        Handles non-streaming responses, parsing the complete response and
        executing any tool calls according to the configured execution
        strategy. On any failure, falls back to persisting a plain assistant
        message with whatever content the response carried.
        """
        try:
            assistant_message = await self.tool_parser.parse_response(response)
            await self.add_message(self.thread_id, assistant_message)
            if execute_tools and 'tool_calls' in assistant_message and assistant_message['tool_calls']:
                results = await self.tool_executor.execute_tool_calls(
                    tool_calls=assistant_message['tool_calls'],
                    available_functions=self.available_functions,
                    thread_id=self.thread_id,
                    executed_tool_calls=self.processed_tool_calls
                )
                for result in results:
                    await self.add_message(self.thread_id, result)
                    logging.info(f"Tool execution result: {result}")
        except Exception as e:
            # Broad catch is deliberate: never lose the assistant turn.
            # NOTE(review): the fallback access below can itself raise if the
            # response object is malformed — confirm callers tolerate that.
            logging.error(f"Error processing response: {e}")
            response_content = response.choices[0].message.get('content', '')
            await self.add_message(self.thread_id, {
                "role": "assistant",
                "content": response_content or ""
            })