diff --git a/backend/agent/run.py b/backend/agent/run.py
index e14576c8..851d7056 100644
--- a/backend/agent/run.py
+++ b/backend/agent/run.py
@@ -1,9 +1,9 @@
import os
import json
import asyncio
-from typing import Optional, Dict, List, Any, AsyncGenerator, Tuple
-from dataclasses import dataclass, field
+from typing import Optional
from agent.tools.message_tool import MessageTool
from agent.tools.sb_deploy_tool import SandboxDeployTool
from agent.tools.sb_expose_tool import SandboxExposeTool
@@ -27,699 +27,13 @@ from agent.tools.sb_vision_tool import SandboxVisionTool
from agent.tools.sb_image_edit_tool import SandboxImageEditTool
from services.langfuse import langfuse
from langfuse.client import StatefulTraceClient
from agent.gemini_prompt import get_gemini_system_prompt
from agent.tools.mcp_tool_wrapper import MCPToolWrapper
from agentpress.tool import SchemaType
load_dotenv()
-@dataclass
-class AgentRunConfig:
- thread_id: str
- project_id: str
- stream: bool
- thread_manager: Optional[ThreadManager] = None
- native_max_auto_continues: int = 25
- max_iterations: int = 100
- model_name: str = "anthropic/claude-sonnet-4-20250514"
- enable_thinking: Optional[bool] = False
- reasoning_effort: Optional[str] = 'low'
- enable_context_manager: bool = True
- agent_config: Optional[dict] = None
- trace: Optional[StatefulTraceClient] = None
- is_agent_builder: Optional[bool] = False
- target_agent_id: Optional[str] = None
-
-@dataclass
-class ExecutionContext:
- client: Any
- account_id: str
- project_data: Dict
- sandbox_info: Dict
- mcp_wrapper_instance: Optional[MCPToolWrapper] = None
-
-class AgentExecutionError(Exception):
- pass
-
-def get_model_max_tokens(model_name: str) -> Optional[int]:
- if "sonnet" in model_name.lower():
- return 8192
- elif "gpt-4" in model_name.lower():
- return 4096
- elif "gemini-2.5-pro" in model_name.lower():
- return 64000
- return None
-
-def is_vision_model(model_name: str) -> bool:
- return any(x in model_name.lower() for x in ['gemini', 'anthropic', 'openai'])
-
-async def setup_execution_context(config: AgentRunConfig) -> ExecutionContext:
- client = await config.thread_manager.db.client
-
- account_id = await get_account_id_from_thread(client, config.thread_id)
- if not account_id:
- raise AgentExecutionError("Could not determine account ID for thread")
-
- project = await client.table('projects').select('*').eq('project_id', config.project_id).execute()
- if not project.data or len(project.data) == 0:
- raise AgentExecutionError(f"Project {config.project_id} not found")
-
- project_data = project.data[0]
- sandbox_info = project_data.get('sandbox', {})
- if not sandbox_info.get('id'):
- raise AgentExecutionError(f"No sandbox found for project {config.project_id}")
-
- return ExecutionContext(
- client=client,
- account_id=account_id,
- project_data=project_data,
- sandbox_info=sandbox_info
- )
-
-def register_agent_builder_tools(thread_manager: ThreadManager, target_agent_id: str):
- from agent.tools.agent_builder_tools.agent_config_tool import AgentConfigTool
- from agent.tools.agent_builder_tools.mcp_search_tool import MCPSearchTool
- from agent.tools.agent_builder_tools.credential_profile_tool import CredentialProfileTool
- from agent.tools.agent_builder_tools.workflow_tool import WorkflowTool
- from agent.tools.agent_builder_tools.trigger_tool import TriggerTool
- from services.supabase import DBConnection
-
- db = DBConnection()
- thread_manager.add_tool(AgentConfigTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
- thread_manager.add_tool(MCPSearchTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
- thread_manager.add_tool(CredentialProfileTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
- thread_manager.add_tool(WorkflowTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
- thread_manager.add_tool(TriggerTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
-
-def register_default_tools(thread_manager: ThreadManager, agent_config: AgentRunConfig):
- thread_manager.add_tool(SandboxShellTool, project_id=agent_config.project_id, thread_manager=thread_manager)
- thread_manager.add_tool(SandboxFilesTool, project_id=agent_config.project_id, thread_manager=thread_manager)
- thread_manager.add_tool(SandboxBrowserTool, project_id=agent_config.project_id, thread_id=agent_config.thread_id, thread_manager=thread_manager)
- thread_manager.add_tool(SandboxDeployTool, project_id=agent_config.project_id, thread_manager=thread_manager)
- thread_manager.add_tool(SandboxExposeTool, project_id=agent_config.project_id, thread_manager=thread_manager)
- thread_manager.add_tool(ExpandMessageTool, thread_id=agent_config.thread_id, thread_manager=thread_manager)
- thread_manager.add_tool(MessageTool)
- thread_manager.add_tool(SandboxWebSearchTool, project_id=agent_config.project_id, thread_manager=thread_manager)
- thread_manager.add_tool(SandboxVisionTool, project_id=agent_config.project_id, thread_id=agent_config.thread_id, thread_manager=thread_manager)
- thread_manager.add_tool(SandboxImageEditTool, project_id=agent_config.project_id, thread_id=agent_config.thread_id, thread_manager=thread_manager)
-
- if config.RAPID_API_KEY:
- thread_manager.add_tool(DataProvidersTool)
-
-def register_custom_tools(thread_manager: ThreadManager, agent_config: AgentRunConfig, enabled_tools: Dict):
- thread_manager.add_tool(ExpandMessageTool, thread_id=agent_config.thread_id, thread_manager=thread_manager)
- thread_manager.add_tool(MessageTool)
-
- tool_mapping = {
- 'sb_shell_tool': (SandboxShellTool, {'project_id': agent_config.project_id, 'thread_manager': thread_manager}),
- 'sb_files_tool': (SandboxFilesTool, {'project_id': agent_config.project_id, 'thread_manager': thread_manager}),
- 'sb_browser_tool': (SandboxBrowserTool, {'project_id': agent_config.project_id, 'thread_id': agent_config.thread_id, 'thread_manager': thread_manager}),
- 'sb_deploy_tool': (SandboxDeployTool, {'project_id': agent_config.project_id, 'thread_manager': thread_manager}),
- 'sb_expose_tool': (SandboxExposeTool, {'project_id': agent_config.project_id, 'thread_manager': thread_manager}),
- 'web_search_tool': (SandboxWebSearchTool, {'project_id': agent_config.project_id, 'thread_manager': thread_manager}),
- 'sb_vision_tool': (SandboxVisionTool, {'project_id': agent_config.project_id, 'thread_id': agent_config.thread_id, 'thread_manager': thread_manager}),
- }
-
- for tool_name, (tool_class, kwargs) in tool_mapping.items():
- if enabled_tools.get(tool_name, {}).get('enabled', False):
- thread_manager.add_tool(tool_class, **kwargs)
-
- if config.RAPID_API_KEY and enabled_tools.get('data_providers_tool', {}).get('enabled', False):
- thread_manager.add_tool(DataProvidersTool)
-
-async def setup_pipedream_mcp_config(custom_mcp: Dict, account_id: str) -> Dict:
- if 'config' not in custom_mcp:
- custom_mcp['config'] = {}
-
- if not custom_mcp['config'].get('external_user_id'):
- profile_id = custom_mcp['config'].get('profile_id')
- if profile_id:
- try:
- from pipedream.profiles import get_profile_manager
- from services.supabase import DBConnection
- profile_db = DBConnection()
- profile_manager = get_profile_manager(profile_db)
-
- profile = await profile_manager.get_profile(account_id, profile_id)
- if profile:
- custom_mcp['config']['external_user_id'] = profile.external_user_id
- logger.info(f"Retrieved external_user_id from profile {profile_id} for Pipedream MCP")
- else:
- logger.error(f"Could not find profile {profile_id} for Pipedream MCP")
- except Exception as e:
- logger.error(f"Error retrieving external_user_id from profile {profile_id}: {e}")
-
- if 'headers' in custom_mcp['config'] and 'x-pd-app-slug' in custom_mcp['config']['headers']:
- custom_mcp['config']['app_slug'] = custom_mcp['config']['headers']['x-pd-app-slug']
-
- return custom_mcp
-
-def create_mcp_config(custom_mcp: Dict, custom_type: str) -> Dict:
- return {
- 'name': custom_mcp['name'],
- 'qualifiedName': f"custom_{custom_type}_{custom_mcp['name'].replace(' ', '_').lower()}",
- 'config': custom_mcp['config'],
- 'enabledTools': custom_mcp.get('enabledTools', []),
- 'instructions': custom_mcp.get('instructions', ''),
- 'isCustom': True,
- 'customType': custom_type
- }
-
-async def setup_mcp_tools(thread_manager: ThreadManager, config: AgentRunConfig, context: ExecutionContext) -> Optional[MCPToolWrapper]:
- if not config.agent_config:
- return None
-
- all_mcps = []
-
- if config.agent_config.get('configured_mcps'):
- all_mcps.extend(config.agent_config['configured_mcps'])
-
- if config.agent_config.get('custom_mcps'):
- for custom_mcp in config.agent_config['custom_mcps']:
- custom_type = custom_mcp.get('customType', custom_mcp.get('type', 'sse'))
-
- if custom_type == 'pipedream':
- custom_mcp = await setup_pipedream_mcp_config(custom_mcp, context.account_id)
-
- mcp_config = create_mcp_config(custom_mcp, custom_type)
- all_mcps.append(mcp_config)
-
- if not all_mcps:
- return None
-
- logger.info(f"Registering MCP tool wrapper for {len(all_mcps)} MCP servers")
- thread_manager.add_tool(MCPToolWrapper, mcp_configs=all_mcps)
-
- mcp_wrapper_instance = None
- for tool_name, tool_info in thread_manager.tool_registry.tools.items():
- if isinstance(tool_info['instance'], MCPToolWrapper):
- mcp_wrapper_instance = tool_info['instance']
- break
-
- if mcp_wrapper_instance:
- try:
- await mcp_wrapper_instance.initialize_and_register_tools()
- logger.info("MCP tools initialized successfully")
-
- updated_schemas = mcp_wrapper_instance.get_schemas()
- logger.info(f"MCP wrapper has {len(updated_schemas)} schemas available")
-
- for method_name, schema_list in updated_schemas.items():
- if method_name != 'call_mcp_tool':
- for schema in schema_list:
- if schema.schema_type == SchemaType.OPENAPI:
- thread_manager.tool_registry.tools[method_name] = {
- "instance": mcp_wrapper_instance,
- "schema": schema
- }
- logger.info(f"Registered dynamic MCP tool: {method_name}")
-
- all_tools = list(thread_manager.tool_registry.tools.keys())
- logger.info(f"All registered tools after MCP initialization: {all_tools}")
-
- except Exception as e:
- logger.error(f"Failed to initialize MCP tools: {e}")
-
- return mcp_wrapper_instance
-
-def setup_tools(thread_manager: ThreadManager, agent_config: AgentRunConfig) -> None:
- if agent_config.is_agent_builder:
- register_agent_builder_tools(thread_manager, agent_config.target_agent_id)
-
- enabled_tools = None
- if agent_config.agent_config and 'agentpress_tools' in agent_config.agent_config:
- enabled_tools = agent_config.agent_config['agentpress_tools']
- logger.info("Using custom tool configuration from agent")
-
- if enabled_tools is None:
- logger.info("No agent specified - registering all tools for full Suna capabilities")
- register_default_tools(thread_manager, agent_config)
- else:
- logger.info("Custom agent specified - registering only enabled tools")
- register_custom_tools(thread_manager, agent_config, enabled_tools)
-
-def get_base_system_prompt(model_name: str) -> str:
- if "gemini-2.5-flash" in model_name.lower() and "gemini-2.5-pro" not in model_name.lower():
- return get_gemini_system_prompt()
- return get_system_prompt()
-
-def add_sample_response(system_content: str, model_name: str) -> str:
- if "anthropic" not in model_name.lower():
- sample_response_path = os.path.join(os.path.dirname(__file__), 'sample_responses/1.txt')
- with open(sample_response_path, 'r') as file:
- sample_response = file.read()
- return system_content + "\n\n " + sample_response + ""
- return system_content
-
-async def add_knowledge_base_context(system_content: str, config: AgentRunConfig) -> str:
- if not await is_enabled("knowledge_base"):
- return system_content
-
- try:
- from services.supabase import DBConnection
- kb_db = DBConnection()
- kb_client = await kb_db.client
-
- current_agent_id = config.agent_config.get('agent_id') if config.agent_config else None
-
- kb_result = await kb_client.rpc('get_combined_knowledge_base_context', {
- 'p_thread_id': config.thread_id,
- 'p_agent_id': current_agent_id,
- 'p_max_tokens': 4000
- }).execute()
-
- if kb_result.data and kb_result.data.strip():
- logger.info(f"Adding combined knowledge base context to system prompt")
- return system_content + "\n\n" + kb_result.data
- else:
- logger.debug(f"No knowledge base context found")
-
- except Exception as e:
- logger.error(f"Error retrieving knowledge base context: {e}")
-
- return system_content
-
-def add_mcp_instructions(system_content: str, config: AgentRunConfig, mcp_wrapper_instance: Optional[MCPToolWrapper]) -> str:
- if not (config.agent_config and (config.agent_config.get('configured_mcps') or config.agent_config.get('custom_mcps'))):
- return system_content
-
- if not (mcp_wrapper_instance and mcp_wrapper_instance._initialized):
- return system_content
-
- mcp_info = "\n\n--- MCP Tools Available ---\n"
- mcp_info += "You have access to external MCP (Model Context Protocol) server tools.\n"
- mcp_info += "MCP tools can be called directly using their native function names in the standard function calling format:\n"
- mcp_info += '<function_calls>\n'
- mcp_info += '<invoke name="tool_name">\n'
- mcp_info += '<parameter name="param1">value1</parameter>\n'
- mcp_info += '<parameter name="param2">value2</parameter>\n'
- mcp_info += '</invoke>\n'
- mcp_info += '</function_calls>\n\n'
-
- mcp_info += "Available MCP tools:\n"
- try:
- registered_schemas = mcp_wrapper_instance.get_schemas()
- for method_name, schema_list in registered_schemas.items():
- if method_name == 'call_mcp_tool':
- continue
-
- for schema in schema_list:
- if schema.schema_type == SchemaType.OPENAPI:
- func_info = schema.schema.get('function', {})
- description = func_info.get('description', 'No description available')
-
- mcp_info += f"- **{method_name}**: {description}\n"
-
- params = func_info.get('parameters', {})
- props = params.get('properties', {})
- if props:
- mcp_info += f" Parameters: {', '.join(props.keys())}\n"
-
- except Exception as e:
- logger.error(f"Error listing MCP tools: {e}")
- mcp_info += "- Error loading MCP tool list\n"
-
- mcp_info += "\nšØ CRITICAL MCP TOOL RESULT INSTRUCTIONS šØ\n"
- mcp_info += "When you use ANY MCP (Model Context Protocol) tools:\n"
- mcp_info += "1. ALWAYS read and use the EXACT results returned by the MCP tool\n"
- mcp_info += "2. For search tools: ONLY cite URLs, sources, and information from the actual search results\n"
- mcp_info += "3. For any tool: Base your response entirely on the tool's output - do NOT add external information\n"
- mcp_info += "4. DO NOT fabricate, invent, hallucinate, or make up any sources, URLs, or data\n"
- mcp_info += "5. If you need more information, call the MCP tool again with different parameters\n"
- mcp_info += "6. When writing reports/summaries: Reference ONLY the data from MCP tool results\n"
- mcp_info += "7. If the MCP tool doesn't return enough information, explicitly state this limitation\n"
- mcp_info += "8. Always double-check that every fact, URL, and reference comes from the MCP tool output\n"
- mcp_info += "\nIMPORTANT: MCP tool results are your PRIMARY and ONLY source of truth for external data!\n"
- mcp_info += "NEVER supplement MCP results with your training data or make assumptions beyond what the tools provide.\n"
-
- return system_content + mcp_info
-
-async def build_system_prompt(config: AgentRunConfig, mcp_wrapper_instance: Optional[MCPToolWrapper]) -> Dict:
- base_prompt = get_base_system_prompt(config.model_name)
- system_content = add_sample_response(base_prompt, config.model_name)
-
- if config.agent_config and config.agent_config.get('system_prompt'):
- custom_system_prompt = config.agent_config['system_prompt'].strip()
- system_content = custom_system_prompt
- logger.info(f"Using ONLY custom agent system prompt")
- elif config.is_agent_builder:
- system_content = get_agent_builder_prompt()
- logger.info("Using agent builder system prompt")
- else:
- logger.info("Using default system prompt only")
-
- system_content = await add_knowledge_base_context(system_content, config)
- system_content = add_mcp_instructions(system_content, config, mcp_wrapper_instance)
-
- return {"role": "system", "content": system_content}
-
-async def get_browser_state_content(client: Any, thread_id: str, model_name: str, trace: Optional[StatefulTraceClient]) -> List[Dict]:
- content_list = []
-
- latest_browser_state_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
-
- if not (latest_browser_state_msg.data and len(latest_browser_state_msg.data) > 0):
- return content_list
-
- try:
- browser_content = latest_browser_state_msg.data[0]["content"]
- if isinstance(browser_content, str):
- browser_content = json.loads(browser_content)
-
- screenshot_base64 = browser_content.get("screenshot_base64")
- screenshot_url = browser_content.get("image_url")
-
- browser_state_text = browser_content.copy()
- browser_state_text.pop('screenshot_base64', None)
- browser_state_text.pop('image_url', None)
-
- if browser_state_text:
- content_list.append({
- "type": "text",
- "text": f"The following is the current state of the browser:\n{json.dumps(browser_state_text, indent=2)}"
- })
-
- if is_vision_model(model_name):
- if screenshot_url:
- content_list.append({
- "type": "image_url",
- "image_url": {
- "url": screenshot_url,
- "format": "image/jpeg"
- }
- })
- if trace:
- trace.event(name="screenshot_url_added_to_temporary_message", level="DEFAULT")
- elif screenshot_base64:
- content_list.append({
- "type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{screenshot_base64}",
- }
- })
- if trace:
- trace.event(name="screenshot_base64_added_to_temporary_message", level="WARNING")
- else:
- logger.warning("Browser state found but no screenshot data.")
- if trace:
- trace.event(name="browser_state_found_but_no_screenshot_data", level="WARNING")
- else:
- logger.warning("Model doesn't support vision, skipping screenshot.")
- if trace:
- trace.event(name="model_doesnt_support_vision", level="WARNING")
-
- except Exception as e:
- logger.error(f"Error parsing browser state: {e}")
- if trace:
- trace.event(name="error_parsing_browser_state", level="ERROR", status_message=str(e))
-
- return content_list
-
-async def get_image_context_content(client: Any, thread_id: str) -> List[Dict]:
- content_list = []
-
- latest_image_context_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'image_context').order('created_at', desc=True).limit(1).execute()
-
- if not (latest_image_context_msg.data and len(latest_image_context_msg.data) > 0):
- return content_list
-
- try:
- image_context_content = latest_image_context_msg.data[0]["content"]
- if isinstance(image_context_content, str):
- image_context_content = json.loads(image_context_content)
-
- base64_image = image_context_content.get("base64")
- mime_type = image_context_content.get("mime_type")
- file_path = image_context_content.get("file_path", "unknown file")
-
- if base64_image and mime_type:
- content_list.extend([
- {
- "type": "text",
- "text": f"Here is the image you requested to see: '{file_path}'"
- },
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:{mime_type};base64,{base64_image}",
- }
- }
- ])
- else:
- logger.warning(f"Image context found for '{file_path}' but missing base64 or mime_type.")
-
- await client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute()
-
- except Exception as e:
- logger.error(f"Error parsing image context: {e}")
-
- return content_list
-
-async def build_temporary_message(client: Any, config: AgentRunConfig, trace: Optional[StatefulTraceClient]) -> Optional[Dict]:
- temp_message_content_list = []
-
- browser_content = await get_browser_state_content(client, config.thread_id, config.model_name, trace)
- temp_message_content_list.extend(browser_content)
-
- image_content = await get_image_context_content(client, config.thread_id)
- temp_message_content_list.extend(image_content)
-
- if temp_message_content_list:
- return {"role": "user", "content": temp_message_content_list}
-
- return None
-
-async def should_continue_execution(client: Any, thread_id: str, trace: Optional[StatefulTraceClient]) -> bool:
- latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).in_('type', ['assistant', 'tool', 'user']).order('created_at', desc=True).limit(1).execute()
-
- if latest_message.data and len(latest_message.data) > 0:
- message_type = latest_message.data[0].get('type')
- if message_type == 'assistant':
- logger.info("Last message was from assistant, stopping execution")
- if trace:
- trace.event(name="last_message_from_assistant", level="DEFAULT")
- return False
-
- return True
-
-async def check_billing_limits(client: Any, account_id: str, trace: Optional[StatefulTraceClient]) -> Tuple[bool, str]:
- can_run, message, subscription = await check_billing_status(client, account_id)
- if not can_run:
- error_msg = f"Billing limit reached: {message}"
- if trace:
- trace.event(name="billing_limit_reached", level="ERROR", status_message=error_msg)
- return False, error_msg
- return True, ""
-
-class ResponseProcessor:
- def __init__(self, trace: Optional[StatefulTraceClient]):
- self.trace = trace
- self.last_tool_call = None
- self.agent_should_terminate = False
- self.full_response = ""
-
- def check_termination_signal(self, chunk: Dict) -> None:
- if chunk.get('type') != 'status':
- return
-
- try:
- metadata = chunk.get('metadata', {})
- if isinstance(metadata, str):
- metadata = json.loads(metadata)
-
- if metadata.get('agent_should_terminate'):
- self.agent_should_terminate = True
- logger.info("Agent termination signal detected")
- if self.trace:
- self.trace.event(name="agent_termination_signal_detected", level="DEFAULT")
-
- content = chunk.get('content', {})
- if isinstance(content, str):
- content = json.loads(content)
-
- if content.get('function_name'):
- self.last_tool_call = content['function_name']
- elif content.get('xml_tag_name'):
- self.last_tool_call = content['xml_tag_name']
-
- except Exception as e:
- logger.debug(f"Error parsing status message for termination check: {e}")
-
- def check_xml_tools(self, chunk: Dict) -> None:
- if chunk.get('type') != 'assistant' or 'content' not in chunk:
- return
-
- try:
- content = chunk.get('content', '{}')
- if isinstance(content, str):
- assistant_content_json = json.loads(content)
- else:
- assistant_content_json = content
-
- assistant_text = assistant_content_json.get('content', '')
- self.full_response += assistant_text
-
- if isinstance(assistant_text, str):
- for tool in ['ask', 'complete', 'web-browser-takeover']:
- if f'</{tool}>' in assistant_text:
- self.last_tool_call = tool
- logger.info(f"Agent used XML tool: {tool}")
- if self.trace:
- self.trace.event(name="agent_used_xml_tool", level="DEFAULT", status_message=f"Agent used XML tool: {tool}")
- break
-
- except json.JSONDecodeError:
- logger.warning(f"Could not parse assistant content JSON: {chunk.get('content')}")
- if self.trace:
- self.trace.event(name="warning_could_not_parse_assistant_content_json", level="WARNING")
- except Exception as e:
- logger.error(f"Error processing assistant chunk: {e}")
- if self.trace:
- self.trace.event(name="error_processing_assistant_chunk", level="ERROR", status_message=str(e))
-
- def process_chunk(self, chunk: Dict) -> None:
- self.check_termination_signal(chunk)
- self.check_xml_tools(chunk)
-
- def should_terminate(self) -> bool:
- return (self.agent_should_terminate or
- self.last_tool_call in ['ask', 'complete', 'web-browser-takeover'])
-
-async def process_llm_response(response: Any, processor: ResponseProcessor, trace: Optional[StatefulTraceClient]) -> AsyncGenerator[Dict, None]:
- try:
- if hasattr(response, '__aiter__') and not isinstance(response, dict):
- async for chunk in response:
- if isinstance(chunk, dict) and chunk.get('type') == 'status' and chunk.get('status') == 'error':
- logger.error(f"Error chunk detected: {chunk.get('message', 'Unknown error')}")
- if trace:
- trace.event(name="error_chunk_detected", level="ERROR", status_message=chunk.get('message', 'Unknown error'))
- yield chunk
- return
-
- processor.process_chunk(chunk)
- yield chunk
- else:
- logger.error(f"Response is not iterable: {response}")
- yield {
- "type": "status",
- "status": "error",
- "message": "Response is not iterable"
- }
-
- except Exception as e:
- error_msg = f"Error during response streaming: {str(e)}"
- logger.error(error_msg)
- if trace:
- trace.event(name="error_during_response_streaming", level="ERROR", status_message=error_msg)
- yield {
- "type": "status",
- "status": "error",
- "message": error_msg
- }
-
-async def run_single_iteration(
- config: AgentRunConfig,
- context: ExecutionContext,
- system_message: Dict,
- temporary_message: Optional[Dict],
- iteration_count: int
-) -> AsyncGenerator[Dict, None]:
-
- can_continue, error_msg = await check_billing_limits(context.client, context.account_id, config.trace)
- if not can_continue:
- yield {
- "type": "status",
- "status": "stopped",
- "message": error_msg
- }
- return
-
- if not await should_continue_execution(context.client, config.thread_id, config.trace):
- return
-
- max_tokens = get_model_max_tokens(config.model_name)
- generation = config.trace.generation(name="thread_manager.run_thread") if config.trace else None
- processor = ResponseProcessor(config.trace)
-
- try:
- response = await config.thread_manager.run_thread(
- thread_id=config.thread_id,
- system_prompt=system_message,
- stream=config.stream,
- llm_model=config.model_name,
- llm_temperature=0,
- llm_max_tokens=max_tokens,
- tool_choice="auto",
- max_xml_tool_calls=1,
- temporary_message=temporary_message,
- processor_config=ProcessorConfig(
- xml_tool_calling=True,
- native_tool_calling=False,
- execute_tools=True,
- execute_on_stream=True,
- tool_execution_strategy="parallel",
- xml_adding_strategy="user_message"
- ),
- native_max_auto_continues=config.native_max_auto_continues,
- include_xml_examples=True,
- enable_thinking=config.enable_thinking,
- reasoning_effort=config.reasoning_effort,
- enable_context_manager=config.enable_context_manager,
- generation=generation
- )
-
- if isinstance(response, dict) and response.get("status") == "error":
- logger.error(f"Error response from run_thread: {response.get('message', 'Unknown error')}")
- if config.trace:
- config.trace.event(name="error_response_from_run_thread", level="ERROR", status_message=response.get('message', 'Unknown error'))
- yield response
- return
-
- async for chunk in process_llm_response(response, processor, config.trace):
- yield chunk
-
- if processor.should_terminate():
- logger.info(f"Agent decided to stop with tool: {processor.last_tool_call}")
- if config.trace:
- config.trace.event(name="agent_decided_to_stop_with_tool", level="DEFAULT", status_message=f"Agent decided to stop with tool: {processor.last_tool_call}")
-
- if generation:
- generation.end(output=processor.full_response, status_message="agent_stopped")
-
- yield {
- "type": "status",
- "status": "completed",
- "terminate": True
- }
- else:
- if generation:
- generation.end(output=processor.full_response)
-
- except Exception as e:
- error_msg = f"Error running thread: {str(e)}"
- logger.error(error_msg)
- if config.trace:
- config.trace.event(name="error_running_thread", level="ERROR", status_message=error_msg)
- if generation:
- generation.end(output=processor.full_response, status_message=error_msg, level="ERROR")
- yield {
- "type": "status",
- "status": "error",
- "message": error_msg
- }
-
-async def setup_trace_input(config: AgentRunConfig, context: ExecutionContext) -> None:
- if not config.trace:
- return
-
- latest_user_message = await context.client.table('messages').select('*').eq('thread_id', config.thread_id).eq('type', 'user').order('created_at', desc=True).limit(1).execute()
-
- if latest_user_message.data and len(latest_user_message.data) > 0:
- data = latest_user_message.data[0]['content']
- if isinstance(data, str):
- data = json.loads(data)
- config.trace.update(input=data['content'])
-
async def run_agent(
thread_id: str,
project_id: str,
@@ -736,81 +50,625 @@ async def run_agent(
is_agent_builder: Optional[bool] = False,
target_agent_id: Optional[str] = None
):
+ """Run the development agent with specified configuration."""
logger.info(f"š Starting agent with model: {model_name}")
if agent_config:
logger.info(f"Using custom agent: {agent_config.get('name', 'Unknown')}")
- config = AgentRunConfig(
- thread_id=thread_id,
- project_id=project_id,
- stream=stream,
- thread_manager=thread_manager,
- native_max_auto_continues=native_max_auto_continues,
- max_iterations=max_iterations,
- model_name=model_name,
- enable_thinking=enable_thinking,
- reasoning_effort=reasoning_effort,
- enable_context_manager=enable_context_manager,
- agent_config=agent_config,
- trace=trace or langfuse.trace(name="run_agent", session_id=thread_id, metadata={"project_id": project_id}),
- is_agent_builder=is_agent_builder or False,
- target_agent_id=target_agent_id
- )
+ if not trace:
+ trace = langfuse.trace(name="run_agent", session_id=thread_id, metadata={"project_id": project_id})
+ thread_manager = ThreadManager(trace=trace, is_agent_builder=is_agent_builder or False, target_agent_id=target_agent_id, agent_config=agent_config)
- config.thread_manager = ThreadManager(
- trace=config.trace,
- is_agent_builder=config.is_agent_builder,
- target_agent_id=config.target_agent_id,
- agent_config=config.agent_config
- )
+ client = await thread_manager.db.client
- try:
- context = await setup_execution_context(config)
- setup_tools(config.thread_manager, config)
- mcp_wrapper_instance = await setup_mcp_tools(config.thread_manager, config, context)
- system_message = await build_system_prompt(config, mcp_wrapper_instance)
- await setup_trace_input(config, context)
+ # Get account ID from thread for billing checks
+ account_id = await get_account_id_from_thread(client, thread_id)
+ if not account_id:
+ raise ValueError("Could not determine account ID for thread")
- iteration_count = 0
+ # Get sandbox info from project
+ project = await client.table('projects').select('*').eq('project_id', project_id).execute()
+ if not project.data or len(project.data) == 0:
+ raise ValueError(f"Project {project_id} not found")
+
+ project_data = project.data[0]
+ sandbox_info = project_data.get('sandbox', {})
+ if not sandbox_info.get('id'):
+ raise ValueError(f"No sandbox found for project {project_id}")
+
+ # Initialize tools with project_id instead of sandbox object
+ # This ensures each tool independently verifies it's operating on the correct project
+
+ # Get enabled tools from agent config, or use defaults
+ enabled_tools = None
+ if agent_config and 'agentpress_tools' in agent_config:
+ enabled_tools = agent_config['agentpress_tools']
+ logger.info(f"Using custom tool configuration from agent")
+
+
+ if is_agent_builder:
+ from agent.tools.agent_builder_tools.agent_config_tool import AgentConfigTool
+ from agent.tools.agent_builder_tools.mcp_search_tool import MCPSearchTool
+ from agent.tools.agent_builder_tools.credential_profile_tool import CredentialProfileTool
+ from agent.tools.agent_builder_tools.workflow_tool import WorkflowTool
+ from agent.tools.agent_builder_tools.trigger_tool import TriggerTool
+ from services.supabase import DBConnection
+ db = DBConnection()
+
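+ # Builder mode: register meta-tools so the agent can manage the target agent's config, MCP integrations, credential profiles, workflows, and triggers.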
+ thread_manager.add_tool(AgentConfigTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
+ thread_manager.add_tool(MCPSearchTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
+ thread_manager.add_tool(CredentialProfileTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
+ thread_manager.add_tool(WorkflowTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
+ thread_manager.add_tool(TriggerTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id)
- while iteration_count < config.max_iterations:
- iteration_count += 1
- logger.info(f"š Running iteration {iteration_count} of {config.max_iterations}...")
- temporary_message = await build_temporary_message(context.client, config, config.trace)
+ if enabled_tools is None:
+ logger.info("No agent specified - registering all tools for full Suna capabilities")
+ thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager)
+ thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager)
+ thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
+ thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
+ thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
+ thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager)
+ thread_manager.add_tool(MessageTool)
+ thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
+ thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
+ thread_manager.add_tool(SandboxImageEditTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
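+ # DataProvidersTool is backed by RapidAPI, so it is only registered when an API key is configured.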
+ if config.RAPID_API_KEY:
+ thread_manager.add_tool(DataProvidersTool)
+ else:
+ logger.info("Custom agent specified - registering only enabled tools")
+ thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager)
+ thread_manager.add_tool(MessageTool)
+ if enabled_tools.get('sb_shell_tool', {}).get('enabled', False):
+ thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager)
+ if enabled_tools.get('sb_files_tool', {}).get('enabled', False):
+ thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager)
+ if enabled_tools.get('sb_browser_tool', {}).get('enabled', False):
+ thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
+ if enabled_tools.get('sb_deploy_tool', {}).get('enabled', False):
+ thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager)
+ if enabled_tools.get('sb_expose_tool', {}).get('enabled', False):
+ thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager)
+ if enabled_tools.get('web_search_tool', {}).get('enabled', False):
+ thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager)
+ if enabled_tools.get('sb_vision_tool', {}).get('enabled', False):
+ thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager)
+ if config.RAPID_API_KEY and enabled_tools.get('data_providers_tool', {}).get('enabled', False):
+ thread_manager.add_tool(DataProvidersTool)
+
+ # Register MCP tool wrapper if agent has configured MCPs or custom MCPs
+ mcp_wrapper_instance = None
+ if agent_config:
+ # Merge configured_mcps and custom_mcps
+ all_mcps = []
+
+ # Add standard configured MCPs
+ if agent_config.get('configured_mcps'):
+ all_mcps.extend(agent_config['configured_mcps'])
+
+ # Add custom MCPs
+ if agent_config.get('custom_mcps'):
+ for custom_mcp in agent_config['custom_mcps']:
+ # Transform custom MCP to standard format
+ custom_type = custom_mcp.get('customType', custom_mcp.get('type', 'sse'))
+
+ # For Pipedream MCPs, ensure we have the user ID and proper config
+ if custom_type == 'pipedream':
+ # Get user ID from thread
+ if 'config' not in custom_mcp:
+ custom_mcp['config'] = {}
+
+ # Get external_user_id from profile if not present
+ if not custom_mcp['config'].get('external_user_id'):
+ profile_id = custom_mcp['config'].get('profile_id')
+ if profile_id:
+ try:
+ from pipedream.profiles import get_profile_manager
+ from services.supabase import DBConnection
+ profile_db = DBConnection()
+ profile_manager = get_profile_manager(profile_db)
+
+ # Get the profile to retrieve external_user_id
+ profile = await profile_manager.get_profile(account_id, profile_id)
+ if profile:
+ custom_mcp['config']['external_user_id'] = profile.external_user_id
+ logger.info(f"Retrieved external_user_id from profile {profile_id} for Pipedream MCP")
+ else:
+ logger.error(f"Could not find profile {profile_id} for Pipedream MCP")
+ except Exception as e:
+ logger.error(f"Error retrieving external_user_id from profile {profile_id}: {e}")
+
+ if 'headers' in custom_mcp['config'] and 'x-pd-app-slug' in custom_mcp['config']['headers']:
+ custom_mcp['config']['app_slug'] = custom_mcp['config']['headers']['x-pd-app-slug']
+
+ mcp_config = {
+ 'name': custom_mcp['name'],
+ 'qualifiedName': f"custom_{custom_type}_{custom_mcp['name'].replace(' ', '_').lower()}",
+ 'config': custom_mcp['config'],
+ 'enabledTools': custom_mcp.get('enabledTools', []),
+ 'instructions': custom_mcp.get('instructions', ''),
+ 'isCustom': True,
+ 'customType': custom_type
+ }
+ all_mcps.append(mcp_config)
+
+ if all_mcps:
+ logger.info(f"Registering MCP tool wrapper for {len(all_mcps)} MCP servers (including {len(agent_config.get('custom_mcps', []))} custom)")
+ thread_manager.add_tool(MCPToolWrapper, mcp_configs=all_mcps)
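+ # Only a single wrapper is registered here; after initialization below, each MCP tool is re-registered under its native method name.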
- should_terminate = False
- async for result in run_single_iteration(config, context, system_message, temporary_message, iteration_count):
- yield result
- if result.get('terminate'):
- should_terminate = True
- break
- if result.get('type') == 'status' and result.get('status') in ['error', 'stopped']:
- should_terminate = True
+ for tool_name, tool_info in thread_manager.tool_registry.tools.items():
+ if isinstance(tool_info['instance'], MCPToolWrapper):
+ mcp_wrapper_instance = tool_info['instance']
break
- if should_terminate:
+ if mcp_wrapper_instance:
+ try:
+ await mcp_wrapper_instance.initialize_and_register_tools()
+ logger.info("MCP tools initialized successfully")
+ updated_schemas = mcp_wrapper_instance.get_schemas()
+ logger.info(f"MCP wrapper has {len(updated_schemas)} schemas available")
+ for method_name, schema_list in updated_schemas.items():
+ if method_name != 'call_mcp_tool':
+ for schema in schema_list:
+ if schema.schema_type == SchemaType.OPENAPI:
+ thread_manager.tool_registry.tools[method_name] = {
+ "instance": mcp_wrapper_instance,
+ "schema": schema
+ }
+ logger.info(f"Registered dynamic MCP tool: {method_name}")
+
+ # Log all registered tools for debugging
+ all_tools = list(thread_manager.tool_registry.tools.keys())
+ logger.info(f"All registered tools after MCP initialization: {all_tools}")
+ mcp_tools = [tool for tool in all_tools if tool not in ['call_mcp_tool', 'sb_files_tool', 'message_tool', 'expand_msg_tool', 'web_search_tool', 'sb_shell_tool', 'sb_vision_tool', 'sb_browser_tool', 'computer_use_tool', 'data_providers_tool', 'sb_deploy_tool', 'sb_expose_tool', 'update_agent_tool']]
+ logger.info(f"MCP tools registered: {mcp_tools}")
+
+ except Exception as e:
+ logger.error(f"Failed to initialize MCP tools: {e}")
+ # Continue without MCP tools if initialization fails
+
+ # Prepare system prompt
+ # First, get the default system prompt
+ if "gemini-2.5-flash" in model_name.lower() and "gemini-2.5-pro" not in model_name.lower():
+ default_system_content = get_gemini_system_prompt()
+ else:
+ # Use the original prompt - the LLM can only use tools that are registered
+ default_system_content = get_system_prompt()
+
+ # Add sample response for non-anthropic models
+ if "anthropic" not in model_name.lower():
+ sample_response_path = os.path.join(os.path.dirname(__file__), 'sample_responses/1.txt')
+ with open(sample_response_path, 'r') as file:
+ sample_response = file.read()
+ default_system_content = default_system_content + "\n\n " + sample_response
+
+ # Handle custom agent system prompt
+ if agent_config and agent_config.get('system_prompt'):
+ custom_system_prompt = agent_config['system_prompt'].strip()
+
+ # Completely replace the default system prompt with the custom one
+ # This prevents confusion and tool hallucination
+ system_content = custom_system_prompt
+ logger.info(f"Using ONLY custom agent system prompt for: {agent_config.get('name', 'Unknown')}")
+ elif is_agent_builder:
+ system_content = get_agent_builder_prompt()
+ logger.info("Using agent builder system prompt")
+ else:
+ # Use just the default system prompt
+ system_content = default_system_content
+ logger.info("Using default system prompt only")
+
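+ # Optionally fold knowledge-base context (thread- and agent-scoped, capped at 4000 tokens) into the system prompt.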
+ if await is_enabled("knowledge_base"):
+ try:
+ from services.supabase import DBConnection
+ kb_db = DBConnection()
+ kb_client = await kb_db.client
+
+ current_agent_id = agent_config.get('agent_id') if agent_config else None
+
+ kb_result = await kb_client.rpc('get_combined_knowledge_base_context', {
+ 'p_thread_id': thread_id,
+ 'p_agent_id': current_agent_id,
+ 'p_max_tokens': 4000
+ }).execute()
+
+ if kb_result.data and kb_result.data.strip():
+ logger.info(f"Adding combined knowledge base context to system prompt for thread {thread_id}, agent {current_agent_id}")
+ system_content += "\n\n" + kb_result.data
+ else:
+ logger.debug(f"No knowledge base context found for thread {thread_id}, agent {current_agent_id}")
+
+ except Exception as e:
+ logger.error(f"Error retrieving knowledge base context for thread {thread_id}: {e}")
+
+
+ if agent_config and (agent_config.get('configured_mcps') or agent_config.get('custom_mcps')) and mcp_wrapper_instance and mcp_wrapper_instance._initialized:
+ mcp_info = "\n\n--- MCP Tools Available ---\n"
+ mcp_info += "You have access to external MCP (Model Context Protocol) server tools.\n"
+ mcp_info += "MCP tools can be called directly using their native function names in the standard function calling format:\n"
+ mcp_info += '<function_calls>\n'
+ mcp_info += '<invoke name="tool_name">\n'
+ mcp_info += '<parameter name="param1">value1</parameter>\n'
+ mcp_info += '<parameter name="param2">value2</parameter>\n'
+ mcp_info += '</invoke>\n'
+ mcp_info += '</function_calls>\n\n'
+
+ # List available MCP tools
+ mcp_info += "Available MCP tools:\n"
+ try:
+ # Get the actual registered schemas from the wrapper
+ registered_schemas = mcp_wrapper_instance.get_schemas()
+ for method_name, schema_list in registered_schemas.items():
+ if method_name == 'call_mcp_tool':
+ continue # Skip the fallback method
+
+ # Get the schema info
+ for schema in schema_list:
+ if schema.schema_type == SchemaType.OPENAPI:
+ func_info = schema.schema.get('function', {})
+ description = func_info.get('description', 'No description available')
+ # Extract server name from description if available
+ server_match = description.find('(MCP Server: ')
+ if server_match != -1:
+ server_end = description.find(')', server_match)
+ server_info = description[server_match:server_end+1]
+ else:
+ server_info = ''
+
+ mcp_info += f"- **{method_name}**: {description}\n"
+
+ # Show parameter info
+ params = func_info.get('parameters', {})
+ props = params.get('properties', {})
+ if props:
+ mcp_info += f" Parameters: {', '.join(props.keys())}\n"
+
+ except Exception as e:
+ logger.error(f"Error listing MCP tools: {e}")
+ mcp_info += "- Error loading MCP tool list\n"
+
+ # Add critical instructions for using search results
+ mcp_info += "\nšØ CRITICAL MCP TOOL RESULT INSTRUCTIONS šØ\n"
+ mcp_info += "When you use ANY MCP (Model Context Protocol) tools:\n"
+ mcp_info += "1. ALWAYS read and use the EXACT results returned by the MCP tool\n"
+ mcp_info += "2. For search tools: ONLY cite URLs, sources, and information from the actual search results\n"
+ mcp_info += "3. For any tool: Base your response entirely on the tool's output - do NOT add external information\n"
+ mcp_info += "4. DO NOT fabricate, invent, hallucinate, or make up any sources, URLs, or data\n"
+ mcp_info += "5. If you need more information, call the MCP tool again with different parameters\n"
+ mcp_info += "6. When writing reports/summaries: Reference ONLY the data from MCP tool results\n"
+ mcp_info += "7. If the MCP tool doesn't return enough information, explicitly state this limitation\n"
+ mcp_info += "8. Always double-check that every fact, URL, and reference comes from the MCP tool output\n"
+ mcp_info += "\nIMPORTANT: MCP tool results are your PRIMARY and ONLY source of truth for external data!\n"
+ mcp_info += "NEVER supplement MCP results with your training data or make assumptions beyond what the tools provide.\n"
+
+ system_content += mcp_info
+
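+ # Assemble the final system message once; the agent loop below reuses it on every iteration.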
+ system_message = { "role": "system", "content": system_content }
+
+ iteration_count = 0
+ continue_execution = True
+
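+ # Record the latest user message as the trace input for observability.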
+ latest_user_message = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'user').order('created_at', desc=True).limit(1).execute()
+ if latest_user_message.data and len(latest_user_message.data) > 0:
+ data = latest_user_message.data[0]['content']
+ if isinstance(data, str):
+ data = json.loads(data)
+ if trace:
+ trace.update(input=data['content'])
+
+ while continue_execution and iteration_count < max_iterations:
+ iteration_count += 1
+ logger.info(f"š Running iteration {iteration_count} of {max_iterations}...")
+
+ # Billing check on each iteration - still needed within the iterations
+ can_run, message, subscription = await check_billing_status(client, account_id)
+ if not can_run:
+ error_msg = f"Billing limit reached: {message}"
+ if trace:
+ trace.event(name="billing_limit_reached", level="ERROR", status_message=(f"{error_msg}"))
+ # Yield a special message to indicate billing limit reached
+ yield {
+ "type": "status",
+ "status": "stopped",
+ "message": error_msg
+ }
+ break
+ # Check if last message is from assistant using direct Supabase query
+ latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).in_('type', ['assistant', 'tool', 'user']).order('created_at', desc=True).limit(1).execute()
+ if latest_message.data and len(latest_message.data) > 0:
+ message_type = latest_message.data[0].get('type')
+ if message_type == 'assistant':
+ logger.info(f"Last message was from assistant, stopping execution")
+ if trace:
+ trace.event(name="last_message_from_assistant", level="DEFAULT", status_message=(f"Last message was from assistant, stopping execution"))
+ continue_execution = False
break
- except AgentExecutionError as e:
- error_msg = str(e)
- logger.error(f"Agent execution error: {error_msg}")
- if config.trace:
- config.trace.event(name="agent_execution_error", level="ERROR", status_message=error_msg)
- yield {
- "type": "status",
- "status": "error",
- "message": error_msg
- }
- except Exception as e:
- error_msg = f"Unexpected error in run_agent: {str(e)}"
- logger.error(error_msg)
- if config.trace:
- config.trace.event(name="unexpected_error_in_run_agent", level="ERROR", status_message=error_msg)
- yield {
- "type": "status",
- "status": "error",
- "message": error_msg
- }
- finally:
- asyncio.create_task(asyncio.to_thread(lambda: langfuse.flush()))
+ # ---- Temporary Message Handling (Browser State & Image Context) ----
+ temporary_message = None
+ temp_message_content_list = [] # List to hold text/image blocks
+
+ # Get the latest browser_state message
+ latest_browser_state_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
+ if latest_browser_state_msg.data and len(latest_browser_state_msg.data) > 0:
+ try:
+ browser_content = latest_browser_state_msg.data[0]["content"]
+ if isinstance(browser_content, str):
+ browser_content = json.loads(browser_content)
+ screenshot_base64 = browser_content.get("screenshot_base64")
+ screenshot_url = browser_content.get("image_url")
+
+ # Create a copy of the browser state without screenshot data
+ browser_state_text = browser_content.copy()
+ browser_state_text.pop('screenshot_base64', None)
+ browser_state_text.pop('image_url', None)
+
+ if browser_state_text:
+ temp_message_content_list.append({
+ "type": "text",
+ "text": f"The following is the current state of the browser:\n{json.dumps(browser_state_text, indent=2)}"
+ })
+
+ # Only add screenshot when the model supports vision (Gemini, Anthropic, or OpenAI)
+ if 'gemini' in model_name.lower() or 'anthropic' in model_name.lower() or 'openai' in model_name.lower():
+ # Prioritize screenshot_url if available
+ if screenshot_url:
+ temp_message_content_list.append({
+ "type": "image_url",
+ "image_url": {
+ "url": screenshot_url,
+ "format": "image/jpeg"
+ }
+ })
+ if trace:
+ trace.event(name="screenshot_url_added_to_temporary_message", level="DEFAULT", status_message=(f"Screenshot URL added to temporary message."))
+ elif screenshot_base64:
+ # Fallback to base64 if URL not available
+ temp_message_content_list.append({
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:image/jpeg;base64,{screenshot_base64}",
+ }
+ })
+ if trace:
+ trace.event(name="screenshot_base64_added_to_temporary_message", level="WARNING", status_message=(f"Screenshot base64 added to temporary message. Prefer screenshot_url if available."))
+ else:
+ logger.warning("Browser state found but no screenshot data.")
+ if trace:
+ trace.event(name="browser_state_found_but_no_screenshot_data", level="WARNING", status_message=(f"Browser state found but no screenshot data."))
+ else:
+ logger.warning("Model is Gemini, Anthropic, or OpenAI, so not adding screenshot to temporary message.")
+ if trace:
+ trace.event(name="model_is_gemini_anthropic_or_openai", level="WARNING", status_message=(f"Model is Gemini, Anthropic, or OpenAI, so not adding screenshot to temporary message."))
+
+ except Exception as e:
+ logger.error(f"Error parsing browser state: {e}")
+ if trace:
+ trace.event(name="error_parsing_browser_state", level="ERROR", status_message=(f"{e}"))
+
+ # Get the latest image_context message (NEW)
+ latest_image_context_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'image_context').order('created_at', desc=True).limit(1).execute()
+ if latest_image_context_msg.data and len(latest_image_context_msg.data) > 0:
+ try:
+ image_context_content = latest_image_context_msg.data[0]["content"] if isinstance(latest_image_context_msg.data[0]["content"], dict) else json.loads(latest_image_context_msg.data[0]["content"])
+ base64_image = image_context_content.get("base64")
+ mime_type = image_context_content.get("mime_type")
+ file_path = image_context_content.get("file_path", "unknown file")
+
+ if base64_image and mime_type:
+ temp_message_content_list.append({
+ "type": "text",
+ "text": f"Here is the image you requested to see: '{file_path}'"
+ })
+ temp_message_content_list.append({
+ "type": "image_url",
+ "image_url": {
+ "url": f"data:{mime_type};base64,{base64_image}",
+ }
+ })
+ else:
+ logger.warning(f"Image context found for '{file_path}' but missing base64 or mime_type.")
+
+ await client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute()
+ except Exception as e:
+ logger.error(f"Error parsing image context: {e}")
+ if trace:
+ trace.event(name="error_parsing_image_context", level="ERROR", status_message=(f"{e}"))
+
+ # If we have any content, construct the temporary_message
+ if temp_message_content_list:
+ temporary_message = {"role": "user", "content": temp_message_content_list}
+ # logger.debug(f"Constructed temporary message with {len(temp_message_content_list)} content blocks.")
+ # ---- End Temporary Message Handling ----
+
+ # Set max_tokens based on model
+ max_tokens = None
+ if "sonnet" in model_name.lower():
+ # Claude 3.5 Sonnet has a limit of 8192 tokens
+ max_tokens = 8192
+ elif "gpt-4" in model_name.lower():
+ max_tokens = 4096
+ elif "gemini-2.5-pro" in model_name.lower():
+ # Gemini 2.5 Pro has 64k max output tokens
+ max_tokens = 64000
+ elif "kimi-k2" in model_name.lower():
+ # Kimi-K2 has 120K context, set reasonable max output tokens
+ max_tokens = 8192
+
+ generation = trace.generation(name="thread_manager.run_thread") if trace else None
+ try:
+ # Make the LLM call and process the response
+ response = await thread_manager.run_thread(
+ thread_id=thread_id,
+ system_prompt=system_message,
+ stream=stream,
+ llm_model=model_name,
+ llm_temperature=0,
+ llm_max_tokens=max_tokens,
+ tool_choice="auto",
+ max_xml_tool_calls=1,
+ temporary_message=temporary_message,
+ processor_config=ProcessorConfig(
+ xml_tool_calling=True,
+ native_tool_calling=False,
+ execute_tools=True,
+ execute_on_stream=True,
+ tool_execution_strategy="parallel",
+ xml_adding_strategy="user_message"
+ ),
+ native_max_auto_continues=native_max_auto_continues,
+ include_xml_examples=True,
+ enable_thinking=enable_thinking,
+ reasoning_effort=reasoning_effort,
+ enable_context_manager=enable_context_manager,
+ generation=generation
+ )
+
+ if isinstance(response, dict) and "status" in response and response["status"] == "error":
+ logger.error(f"Error response from run_thread: {response.get('message', 'Unknown error')}")
+ if trace:
+ trace.event(name="error_response_from_run_thread", level="ERROR", status_message=(f"{response.get('message', 'Unknown error')}"))
+ yield response
+ break
+
+ # Track if we see ask, complete, or web-browser-takeover tool calls
+ last_tool_call = None
+ agent_should_terminate = False
+
+ # Process the response
+ error_detected = False
+ full_response = ""
+ try:
+ # Check if response is iterable (async generator) or a dict (error case)
+ if hasattr(response, '__aiter__') and not isinstance(response, dict):
+ async for chunk in response:
+ # If we receive an error chunk, we should stop after this iteration
+ if isinstance(chunk, dict) and chunk.get('type') == 'status' and chunk.get('status') == 'error':
+ logger.error(f"Error chunk detected: {chunk.get('message', 'Unknown error')}")
+ if trace:
+ trace.event(name="error_chunk_detected", level="ERROR", status_message=(f"{chunk.get('message', 'Unknown error')}"))
+ error_detected = True
+ yield chunk # Forward the error chunk
+ continue # Continue processing other chunks but don't break yet
+
+ # Check for termination signal in status messages
+ if chunk.get('type') == 'status':
+ try:
+ # Parse the metadata to check for termination signal
+ metadata = chunk.get('metadata', {})
+ if isinstance(metadata, str):
+ metadata = json.loads(metadata)
+
+ if metadata.get('agent_should_terminate'):
+ agent_should_terminate = True
+ logger.info("Agent termination signal detected in status message")
+ if trace:
+ trace.event(name="agent_termination_signal_detected", level="DEFAULT", status_message="Agent termination signal detected in status message")
+
+ # Extract the tool name from the status content if available
+ content = chunk.get('content', {})
+ if isinstance(content, str):
+ content = json.loads(content)
+
+ if content.get('function_name'):
+ last_tool_call = content['function_name']
+ elif content.get('xml_tag_name'):
+ last_tool_call = content['xml_tag_name']
+
+ except Exception as e:
+ logger.debug(f"Error parsing status message for termination check: {e}")
+
+ # Check for XML tool calls like </ask>, </complete>, or </web-browser-takeover> in assistant content chunks
+ if chunk.get('type') == 'assistant' and 'content' in chunk:
+ try:
+ # The content field might be a JSON string or object
+ content = chunk.get('content', '{}')
+ if isinstance(content, str):
+ assistant_content_json = json.loads(content)
+ else:
+ assistant_content_json = content
+
+ # The actual text content is nested within
+ assistant_text = assistant_content_json.get('content', '')
+ full_response += assistant_text
+ if isinstance(assistant_text, str):
+ if '</ask>' in assistant_text or '</complete>' in assistant_text or '</web-browser-takeover>' in assistant_text:
+ if '</ask>' in assistant_text:
+ xml_tool = 'ask'
+ elif '</complete>' in assistant_text:
+ xml_tool = 'complete'
+ elif '</web-browser-takeover>' in assistant_text:
+ xml_tool = 'web-browser-takeover'
+
+ last_tool_call = xml_tool
+ logger.info(f"Agent used XML tool: {xml_tool}")
+ if trace:
+ trace.event(name="agent_used_xml_tool", level="DEFAULT", status_message=(f"Agent used XML tool: {xml_tool}"))
+
+ except json.JSONDecodeError:
+ # Handle cases where content might not be valid JSON
+ logger.warning(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}")
+ if trace:
+ trace.event(name="warning_could_not_parse_assistant_content_json", level="WARNING", status_message=(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}"))
+ except Exception as e:
+ logger.error(f"Error processing assistant chunk: {e}")
+ if trace:
+ trace.event(name="error_processing_assistant_chunk", level="ERROR", status_message=(f"Error processing assistant chunk: {e}"))
+
+ yield chunk
+ else:
+ # Response is not iterable, likely an error dict
+ logger.error(f"Response is not iterable: {response}")
+ error_detected = True
+
+ # Check if we should stop based on the last tool call or error
+ if error_detected:
+ logger.info(f"Stopping due to error detected in response")
+ if trace:
+ trace.event(name="stopping_due_to_error_detected_in_response", level="DEFAULT", status_message=(f"Stopping due to error detected in response"))
+ if generation:
+ generation.end(output=full_response, status_message="error_detected", level="ERROR")
+ break
+
+ if agent_should_terminate or last_tool_call in ['ask', 'complete', 'web-browser-takeover']:
+ logger.info(f"Agent decided to stop with tool: {last_tool_call}")
+ if trace:
+ trace.event(name="agent_decided_to_stop_with_tool", level="DEFAULT", status_message=(f"Agent decided to stop with tool: {last_tool_call}"))
+ if generation:
+ generation.end(output=full_response, status_message="agent_stopped")
+ continue_execution = False
+
+ except Exception as e:
+ # Log the error, surface it to the client as an error status chunk, and stop all iterations
+ error_msg = f"Error during response streaming: {str(e)}"
+ logger.error(error_msg)
+ if trace:
+ trace.event(name="error_during_response_streaming", level="ERROR", status_message=(f"Error during response streaming: {str(e)}"))
+ if generation:
+ generation.end(output=full_response, status_message=error_msg, level="ERROR")
+ yield {
+ "type": "status",
+ "status": "error",
+ "message": error_msg
+ }
+ # Stop execution immediately on any error
+ break
+
+ except Exception as e:
+ # Log the error, surface it to the client as an error status chunk, and stop all iterations
+ error_msg = f"Error running thread: {str(e)}"
+ logger.error(error_msg)
+ if trace:
+ trace.event(name="error_running_thread", level="ERROR", status_message=(f"Error running thread: {str(e)}"))
+ yield {
+ "type": "status",
+ "status": "error",
+ "message": error_msg
+ }
+ # Stop execution immediately on any error
+ break
+ if generation:
+ generation.end(output=full_response)
+
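+ # Flush buffered Langfuse events in a worker thread so the event loop is not blocked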
+ asyncio.create_task(asyncio.to_thread(langfuse.flush))
\ No newline at end of file
diff --git a/backend/services/llm.py b/backend/services/llm.py
index 6187e247..7045294c 100644
--- a/backend/services/llm.py
+++ b/backend/services/llm.py
@@ -232,6 +232,12 @@ def prepare_params(
use_thinking = enable_thinking if enable_thinking is not None else False
is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
is_xai = "xai" in effective_model_name.lower() or model_name.startswith("xai/")
+ is_kimi_k2 = "kimi-k2" in effective_model_name.lower() or model_name.startswith("moonshotai/kimi-k2")
+
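+ # Route Kimi K2 through OpenRouter's provider ordering: try Groq first, then Together's fp8 deployment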
+ if is_kimi_k2:
+ params["provider"] = {
+ "order": ["groq", "together/fp8"]
+ }
if is_anthropic and use_thinking:
effort_level = reasoning_effort if reasoning_effort else 'low'
diff --git a/backend/utils/constants.py b/backend/utils/constants.py
index 9d85684f..4a0ac589 100644
--- a/backend/utils/constants.py
+++ b/backend/utils/constants.py
@@ -53,6 +53,14 @@ MODELS = {
},
"tier_availability": ["paid"]
},
+ "openrouter/moonshotai/kimi-k2": {
+ "aliases": ["moonshotai/kimi-k2", "kimi-k2"],
+ "pricing": {
+ "input_cost_per_million_tokens": 1.00,
+ "output_cost_per_million_tokens": 3.00
+ },
+ "tier_availability": ["paid"]
+ },
"openai/gpt-4o": {
"aliases": ["gpt-4o"],
"pricing": {
diff --git a/frontend/Dockerfile b/frontend/Dockerfile
index 69cea002..74bf2b3b 100644
--- a/frontend/Dockerfile
+++ b/frontend/Dockerfile
@@ -9,6 +9,20 @@ WORKDIR /app
# Install dependencies based on the preferred package manager
COPY package.json yarn.lock* package-lock.json* pnpm-lock.yaml* .npmrc* ./
+
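+# Native toolchain and graphics libraries (cairo, pango, jpeg, gif, rsvg) so npm packages
+# with native bindings (e.g. node-canvas) can compile during install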
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ python3 \
+ make \
+ g++ \
+ build-essential \
+ pkg-config \
+ libcairo2-dev \
+ libpango1.0-dev \
+ libjpeg-dev \
+ libgif-dev \
+ librsvg2-dev \
+ && rm -rf /var/lib/apt/lists/*
+
RUN \
if [ -f yarn.lock ]; then yarn --frozen-lockfile; \
elif [ -f package-lock.json ]; then npm ci; \
diff --git a/frontend/src/components/dashboard/layout-content.tsx b/frontend/src/components/dashboard/layout-content.tsx
index 0af30b05..f07c9ad3 100644
--- a/frontend/src/components/dashboard/layout-content.tsx
+++ b/frontend/src/components/dashboard/layout-content.tsx
@@ -9,7 +9,7 @@ import { useAccounts } from '@/hooks/use-accounts';
import { useAuth } from '@/components/AuthProvider';
import { useRouter } from 'next/navigation';
import { Loader2 } from 'lucide-react';
-import { checkApiHealth } from '@/lib/api';
+import { useApiHealth } from '@/hooks/react-query';
import { MaintenancePage } from '@/components/maintenance/maintenance-page';
import { DeleteOperationProvider } from '@/contexts/DeleteOperationContext';
import { StatusOverlay } from '@/components/ui/status-overlay';
@@ -28,37 +28,19 @@ export default function DashboardLayoutContent({
}: DashboardLayoutContentProps) {
// const [showPricingAlert, setShowPricingAlert] = useState(false)
const [showMaintenanceAlert, setShowMaintenanceAlert] = useState(false);
- const [isApiHealthy, setIsApiHealthy] = useState(true);
- const [isCheckingHealth, setIsCheckingHealth] = useState(true);
const { data: accounts } = useAccounts();
const personalAccount = accounts?.find((account) => account.personal_account);
const { user, isLoading } = useAuth();
const router = useRouter();
+ const { data: healthData, isLoading: isCheckingHealth, error: healthError } = useApiHealth();
useEffect(() => {
// setShowPricingAlert(false)
setShowMaintenanceAlert(false);
}, []);
- // Check API health
- useEffect(() => {
- const checkHealth = async () => {
- try {
- const health = await checkApiHealth();
- setIsApiHealthy(health.status === 'ok');
- } catch (error) {
- console.error('API health check failed:', error);
- setIsApiHealthy(false);
- } finally {
- setIsCheckingHealth(false);
- }
- };
-
- checkHealth();
- // Check health every 30 seconds
- const interval = setInterval(checkHealth, 30000);
- return () => clearInterval(interval);
- }, []);
+ // API health is now managed by useApiHealth hook
+ const isApiHealthy = healthData?.status === 'ok' && !healthError;
// Check authentication status
useEffect(() => {
@@ -107,8 +89,8 @@ export default function DashboardLayoutContent({
return null;
}
- // Show maintenance page if API is not healthy
- if (!isApiHealthy) {
+ // Show maintenance page if API is not healthy (but not during initial loading)
+ if (!isCheckingHealth && !isApiHealthy) {
return <MaintenancePage />;
}
diff --git a/frontend/src/components/thread/chat-input/_use-model-selection.ts b/frontend/src/components/thread/chat-input/_use-model-selection.ts
index 7687e1ed..d1ac4643 100644
--- a/frontend/src/components/thread/chat-input/_use-model-selection.ts
+++ b/frontend/src/components/thread/chat-input/_use-model-selection.ts
@@ -70,6 +70,12 @@ export const MODELS = {
recommended: false,
lowQuality: false
},
+ 'moonshotai/kimi-k2': {
+ tier: 'premium',
+ priority: 96,
+ recommended: false,
+ lowQuality: false
+ },
'gpt-4.1': {
tier: 'premium',
priority: 96,
diff --git a/frontend/src/components/thread/chat-input/message-input.tsx b/frontend/src/components/thread/chat-input/message-input.tsx
index 78508bfb..ded27ec4 100644
--- a/frontend/src/components/thread/chat-input/message-input.tsx
+++ b/frontend/src/components/thread/chat-input/message-input.tsx
@@ -16,6 +16,7 @@ import { Tooltip } from '@/components/ui/tooltip';
import { TooltipProvider, TooltipTrigger } from '@radix-ui/react-tooltip';
import { BillingModal } from '@/components/billing/billing-modal';
import ChatDropdown from './chat-dropdown';
+import { handleFiles } from './file-upload-handler';
interface MessageInputProps {
value: string;
@@ -129,6 +130,29 @@ export const MessageInput = forwardRef<HTMLTextAreaElement, MessageInputProps>(
}
};
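+ // Intercept pasted images and route them through the shared file-upload handler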
+ const handlePaste = (e: React.ClipboardEvent<HTMLTextAreaElement>) => {
+ if (!e.clipboardData) return;
+ const items = Array.from(e.clipboardData.items);
+ const imageFiles: File[] = [];
+ for (const item of items) {
+ if (item.kind === 'file' && item.type.startsWith('image/')) {
+ const file = item.getAsFile();
+ if (file) imageFiles.push(file);
+ }
+ }
+ if (imageFiles.length > 0) {
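+ // Prevent the default paste so the image is not inserted into the textarea as text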
+ e.preventDefault();
+ handleFiles(
+ imageFiles,
+ sandboxId,
+ setPendingFiles,
+ setUploadedFiles,
+ setIsUploading,
+ messages,
+ );
+ }
+ };
+
const renderDropdown = () => {
if (isLoggedIn) {
const showAdvancedFeatures = enableAdvancedConfig || (customAgentsEnabled && !flagsLoading);
@@ -167,6 +191,7 @@ export const MessageInput = forwardRef<HTMLTextAreaElement, MessageInputProps>(
value={value}
onChange={onChange}
onKeyDown={handleKeyDown}
+ onPaste={handlePaste}
placeholder={placeholder}
className={cn(
'w-full bg-transparent dark:bg-transparent border-none shadow-none focus-visible:ring-0 px-0.5 pb-6 pt-4 !text-[15px] min-h-[36px] max-h-[200px] overflow-y-auto resize-none',