diff --git a/backend/agent/run.py b/backend/agent/run.py
index e14576c8..851d7056 100644
--- a/backend/agent/run.py
+++ b/backend/agent/run.py
@@ -1,9 +1,8 @@
 import os
 import json
 import asyncio
-from typing import Optional, Dict, List, Any, AsyncGenerator, Tuple
-from dataclasses import dataclass, field
+from typing import Optional
 from agent.tools.message_tool import MessageTool
 from agent.tools.sb_deploy_tool import SandboxDeployTool
 from agent.tools.sb_expose_tool import SandboxExposeTool
@@ -27,699 +27,12 @@ from agent.tools.sb_vision_tool import SandboxVisionTool
 from agent.tools.sb_image_edit_tool import SandboxImageEditTool
 from services.langfuse import langfuse
 from langfuse.client import StatefulTraceClient
 from agent.gemini_prompt import get_gemini_system_prompt
 from agent.tools.mcp_tool_wrapper import MCPToolWrapper
 from agentpress.tool import SchemaType

 load_dotenv()

-@dataclass
-class AgentRunConfig:
-    thread_id: str
-    project_id: str
-    stream: bool
-    thread_manager: Optional[ThreadManager] = None
-    native_max_auto_continues: int = 25
-    max_iterations: int = 100
-    model_name: str = "anthropic/claude-sonnet-4-20250514"
-    enable_thinking: Optional[bool] = False
-    reasoning_effort: Optional[str] = 'low'
-    enable_context_manager: bool = True
-    agent_config: Optional[dict] = None
-    trace: Optional[StatefulTraceClient] = None
-    is_agent_builder: Optional[bool] = False
-    target_agent_id: Optional[str] = None
-
-@dataclass
-class ExecutionContext:
-    client: Any
-    account_id: str
-    project_data: Dict
-    sandbox_info: Dict
-    mcp_wrapper_instance: Optional[MCPToolWrapper] = None
-
-class AgentExecutionError(Exception):
-    pass
-
-def get_model_max_tokens(model_name: str) -> Optional[int]:
-    if "sonnet" in model_name.lower():
-        return 8192
-    elif "gpt-4" in model_name.lower():
-        return 4096
-    elif "gemini-2.5-pro" in model_name.lower():
-        return 64000
-    return None
-
-def is_vision_model(model_name: str) -> bool:
-    return any(x in model_name.lower() for x in ['gemini', 'anthropic', 'openai'])
-
-async def setup_execution_context(config: AgentRunConfig) -> ExecutionContext:
-    client = await config.thread_manager.db.client
-
-    account_id = await get_account_id_from_thread(client, config.thread_id)
-    if not account_id:
-        raise AgentExecutionError("Could not determine account ID for thread")
-
-    project = await client.table('projects').select('*').eq('project_id', config.project_id).execute()
-    if not project.data or len(project.data) == 0:
-        raise AgentExecutionError(f"Project {config.project_id} not found")
-
-    project_data = project.data[0]
-    sandbox_info = project_data.get('sandbox', {})
-    if not sandbox_info.get('id'):
-        raise AgentExecutionError(f"No sandbox found for project {config.project_id}")
-
-    return ExecutionContext(
-        client=client,
-        account_id=account_id,
-        project_data=project_data,
-        sandbox_info=sandbox_info
-    )
-
-def register_agent_builder_tools(thread_manager: ThreadManager, target_agent_id: str):
-    from agent.tools.agent_builder_tools.agent_config_tool import AgentConfigTool
-    from agent.tools.agent_builder_tools.mcp_search_tool import MCPSearchTool
-    from agent.tools.agent_builder_tools.credential_profile_tool import CredentialProfileTool
-    from agent.tools.agent_builder_tools.workflow_tool import WorkflowTool
-    from agent.tools.agent_builder_tools.trigger_tool import TriggerTool
-    from services.supabase import DBConnection
-
-    db = DBConnection()
-    
thread_manager.add_tool(AgentConfigTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - thread_manager.add_tool(MCPSearchTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - thread_manager.add_tool(CredentialProfileTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - thread_manager.add_tool(WorkflowTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - thread_manager.add_tool(TriggerTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - -def register_default_tools(thread_manager: ThreadManager, agent_config: AgentRunConfig): - thread_manager.add_tool(SandboxShellTool, project_id=agent_config.project_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxFilesTool, project_id=agent_config.project_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxBrowserTool, project_id=agent_config.project_id, thread_id=agent_config.thread_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxDeployTool, project_id=agent_config.project_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxExposeTool, project_id=agent_config.project_id, thread_manager=thread_manager) - thread_manager.add_tool(ExpandMessageTool, thread_id=agent_config.thread_id, thread_manager=thread_manager) - thread_manager.add_tool(MessageTool) - thread_manager.add_tool(SandboxWebSearchTool, project_id=agent_config.project_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxVisionTool, project_id=agent_config.project_id, thread_id=agent_config.thread_id, thread_manager=thread_manager) - thread_manager.add_tool(SandboxImageEditTool, project_id=agent_config.project_id, thread_id=agent_config.thread_id, thread_manager=thread_manager) - - if config.RAPID_API_KEY: - thread_manager.add_tool(DataProvidersTool) - -def register_custom_tools(thread_manager: ThreadManager, agent_config: AgentRunConfig, enabled_tools: Dict): - thread_manager.add_tool(ExpandMessageTool, thread_id=agent_config.thread_id, thread_manager=thread_manager) - thread_manager.add_tool(MessageTool) - - tool_mapping = { - 'sb_shell_tool': (SandboxShellTool, {'project_id': agent_config.project_id, 'thread_manager': thread_manager}), - 'sb_files_tool': (SandboxFilesTool, {'project_id': agent_config.project_id, 'thread_manager': thread_manager}), - 'sb_browser_tool': (SandboxBrowserTool, {'project_id': agent_config.project_id, 'thread_id': agent_config.thread_id, 'thread_manager': thread_manager}), - 'sb_deploy_tool': (SandboxDeployTool, {'project_id': agent_config.project_id, 'thread_manager': thread_manager}), - 'sb_expose_tool': (SandboxExposeTool, {'project_id': agent_config.project_id, 'thread_manager': thread_manager}), - 'web_search_tool': (SandboxWebSearchTool, {'project_id': agent_config.project_id, 'thread_manager': thread_manager}), - 'sb_vision_tool': (SandboxVisionTool, {'project_id': agent_config.project_id, 'thread_id': agent_config.thread_id, 'thread_manager': thread_manager}), - } - - for tool_name, (tool_class, kwargs) in tool_mapping.items(): - if enabled_tools.get(tool_name, {}).get('enabled', False): - thread_manager.add_tool(tool_class, **kwargs) - - if config.RAPID_API_KEY and enabled_tools.get('data_providers_tool', {}).get('enabled', False): - thread_manager.add_tool(DataProvidersTool) - -async def setup_pipedream_mcp_config(custom_mcp: Dict, account_id: str) -> Dict: - if 'config' not in custom_mcp: - custom_mcp['config'] = {} - - if not 
custom_mcp['config'].get('external_user_id'): - profile_id = custom_mcp['config'].get('profile_id') - if profile_id: - try: - from pipedream.profiles import get_profile_manager - from services.supabase import DBConnection - profile_db = DBConnection() - profile_manager = get_profile_manager(profile_db) - - profile = await profile_manager.get_profile(account_id, profile_id) - if profile: - custom_mcp['config']['external_user_id'] = profile.external_user_id - logger.info(f"Retrieved external_user_id from profile {profile_id} for Pipedream MCP") - else: - logger.error(f"Could not find profile {profile_id} for Pipedream MCP") - except Exception as e: - logger.error(f"Error retrieving external_user_id from profile {profile_id}: {e}") - - if 'headers' in custom_mcp['config'] and 'x-pd-app-slug' in custom_mcp['config']['headers']: - custom_mcp['config']['app_slug'] = custom_mcp['config']['headers']['x-pd-app-slug'] - - return custom_mcp - -def create_mcp_config(custom_mcp: Dict, custom_type: str) -> Dict: - return { - 'name': custom_mcp['name'], - 'qualifiedName': f"custom_{custom_type}_{custom_mcp['name'].replace(' ', '_').lower()}", - 'config': custom_mcp['config'], - 'enabledTools': custom_mcp.get('enabledTools', []), - 'instructions': custom_mcp.get('instructions', ''), - 'isCustom': True, - 'customType': custom_type - } - -async def setup_mcp_tools(thread_manager: ThreadManager, config: AgentRunConfig, context: ExecutionContext) -> Optional[MCPToolWrapper]: - if not config.agent_config: - return None - - all_mcps = [] - - if config.agent_config.get('configured_mcps'): - all_mcps.extend(config.agent_config['configured_mcps']) - - if config.agent_config.get('custom_mcps'): - for custom_mcp in config.agent_config['custom_mcps']: - custom_type = custom_mcp.get('customType', custom_mcp.get('type', 'sse')) - - if custom_type == 'pipedream': - custom_mcp = await setup_pipedream_mcp_config(custom_mcp, context.account_id) - - mcp_config = create_mcp_config(custom_mcp, custom_type) - all_mcps.append(mcp_config) - - if not all_mcps: - return None - - logger.info(f"Registering MCP tool wrapper for {len(all_mcps)} MCP servers") - thread_manager.add_tool(MCPToolWrapper, mcp_configs=all_mcps) - - mcp_wrapper_instance = None - for tool_name, tool_info in thread_manager.tool_registry.tools.items(): - if isinstance(tool_info['instance'], MCPToolWrapper): - mcp_wrapper_instance = tool_info['instance'] - break - - if mcp_wrapper_instance: - try: - await mcp_wrapper_instance.initialize_and_register_tools() - logger.info("MCP tools initialized successfully") - - updated_schemas = mcp_wrapper_instance.get_schemas() - logger.info(f"MCP wrapper has {len(updated_schemas)} schemas available") - - for method_name, schema_list in updated_schemas.items(): - if method_name != 'call_mcp_tool': - for schema in schema_list: - if schema.schema_type == SchemaType.OPENAPI: - thread_manager.tool_registry.tools[method_name] = { - "instance": mcp_wrapper_instance, - "schema": schema - } - logger.info(f"Registered dynamic MCP tool: {method_name}") - - all_tools = list(thread_manager.tool_registry.tools.keys()) - logger.info(f"All registered tools after MCP initialization: {all_tools}") - - except Exception as e: - logger.error(f"Failed to initialize MCP tools: {e}") - - return mcp_wrapper_instance - -def setup_tools(thread_manager: ThreadManager, agent_config: AgentRunConfig) -> None: - if agent_config.is_agent_builder: - register_agent_builder_tools(thread_manager, agent_config.target_agent_id) - - enabled_tools = None - if 
agent_config.agent_config and 'agentpress_tools' in agent_config.agent_config:
-        enabled_tools = agent_config.agent_config['agentpress_tools']
-        logger.info("Using custom tool configuration from agent")
-
-    if enabled_tools is None:
-        logger.info("No agent specified - registering all tools for full Suna capabilities")
-        register_default_tools(thread_manager, agent_config)
-    else:
-        logger.info("Custom agent specified - registering only enabled tools")
-        register_custom_tools(thread_manager, agent_config, enabled_tools)
-
-def get_base_system_prompt(model_name: str) -> str:
-    if "gemini-2.5-flash" in model_name.lower() and "gemini-2.5-pro" not in model_name.lower():
-        return get_gemini_system_prompt()
-    return get_system_prompt()
-
-def add_sample_response(system_content: str, model_name: str) -> str:
-    if "anthropic" not in model_name.lower():
-        sample_response_path = os.path.join(os.path.dirname(__file__), 'sample_responses/1.txt')
-        with open(sample_response_path, 'r') as file:
-            sample_response = file.read()
-        return system_content + "\n\n <sample_assistant_response>" + sample_response + "</sample_assistant_response>"
-    return system_content
-
-async def add_knowledge_base_context(system_content: str, config: AgentRunConfig) -> str:
-    if not await is_enabled("knowledge_base"):
-        return system_content
-
-    try:
-        from services.supabase import DBConnection
-        kb_db = DBConnection()
-        kb_client = await kb_db.client
-
-        current_agent_id = config.agent_config.get('agent_id') if config.agent_config else None
-
-        kb_result = await kb_client.rpc('get_combined_knowledge_base_context', {
-            'p_thread_id': config.thread_id,
-            'p_agent_id': current_agent_id,
-            'p_max_tokens': 4000
-        }).execute()
-
-        if kb_result.data and kb_result.data.strip():
-            logger.info(f"Adding combined knowledge base context to system prompt")
-            return system_content + "\n\n" + kb_result.data
-        else:
-            logger.debug(f"No knowledge base context found")
-
-    except Exception as e:
-        logger.error(f"Error retrieving knowledge base context: {e}")
-
-    return system_content
-
-def add_mcp_instructions(system_content: str, config: AgentRunConfig, mcp_wrapper_instance: Optional[MCPToolWrapper]) -> str:
-    if not (config.agent_config and (config.agent_config.get('configured_mcps') or config.agent_config.get('custom_mcps'))):
-        return system_content
-
-    if not (mcp_wrapper_instance and mcp_wrapper_instance._initialized):
-        return system_content
-
-    mcp_info = "\n\n--- MCP Tools Available ---\n"
-    mcp_info += "You have access to external MCP (Model Context Protocol) server tools.\n"
-    mcp_info += "MCP tools can be called directly using their native function names in the standard function calling format:\n"
-    mcp_info += '<function_calls>\n'
-    mcp_info += '<invoke name="{tool_name}">\n'
-    mcp_info += '<parameter name="param1">value1</parameter>\n'
-    mcp_info += '<parameter name="param2">value2</parameter>\n'
-    mcp_info += '</invoke>\n'
-    mcp_info += '</function_calls>\n\n'
-
-    mcp_info += "Available MCP tools:\n"
-    try:
-        registered_schemas = mcp_wrapper_instance.get_schemas()
-        for method_name, schema_list in registered_schemas.items():
-            if method_name == 'call_mcp_tool':
-                continue
-
-            for schema in schema_list:
-                if schema.schema_type == SchemaType.OPENAPI:
-                    func_info = schema.schema.get('function', {})
-                    description = func_info.get('description', 'No description available')
-
-                    mcp_info += f"- **{method_name}**: {description}\n"
-
-                    params = func_info.get('parameters', {})
-                    props = params.get('properties', {})
-                    if props:
-                        mcp_info += f"  Parameters: {', '.join(props.keys())}\n"
-
-    except Exception as e:
-        logger.error(f"Error listing MCP tools: {e}")
-        mcp_info += "- Error loading MCP tool list\n"
-
-    mcp_info += "\n🚨 CRITICAL 
MCP TOOL RESULT INSTRUCTIONS 🚨\n" - mcp_info += "When you use ANY MCP (Model Context Protocol) tools:\n" - mcp_info += "1. ALWAYS read and use the EXACT results returned by the MCP tool\n" - mcp_info += "2. For search tools: ONLY cite URLs, sources, and information from the actual search results\n" - mcp_info += "3. For any tool: Base your response entirely on the tool's output - do NOT add external information\n" - mcp_info += "4. DO NOT fabricate, invent, hallucinate, or make up any sources, URLs, or data\n" - mcp_info += "5. If you need more information, call the MCP tool again with different parameters\n" - mcp_info += "6. When writing reports/summaries: Reference ONLY the data from MCP tool results\n" - mcp_info += "7. If the MCP tool doesn't return enough information, explicitly state this limitation\n" - mcp_info += "8. Always double-check that every fact, URL, and reference comes from the MCP tool output\n" - mcp_info += "\nIMPORTANT: MCP tool results are your PRIMARY and ONLY source of truth for external data!\n" - mcp_info += "NEVER supplement MCP results with your training data or make assumptions beyond what the tools provide.\n" - - return system_content + mcp_info - -async def build_system_prompt(config: AgentRunConfig, mcp_wrapper_instance: Optional[MCPToolWrapper]) -> Dict: - base_prompt = get_base_system_prompt(config.model_name) - system_content = add_sample_response(base_prompt, config.model_name) - - if config.agent_config and config.agent_config.get('system_prompt'): - custom_system_prompt = config.agent_config['system_prompt'].strip() - system_content = custom_system_prompt - logger.info(f"Using ONLY custom agent system prompt") - elif config.is_agent_builder: - system_content = get_agent_builder_prompt() - logger.info("Using agent builder system prompt") - else: - logger.info("Using default system prompt only") - - system_content = await add_knowledge_base_context(system_content, config) - system_content = add_mcp_instructions(system_content, config, mcp_wrapper_instance) - - return {"role": "system", "content": system_content} - -async def get_browser_state_content(client: Any, thread_id: str, model_name: str, trace: Optional[StatefulTraceClient]) -> List[Dict]: - content_list = [] - - latest_browser_state_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute() - - if not (latest_browser_state_msg.data and len(latest_browser_state_msg.data) > 0): - return content_list - - try: - browser_content = latest_browser_state_msg.data[0]["content"] - if isinstance(browser_content, str): - browser_content = json.loads(browser_content) - - screenshot_base64 = browser_content.get("screenshot_base64") - screenshot_url = browser_content.get("image_url") - - browser_state_text = browser_content.copy() - browser_state_text.pop('screenshot_base64', None) - browser_state_text.pop('image_url', None) - - if browser_state_text: - content_list.append({ - "type": "text", - "text": f"The following is the current state of the browser:\n{json.dumps(browser_state_text, indent=2)}" - }) - - if is_vision_model(model_name): - if screenshot_url: - content_list.append({ - "type": "image_url", - "image_url": { - "url": screenshot_url, - "format": "image/jpeg" - } - }) - if trace: - trace.event(name="screenshot_url_added_to_temporary_message", level="DEFAULT") - elif screenshot_base64: - content_list.append({ - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{screenshot_base64}", 
- } - }) - if trace: - trace.event(name="screenshot_base64_added_to_temporary_message", level="WARNING") - else: - logger.warning("Browser state found but no screenshot data.") - if trace: - trace.event(name="browser_state_found_but_no_screenshot_data", level="WARNING") - else: - logger.warning("Model doesn't support vision, skipping screenshot.") - if trace: - trace.event(name="model_doesnt_support_vision", level="WARNING") - - except Exception as e: - logger.error(f"Error parsing browser state: {e}") - if trace: - trace.event(name="error_parsing_browser_state", level="ERROR", status_message=str(e)) - - return content_list - -async def get_image_context_content(client: Any, thread_id: str) -> List[Dict]: - content_list = [] - - latest_image_context_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'image_context').order('created_at', desc=True).limit(1).execute() - - if not (latest_image_context_msg.data and len(latest_image_context_msg.data) > 0): - return content_list - - try: - image_context_content = latest_image_context_msg.data[0]["content"] - if isinstance(image_context_content, str): - image_context_content = json.loads(image_context_content) - - base64_image = image_context_content.get("base64") - mime_type = image_context_content.get("mime_type") - file_path = image_context_content.get("file_path", "unknown file") - - if base64_image and mime_type: - content_list.extend([ - { - "type": "text", - "text": f"Here is the image you requested to see: '{file_path}'" - }, - { - "type": "image_url", - "image_url": { - "url": f"data:{mime_type};base64,{base64_image}", - } - } - ]) - else: - logger.warning(f"Image context found for '{file_path}' but missing base64 or mime_type.") - - await client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute() - - except Exception as e: - logger.error(f"Error parsing image context: {e}") - - return content_list - -async def build_temporary_message(client: Any, config: AgentRunConfig, trace: Optional[StatefulTraceClient]) -> Optional[Dict]: - temp_message_content_list = [] - - browser_content = await get_browser_state_content(client, config.thread_id, config.model_name, trace) - temp_message_content_list.extend(browser_content) - - image_content = await get_image_context_content(client, config.thread_id) - temp_message_content_list.extend(image_content) - - if temp_message_content_list: - return {"role": "user", "content": temp_message_content_list} - - return None - -async def should_continue_execution(client: Any, thread_id: str, trace: Optional[StatefulTraceClient]) -> bool: - latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).in_('type', ['assistant', 'tool', 'user']).order('created_at', desc=True).limit(1).execute() - - if latest_message.data and len(latest_message.data) > 0: - message_type = latest_message.data[0].get('type') - if message_type == 'assistant': - logger.info("Last message was from assistant, stopping execution") - if trace: - trace.event(name="last_message_from_assistant", level="DEFAULT") - return False - - return True - -async def check_billing_limits(client: Any, account_id: str, trace: Optional[StatefulTraceClient]) -> Tuple[bool, str]: - can_run, message, subscription = await check_billing_status(client, account_id) - if not can_run: - error_msg = f"Billing limit reached: {message}" - if trace: - trace.event(name="billing_limit_reached", level="ERROR", status_message=error_msg) - return False, error_msg - 
return True, "" - -class ResponseProcessor: - def __init__(self, trace: Optional[StatefulTraceClient]): - self.trace = trace - self.last_tool_call = None - self.agent_should_terminate = False - self.full_response = "" - - def check_termination_signal(self, chunk: Dict) -> None: - if chunk.get('type') != 'status': - return - - try: - metadata = chunk.get('metadata', {}) - if isinstance(metadata, str): - metadata = json.loads(metadata) - - if metadata.get('agent_should_terminate'): - self.agent_should_terminate = True - logger.info("Agent termination signal detected") - if self.trace: - self.trace.event(name="agent_termination_signal_detected", level="DEFAULT") - - content = chunk.get('content', {}) - if isinstance(content, str): - content = json.loads(content) - - if content.get('function_name'): - self.last_tool_call = content['function_name'] - elif content.get('xml_tag_name'): - self.last_tool_call = content['xml_tag_name'] - - except Exception as e: - logger.debug(f"Error parsing status message for termination check: {e}") - - def check_xml_tools(self, chunk: Dict) -> None: - if chunk.get('type') != 'assistant' or 'content' not in chunk: - return - - try: - content = chunk.get('content', '{}') - if isinstance(content, str): - assistant_content_json = json.loads(content) - else: - assistant_content_json = content - - assistant_text = assistant_content_json.get('content', '') - self.full_response += assistant_text - - if isinstance(assistant_text, str): - for tool in ['ask', 'complete', 'web-browser-takeover']: - if f'' in assistant_text: - self.last_tool_call = tool - logger.info(f"Agent used XML tool: {tool}") - if self.trace: - self.trace.event(name="agent_used_xml_tool", level="DEFAULT", status_message=f"Agent used XML tool: {tool}") - break - - except json.JSONDecodeError: - logger.warning(f"Could not parse assistant content JSON: {chunk.get('content')}") - if self.trace: - self.trace.event(name="warning_could_not_parse_assistant_content_json", level="WARNING") - except Exception as e: - logger.error(f"Error processing assistant chunk: {e}") - if self.trace: - self.trace.event(name="error_processing_assistant_chunk", level="ERROR", status_message=str(e)) - - def process_chunk(self, chunk: Dict) -> None: - self.check_termination_signal(chunk) - self.check_xml_tools(chunk) - - def should_terminate(self) -> bool: - return (self.agent_should_terminate or - self.last_tool_call in ['ask', 'complete', 'web-browser-takeover']) - -async def process_llm_response(response: Any, processor: ResponseProcessor, trace: Optional[StatefulTraceClient]) -> AsyncGenerator[Dict, None]: - try: - if hasattr(response, '__aiter__') and not isinstance(response, dict): - async for chunk in response: - if isinstance(chunk, dict) and chunk.get('type') == 'status' and chunk.get('status') == 'error': - logger.error(f"Error chunk detected: {chunk.get('message', 'Unknown error')}") - if trace: - trace.event(name="error_chunk_detected", level="ERROR", status_message=chunk.get('message', 'Unknown error')) - yield chunk - return - - processor.process_chunk(chunk) - yield chunk - else: - logger.error(f"Response is not iterable: {response}") - yield { - "type": "status", - "status": "error", - "message": "Response is not iterable" - } - - except Exception as e: - error_msg = f"Error during response streaming: {str(e)}" - logger.error(error_msg) - if trace: - trace.event(name="error_during_response_streaming", level="ERROR", status_message=error_msg) - yield { - "type": "status", - "status": "error", - "message": error_msg 
- } - -async def run_single_iteration( - config: AgentRunConfig, - context: ExecutionContext, - system_message: Dict, - temporary_message: Optional[Dict], - iteration_count: int -) -> AsyncGenerator[Dict, None]: - - can_continue, error_msg = await check_billing_limits(context.client, context.account_id, config.trace) - if not can_continue: - yield { - "type": "status", - "status": "stopped", - "message": error_msg - } - return - - if not await should_continue_execution(context.client, config.thread_id, config.trace): - return - - max_tokens = get_model_max_tokens(config.model_name) - generation = config.trace.generation(name="thread_manager.run_thread") if config.trace else None - processor = ResponseProcessor(config.trace) - - try: - response = await config.thread_manager.run_thread( - thread_id=config.thread_id, - system_prompt=system_message, - stream=config.stream, - llm_model=config.model_name, - llm_temperature=0, - llm_max_tokens=max_tokens, - tool_choice="auto", - max_xml_tool_calls=1, - temporary_message=temporary_message, - processor_config=ProcessorConfig( - xml_tool_calling=True, - native_tool_calling=False, - execute_tools=True, - execute_on_stream=True, - tool_execution_strategy="parallel", - xml_adding_strategy="user_message" - ), - native_max_auto_continues=config.native_max_auto_continues, - include_xml_examples=True, - enable_thinking=config.enable_thinking, - reasoning_effort=config.reasoning_effort, - enable_context_manager=config.enable_context_manager, - generation=generation - ) - - if isinstance(response, dict) and response.get("status") == "error": - logger.error(f"Error response from run_thread: {response.get('message', 'Unknown error')}") - if config.trace: - config.trace.event(name="error_response_from_run_thread", level="ERROR", status_message=response.get('message', 'Unknown error')) - yield response - return - - async for chunk in process_llm_response(response, processor, config.trace): - yield chunk - - if processor.should_terminate(): - logger.info(f"Agent decided to stop with tool: {processor.last_tool_call}") - if config.trace: - config.trace.event(name="agent_decided_to_stop_with_tool", level="DEFAULT", status_message=f"Agent decided to stop with tool: {processor.last_tool_call}") - - if generation: - generation.end(output=processor.full_response, status_message="agent_stopped") - - yield { - "type": "status", - "status": "completed", - "terminate": True - } - else: - if generation: - generation.end(output=processor.full_response) - - except Exception as e: - error_msg = f"Error running thread: {str(e)}" - logger.error(error_msg) - if config.trace: - config.trace.event(name="error_running_thread", level="ERROR", status_message=error_msg) - if generation: - generation.end(output=processor.full_response, status_message=error_msg, level="ERROR") - yield { - "type": "status", - "status": "error", - "message": error_msg - } - -async def setup_trace_input(config: AgentRunConfig, context: ExecutionContext) -> None: - if not config.trace: - return - - latest_user_message = await context.client.table('messages').select('*').eq('thread_id', config.thread_id).eq('type', 'user').order('created_at', desc=True).limit(1).execute() - - if latest_user_message.data and len(latest_user_message.data) > 0: - data = latest_user_message.data[0]['content'] - if isinstance(data, str): - data = json.loads(data) - config.trace.update(input=data['content']) - async def run_agent( thread_id: str, project_id: str, @@ -736,81 +50,625 @@ async def run_agent( is_agent_builder: 
Optional[bool] = False, target_agent_id: Optional[str] = None ): + """Run the development agent with specified configuration.""" logger.info(f"šŸš€ Starting agent with model: {model_name}") if agent_config: logger.info(f"Using custom agent: {agent_config.get('name', 'Unknown')}") - config = AgentRunConfig( - thread_id=thread_id, - project_id=project_id, - stream=stream, - thread_manager=thread_manager, - native_max_auto_continues=native_max_auto_continues, - max_iterations=max_iterations, - model_name=model_name, - enable_thinking=enable_thinking, - reasoning_effort=reasoning_effort, - enable_context_manager=enable_context_manager, - agent_config=agent_config, - trace=trace or langfuse.trace(name="run_agent", session_id=thread_id, metadata={"project_id": project_id}), - is_agent_builder=is_agent_builder or False, - target_agent_id=target_agent_id - ) + if not trace: + trace = langfuse.trace(name="run_agent", session_id=thread_id, metadata={"project_id": project_id}) + thread_manager = ThreadManager(trace=trace, is_agent_builder=is_agent_builder or False, target_agent_id=target_agent_id, agent_config=agent_config) - config.thread_manager = ThreadManager( - trace=config.trace, - is_agent_builder=config.is_agent_builder, - target_agent_id=config.target_agent_id, - agent_config=config.agent_config - ) + client = await thread_manager.db.client - try: - context = await setup_execution_context(config) - setup_tools(config.thread_manager, config) - mcp_wrapper_instance = await setup_mcp_tools(config.thread_manager, config, context) - system_message = await build_system_prompt(config, mcp_wrapper_instance) - await setup_trace_input(config, context) + # Get account ID from thread for billing checks + account_id = await get_account_id_from_thread(client, thread_id) + if not account_id: + raise ValueError("Could not determine account ID for thread") - iteration_count = 0 + # Get sandbox info from project + project = await client.table('projects').select('*').eq('project_id', project_id).execute() + if not project.data or len(project.data) == 0: + raise ValueError(f"Project {project_id} not found") + + project_data = project.data[0] + sandbox_info = project_data.get('sandbox', {}) + if not sandbox_info.get('id'): + raise ValueError(f"No sandbox found for project {project_id}") + + # Initialize tools with project_id instead of sandbox object + # This ensures each tool independently verifies it's operating on the correct project + + # Get enabled tools from agent config, or use defaults + enabled_tools = None + if agent_config and 'agentpress_tools' in agent_config: + enabled_tools = agent_config['agentpress_tools'] + logger.info(f"Using custom tool configuration from agent") + + + if is_agent_builder: + from agent.tools.agent_builder_tools.agent_config_tool import AgentConfigTool + from agent.tools.agent_builder_tools.mcp_search_tool import MCPSearchTool + from agent.tools.agent_builder_tools.credential_profile_tool import CredentialProfileTool + from agent.tools.agent_builder_tools.workflow_tool import WorkflowTool + from agent.tools.agent_builder_tools.trigger_tool import TriggerTool + from services.supabase import DBConnection + db = DBConnection() + + thread_manager.add_tool(AgentConfigTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) + thread_manager.add_tool(MCPSearchTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) + thread_manager.add_tool(CredentialProfileTool, thread_manager=thread_manager, db_connection=db, 
agent_id=target_agent_id) + thread_manager.add_tool(WorkflowTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) + thread_manager.add_tool(TriggerTool, thread_manager=thread_manager, db_connection=db, agent_id=target_agent_id) - while iteration_count < config.max_iterations: - iteration_count += 1 - logger.info(f"šŸ”„ Running iteration {iteration_count} of {config.max_iterations}...") - temporary_message = await build_temporary_message(context.client, config, config.trace) + if enabled_tools is None: + logger.info("No agent specified - registering all tools for full Suna capabilities") + thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager) + thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager) + thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) + thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager) + thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager) + thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager) + thread_manager.add_tool(MessageTool) + thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager) + thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) + thread_manager.add_tool(SandboxImageEditTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) + if config.RAPID_API_KEY: + thread_manager.add_tool(DataProvidersTool) + else: + logger.info("Custom agent specified - registering only enabled tools") + thread_manager.add_tool(ExpandMessageTool, thread_id=thread_id, thread_manager=thread_manager) + thread_manager.add_tool(MessageTool) + if enabled_tools.get('sb_shell_tool', {}).get('enabled', False): + thread_manager.add_tool(SandboxShellTool, project_id=project_id, thread_manager=thread_manager) + if enabled_tools.get('sb_files_tool', {}).get('enabled', False): + thread_manager.add_tool(SandboxFilesTool, project_id=project_id, thread_manager=thread_manager) + if enabled_tools.get('sb_browser_tool', {}).get('enabled', False): + thread_manager.add_tool(SandboxBrowserTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) + if enabled_tools.get('sb_deploy_tool', {}).get('enabled', False): + thread_manager.add_tool(SandboxDeployTool, project_id=project_id, thread_manager=thread_manager) + if enabled_tools.get('sb_expose_tool', {}).get('enabled', False): + thread_manager.add_tool(SandboxExposeTool, project_id=project_id, thread_manager=thread_manager) + if enabled_tools.get('web_search_tool', {}).get('enabled', False): + thread_manager.add_tool(SandboxWebSearchTool, project_id=project_id, thread_manager=thread_manager) + if enabled_tools.get('sb_vision_tool', {}).get('enabled', False): + thread_manager.add_tool(SandboxVisionTool, project_id=project_id, thread_id=thread_id, thread_manager=thread_manager) + if config.RAPID_API_KEY and enabled_tools.get('data_providers_tool', {}).get('enabled', False): + thread_manager.add_tool(DataProvidersTool) + + # Register MCP tool wrapper if agent has configured MCPs or custom MCPs + mcp_wrapper_instance = None + if agent_config: + # Merge configured_mcps and custom_mcps + all_mcps = [] + + # Add standard configured MCPs + if agent_config.get('configured_mcps'): + 
all_mcps.extend(agent_config['configured_mcps']) + + # Add custom MCPs + if agent_config.get('custom_mcps'): + for custom_mcp in agent_config['custom_mcps']: + # Transform custom MCP to standard format + custom_type = custom_mcp.get('customType', custom_mcp.get('type', 'sse')) + + # For Pipedream MCPs, ensure we have the user ID and proper config + if custom_type == 'pipedream': + # Get user ID from thread + if 'config' not in custom_mcp: + custom_mcp['config'] = {} + + # Get external_user_id from profile if not present + if not custom_mcp['config'].get('external_user_id'): + profile_id = custom_mcp['config'].get('profile_id') + if profile_id: + try: + from pipedream.profiles import get_profile_manager + from services.supabase import DBConnection + profile_db = DBConnection() + profile_manager = get_profile_manager(profile_db) + + # Get the profile to retrieve external_user_id + profile = await profile_manager.get_profile(account_id, profile_id) + if profile: + custom_mcp['config']['external_user_id'] = profile.external_user_id + logger.info(f"Retrieved external_user_id from profile {profile_id} for Pipedream MCP") + else: + logger.error(f"Could not find profile {profile_id} for Pipedream MCP") + except Exception as e: + logger.error(f"Error retrieving external_user_id from profile {profile_id}: {e}") + + if 'headers' in custom_mcp['config'] and 'x-pd-app-slug' in custom_mcp['config']['headers']: + custom_mcp['config']['app_slug'] = custom_mcp['config']['headers']['x-pd-app-slug'] + + mcp_config = { + 'name': custom_mcp['name'], + 'qualifiedName': f"custom_{custom_type}_{custom_mcp['name'].replace(' ', '_').lower()}", + 'config': custom_mcp['config'], + 'enabledTools': custom_mcp.get('enabledTools', []), + 'instructions': custom_mcp.get('instructions', ''), + 'isCustom': True, + 'customType': custom_type + } + all_mcps.append(mcp_config) + + if all_mcps: + logger.info(f"Registering MCP tool wrapper for {len(all_mcps)} MCP servers (including {len(agent_config.get('custom_mcps', []))} custom)") + thread_manager.add_tool(MCPToolWrapper, mcp_configs=all_mcps) - should_terminate = False - async for result in run_single_iteration(config, context, system_message, temporary_message, iteration_count): - yield result - if result.get('terminate'): - should_terminate = True - break - if result.get('type') == 'status' and result.get('status') in ['error', 'stopped']: - should_terminate = True + for tool_name, tool_info in thread_manager.tool_registry.tools.items(): + if isinstance(tool_info['instance'], MCPToolWrapper): + mcp_wrapper_instance = tool_info['instance'] break - if should_terminate: + if mcp_wrapper_instance: + try: + await mcp_wrapper_instance.initialize_and_register_tools() + logger.info("MCP tools initialized successfully") + updated_schemas = mcp_wrapper_instance.get_schemas() + logger.info(f"MCP wrapper has {len(updated_schemas)} schemas available") + for method_name, schema_list in updated_schemas.items(): + if method_name != 'call_mcp_tool': + for schema in schema_list: + if schema.schema_type == SchemaType.OPENAPI: + thread_manager.tool_registry.tools[method_name] = { + "instance": mcp_wrapper_instance, + "schema": schema + } + logger.info(f"Registered dynamic MCP tool: {method_name}") + + # Log all registered tools for debugging + all_tools = list(thread_manager.tool_registry.tools.keys()) + logger.info(f"All registered tools after MCP initialization: {all_tools}") + mcp_tools = [tool for tool in all_tools if tool not in ['call_mcp_tool', 'sb_files_tool', 'message_tool', 
'expand_msg_tool', 'web_search_tool', 'sb_shell_tool', 'sb_vision_tool', 'sb_browser_tool', 'computer_use_tool', 'data_providers_tool', 'sb_deploy_tool', 'sb_expose_tool', 'update_agent_tool']]
+                logger.info(f"MCP tools registered: {mcp_tools}")
+
+            except Exception as e:
+                logger.error(f"Failed to initialize MCP tools: {e}")
+                # Continue without MCP tools if initialization fails
+
+    # Prepare system prompt
+    # First, get the default system prompt
+    if "gemini-2.5-flash" in model_name.lower() and "gemini-2.5-pro" not in model_name.lower():
+        default_system_content = get_gemini_system_prompt()
+    else:
+        # Use the original prompt - the LLM can only use tools that are registered
+        default_system_content = get_system_prompt()
+
+    # Add sample response for non-anthropic models
+    if "anthropic" not in model_name.lower():
+        sample_response_path = os.path.join(os.path.dirname(__file__), 'sample_responses/1.txt')
+        with open(sample_response_path, 'r') as file:
+            sample_response = file.read()
+        default_system_content = default_system_content + "\n\n <sample_assistant_response>" + sample_response + "</sample_assistant_response>"
+
+    # Handle custom agent system prompt
+    if agent_config and agent_config.get('system_prompt'):
+        custom_system_prompt = agent_config['system_prompt'].strip()
+
+        # Completely replace the default system prompt with the custom one
+        # This prevents confusion and tool hallucination
+        system_content = custom_system_prompt
+        logger.info(f"Using ONLY custom agent system prompt for: {agent_config.get('name', 'Unknown')}")
+    elif is_agent_builder:
+        system_content = get_agent_builder_prompt()
+        logger.info("Using agent builder system prompt")
+    else:
+        # Use just the default system prompt
+        system_content = default_system_content
+        logger.info("Using default system prompt only")
+
+    if await is_enabled("knowledge_base"):
+        try:
+            from services.supabase import DBConnection
+            kb_db = DBConnection()
+            kb_client = await kb_db.client
+
+            current_agent_id = agent_config.get('agent_id') if agent_config else None
+
+            kb_result = await kb_client.rpc('get_combined_knowledge_base_context', {
+                'p_thread_id': thread_id,
+                'p_agent_id': current_agent_id,
+                'p_max_tokens': 4000
+            }).execute()
+
+            if kb_result.data and kb_result.data.strip():
+                logger.info(f"Adding combined knowledge base context to system prompt for thread {thread_id}, agent {current_agent_id}")
+                system_content += "\n\n" + kb_result.data
+            else:
+                logger.debug(f"No knowledge base context found for thread {thread_id}, agent {current_agent_id}")
+
+        except Exception as e:
+            logger.error(f"Error retrieving knowledge base context for thread {thread_id}: {e}")
+
+
+    if agent_config and (agent_config.get('configured_mcps') or agent_config.get('custom_mcps')) and mcp_wrapper_instance and mcp_wrapper_instance._initialized:
+        mcp_info = "\n\n--- MCP Tools Available ---\n"
+        mcp_info += "You have access to external MCP (Model Context Protocol) server tools.\n"
+        mcp_info += "MCP tools can be called directly using their native function names in the standard function calling format:\n"
+        mcp_info += '<function_calls>\n'
+        mcp_info += '<invoke name="{tool_name}">\n'
+        mcp_info += '<parameter name="param1">value1</parameter>\n'
+        mcp_info += '<parameter name="param2">value2</parameter>\n'
+        mcp_info += '</invoke>\n'
+        mcp_info += '</function_calls>\n\n'
+
+        # List available MCP tools
+        mcp_info += "Available MCP tools:\n"
+        try:
+            # Get the actual registered schemas from the wrapper
+            registered_schemas = mcp_wrapper_instance.get_schemas()
+            for method_name, schema_list in registered_schemas.items():
+                if method_name == 'call_mcp_tool':
+                    continue  # Skip the fallback method
+
+                # Get the schema info
+                for schema in schema_list:
+                    if schema.schema_type == SchemaType.OPENAPI:
+                        func_info = schema.schema.get('function', {})
+                        description = func_info.get('description', 'No description available')
+
+                        mcp_info += f"- **{method_name}**: {description}\n"
+
+                        # Show parameter info
+                        params = func_info.get('parameters', {})
+                        props = params.get('properties', {})
+                        if props:
+                            mcp_info += f"  Parameters: {', '.join(props.keys())}\n"
+
+        except Exception as e:
+            logger.error(f"Error listing MCP tools: {e}")
+            mcp_info += "- Error loading MCP tool list\n"
+
+        # Add critical instructions for using search results
+        mcp_info += "\n🚨 CRITICAL MCP TOOL RESULT INSTRUCTIONS 🚨\n"
+        mcp_info += "When you use ANY MCP (Model Context Protocol) tools:\n"
+        mcp_info += "1. ALWAYS read and use the EXACT results returned by the MCP tool\n"
+        mcp_info += "2. For search tools: ONLY cite URLs, sources, and information from the actual search results\n"
+        mcp_info += "3. For any tool: Base your response entirely on the tool's output - do NOT add external information\n"
+        mcp_info += "4. DO NOT fabricate, invent, hallucinate, or make up any sources, URLs, or data\n"
+        mcp_info += "5. If you need more information, call the MCP tool again with different parameters\n"
+        mcp_info += "6. When writing reports/summaries: Reference ONLY the data from MCP tool results\n"
+        mcp_info += "7. If the MCP tool doesn't return enough information, explicitly state this limitation\n"
+        mcp_info += "8. Always double-check that every fact, URL, and reference comes from the MCP tool output\n"
+        mcp_info += "\nIMPORTANT: MCP tool results are your PRIMARY and ONLY source of truth for external data!\n"
+        mcp_info += "NEVER supplement MCP results with your training data or make assumptions beyond what the tools provide.\n"
+
+        system_content += mcp_info
+
+    system_message = { "role": "system", "content": system_content }
+
+    iteration_count = 0
+    continue_execution = True
+
+    latest_user_message = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'user').order('created_at', desc=True).limit(1).execute()
+    if latest_user_message.data and len(latest_user_message.data) > 0:
+        data = latest_user_message.data[0]['content']
+        if isinstance(data, str):
+            data = json.loads(data)
+        if trace:
+            trace.update(input=data['content'])
+
+    while continue_execution and iteration_count < max_iterations:
+        iteration_count += 1
+        logger.info(f"šŸ”„ Running iteration {iteration_count} of {max_iterations}...")
+
+        # Billing check on each iteration - still needed within the iterations
+        can_run, message, subscription = await check_billing_status(client, account_id)
+        if not can_run:
+            error_msg = f"Billing limit reached: {message}"
+            if trace:
+                trace.event(name="billing_limit_reached", level="ERROR", status_message=(f"{error_msg}"))
+            # Yield a special message to indicate billing limit reached
+            yield {
+                "type": "status",
+                "status": "stopped",
+                "message": error_msg
+            }
+            break
+        # Check if last message is from assistant using direct Supabase query
+        latest_message = await client.table('messages').select('*').eq('thread_id', thread_id).in_('type', ['assistant', 'tool', 'user']).order('created_at', desc=True).limit(1).execute()
+        if latest_message.data and len(latest_message.data) > 0:
+            message_type = latest_message.data[0].get('type')
+            if message_type == 'assistant':
+                logger.info(f"Last message was from assistant, stopping execution")
+                if trace:
+                    trace.event(name="last_message_from_assistant", level="DEFAULT", status_message=(f"Last message was from assistant, stopping execution"))
+                continue_execution = False
+                break
+
+        # ---- Temporary Message Handling (Browser State & Image Context) ----
+        temporary_message = None
+        temp_message_content_list = []  # List to hold text/image blocks
+
+        # Get the latest browser_state message
+        latest_browser_state_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'browser_state').order('created_at', desc=True).limit(1).execute()
+        if latest_browser_state_msg.data and len(latest_browser_state_msg.data) > 0:
+            try:
+                browser_content = latest_browser_state_msg.data[0]["content"]
+                if isinstance(browser_content, str):
+                    browser_content = json.loads(browser_content)
+                screenshot_base64 = browser_content.get("screenshot_base64")
+                screenshot_url = browser_content.get("image_url")
+
+                # Create a copy of the browser state without screenshot data
+                browser_state_text = browser_content.copy()
+                browser_state_text.pop('screenshot_base64', None)
+                browser_state_text.pop('image_url', None)
+
+                if browser_state_text:
+                    temp_message_content_list.append({
+                        "type": "text",
+                        "text": f"The following is the current state of the browser:\n{json.dumps(browser_state_text, indent=2)}"
+                    })
+
+                # Only add the screenshot for vision-capable models (Gemini, Anthropic, or OpenAI)
+                if 'gemini' in model_name.lower() or 'anthropic' in model_name.lower() or 'openai' in model_name.lower():
+                    # Prioritize screenshot_url if available
+                    if screenshot_url:
+                        temp_message_content_list.append({
+                            "type": "image_url",
+                            "image_url": {
+                                "url": screenshot_url,
+                                "format": "image/jpeg"
+                            }
+                        })
+                        if trace:
+                            trace.event(name="screenshot_url_added_to_temporary_message", level="DEFAULT", status_message=(f"Screenshot URL added to temporary message."))
+                    elif screenshot_base64:
+                        # Fallback to base64 if URL not available
+                        temp_message_content_list.append({
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{screenshot_base64}",
+                            }
+                        })
+                        if trace:
+                            trace.event(name="screenshot_base64_added_to_temporary_message", level="WARNING", status_message=(f"Screenshot base64 added to temporary message. Prefer screenshot_url if available."))
+                    else:
+                        logger.warning("Browser state found but no screenshot data.")
+                        if trace:
+                            trace.event(name="browser_state_found_but_no_screenshot_data", level="WARNING", status_message=(f"Browser state found but no screenshot data."))
+                else:
+                    logger.warning("Model doesn't support vision, so not adding screenshot to temporary message.")
+                    if trace:
+                        trace.event(name="model_doesnt_support_vision", level="WARNING", status_message=(f"Model doesn't support vision, so not adding screenshot to temporary message."))
+
+            except Exception as e:
+                logger.error(f"Error parsing browser state: {e}")
+                if trace:
+                    trace.event(name="error_parsing_browser_state", level="ERROR", status_message=(f"{e}"))
+
+        # Get the latest image_context message (NEW)
+        latest_image_context_msg = await client.table('messages').select('*').eq('thread_id', thread_id).eq('type', 'image_context').order('created_at', desc=True).limit(1).execute()
+        if latest_image_context_msg.data and len(latest_image_context_msg.data) > 0:
+            try:
+                image_context_content = latest_image_context_msg.data[0]["content"] if isinstance(latest_image_context_msg.data[0]["content"], dict) else json.loads(latest_image_context_msg.data[0]["content"])
+                base64_image = image_context_content.get("base64")
+                mime_type = image_context_content.get("mime_type")
+                file_path = image_context_content.get("file_path", "unknown file")
+
+                if base64_image and mime_type:
+                    temp_message_content_list.append({
+                        "type": "text",
+                        "text": f"Here is the image you requested to see: '{file_path}'"
+                    })
+                    temp_message_content_list.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:{mime_type};base64,{base64_image}",
+                        }
+                    })
+                else:
+                    logger.warning(f"Image context found for '{file_path}' but missing base64 or mime_type.")
+
+                await client.table('messages').delete().eq('message_id', latest_image_context_msg.data[0]["message_id"]).execute()
+            except Exception as e:
+                logger.error(f"Error parsing image context: {e}")
+                if trace:
+                    trace.event(name="error_parsing_image_context", level="ERROR", status_message=(f"{e}"))
+
+        # If we have any content, construct the temporary_message
+        if temp_message_content_list:
+            temporary_message = {"role": "user", "content": temp_message_content_list}
+            # logger.debug(f"Constructed temporary message with {len(temp_message_content_list)} content blocks.")
+        # ---- End Temporary Message Handling ----
+
+        # Set max_tokens based on model
+        max_tokens = None
+        if "sonnet" in model_name.lower():
+            # Claude Sonnet models cap output at 8192 tokens
+            max_tokens = 8192
+        elif "gpt-4" in model_name.lower():
+            max_tokens = 4096
+        elif "gemini-2.5-pro" in model_name.lower():
+            # Gemini 2.5 Pro has 64k max output tokens
+            max_tokens = 64000
+        elif "kimi-k2" in model_name.lower():
+            # Kimi-K2 has 120K context, set reasonable max output tokens
+            max_tokens = 8192
+
+        generation = trace.generation(name="thread_manager.run_thread") if trace else None
+        try:
+            # Make the LLM call and process the response
+            response = await thread_manager.run_thread(
+                thread_id=thread_id,
+                system_prompt=system_message,
+                stream=stream,
+                llm_model=model_name,
+                llm_temperature=0,
+                llm_max_tokens=max_tokens,
+                tool_choice="auto",
+                max_xml_tool_calls=1,
+                temporary_message=temporary_message,
+                processor_config=ProcessorConfig(
+                    xml_tool_calling=True,
+                    native_tool_calling=False,
+                    execute_tools=True,
+                    execute_on_stream=True,
+                    tool_execution_strategy="parallel",
+                    xml_adding_strategy="user_message"
+                ),
+                native_max_auto_continues=native_max_auto_continues,
+                include_xml_examples=True,
+                enable_thinking=enable_thinking,
+                reasoning_effort=reasoning_effort,
+                enable_context_manager=enable_context_manager,
+                generation=generation
+            )
+
+            if isinstance(response, dict) and "status" in response and response["status"] == "error":
+                logger.error(f"Error response from run_thread: {response.get('message', 'Unknown error')}")
+                if trace:
+                    trace.event(name="error_response_from_run_thread", level="ERROR", status_message=(f"{response.get('message', 'Unknown error')}"))
+                yield response
+                break
+
+            # Track if we see ask, complete, or web-browser-takeover tool calls
+            last_tool_call = None
+            agent_should_terminate = False
+
+            # Process the response
+            error_detected = False
+            full_response = ""
+            try:
+                # Check if response is iterable (async generator) or a dict (error case)
+                if hasattr(response, '__aiter__') and not isinstance(response, dict):
+                    async for chunk in response:
+                        # If we receive an error chunk, we should stop after this iteration
+                        if isinstance(chunk, dict) and chunk.get('type') == 'status' and chunk.get('status') == 'error':
+                            logger.error(f"Error chunk detected: {chunk.get('message', 'Unknown error')}")
+                            if trace:
+                                trace.event(name="error_chunk_detected", level="ERROR", status_message=(f"{chunk.get('message', 'Unknown error')}"))
+                            error_detected = True
+                            yield chunk  # Forward the error chunk
+                            continue  # Continue processing other chunks but don't break yet
+
+                        # Check for termination signal in status messages
+                        if chunk.get('type') == 'status':
+                            try:
+                                # Parse the metadata to check for termination signal
+                                metadata = chunk.get('metadata', {})
+                                if isinstance(metadata, str):
+                                    metadata = json.loads(metadata)
+
+                                if metadata.get('agent_should_terminate'):
+                                    agent_should_terminate = True
+                                    logger.info("Agent termination signal detected in status message")
+                                    if trace:
+                                        trace.event(name="agent_termination_signal_detected", level="DEFAULT", status_message="Agent termination signal detected in status message")
+
+                                    # Extract the tool name from the status content if available
+                                    content = chunk.get('content', {})
+                                    if isinstance(content, str):
+                                        content = json.loads(content)
+
+                                    if content.get('function_name'):
+                                        last_tool_call = content['function_name']
+                                    elif content.get('xml_tag_name'):
+                                        last_tool_call = content['xml_tag_name']
+
+                            except Exception as e:
+                                logger.debug(f"Error parsing status message for termination check: {e}")
+
+                        # Check for XML closers like </ask>, </complete>, or </web-browser-takeover> in assistant content chunks
+                        if chunk.get('type') == 'assistant' and 'content' in chunk:
+                            try:
+                                # The content field might be a JSON string or object
+                                content = chunk.get('content', '{}')
+                                if isinstance(content, str):
+                                    assistant_content_json = json.loads(content)
+                                else:
+                                    assistant_content_json = content
+
+                                # The actual text content is nested within
+                                assistant_text = assistant_content_json.get('content', '')
+                                full_response += assistant_text
+                                if isinstance(assistant_text, str):
+                                    if '</ask>' in assistant_text or '</complete>' in assistant_text or '</web-browser-takeover>' in assistant_text:
+                                        if '</ask>' in assistant_text:
+                                            xml_tool = 'ask'
+                                        elif '</complete>' in assistant_text:
+                                            xml_tool = 'complete'
+                                        elif '</web-browser-takeover>' in assistant_text:
+                                            xml_tool = 'web-browser-takeover'
+
+                                        last_tool_call = xml_tool
+                                        logger.info(f"Agent used XML tool: {xml_tool}")
+                                        if trace:
+                                            trace.event(name="agent_used_xml_tool", level="DEFAULT", status_message=(f"Agent used XML tool: {xml_tool}"))
+
+                            except json.JSONDecodeError:
+                                # Handle cases where content might not be valid JSON
+                                
logger.warning(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}") + if trace: + trace.event(name="warning_could_not_parse_assistant_content_json", level="WARNING", status_message=(f"Warning: Could not parse assistant content JSON: {chunk.get('content')}")) + except Exception as e: + logger.error(f"Error processing assistant chunk: {e}") + if trace: + trace.event(name="error_processing_assistant_chunk", level="ERROR", status_message=(f"Error processing assistant chunk: {e}")) + + yield chunk + else: + # Response is not iterable, likely an error dict + logger.error(f"Response is not iterable: {response}") + error_detected = True + + # Check if we should stop based on the last tool call or error + if error_detected: + logger.info(f"Stopping due to error detected in response") + if trace: + trace.event(name="stopping_due_to_error_detected_in_response", level="DEFAULT", status_message=(f"Stopping due to error detected in response")) + if generation: + generation.end(output=full_response, status_message="error_detected", level="ERROR") + break + + if agent_should_terminate or last_tool_call in ['ask', 'complete', 'web-browser-takeover']: + logger.info(f"Agent decided to stop with tool: {last_tool_call}") + if trace: + trace.event(name="agent_decided_to_stop_with_tool", level="DEFAULT", status_message=(f"Agent decided to stop with tool: {last_tool_call}")) + if generation: + generation.end(output=full_response, status_message="agent_stopped") + continue_execution = False + + except Exception as e: + # Just log the error and re-raise to stop all iterations + error_msg = f"Error during response streaming: {str(e)}" + logger.error(f"Error: {error_msg}") + if trace: + trace.event(name="error_during_response_streaming", level="ERROR", status_message=(f"Error during response streaming: {str(e)}")) + if generation: + generation.end(output=full_response, status_message=error_msg, level="ERROR") + yield { + "type": "status", + "status": "error", + "message": error_msg + } + # Stop execution immediately on any error + break + + except Exception as e: + # Just log the error and re-raise to stop all iterations + error_msg = f"Error running thread: {str(e)}" + logger.error(f"Error: {error_msg}") + if trace: + trace.event(name="error_running_thread", level="ERROR", status_message=(f"Error running thread: {str(e)}")) + yield { + "type": "status", + "status": "error", + "message": error_msg + } + # Stop execution immediately on any error + break + if generation: + generation.end(output=full_response) + + asyncio.create_task(asyncio.to_thread(lambda: langfuse.flush())) \ No newline at end of file diff --git a/backend/services/llm.py b/backend/services/llm.py index 6187e247..7045294c 100644 --- a/backend/services/llm.py +++ b/backend/services/llm.py @@ -232,6 +232,12 @@ def prepare_params( use_thinking = enable_thinking if enable_thinking is not None else False is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower() is_xai = "xai" in effective_model_name.lower() or model_name.startswith("xai/") + is_kimi_k2 = "kimi-k2" in effective_model_name.lower() or model_name.startswith("moonshotai/kimi-k2") + + if is_kimi_k2: + params["provider"] = { + "order": ["groq", "together/fp8"] + } if is_anthropic and use_thinking: effort_level = reasoning_effort if reasoning_effort else 'low' diff --git a/backend/utils/constants.py b/backend/utils/constants.py index 9d85684f..4a0ac589 100644 --- a/backend/utils/constants.py +++ b/backend/utils/constants.py @@ 
diff --git a/backend/utils/constants.py b/backend/utils/constants.py
index 9d85684f..4a0ac589 100644
--- a/backend/utils/constants.py
+++ b/backend/utils/constants.py
@@ -53,6 +53,14 @@ MODELS = {
         },
         "tier_availability": ["paid"]
     },
+    "openrouter/moonshotai/kimi-k2": {
+        "aliases": ["moonshotai/kimi-k2", "kimi-k2"],
+        "pricing": {
+            "input_cost_per_million_tokens": 1.00,
+            "output_cost_per_million_tokens": 3.00
+        },
+        "tier_availability": ["paid"]
+    },
     "openai/gpt-4o": {
         "aliases": ["gpt-4o"],
         "pricing": {
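As a sanity check on the new pricing entry, the per-request cost it implies can be computed directly from the per-million-token rates (the token counts below are hypothetical):

# Arithmetic check of the kimi-k2 pricing entry above.
input_rate, output_rate = 1.00, 3.00                 # USD per million tokens
prompt_tokens, completion_tokens = 12_000, 2_000     # hypothetical request
cost = (prompt_tokens * input_rate + completion_tokens * output_rate) / 1_000_000
print(f"${cost:.4f}")  # -> $0.0180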
diff --git a/frontend/Dockerfile b/frontend/Dockerfile
index 69cea002..74bf2b3b 100644
--- a/frontend/Dockerfile
+++ b/frontend/Dockerfile
@@ -9,6 +9,20 @@ WORKDIR /app
 
 # Install dependencies based on the preferred package manager
 COPY package.json yarn.lock* package-lock.json* pnpm-lock.yaml* .npmrc* ./
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3 \
+    make \
+    g++ \
+    build-essential \
+    pkg-config \
+    libcairo2-dev \
+    libpango1.0-dev \
+    libjpeg-dev \
+    libgif-dev \
+    librsvg2-dev \
+    && rm -rf /var/lib/apt/lists/*
+
 RUN \
   if [ -f yarn.lock ]; then yarn --frozen-lockfile; \
   elif [ -f package-lock.json ]; then npm ci; \
diff --git a/frontend/src/components/dashboard/layout-content.tsx b/frontend/src/components/dashboard/layout-content.tsx
index 0af30b05..f07c9ad3 100644
--- a/frontend/src/components/dashboard/layout-content.tsx
+++ b/frontend/src/components/dashboard/layout-content.tsx
@@ -9,7 +9,7 @@ import { useAccounts } from '@/hooks/use-accounts';
 import { useAuth } from '@/components/AuthProvider';
 import { useRouter } from 'next/navigation';
 import { Loader2 } from 'lucide-react';
-import { checkApiHealth } from '@/lib/api';
+import { useApiHealth } from '@/hooks/react-query';
 import { MaintenancePage } from '@/components/maintenance/maintenance-page';
 import { DeleteOperationProvider } from '@/contexts/DeleteOperationContext';
 import { StatusOverlay } from '@/components/ui/status-overlay';
@@ -28,37 +28,19 @@ export default function DashboardLayoutContent({
 }: DashboardLayoutContentProps) {
   // const [showPricingAlert, setShowPricingAlert] = useState(false)
   const [showMaintenanceAlert, setShowMaintenanceAlert] = useState(false);
-  const [isApiHealthy, setIsApiHealthy] = useState(true);
-  const [isCheckingHealth, setIsCheckingHealth] = useState(true);
   const { data: accounts } = useAccounts();
   const personalAccount = accounts?.find((account) => account.personal_account);
   const { user, isLoading } = useAuth();
   const router = useRouter();
+  const { data: healthData, isLoading: isCheckingHealth, error: healthError } = useApiHealth();
 
   useEffect(() => {
     // setShowPricingAlert(false)
     setShowMaintenanceAlert(false);
   }, []);
 
-  // Check API health
-  useEffect(() => {
-    const checkHealth = async () => {
-      try {
-        const health = await checkApiHealth();
-        setIsApiHealthy(health.status === 'ok');
-      } catch (error) {
-        console.error('API health check failed:', error);
-        setIsApiHealthy(false);
-      } finally {
-        setIsCheckingHealth(false);
-      }
-    };
-
-    checkHealth();
-    // Check health every 30 seconds
-    const interval = setInterval(checkHealth, 30000);
-    return () => clearInterval(interval);
-  }, []);
+  // API health is now managed by the useApiHealth hook
+  const isApiHealthy = healthData?.status === 'ok' && !healthError;
 
   // Check authentication status
   useEffect(() => {
@@ -107,8 +89,8 @@ export default function DashboardLayoutContent({
     return null;
   }
 
-  // Show maintenance page if API is not healthy
-  if (!isApiHealthy) {
+  // Show maintenance page if API is not healthy (but not during initial loading)
+  if (!isCheckingHealth && !isApiHealthy) {
     return <MaintenancePage />;
   }
 
diff --git a/frontend/src/components/thread/chat-input/_use-model-selection.ts b/frontend/src/components/thread/chat-input/_use-model-selection.ts
index 7687e1ed..d1ac4643 100644
--- a/frontend/src/components/thread/chat-input/_use-model-selection.ts
+++ b/frontend/src/components/thread/chat-input/_use-model-selection.ts
@@ -70,6 +70,12 @@ export const MODELS = {
     recommended: false,
     lowQuality: false
   },
+  'moonshotai/kimi-k2': {
+    tier: 'premium',
+    priority: 96,
+    recommended: false,
+    lowQuality: false
+  },
   'gpt-4.1': {
     tier: 'premium',
     priority: 96,
diff --git a/frontend/src/components/thread/chat-input/message-input.tsx b/frontend/src/components/thread/chat-input/message-input.tsx
index 78508bfb..ded27ec4 100644
--- a/frontend/src/components/thread/chat-input/message-input.tsx
+++ b/frontend/src/components/thread/chat-input/message-input.tsx
@@ -16,6 +16,7 @@ import { Tooltip } from '@/components/ui/tooltip';
 import { TooltipProvider, TooltipTrigger } from '@radix-ui/react-tooltip';
 import { BillingModal } from '@/components/billing/billing-modal';
 import ChatDropdown from './chat-dropdown';
+import { handleFiles } from './file-upload-handler';
 
 interface MessageInputProps {
   value: string;
@@ -129,6 +130,29 @@ export const MessageInput = forwardRef<HTMLTextAreaElement, MessageInputProps>(
     }
   };
 
+  const handlePaste = (e: React.ClipboardEvent<HTMLTextAreaElement>) => {
+    if (!e.clipboardData) return;
+    const items = Array.from(e.clipboardData.items);
+    const imageFiles: File[] = [];
+    for (const item of items) {
+      if (item.kind === 'file' && item.type.startsWith('image/')) {
+        const file = item.getAsFile();
+        if (file) imageFiles.push(file);
+      }
+    }
+    if (imageFiles.length > 0) {
+      e.preventDefault();
+      handleFiles(
+        imageFiles,
+        sandboxId,
+        setPendingFiles,
+        setUploadedFiles,
+        setIsUploading,
+        messages,
+      );
+    }
+  };
+
   const renderDropdown = () => {
     if (isLoggedIn) {
       const showAdvancedFeatures = enableAdvancedConfig || (customAgentsEnabled && !flagsLoading);
@@ -167,6 +191,7 @@ export const MessageInput = forwardRef<HTMLTextAreaElement, MessageInputProps>(
          value={value}
          onChange={onChange}
          onKeyDown={handleKeyDown}
+         onPaste={handlePaste}
          placeholder={placeholder}
          className={cn(
            'w-full bg-transparent dark:bg-transparent border-none shadow-none focus-visible:ring-0 px-0.5 pb-6 pt-4 !text-[15px] min-h-[36px] max-h-[200px] overflow-y-auto resize-none',