save assistant message id in tool result stream

This commit is contained in:
marko-kraemer 2025-04-18 02:17:29 +01:00
parent fe7ece2880
commit ecf78291cf
2 changed files with 37 additions and 19 deletions

View File

@ -36,6 +36,7 @@ class ToolExecutionContext:
function_name: Optional[str] = None function_name: Optional[str] = None
xml_tag_name: Optional[str] = None xml_tag_name: Optional[str] = None
error: Optional[Exception] = None error: Optional[Exception] = None
assistant_message_id: Optional[str] = None
@dataclass @dataclass
class ProcessorConfig: class ProcessorConfig:
@ -139,7 +140,6 @@ class ResponseProcessor:
# logger.debug(f"Starting to process streaming response for thread {thread_id}") # logger.debug(f"Starting to process streaming response for thread {thread_id}")
logger.info(f"Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, " logger.info(f"Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
f"Execute on stream={config.execute_on_stream}, Execution strategy={config.tool_execution_strategy}") f"Execute on stream={config.execute_on_stream}, Execution strategy={config.tool_execution_strategy}")
logger.info(f"Avoiding duplicate tool results using tracking mechanism")
# if config.max_xml_tool_calls > 0: # if config.max_xml_tool_calls > 0:
# logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}") # logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")
@ -215,7 +215,8 @@ class ResponseProcessor:
# Create a context for this tool execution # Create a context for this tool execution
context = self._create_tool_context( context = self._create_tool_context(
tool_call=tool_call, tool_call=tool_call,
tool_index=tool_index tool_index=tool_index,
assistant_message_id=last_assistant_message_id
) )
# Execute tool if needed, but in background # Execute tool if needed, but in background
@ -326,7 +327,8 @@ class ResponseProcessor:
# Create a context for this tool execution # Create a context for this tool execution
context = self._create_tool_context( context = self._create_tool_context(
tool_call=tool_call_data, tool_call=tool_call_data,
tool_index=tool_index tool_index=tool_index,
assistant_message_id=last_assistant_message_id
) )
# Yield tool execution start message # Yield tool execution start message
@ -375,7 +377,7 @@ class ResponseProcessor:
context = execution["context"] context = execution["context"]
context.result = result context.result = result
else: else:
context = self._create_tool_context(tool_call, tool_index) context = self._create_tool_context(tool_call, tool_index, last_assistant_message_id)
context.result = result context.result = result
# Skip yielding if already yielded during streaming # Skip yielding if already yielded during streaming
@ -408,7 +410,7 @@ class ResponseProcessor:
context = execution["context"] context = execution["context"]
context.error = e context.error = e
else: else:
context = self._create_tool_context(tool_call, tool_index) context = self._create_tool_context(tool_call, tool_index, last_assistant_message_id)
context.error = e context.error = e
# Yield error status for the tool # Yield error status for the tool
@ -542,14 +544,14 @@ class ResponseProcessor:
# We need a deterministic order, sort by index # We need a deterministic order, sort by index
for tool_idx in sorted(tool_results_map.keys()): for tool_idx in sorted(tool_results_map.keys()):
tool_call, result = tool_results_map[tool_idx] tool_call, result = tool_results_map[tool_idx]
context = self._create_tool_context(tool_call, tool_idx) context = self._create_tool_context(tool_call, tool_idx, last_assistant_message_id)
context.result = result context.result = result
# Yield start status (even if streamed, yield again here for strict order) # Yield start status (even if streamed, yield again here for strict order)
yield self._yield_tool_started(context, thread_run_id) yield self._yield_tool_started(context, thread_run_id)
# Save result to DB and get ID # Save result to DB and get ID
tool_msg_id = await self._add_tool_result(thread_id, tool_call, result, config.xml_adding_strategy) tool_msg_id = await self._add_tool_result(thread_id, tool_call, result, config.xml_adding_strategy, assistant_message_id=last_assistant_message_id)
if tool_msg_id: if tool_msg_id:
tool_result_message_ids[tool_idx] = tool_msg_id # Store for reference tool_result_message_ids[tool_idx] = tool_msg_id # Store for reference
else: else:
@ -754,13 +756,14 @@ class ResponseProcessor:
thread_id, thread_id,
tool_call, tool_call,
result, result,
config.xml_adding_strategy config.xml_adding_strategy,
assistant_message_id=assistant_message_id
) )
if message_id: if message_id:
tool_result_message_ids[tool_index] = message_id tool_result_message_ids[tool_index] = message_id
# Create context for tool result # Create context for tool result
context = self._create_tool_context(tool_call, tool_index) context = self._create_tool_context(tool_call, tool_index, assistant_message_id)
context.result = result context.result = result
# Yield tool execution result # Yield tool execution result
@ -1190,7 +1193,8 @@ class ResponseProcessor:
thread_id: str, thread_id: str,
tool_call: Dict[str, Any], tool_call: Dict[str, Any],
result: ToolResult, result: ToolResult,
strategy: Union[XmlAddingStrategy, str] = "assistant_message" strategy: Union[XmlAddingStrategy, str] = "assistant_message",
assistant_message_id: Optional[str] = None
) -> Optional[str]: # Return the message ID ) -> Optional[str]: # Return the message ID
"""Add a tool result to the conversation thread based on the specified format. """Add a tool result to the conversation thread based on the specified format.
@ -1205,9 +1209,17 @@ class ResponseProcessor:
result: The result from the tool execution result: The result from the tool execution
strategy: How to add XML tool results to the conversation strategy: How to add XML tool results to the conversation
("user_message", "assistant_message", or "inline_edit") ("user_message", "assistant_message", or "inline_edit")
assistant_message_id: ID of the assistant message that generated this tool call
""" """
try: try:
message_id = None # Initialize message_id message_id = None # Initialize message_id
# Create metadata with assistant_message_id if provided
metadata = {}
if assistant_message_id:
metadata["assistant_message_id"] = assistant_message_id
logger.info(f"Linking tool result to assistant message: {assistant_message_id}")
# Check if this is a native function call (has id field) # Check if this is a native function call (has id field)
if "id" in tool_call: if "id" in tool_call:
# Format as a proper tool message according to OpenAI spec # Format as a proper tool message according to OpenAI spec
@ -1246,7 +1258,8 @@ class ResponseProcessor:
thread_id=thread_id, thread_id=thread_id,
type="tool", # Special type for tool responses type="tool", # Special type for tool responses
content=tool_message, content=tool_message,
is_llm_message=True is_llm_message=True,
metadata=metadata
) )
return message_id # Return the message ID return message_id # Return the message ID
@ -1255,7 +1268,7 @@ class ResponseProcessor:
result_role = "user" if strategy == "user_message" else "assistant" result_role = "user" if strategy == "user_message" else "assistant"
# Create a context for consistent formatting # Create a context for consistent formatting
context = self._create_tool_context(tool_call, 0) # Index doesn't matter for DB context = self._create_tool_context(tool_call, 0, assistant_message_id)
context.result = result context.result = result
# Format the content using the formatting helper # Format the content using the formatting helper
@ -1271,7 +1284,8 @@ class ResponseProcessor:
thread_id=thread_id, thread_id=thread_id,
type="tool", type="tool",
content=result_message, content=result_message,
is_llm_message=True is_llm_message=True,
metadata=metadata
) )
return message_id # Return the message ID return message_id # Return the message ID
except Exception as e: except Exception as e:
@ -1286,7 +1300,8 @@ class ResponseProcessor:
thread_id=thread_id, thread_id=thread_id,
type="tool", type="tool",
content=fallback_message, content=fallback_message,
is_llm_message=True is_llm_message=True,
metadata={"assistant_message_id": assistant_message_id} if assistant_message_id else {}
) )
return message_id # Return the message ID return message_id # Return the message ID
except Exception as e2: except Exception as e2:
@ -1323,7 +1338,8 @@ class ResponseProcessor:
"result": "Error: No result available in context", "result": "Error: No result available in context",
"tool_index": context.tool_index, "tool_index": context.tool_index,
"tool_message_id": tool_message_id, "tool_message_id": tool_message_id,
"thread_run_id": thread_run_id "thread_run_id": thread_run_id,
"assistant_message_id": context.assistant_message_id if hasattr(context, "assistant_message_id") else None
} }
formatted_result = self._format_xml_tool_result(context.tool_call, context.result) formatted_result = self._format_xml_tool_result(context.tool_call, context.result)
@ -1334,14 +1350,16 @@ class ResponseProcessor:
"result": formatted_result, "result": formatted_result,
"tool_index": context.tool_index, "tool_index": context.tool_index,
"tool_message_id": tool_message_id, "tool_message_id": tool_message_id,
"thread_run_id": thread_run_id "thread_run_id": thread_run_id,
"assistant_message_id": context.assistant_message_id if hasattr(context, "assistant_message_id") else None
} }
def _create_tool_context(self, tool_call: Dict[str, Any], tool_index: int) -> ToolExecutionContext: def _create_tool_context(self, tool_call: Dict[str, Any], tool_index: int, assistant_message_id: Optional[str] = None) -> ToolExecutionContext:
"""Create a tool execution context with display name populated.""" """Create a tool execution context with display name populated."""
context = ToolExecutionContext( context = ToolExecutionContext(
tool_call=tool_call, tool_call=tool_call,
tool_index=tool_index tool_index=tool_index,
assistant_message_id=assistant_message_id
) )
# Set function_name and xml_tag_name fields # Set function_name and xml_tag_name fields