mirror of https://github.com/kortix-ai/suna.git
save assistant message id in tool result stream
This commit is contained in:
parent
fe7ece2880
commit
ecf78291cf
|
@ -36,6 +36,7 @@ class ToolExecutionContext:
|
||||||
function_name: Optional[str] = None
|
function_name: Optional[str] = None
|
||||||
xml_tag_name: Optional[str] = None
|
xml_tag_name: Optional[str] = None
|
||||||
error: Optional[Exception] = None
|
error: Optional[Exception] = None
|
||||||
|
assistant_message_id: Optional[str] = None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProcessorConfig:
|
class ProcessorConfig:
|
||||||
|
@ -139,7 +140,6 @@ class ResponseProcessor:
|
||||||
# logger.debug(f"Starting to process streaming response for thread {thread_id}")
|
# logger.debug(f"Starting to process streaming response for thread {thread_id}")
|
||||||
logger.info(f"Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
|
logger.info(f"Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
|
||||||
f"Execute on stream={config.execute_on_stream}, Execution strategy={config.tool_execution_strategy}")
|
f"Execute on stream={config.execute_on_stream}, Execution strategy={config.tool_execution_strategy}")
|
||||||
logger.info(f"Avoiding duplicate tool results using tracking mechanism")
|
|
||||||
|
|
||||||
# if config.max_xml_tool_calls > 0:
|
# if config.max_xml_tool_calls > 0:
|
||||||
# logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")
|
# logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")
|
||||||
|
@ -215,7 +215,8 @@ class ResponseProcessor:
|
||||||
# Create a context for this tool execution
|
# Create a context for this tool execution
|
||||||
context = self._create_tool_context(
|
context = self._create_tool_context(
|
||||||
tool_call=tool_call,
|
tool_call=tool_call,
|
||||||
tool_index=tool_index
|
tool_index=tool_index,
|
||||||
|
assistant_message_id=last_assistant_message_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute tool if needed, but in background
|
# Execute tool if needed, but in background
|
||||||
|
@ -326,7 +327,8 @@ class ResponseProcessor:
|
||||||
# Create a context for this tool execution
|
# Create a context for this tool execution
|
||||||
context = self._create_tool_context(
|
context = self._create_tool_context(
|
||||||
tool_call=tool_call_data,
|
tool_call=tool_call_data,
|
||||||
tool_index=tool_index
|
tool_index=tool_index,
|
||||||
|
assistant_message_id=last_assistant_message_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Yield tool execution start message
|
# Yield tool execution start message
|
||||||
|
@ -375,7 +377,7 @@ class ResponseProcessor:
|
||||||
context = execution["context"]
|
context = execution["context"]
|
||||||
context.result = result
|
context.result = result
|
||||||
else:
|
else:
|
||||||
context = self._create_tool_context(tool_call, tool_index)
|
context = self._create_tool_context(tool_call, tool_index, last_assistant_message_id)
|
||||||
context.result = result
|
context.result = result
|
||||||
|
|
||||||
# Skip yielding if already yielded during streaming
|
# Skip yielding if already yielded during streaming
|
||||||
|
@ -408,7 +410,7 @@ class ResponseProcessor:
|
||||||
context = execution["context"]
|
context = execution["context"]
|
||||||
context.error = e
|
context.error = e
|
||||||
else:
|
else:
|
||||||
context = self._create_tool_context(tool_call, tool_index)
|
context = self._create_tool_context(tool_call, tool_index, last_assistant_message_id)
|
||||||
context.error = e
|
context.error = e
|
||||||
|
|
||||||
# Yield error status for the tool
|
# Yield error status for the tool
|
||||||
|
@ -542,14 +544,14 @@ class ResponseProcessor:
|
||||||
# We need a deterministic order, sort by index
|
# We need a deterministic order, sort by index
|
||||||
for tool_idx in sorted(tool_results_map.keys()):
|
for tool_idx in sorted(tool_results_map.keys()):
|
||||||
tool_call, result = tool_results_map[tool_idx]
|
tool_call, result = tool_results_map[tool_idx]
|
||||||
context = self._create_tool_context(tool_call, tool_idx)
|
context = self._create_tool_context(tool_call, tool_idx, last_assistant_message_id)
|
||||||
context.result = result
|
context.result = result
|
||||||
|
|
||||||
# Yield start status (even if streamed, yield again here for strict order)
|
# Yield start status (even if streamed, yield again here for strict order)
|
||||||
yield self._yield_tool_started(context, thread_run_id)
|
yield self._yield_tool_started(context, thread_run_id)
|
||||||
|
|
||||||
# Save result to DB and get ID
|
# Save result to DB and get ID
|
||||||
tool_msg_id = await self._add_tool_result(thread_id, tool_call, result, config.xml_adding_strategy)
|
tool_msg_id = await self._add_tool_result(thread_id, tool_call, result, config.xml_adding_strategy, assistant_message_id=last_assistant_message_id)
|
||||||
if tool_msg_id:
|
if tool_msg_id:
|
||||||
tool_result_message_ids[tool_idx] = tool_msg_id # Store for reference
|
tool_result_message_ids[tool_idx] = tool_msg_id # Store for reference
|
||||||
else:
|
else:
|
||||||
|
@ -754,13 +756,14 @@ class ResponseProcessor:
|
||||||
thread_id,
|
thread_id,
|
||||||
tool_call,
|
tool_call,
|
||||||
result,
|
result,
|
||||||
config.xml_adding_strategy
|
config.xml_adding_strategy,
|
||||||
|
assistant_message_id=assistant_message_id
|
||||||
)
|
)
|
||||||
if message_id:
|
if message_id:
|
||||||
tool_result_message_ids[tool_index] = message_id
|
tool_result_message_ids[tool_index] = message_id
|
||||||
|
|
||||||
# Create context for tool result
|
# Create context for tool result
|
||||||
context = self._create_tool_context(tool_call, tool_index)
|
context = self._create_tool_context(tool_call, tool_index, assistant_message_id)
|
||||||
context.result = result
|
context.result = result
|
||||||
|
|
||||||
# Yield tool execution result
|
# Yield tool execution result
|
||||||
|
@ -1190,7 +1193,8 @@ class ResponseProcessor:
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
tool_call: Dict[str, Any],
|
tool_call: Dict[str, Any],
|
||||||
result: ToolResult,
|
result: ToolResult,
|
||||||
strategy: Union[XmlAddingStrategy, str] = "assistant_message"
|
strategy: Union[XmlAddingStrategy, str] = "assistant_message",
|
||||||
|
assistant_message_id: Optional[str] = None
|
||||||
) -> Optional[str]: # Return the message ID
|
) -> Optional[str]: # Return the message ID
|
||||||
"""Add a tool result to the conversation thread based on the specified format.
|
"""Add a tool result to the conversation thread based on the specified format.
|
||||||
|
|
||||||
|
@ -1205,9 +1209,17 @@ class ResponseProcessor:
|
||||||
result: The result from the tool execution
|
result: The result from the tool execution
|
||||||
strategy: How to add XML tool results to the conversation
|
strategy: How to add XML tool results to the conversation
|
||||||
("user_message", "assistant_message", or "inline_edit")
|
("user_message", "assistant_message", or "inline_edit")
|
||||||
|
assistant_message_id: ID of the assistant message that generated this tool call
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
message_id = None # Initialize message_id
|
message_id = None # Initialize message_id
|
||||||
|
|
||||||
|
# Create metadata with assistant_message_id if provided
|
||||||
|
metadata = {}
|
||||||
|
if assistant_message_id:
|
||||||
|
metadata["assistant_message_id"] = assistant_message_id
|
||||||
|
logger.info(f"Linking tool result to assistant message: {assistant_message_id}")
|
||||||
|
|
||||||
# Check if this is a native function call (has id field)
|
# Check if this is a native function call (has id field)
|
||||||
if "id" in tool_call:
|
if "id" in tool_call:
|
||||||
# Format as a proper tool message according to OpenAI spec
|
# Format as a proper tool message according to OpenAI spec
|
||||||
|
@ -1246,7 +1258,8 @@ class ResponseProcessor:
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
type="tool", # Special type for tool responses
|
type="tool", # Special type for tool responses
|
||||||
content=tool_message,
|
content=tool_message,
|
||||||
is_llm_message=True
|
is_llm_message=True,
|
||||||
|
metadata=metadata
|
||||||
)
|
)
|
||||||
return message_id # Return the message ID
|
return message_id # Return the message ID
|
||||||
|
|
||||||
|
@ -1255,7 +1268,7 @@ class ResponseProcessor:
|
||||||
result_role = "user" if strategy == "user_message" else "assistant"
|
result_role = "user" if strategy == "user_message" else "assistant"
|
||||||
|
|
||||||
# Create a context for consistent formatting
|
# Create a context for consistent formatting
|
||||||
context = self._create_tool_context(tool_call, 0) # Index doesn't matter for DB
|
context = self._create_tool_context(tool_call, 0, assistant_message_id)
|
||||||
context.result = result
|
context.result = result
|
||||||
|
|
||||||
# Format the content using the formatting helper
|
# Format the content using the formatting helper
|
||||||
|
@ -1271,7 +1284,8 @@ class ResponseProcessor:
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
type="tool",
|
type="tool",
|
||||||
content=result_message,
|
content=result_message,
|
||||||
is_llm_message=True
|
is_llm_message=True,
|
||||||
|
metadata=metadata
|
||||||
)
|
)
|
||||||
return message_id # Return the message ID
|
return message_id # Return the message ID
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1286,7 +1300,8 @@ class ResponseProcessor:
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
type="tool",
|
type="tool",
|
||||||
content=fallback_message,
|
content=fallback_message,
|
||||||
is_llm_message=True
|
is_llm_message=True,
|
||||||
|
metadata={"assistant_message_id": assistant_message_id} if assistant_message_id else {}
|
||||||
)
|
)
|
||||||
return message_id # Return the message ID
|
return message_id # Return the message ID
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
|
@ -1323,7 +1338,8 @@ class ResponseProcessor:
|
||||||
"result": "Error: No result available in context",
|
"result": "Error: No result available in context",
|
||||||
"tool_index": context.tool_index,
|
"tool_index": context.tool_index,
|
||||||
"tool_message_id": tool_message_id,
|
"tool_message_id": tool_message_id,
|
||||||
"thread_run_id": thread_run_id
|
"thread_run_id": thread_run_id,
|
||||||
|
"assistant_message_id": context.assistant_message_id if hasattr(context, "assistant_message_id") else None
|
||||||
}
|
}
|
||||||
|
|
||||||
formatted_result = self._format_xml_tool_result(context.tool_call, context.result)
|
formatted_result = self._format_xml_tool_result(context.tool_call, context.result)
|
||||||
|
@ -1334,14 +1350,16 @@ class ResponseProcessor:
|
||||||
"result": formatted_result,
|
"result": formatted_result,
|
||||||
"tool_index": context.tool_index,
|
"tool_index": context.tool_index,
|
||||||
"tool_message_id": tool_message_id,
|
"tool_message_id": tool_message_id,
|
||||||
"thread_run_id": thread_run_id
|
"thread_run_id": thread_run_id,
|
||||||
|
"assistant_message_id": context.assistant_message_id if hasattr(context, "assistant_message_id") else None
|
||||||
}
|
}
|
||||||
|
|
||||||
def _create_tool_context(self, tool_call: Dict[str, Any], tool_index: int) -> ToolExecutionContext:
|
def _create_tool_context(self, tool_call: Dict[str, Any], tool_index: int, assistant_message_id: Optional[str] = None) -> ToolExecutionContext:
|
||||||
"""Create a tool execution context with display name populated."""
|
"""Create a tool execution context with display name populated."""
|
||||||
context = ToolExecutionContext(
|
context = ToolExecutionContext(
|
||||||
tool_call=tool_call,
|
tool_call=tool_call,
|
||||||
tool_index=tool_index
|
tool_index=tool_index,
|
||||||
|
assistant_message_id=assistant_message_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set function_name and xml_tag_name fields
|
# Set function_name and xml_tag_name fields
|
||||||
|
|
Loading…
Reference in New Issue