2024-11-18 06:36:37 +08:00
|
|
|
import logging
|
2024-11-18 08:38:31 +08:00
|
|
|
from typing import Dict, Any, Optional, List
|
2024-11-18 06:36:37 +08:00
|
|
|
from agentpress.thread_llm_response_processor import ToolParserBase
|
|
|
|
import json
|
|
|
|
import re
|
2024-11-18 08:38:31 +08:00
|
|
|
from agentpress.tool_registry import ToolRegistry
|
2024-11-18 06:36:37 +08:00
|
|
|
|
|
|
|
class XMLToolParser(ToolParserBase):
    """Parse XML-style tool invocations out of assistant messages.

    Tool calls are written inline in the assistant's text as XML elements
    whose tag names are the keys of ``tool_registry.xml_tools``. Each parsed
    element is converted, using the tag's ``xml_schema`` from the registry,
    into an OpenAI-style tool call::

        {"id": ..., "type": "function",
         "function": {"name": <tool method>, "arguments": <JSON string>}}
    """

    # Strict opening tag: name, zero or more quoted attributes, and an optional
    # "/" for self-closing tags. Quoted values may contain ">" without ending
    # the tag.
    _OPENING_TAG_RE = re.compile(
        r'<(?P<name>[^\s>/]+)'
        r'(?P<attrs>(?:\s+[^\s=>/]+\s*=\s*(?:"[^"]*"|\'[^\']*\'))*)'
        r'\s*(?P<self_closing>/?)>'
    )
    # Fallback tag-name extraction when the opening tag is not in strict form.
    _TAG_NAME_RE = re.compile(r'<([^\s>/]+)')
    # A closing tag that ends the string; it delimits the root element's content.
    _CLOSING_AT_END_RE = re.compile(r'</[^>]+>$')

    def __init__(self, tool_registry: Optional[ToolRegistry] = None):
        """Create a parser.

        Args:
            tool_registry: Registry used to look up XML tools by tag name.
                A new ``ToolRegistry()`` is created only when none is given.
        """
        # Explicit None check: a registry object that happens to be falsy must
        # still be honoured rather than silently replaced.
        self.tool_registry = tool_registry if tool_registry is not None else ToolRegistry()

    async def parse_response(self, response: Any) -> Dict[str, Any]:
        """Convert a complete (non-streaming) LLM response into an assistant message.

        Args:
            response: Completion object exposing ``choices[0].message``; the
                message must support ``.get('content')``.

        Returns:
            ``{"role": "assistant", "content": <text>}``, plus a ``"tool_calls"``
            list when at least one registered XML tool element was parsed.
            Unexpected parsing errors are logged and yield a message without
            tool calls.
        """
        response_message = response.choices[0].message
        content = response_message.get('content') or ""

        message: Dict[str, Any] = {
            "role": "assistant",
            "content": content,
        }

        try:
            tool_calls = await self._parse_chunks(self._extract_xml_chunks(content))
        except Exception as e:
            logging.error(f"Error parsing XML response: {e}")
            return message

        if tool_calls:
            message["tool_calls"] = tool_calls
        return message

    async def parse_stream(self, response_chunk: Any, tool_calls_buffer: Dict[int, Dict]) -> tuple[Optional[Dict[str, Any]], bool]:
        """Incrementally extract tool calls from one streamed response chunk.

        Streamed text is accumulated in ``tool_calls_buffer['xml_buffer']``
        (a string, despite the inherited annotation). Whenever the buffer
        contains complete tool elements, they are parsed and the buffer is
        trimmed through the end of the last complete element.

        Args:
            response_chunk: Streaming chunk exposing ``choices[0]`` with
                optional ``delta.content`` and ``finish_reason``.
            tool_calls_buffer: Mutable per-stream state shared across calls.

        Returns:
            ``(message, is_complete)``. ``message`` is an assistant message
            holding this chunk's content and the newly parsed ``tool_calls``,
            or ``None`` when this chunk produced no tool calls.
            ``is_complete`` is true when the chunk carries a ``finish_reason``.
        """
        choice = response_chunk.choices[0]
        # Determine completion up front so a final chunk that also carries
        # content (and tool calls) is still reported as complete.
        is_complete = bool(getattr(choice, 'finish_reason', None))
        delta = getattr(choice, 'delta', None)
        content_chunk = getattr(delta, 'content', None) or ""

        tool_calls: List[Dict[str, Any]] = []
        if content_chunk:
            buffer = tool_calls_buffer.get('xml_buffer', '') + content_chunk
            spans = self._find_tool_spans(buffer)
            if spans:
                tool_calls = await self._parse_chunks([chunk for _, _, chunk in spans])
                # Drop everything up to and including the last complete
                # closing tag, so no element is parsed (or logged) twice.
                buffer = buffer[spans[-1][1]:]
            tool_calls_buffer['xml_buffer'] = buffer
        elif is_complete and 'xml_buffer' in tool_calls_buffer:
            # Final chunk without content: give the remaining buffer one
            # last pass.
            tool_calls = await self._process_streaming_xml(tool_calls_buffer['xml_buffer'])

        if tool_calls:
            return {
                "role": "assistant",
                "content": content_chunk,
                "tool_calls": tool_calls
            }, is_complete

        return None, is_complete

    def _find_tool_spans(self, content: str, include_unclosed: bool = False) -> List[tuple[int, int, str]]:
        """Locate registered tool elements in ``content`` in document order.

        Returns ``(start, end, chunk)`` triples with ``chunk == content[start:end]``.
        After each complete element, scanning resumes past its closing tag, so
        tag-like text inside a tool's body is not treated as a separate call.
        Self-closing tags (``<tag .../>``) count as complete elements.
        Scanning stops at the first opening tag that has no matching closing
        tag. When ``include_unclosed`` is true, that tag and everything after
        it are returned as a final chunk.
        """
        tag_names = list(self.tool_registry.xml_tools.keys())
        if not content or not tag_names:
            return []

        # The lookahead keeps "<tag" from matching a longer name like "<tag-v2"
        # and ignores openers whose name is still being streamed.
        opener = re.compile(
            '<(' + '|'.join(re.escape(name) for name in tag_names) + r')(?=[\s/>])'
        )

        spans: List[tuple[int, int, str]] = []
        pos = 0
        while True:
            match = opener.search(content, pos)
            if match is None:
                break
            start = match.start()

            opening = self._OPENING_TAG_RE.match(content, start)
            if opening and opening.group('self_closing'):
                end = opening.end()
                spans.append((start, end, content[start:end]))
                pos = end
                continue

            end_tag = f'</{match.group(1)}>'
            close_idx = content.find(end_tag, match.end())
            if close_idx == -1:
                if include_unclosed:
                    spans.append((start, len(content), content[start:]))
                break

            end = close_idx + len(end_tag)
            spans.append((start, end, content[start:end]))
            pos = end

        return spans

    async def _parse_chunks(self, chunks: List[str]) -> List[Dict[str, Any]]:
        """Parse each XML chunk into a tool call, skipping ones that fail."""
        tool_calls = []
        for xml_chunk in chunks:
            try:
                tool_call = await self._parse_xml_to_tool_call(xml_chunk)
            except Exception as e:
                logging.error(f"Error parsing XML chunk: {e}")
                continue
            if tool_call:
                tool_calls.append(tool_call)
        return tool_calls

    async def _process_streaming_xml(self, content: str) -> List[Dict[str, Any]]:
        """Parse every complete registered tool element in ``content``, in document order."""
        return await self._parse_chunks([chunk for _, _, chunk in self._find_tool_spans(content)])

    def _extract_xml_chunks(self, content: str) -> List[str]:
        """Return the raw XML text of each registered tool element in ``content``.

        Elements are returned in document order. A trailing element whose
        closing tag never appears is returned with the rest of ``content``.
        """
        return [chunk for _, _, chunk in self._find_tool_spans(content, include_unclosed=True)]

    @staticmethod
    def _find_attribute(opening_tag: str, attr_name: str) -> Optional[str]:
        """Return the non-empty quoted value of ``attr_name`` in ``opening_tag``, or None."""
        # The lookbehind stops e.g. "path" from matching inside "file_path".
        match = re.search(
            rf'(?<![\w.:-]){re.escape(attr_name)}\s*=\s*(?:"([^"]+)"|\'([^\']+)\')',
            opening_tag,
        )
        if not match:
            return None
        return match.group(1) if match.group(1) is not None else match.group(2)

    async def _parse_xml_to_tool_call(self, xml_chunk: str) -> Optional[Dict[str, Any]]:
        """Convert one XML tool element into an OpenAI-style tool call.

        The tag name selects the tool through ``tool_registry.get_xml_tool``.
        Its ``schema.xml_schema`` then maps:

        - attributes of the opening tag (``schema.attributes``);
        - the root element's text (``param_mapping`` key ``"."``);
        - nested element text (other ``param_mapping`` keys)

        onto the method's parameters.

        Returns:
            The tool call dict. Returns None, after logging, when the tag is
            unknown, has no schema, a ``param_mapping`` parameter is missing,
            or parsing raises.
        """
        try:
            chunk = xml_chunk.strip()

            # Extract the tag name (to look up the tool) and the opening tag text.
            opening_match = self._OPENING_TAG_RE.match(chunk)
            if opening_match:
                tag_name = opening_match.group('name')
                opening_tag = opening_match.group(0)
            else:
                name_match = self._TAG_NAME_RE.match(chunk)
                if not name_match:
                    logging.error(f"No tag found in XML chunk: {xml_chunk}")
                    return None
                tag_name = name_match.group(1)
                # Malformed attributes: treat everything up to the first '>'
                # as the opening tag.
                gt_idx = chunk.find('>')
                opening_tag = chunk if gt_idx == -1 else chunk[:gt_idx + 1]
            body = chunk[len(opening_tag):]
            logging.info(f"Found XML tag: {tag_name}")

            tool_info = self.tool_registry.get_xml_tool(tag_name)
            if not tool_info:
                logging.error(f"No tool found for tag: {tag_name}")
                return None

            schema = tool_info['schema'].xml_schema
            if not schema:
                logging.error(f"No XML schema found for tag: {tag_name}")
                return None

            params: Dict[str, Any] = {}

            # Attributes are read from the opening tag only, never from the body.
            for attr_name, param_name in schema.attributes.items():
                value = self._find_attribute(opening_tag, attr_name)
                if value is not None:
                    params[param_name] = value
                    logging.info(f"Found attribute {attr_name} -> {param_name}: {value}")

            # Mapped parameters: root element content or nested element content.
            for xml_element, param_name in schema.param_mapping.items():
                if xml_element == ".":
                    closing = self._CLOSING_AT_END_RE.search(body)
                    if closing:
                        content = body[:closing.start()].strip()
                        if content:  # Only set if there's actual content
                            params[param_name] = content
                            logging.info(f"Found root content for {param_name}: {content}")
                else:
                    element = re.escape(xml_element)
                    nested_match = re.search(
                        rf'<{element}(?:\s[^>]*)?>(.*?)</{element}>', body, re.DOTALL
                    )
                    if nested_match:
                        params[param_name] = nested_match.group(1).strip()
                        logging.info(f"Found nested tag {xml_element} -> {param_name}: {nested_match.group(1)}")

            missing = [param for param in schema.param_mapping.values() if param not in params]
            if missing:
                logging.error(f"Missing required parameters: {missing}")
                logging.error(f"Current params: {params}")
                logging.error(f"XML chunk: {xml_chunk}")
                return None

            tool_call = {
                # NOTE(review): str hash() is salted per process, and identical
                # chunks get identical ids. Kept for compatibility with callers.
                "id": f"tool_{hash(xml_chunk)}",
                "type": "function",
                "function": {
                    "name": tool_info['method'],
                    "arguments": json.dumps(params)
                }
            }

            logging.info(f"Created tool call: {tool_call}")
            return tool_call

        except Exception as e:
            logging.error(f"Error parsing XML chunk: {e}")
            logging.error(f"XML chunk was: {xml_chunk}")
            return None
|