save assistant message id in tool result stream

This commit is contained in:
marko-kraemer 2025-04-18 02:17:29 +01:00
parent fe7ece2880
commit ecf78291cf
2 changed files with 37 additions and 19 deletions

View File

@ -36,6 +36,7 @@ class ToolExecutionContext:
function_name: Optional[str] = None
xml_tag_name: Optional[str] = None
error: Optional[Exception] = None
assistant_message_id: Optional[str] = None
@dataclass
class ProcessorConfig:
@ -139,7 +140,6 @@ class ResponseProcessor:
# logger.debug(f"Starting to process streaming response for thread {thread_id}")
logger.info(f"Config: XML={config.xml_tool_calling}, Native={config.native_tool_calling}, "
f"Execute on stream={config.execute_on_stream}, Execution strategy={config.tool_execution_strategy}")
logger.info(f"Avoiding duplicate tool results using tracking mechanism")
# if config.max_xml_tool_calls > 0:
# logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")
@ -215,7 +215,8 @@ class ResponseProcessor:
# Create a context for this tool execution
context = self._create_tool_context(
tool_call=tool_call,
tool_index=tool_index
tool_index=tool_index,
assistant_message_id=last_assistant_message_id
)
# Execute tool if needed, but in background
@ -326,7 +327,8 @@ class ResponseProcessor:
# Create a context for this tool execution
context = self._create_tool_context(
tool_call=tool_call_data,
tool_index=tool_index
tool_index=tool_index,
assistant_message_id=last_assistant_message_id
)
# Yield tool execution start message
@ -375,7 +377,7 @@ class ResponseProcessor:
context = execution["context"]
context.result = result
else:
context = self._create_tool_context(tool_call, tool_index)
context = self._create_tool_context(tool_call, tool_index, last_assistant_message_id)
context.result = result
# Skip yielding if already yielded during streaming
@ -408,7 +410,7 @@ class ResponseProcessor:
context = execution["context"]
context.error = e
else:
context = self._create_tool_context(tool_call, tool_index)
context = self._create_tool_context(tool_call, tool_index, last_assistant_message_id)
context.error = e
# Yield error status for the tool
@ -542,14 +544,14 @@ class ResponseProcessor:
# We need a deterministic order, sort by index
for tool_idx in sorted(tool_results_map.keys()):
tool_call, result = tool_results_map[tool_idx]
context = self._create_tool_context(tool_call, tool_idx)
context = self._create_tool_context(tool_call, tool_idx, last_assistant_message_id)
context.result = result
# Yield start status (even if streamed, yield again here for strict order)
yield self._yield_tool_started(context, thread_run_id)
# Save result to DB and get ID
tool_msg_id = await self._add_tool_result(thread_id, tool_call, result, config.xml_adding_strategy)
tool_msg_id = await self._add_tool_result(thread_id, tool_call, result, config.xml_adding_strategy, assistant_message_id=last_assistant_message_id)
if tool_msg_id:
tool_result_message_ids[tool_idx] = tool_msg_id # Store for reference
else:
@ -754,13 +756,14 @@ class ResponseProcessor:
thread_id,
tool_call,
result,
config.xml_adding_strategy
config.xml_adding_strategy,
assistant_message_id=assistant_message_id
)
if message_id:
tool_result_message_ids[tool_index] = message_id
# Create context for tool result
context = self._create_tool_context(tool_call, tool_index)
context = self._create_tool_context(tool_call, tool_index, assistant_message_id)
context.result = result
# Yield tool execution result
@ -1190,7 +1193,8 @@ class ResponseProcessor:
thread_id: str,
tool_call: Dict[str, Any],
result: ToolResult,
strategy: Union[XmlAddingStrategy, str] = "assistant_message"
strategy: Union[XmlAddingStrategy, str] = "assistant_message",
assistant_message_id: Optional[str] = None
) -> Optional[str]: # Return the message ID
"""Add a tool result to the conversation thread based on the specified format.
@ -1205,9 +1209,17 @@ class ResponseProcessor:
result: The result from the tool execution
strategy: How to add XML tool results to the conversation
("user_message", "assistant_message", or "inline_edit")
assistant_message_id: ID of the assistant message that generated this tool call
"""
try:
message_id = None # Initialize message_id
# Create metadata with assistant_message_id if provided
metadata = {}
if assistant_message_id:
metadata["assistant_message_id"] = assistant_message_id
logger.info(f"Linking tool result to assistant message: {assistant_message_id}")
# Check if this is a native function call (has id field)
if "id" in tool_call:
# Format as a proper tool message according to OpenAI spec
@ -1246,7 +1258,8 @@ class ResponseProcessor:
thread_id=thread_id,
type="tool", # Special type for tool responses
content=tool_message,
is_llm_message=True
is_llm_message=True,
metadata=metadata
)
return message_id # Return the message ID
@ -1255,7 +1268,7 @@ class ResponseProcessor:
result_role = "user" if strategy == "user_message" else "assistant"
# Create a context for consistent formatting
context = self._create_tool_context(tool_call, 0) # Index doesn't matter for DB
context = self._create_tool_context(tool_call, 0, assistant_message_id)
context.result = result
# Format the content using the formatting helper
@ -1271,7 +1284,8 @@ class ResponseProcessor:
thread_id=thread_id,
type="tool",
content=result_message,
is_llm_message=True
is_llm_message=True,
metadata=metadata
)
return message_id # Return the message ID
except Exception as e:
@ -1286,7 +1300,8 @@ class ResponseProcessor:
thread_id=thread_id,
type="tool",
content=fallback_message,
is_llm_message=True
is_llm_message=True,
metadata={"assistant_message_id": assistant_message_id} if assistant_message_id else {}
)
return message_id # Return the message ID
except Exception as e2:
@ -1323,7 +1338,8 @@ class ResponseProcessor:
"result": "Error: No result available in context",
"tool_index": context.tool_index,
"tool_message_id": tool_message_id,
"thread_run_id": thread_run_id
"thread_run_id": thread_run_id,
"assistant_message_id": context.assistant_message_id if hasattr(context, "assistant_message_id") else None
}
formatted_result = self._format_xml_tool_result(context.tool_call, context.result)
@ -1334,14 +1350,16 @@ class ResponseProcessor:
"result": formatted_result,
"tool_index": context.tool_index,
"tool_message_id": tool_message_id,
"thread_run_id": thread_run_id
"thread_run_id": thread_run_id,
"assistant_message_id": context.assistant_message_id if hasattr(context, "assistant_message_id") else None
}
def _create_tool_context(self, tool_call: Dict[str, Any], tool_index: int) -> ToolExecutionContext:
def _create_tool_context(self, tool_call: Dict[str, Any], tool_index: int, assistant_message_id: Optional[str] = None) -> ToolExecutionContext:
"""Create a tool execution context with display name populated."""
context = ToolExecutionContext(
tool_call=tool_call,
tool_index=tool_index
tool_index=tool_index,
assistant_message_id=assistant_message_id
)
# Set function_name and xml_tag_name fields

View File

@ -264,7 +264,7 @@
--primary: oklch(0.985 0 0);
--primary-foreground: oklch(0.205 0 0);
--secondary: oklch(54.65% 0.246 262.87);
--secondary-foreground: oklch(0.985 0 0);
--secondary-foreground: oklch(0.985 0 0);
--muted: oklch(0.269 0 0);
--muted-foreground: oklch(0.708 0 0);
--accent: oklch(27.39% 0.005 286.03);