mirror of https://github.com/kortix-ai/suna.git
wip
This commit is contained in:
parent
2b31379b89
commit
b9a4beb136
|
@ -98,8 +98,8 @@ class ResponseProcessor:
|
|||
self.tool_registry = tool_registry
|
||||
self.add_message = add_message_callback
|
||||
self.trace = trace or langfuse.trace(name="anonymous:response_processor")
|
||||
# Initialize the XML parser with backwards compatibility
|
||||
self.xml_parser = XMLToolParser(strict_mode=False)
|
||||
# Initialize the XML parser
|
||||
self.xml_parser = XMLToolParser()
|
||||
self.is_agent_builder = is_agent_builder
|
||||
self.target_agent_id = target_agent_id
|
||||
self.agent_config = agent_config
|
||||
|
@ -1047,80 +1047,6 @@ class ResponseProcessor:
|
|||
)
|
||||
if end_msg_obj: yield format_for_yield(end_msg_obj)
|
||||
|
||||
# XML parsing methods
|
||||
def _extract_tag_content(self, xml_chunk: str, tag_name: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Extract content between opening and closing tags, handling nested tags."""
|
||||
start_tag = f'<{tag_name}'
|
||||
end_tag = f'</{tag_name}>'
|
||||
|
||||
try:
|
||||
# Find start tag position
|
||||
start_pos = xml_chunk.find(start_tag)
|
||||
if start_pos == -1:
|
||||
return None, xml_chunk
|
||||
|
||||
# Find end of opening tag
|
||||
tag_end = xml_chunk.find('>', start_pos)
|
||||
if tag_end == -1:
|
||||
return None, xml_chunk
|
||||
|
||||
# Find matching closing tag
|
||||
content_start = tag_end + 1
|
||||
nesting_level = 1
|
||||
pos = content_start
|
||||
|
||||
while nesting_level > 0 and pos < len(xml_chunk):
|
||||
next_start = xml_chunk.find(start_tag, pos)
|
||||
next_end = xml_chunk.find(end_tag, pos)
|
||||
|
||||
if next_end == -1:
|
||||
return None, xml_chunk
|
||||
|
||||
if next_start != -1 and next_start < next_end:
|
||||
nesting_level += 1
|
||||
pos = next_start + len(start_tag)
|
||||
else:
|
||||
nesting_level -= 1
|
||||
if nesting_level == 0:
|
||||
content = xml_chunk[content_start:next_end]
|
||||
remaining = xml_chunk[next_end + len(end_tag):]
|
||||
return content, remaining
|
||||
else:
|
||||
pos = next_end + len(end_tag)
|
||||
|
||||
return None, xml_chunk
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting tag content: {e}")
|
||||
self.trace.event(name="error_extracting_tag_content", level="ERROR", status_message=(f"Error extracting tag content: {e}"))
|
||||
return None, xml_chunk
|
||||
|
||||
def _extract_attribute(self, opening_tag: str, attr_name: str) -> Optional[str]:
|
||||
"""Extract attribute value from opening tag."""
|
||||
try:
|
||||
# Handle both single and double quotes with raw strings
|
||||
patterns = [
|
||||
fr'{attr_name}="([^"]*)"', # Double quotes
|
||||
fr"{attr_name}='([^']*)'", # Single quotes
|
||||
fr'{attr_name}=([^\s/>;]+)' # No quotes - fixed escape sequence
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, opening_tag)
|
||||
if match:
|
||||
value = match.group(1)
|
||||
# Unescape common XML entities
|
||||
value = value.replace('"', '"').replace(''', "'")
|
||||
value = value.replace('<', '<').replace('>', '>')
|
||||
value = value.replace('&', '&')
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting attribute: {e}")
|
||||
self.trace.event(name="error_extracting_attribute", level="ERROR", status_message=(f"Error extracting attribute: {e}"))
|
||||
return None
|
||||
|
||||
def _extract_xml_chunks(self, content: str) -> List[str]:
|
||||
"""Extract complete XML chunks using start and end pattern matching."""
|
||||
|
@ -1160,7 +1086,7 @@ class ResponseProcessor:
|
|||
current_tag = None
|
||||
|
||||
# Find the earliest occurrence of any registered tool function name
|
||||
# Since we no longer use xml_tools, check for available function names
|
||||
# Check for available function names
|
||||
available_functions = self.tool_registry.get_available_functions()
|
||||
for func_name in available_functions.keys():
|
||||
# Convert function name to potential tag name (underscore to dash)
|
||||
|
@ -1253,95 +1179,9 @@ class ResponseProcessor:
|
|||
logger.debug(f"Parsed new format tool call: {tool_call}")
|
||||
return tool_call, parsing_details
|
||||
|
||||
# Fall back to old format parsing
|
||||
# Extract tag name and validate
|
||||
tag_match = re.match(r'<([^\s>]+)', xml_chunk)
|
||||
if not tag_match:
|
||||
logger.error(f"No tag found in XML chunk: {xml_chunk}")
|
||||
self.trace.event(name="no_tag_found_in_xml_chunk", level="ERROR", status_message=(f"No tag found in XML chunk: {xml_chunk}"))
|
||||
return None
|
||||
|
||||
# This is the XML tag as it appears in the text (e.g., "create-file")
|
||||
xml_tag_name = tag_match.group(1)
|
||||
logger.info(f"Found XML tag: {xml_tag_name}")
|
||||
self.trace.event(name="found_xml_tag", level="DEFAULT", status_message=(f"Found XML tag: {xml_tag_name}"))
|
||||
|
||||
# Get tool info and schema from registry
|
||||
tool_info = self.tool_registry.get_xml_tool(xml_tag_name)
|
||||
if not tool_info or not tool_info['schema'].xml_schema:
|
||||
logger.error(f"No tool or schema found for tag: {xml_tag_name}")
|
||||
self.trace.event(name="no_tool_or_schema_found_for_tag", level="ERROR", status_message=(f"No tool or schema found for tag: {xml_tag_name}"))
|
||||
return None
|
||||
|
||||
# This is the actual function name to call (e.g., "create_file")
|
||||
function_name = tool_info['method']
|
||||
|
||||
schema = tool_info['schema'].xml_schema
|
||||
params = {}
|
||||
remaining_chunk: str = xml_chunk
|
||||
|
||||
# --- Store detailed parsing info ---
|
||||
parsing_details = {
|
||||
"attributes": {},
|
||||
"elements": {},
|
||||
"text_content": None,
|
||||
"root_content": None,
|
||||
"raw_chunk": xml_chunk # Store the original chunk for reference
|
||||
}
|
||||
# ---
|
||||
|
||||
# Process each mapping
|
||||
for mapping in schema.mappings:
|
||||
try:
|
||||
if mapping.node_type == "attribute":
|
||||
# Extract attribute from opening tag
|
||||
opening_tag = remaining_chunk.split('>', 1)[0]
|
||||
value = self._extract_attribute(opening_tag, mapping.param_name)
|
||||
if value is not None:
|
||||
params[mapping.param_name] = value
|
||||
parsing_details["attributes"][mapping.param_name] = value # Store raw attribute
|
||||
# logger.info(f"Found attribute {mapping.param_name}: {value}")
|
||||
|
||||
elif mapping.node_type == "element":
|
||||
# Extract element content
|
||||
content, new_remaining_chunk = self._extract_tag_content(remaining_chunk, mapping.path)
|
||||
if new_remaining_chunk is not None:
|
||||
remaining_chunk = new_remaining_chunk
|
||||
if content is not None:
|
||||
params[mapping.param_name] = content.strip()
|
||||
parsing_details["elements"][mapping.param_name] = content.strip() # Store raw element content
|
||||
# logger.info(f"Found element {mapping.param_name}: {content.strip()}")
|
||||
|
||||
elif mapping.node_type == "text":
|
||||
# Extract text content
|
||||
content, _ = self._extract_tag_content(remaining_chunk, xml_tag_name)
|
||||
if content is not None:
|
||||
params[mapping.param_name] = content.strip()
|
||||
parsing_details["text_content"] = content.strip() # Store raw text content
|
||||
# logger.info(f"Found text content for {mapping.param_name}: {content.strip()}")
|
||||
|
||||
elif mapping.node_type == "content":
|
||||
# Extract root content
|
||||
content, _ = self._extract_tag_content(remaining_chunk, xml_tag_name)
|
||||
if content is not None:
|
||||
params[mapping.param_name] = content.strip()
|
||||
parsing_details["root_content"] = content.strip() # Store raw root content
|
||||
# logger.info(f"Found root content for {mapping.param_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing mapping {mapping}: {e}")
|
||||
self.trace.event(name="error_processing_mapping", level="ERROR", status_message=(f"Error processing mapping {mapping}: {e}"))
|
||||
continue
|
||||
|
||||
# Create tool call with clear separation between function_name and xml_tag_name
|
||||
tool_call = {
|
||||
"function_name": function_name, # The actual method to call (e.g., create_file)
|
||||
"xml_tag_name": xml_tag_name, # The original XML tag (e.g., create-file)
|
||||
"arguments": params # The extracted parameters
|
||||
}
|
||||
|
||||
# logger.debug(f"Parsed old format tool call: {tool_call["function_name"]}")
|
||||
return tool_call, parsing_details # Return both dicts
|
||||
# If not the expected <function_calls><invoke> format, return None
|
||||
logger.error(f"XML chunk does not contain expected <function_calls><invoke> format: {xml_chunk}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing XML chunk: {e}")
|
||||
|
|
|
@ -437,7 +437,8 @@ When using the tools:
|
|||
openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
|
||||
logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
|
||||
|
||||
print(f"\n\n\n\n prepared_messages: {prepared_messages}\n\n\n\n")
|
||||
# print(f"\n\n\n\n prepared_messages: {prepared_messages}\n\n\n\n")
|
||||
|
||||
prepared_messages = self.context_manager.compress_messages(prepared_messages, llm_model)
|
||||
|
||||
# 5. Make LLM API call
|
||||
|
|
|
@ -18,14 +18,13 @@ from utils.logger import logger
|
|||
class SchemaType(Enum):
|
||||
"""Enumeration of supported schema types for tool definitions."""
|
||||
OPENAPI = "openapi"
|
||||
CUSTOM = "custom"
|
||||
|
||||
@dataclass
|
||||
class ToolSchema:
|
||||
"""Container for tool schemas with type information.
|
||||
|
||||
Attributes:
|
||||
schema_type (SchemaType): Type of schema (OpenAPI or Custom)
|
||||
schema_type (SchemaType): Type of schema (OpenAPI)
|
||||
schema (Dict[str, Any]): The actual schema definition
|
||||
"""
|
||||
schema_type: SchemaType
|
||||
|
@ -124,12 +123,9 @@ def openapi_schema(schema: Dict[str, Any]):
|
|||
))
|
||||
return decorator
|
||||
|
||||
def custom_schema(schema: Dict[str, Any]):
|
||||
"""Decorator for custom schema tools."""
|
||||
def xml_schema(schema: Dict[str, Any]):
|
||||
"""Deprecated decorator - does nothing, kept for compatibility."""
|
||||
def decorator(func):
|
||||
logger.debug(f"Applying custom schema to function {func.__name__}")
|
||||
return _add_schema(func, ToolSchema(
|
||||
schema_type=SchemaType.CUSTOM,
|
||||
schema=schema
|
||||
))
|
||||
logger.debug(f"xml_schema decorator called on {func.__name__} - ignoring (deprecated)")
|
||||
return func
|
||||
return decorator
|
||||
|
|
|
@ -52,15 +52,9 @@ class XMLToolParser:
|
|||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
def __init__(self, strict_mode: bool = False):
|
||||
"""
|
||||
Initialize the XML tool parser.
|
||||
|
||||
Args:
|
||||
strict_mode: If True, only accept the exact format. If False,
|
||||
also try to parse legacy formats for backwards compatibility.
|
||||
"""
|
||||
self.strict_mode = strict_mode
|
||||
def __init__(self):
|
||||
"""Initialize the XML tool parser."""
|
||||
pass
|
||||
|
||||
def parse_content(self, content: str) -> List[XMLToolCall]:
|
||||
"""
|
||||
|
@ -74,7 +68,7 @@ class XMLToolParser:
|
|||
"""
|
||||
tool_calls = []
|
||||
|
||||
# First, try to find function_calls blocks
|
||||
# Find function_calls blocks
|
||||
function_calls_matches = self.FUNCTION_CALLS_PATTERN.findall(content)
|
||||
|
||||
for fc_content in function_calls_matches:
|
||||
|
@ -93,10 +87,6 @@ class XMLToolParser:
|
|||
except Exception as e:
|
||||
logger.error(f"Error parsing invoke block for {function_name}: {e}")
|
||||
|
||||
# If not in strict mode and no tool calls found, try legacy format
|
||||
if not self.strict_mode and not tool_calls:
|
||||
tool_calls.extend(self._parse_legacy_format(content))
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _parse_invoke_block(
|
||||
|
@ -108,7 +98,6 @@ class XMLToolParser:
|
|||
"""Parse a single invoke block into an XMLToolCall."""
|
||||
parameters = {}
|
||||
parsing_details = {
|
||||
"format": "v2",
|
||||
"function_name": function_name,
|
||||
"raw_parameters": {}
|
||||
}
|
||||
|
@ -176,58 +165,6 @@ class XMLToolParser:
|
|||
# Return as string
|
||||
return value
|
||||
|
||||
def _parse_legacy_format(self, content: str) -> List[XMLToolCall]:
    """Parse legacy XML tool formats for backwards compatibility.

    Handles ``<tool_name>...</tool_name>`` and
    ``<tool_name param="value">...</tool_name>``. Attributes become
    parameters; when a tag has no attributes, its non-empty inner text is
    used as a single ``content`` parameter. Tags belonging to the current
    format (``function_calls``, ``invoke``, ``parameter``) are skipped.

    Args:
        content: Text that may contain legacy-format tool tags.

    Returns:
        One ``XMLToolCall`` per matched legacy tag, in order of appearance.
    """
    # A quoted value must be closed by the same quote that opened it, and may
    # contain the other kind of quote (e.g. msg="it's"). Alternation is used
    # instead of a backreference so the group numbering stays 1/2/3.
    tag_pattern = re.compile(
        r'<([a-zA-Z][\w\-]*)((?:\s+[\w\-]+=(?:"[^"]*"|\'[^\']*\'))*)\s*>(.*?)</\1>',
        re.DOTALL,
    )
    # Compiled once per call rather than once per matched tag.
    attr_pattern = re.compile(r'([\w\-]+)=(?:"([^"]*)"|\'([^\']*)\')')

    tool_calls = []

    for match in tag_pattern.finditer(content):
        tag_name, attributes_str, inner_content = match.groups()

        # Skip our own format tags
        if tag_name in ('function_calls', 'invoke', 'parameter'):
            continue

        inner_text = inner_content.strip()
        parameters = {}
        parsing_details = {
            "format": "legacy",
            "tag_name": tag_name,
            "attributes": {},
            "inner_content": inner_text,
        }

        # Parse attributes
        for attr_match in attr_pattern.finditer(attributes_str):
            attr_name = attr_match.group(1)
            # Exactly one of the two value groups participates in a match.
            attr_value = attr_match.group(2)
            if attr_value is None:
                attr_value = attr_match.group(3)
            parameters[attr_name] = self._parse_parameter_value(attr_value)
            parsing_details["attributes"][attr_name] = attr_value

        # If there's inner content and no attributes, use it as a 'content' parameter
        if inner_text and not parameters:
            parameters['content'] = inner_text

        # Convert tag name to function name (e.g., create-file -> create_file)
        function_name = tag_name.replace('-', '_')

        tool_calls.append(XMLToolCall(
            function_name=function_name,
            parameters=parameters,
            raw_xml=match.group(0),
            parsing_details=parsing_details
        ))

    return tool_calls
|
||||
|
||||
def format_tool_call(self, function_name: str, parameters: Dict[str, Any]) -> str:
|
||||
"""
|
||||
|
@ -285,16 +222,15 @@ class XMLToolParser:
|
|||
|
||||
|
||||
# Convenience function for quick parsing
|
||||
def parse_xml_tool_calls(content: str, strict_mode: bool = False) -> List[XMLToolCall]:
|
||||
def parse_xml_tool_calls(content: str) -> List[XMLToolCall]:
|
||||
"""
|
||||
Parse XML tool calls from content.
|
||||
|
||||
Args:
|
||||
content: The text content potentially containing XML tool calls
|
||||
strict_mode: If True, only accept the Cursor-style format
|
||||
|
||||
Returns:
|
||||
List of parsed XMLToolCall objects
|
||||
"""
|
||||
parser = XMLToolParser(strict_mode=strict_mode)
|
||||
parser = XMLToolParser()
|
||||
return parser.parse_content(content)
|
|
@ -6,7 +6,7 @@ ENV_MODE = os.getenv("ENV_MODE", "LOCAL")
|
|||
if ENV_MODE.upper() == "PRODUCTION":
|
||||
default_level = "DEBUG"
|
||||
else:
|
||||
default_level = "WARNING"
|
||||
default_level = "INFO"
|
||||
|
||||
LOGGING_LEVEL = logging.getLevelNamesMapping().get(
|
||||
os.getenv("LOGGING_LEVEL", default_level).upper(),
|
||||
|
|
Loading…
Reference in New Issue