diff --git a/backend/agent/api.py b/backend/agent/api.py index edb54c35..93da2d3f 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, Depends, Request +from fastapi import APIRouter, HTTPException, Depends, Request, Body from fastapi.responses import StreamingResponse import asyncio import json @@ -7,6 +7,7 @@ from datetime import datetime, timezone import uuid from typing import Optional, List, Dict, Any import jwt +from pydantic import BaseModel from agentpress.thread_manager import ThreadManager from services.supabase import DBConnection @@ -26,6 +27,19 @@ db = None # In-memory storage for active agent runs and their responses active_agent_runs: Dict[str, List[Any]] = {} +MODEL_NAME_ALIASES = { + "sonnet-3.7": "anthropic/claude-3-7-sonnet-latest", + "gpt-4.1": "openai/gpt-4.1-2025-04-14", + "gemini-flash-2.5": "openrouter/google/gemini-2.5-flash-preview", +} + +class AgentStartRequest(BaseModel): + model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest" + enable_thinking: Optional[bool] = False + reasoning_effort: Optional[str] = 'low' + stream: Optional[bool] = True + enable_context_manager: Optional[bool] = False + def initialize( _thread_manager: ThreadManager, _db: DBConnection, @@ -237,9 +251,13 @@ async def _cleanup_agent_run(agent_run_id: str): # Non-fatal error, can continue @router.post("/thread/{thread_id}/agent/start") -async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)): +async def start_agent( + thread_id: str, + body: AgentStartRequest = Body(...), # Accept request body + user_id: str = Depends(get_current_user_id) +): """Start an agent for a specific thread in the background.""" - logger.info(f"Starting new agent for thread: {thread_id}") + logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}, context_manager={body.enable_context_manager}") client = await db.client # Verify user has access to this thread @@ -314,7 +332,18 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id # Run the agent in the background task = asyncio.create_task( - run_agent_background(agent_run_id, thread_id, instance_id, project_id, sandbox) + run_agent_background( + agent_run_id=agent_run_id, + thread_id=thread_id, + instance_id=instance_id, + project_id=project_id, + sandbox=sandbox, + model_name=MODEL_NAME_ALIASES.get(body.model_name, body.model_name), + enable_thinking=body.enable_thinking, + reasoning_effort=body.reasoning_effort, + stream=body.stream, + enable_context_manager=body.enable_context_manager + ) ) # Set a callback to clean up when task is done @@ -441,9 +470,20 @@ async def stream_agent_run( } ) -async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str, sandbox): +async def run_agent_background( + agent_run_id: str, + thread_id: str, + instance_id: str, + project_id: str, + sandbox, + model_name: str, + enable_thinking: Optional[bool], + reasoning_effort: Optional[str], + stream: bool, + enable_context_manager: bool +): """Run the agent in the background and handle status updates.""" - logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})") + logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model={model_name}, thinking={enable_thinking}, 
effort={reasoning_effort}, stream={stream}, context_manager={enable_context_manager}") client = await db.client # Tracking variables @@ -561,9 +601,17 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s try: # Run the agent logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})") - agent_gen = run_agent(thread_id, stream=True, - thread_manager=thread_manager, project_id=project_id, - sandbox=sandbox) + agent_gen = run_agent( + thread_id=thread_id, + project_id=project_id, + stream=stream, + thread_manager=thread_manager, + sandbox=sandbox, + model_name=model_name, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort, + enable_context_manager=enable_context_manager + ) # Collect all responses to save to database all_responses = [] diff --git a/backend/agent/run.py b/backend/agent/run.py index 0c784395..b860c7b0 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -22,7 +22,19 @@ from utils.billing import check_billing_status, get_account_id_from_thread load_dotenv() -async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150): +async def run_agent( + thread_id: str, + project_id: str, + sandbox, + stream: bool, + thread_manager: Optional[ThreadManager] = None, + native_max_auto_continues: int = 25, + max_iterations: int = 150, + model_name: str = "anthropic/claude-3-7-sonnet-latest", + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low', + enable_context_manager: bool = True +): """Run the development agent with specified configuration.""" if not thread_manager: @@ -42,17 +54,15 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru thread_manager.add_tool(SandboxDeployTool, sandbox=sandbox) thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool - if os.getenv("EXA_API_KEY"): + if os.getenv("TAVILY_API_KEY"): thread_manager.add_tool(WebSearchTool) + else: + print("TAVILY_API_KEY not found, WebSearchTool will not be available.") if os.getenv("RAPID_API_KEY"): thread_manager.add_tool(DataProvidersTool) - xml_examples = "" - for tag_name, example in thread_manager.tool_registry.get_xml_examples().items(): - xml_examples += f"{example}\n" - - system_message = { "role": "system", "content": get_system_prompt() + "\n\n" + f"\n{xml_examples}\n" } + system_message = { "role": "system", "content": get_system_prompt() } iteration_count = 0 continue_execution = True @@ -112,14 +122,16 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru except Exception as e: print(f"Error parsing browser state: {e}") # print(latest_browser_state.data[0]) + + max_tokens = 64000 if "sonnet" in model_name.lower() else None response = await thread_manager.run_thread( thread_id=thread_id, system_prompt=system_message, stream=stream, - llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"), + llm_model=model_name, llm_temperature=0, - llm_max_tokens=64000, + llm_max_tokens=max_tokens, tool_choice="auto", max_xml_tool_calls=1, temporary_message=temporary_message, @@ -133,6 +145,9 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru ), native_max_auto_continues=native_max_auto_continues, include_xml_examples=True, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort, + 
enable_context_manager=enable_context_manager ) if isinstance(response, dict) and "status" in response and response["status"] == "error": @@ -267,7 +282,16 @@ async def test_agent(): print("\nšŸ‘‹ Test completed. Goodbye!") -async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager): +async def process_agent_response( + thread_id: str, + project_id: str, + thread_manager: ThreadManager, + stream: bool = True, + model_name: str = "anthropic/claude-3-7-sonnet-latest", + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low', + enable_context_manager: bool = True +): """Process the streaming response from the agent.""" chunk_counter = 0 current_response = "" @@ -276,9 +300,20 @@ async def process_agent_response(thread_id: str, project_id: str, thread_manager # Create a test sandbox for processing sandbox_pass = str(uuid4()) sandbox = create_sandbox(sandbox_pass) - print(f"\033[91mTest sandbox created: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}\033[0m") + print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m") - async for chunk in run_agent(thread_id=thread_id, project_id=project_id, sandbox=sandbox, stream=True, thread_manager=thread_manager, native_max_auto_continues=25): + async for chunk in run_agent( + thread_id=thread_id, + project_id=project_id, + sandbox=sandbox, + stream=stream, + thread_manager=thread_manager, + native_max_auto_continues=25, + model_name=model_name, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort, + enable_context_manager=enable_context_manager + ): chunk_counter += 1 # print(f"CHUNK: {chunk}") # Uncomment for debugging diff --git a/backend/agent/tools/web_search_tool.py b/backend/agent/tools/web_search_tool.py index 2964ee14..e375cf39 100644 --- a/backend/agent/tools/web_search_tool.py +++ b/backend/agent/tools/web_search_tool.py @@ -1,4 +1,5 @@ -from exa_py import Exa +from tavily import AsyncTavilyClient +import httpx from typing import List, Optional from datetime import datetime import os @@ -15,10 +16,12 @@ class WebSearchTool(Tool): # Load environment variables load_dotenv() # Use the provided API key or get it from environment variables - self.api_key = api_key or os.getenv("EXA_API_KEY") + self.api_key = api_key or os.getenv("TAVILY_API_KEY") if not self.api_key: - raise ValueError("EXA_API_KEY not found in environment variables") - self.exa = Exa(api_key=self.api_key) + raise ValueError("TAVILY_API_KEY not found in environment variables") + + # Tavily asynchronous search client + self.tavily_client = AsyncTavilyClient(api_key=self.api_key) @openapi_schema({ "type": "function", @@ -111,57 +114,49 @@ class WebSearchTool(Tool): if not query or not isinstance(query, str): return self.fail_response("A valid search query is required.") - # Basic parameters - use only the minimum required to avoid API errors - params = { - "query": query, - "type": "auto", - "livecrawl": "auto" - } - - # Handle summary parameter (boolean conversion) - if summary is None: - params["summary"] = True - elif isinstance(summary, bool): - params["summary"] = summary - elif isinstance(summary, str): - params["summary"] = summary.lower() == "true" - else: - params["summary"] = True - - # Handle num_results parameter (integer conversion) + # ---------- Tavily search parameters ---------- + # num_results normalisation (1‑50) if num_results is None: - params["num_results"] = 20 + num_results = 20 elif 
isinstance(num_results, int): - params["num_results"] = max(1, min(num_results, 50)) + num_results = max(1, min(num_results, 50)) elif isinstance(num_results, str): try: - params["num_results"] = max(1, min(int(num_results), 50)) + num_results = max(1, min(int(num_results), 50)) except ValueError: - params["num_results"] = 20 + num_results = 20 else: - params["num_results"] = 20 - - # Execute the search with minimal parameters - search_response = self.exa.search_and_contents(**params) - - # Format the results + num_results = 20 + + # Execute the search with Tavily + search_response = await self.tavily_client.search( + query=query, + max_results=num_results, + include_answer=False, + include_images=False, + ) + + # `tavily` may return a dict with `results` or a bare list + raw_results = ( + search_response.get("results") + if isinstance(search_response, dict) + else search_response + ) + formatted_results = [] - for result in search_response.results: + for result in raw_results: formatted_result = { - "Title": result.title, - "URL": result.url + "Title": result.get("title"), + "URL": result.get("url"), } - - # Add optional fields if they exist - if hasattr(result, 'summary') and result.summary: - formatted_result["Summary"] = result.summary - - if hasattr(result, 'published_date') and result.published_date: - formatted_result["Published Date"] = result.published_date - - if hasattr(result, 'score'): - formatted_result["Score"] = result.score - + + if summary: + # Prefer full content; fall back to description + if result.get("content"): + formatted_result["Summary"] = result["content"] + elif result.get("description"): + formatted_result["Summary"] = result["description"] + formatted_results.append(formatted_result) return self.success_response(formatted_results) @@ -243,26 +238,50 @@ class WebSearchTool(Tool): else: return self.fail_response("URL must be a string.") - # Execute the crawl with the parsed URL - result = self.exa.get_contents( - [url], - text=True, - livecrawl="auto" - ) - - # Format the results to include all available fields - formatted_results = [] - for content in result.results: - formatted_result = { - "Title": content.title, - "URL": content.url, - "Text": content.text + # ---------- Tavily extract endpoint ---------- + async with httpx.AsyncClient() as client: + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", } - - # Add optional fields if they exist - if hasattr(content, 'published_date') and content.published_date: - formatted_result["Published Date"] = content.published_date - + payload = { + "urls": url, + "include_images": False, + "extract_depth": "basic", + } + response = await client.post( + "https://api.tavily.com/extract", + json=payload, + headers=headers, + timeout=60, + ) + response.raise_for_status() + data = response.json() + print(f"--- Raw Tavily Response ---") + print(data) + print(f"--------------------------") + + # Normalise Tavily extract output to a list of dicts + extracted = [] + if isinstance(data, list): + extracted = data + elif isinstance(data, dict): + if "results" in data and isinstance(data["results"], list): + extracted = data["results"] + elif "urls" in data and isinstance(data["urls"], dict): + extracted = list(data["urls"].values()) + else: + extracted = [data] + + formatted_results = [] + for item in extracted: + formatted_result = { + "Title": item.get("title"), + "URL": item.get("url") or url, + "Text": item.get("content") or item.get("text") or "", + } + if 
item.get("published_date"): + formatted_result["Published Date"] = item["published_date"] formatted_results.append(formatted_result) return self.success_response(formatted_results) @@ -279,27 +298,27 @@ class WebSearchTool(Tool): if __name__ == "__main__": import asyncio - # async def test_web_search(): - # """Test function for the web search tool""" - # search_tool = WebSearchTool() - # result = await search_tool.web_search( - # query="rubber gym mats best prices comparison", - # summary=True, - # num_results=20 - # ) - # print(result) + async def test_web_search(): + """Test function for the web search tool""" + search_tool = WebSearchTool() + result = await search_tool.web_search( + query="rubber gym mats best prices comparison", + summary=True, + num_results=20 + ) + print(result) async def test_crawl_webpage(): """Test function for the webpage crawl tool""" search_tool = WebSearchTool() result = await search_tool.crawl_webpage( - url="https://example.com" + url="https://google.com" ) print(result) async def run_tests(): """Run all test functions""" - # await test_web_search() + await test_web_search() await test_crawl_webpage() - asyncio.run(run_tests()) + asyncio.run(run_tests()) \ No newline at end of file diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py index 096c8fdf..02a28ed9 100644 --- a/backend/agentpress/response_processor.py +++ b/backend/agentpress/response_processor.py @@ -96,10 +96,19 @@ class ResponseProcessor: self, llm_response: AsyncGenerator, thread_id: str, + prompt_messages: List[Dict[str, Any]], + llm_model: str, config: ProcessorConfig = ProcessorConfig(), ) -> AsyncGenerator[Dict[str, Any], None]: """Process a streaming LLM response, handling tool calls and execution. + Args: + llm_response: Streaming response from the LLM + thread_id: ID of the conversation thread + prompt_messages: List of messages sent to the LLM (the prompt) + llm_model: The name of the LLM model used + config: Configuration for parsing and execution + Yields: Complete message objects matching the DB schema, except for content chunks. """ @@ -144,8 +153,14 @@ class ResponseProcessor: if hasattr(chunk, 'choices') and chunk.choices: delta = chunk.choices[0].delta if hasattr(chunk.choices[0], 'delta') else None + + # Check for and log Anthropic thinking content + if delta and hasattr(delta, 'reasoning_content') and delta.reasoning_content: + logger.info(f"[THINKING]: {delta.reasoning_content}") + # Append reasoning to main content to be saved in the final message + accumulated_content += delta.reasoning_content - # --- Process Content Chunk --- + # Process content chunk if delta and hasattr(delta, 'content') and delta.content: chunk_content = delta.content accumulated_content += chunk_content @@ -263,8 +278,8 @@ class ResponseProcessor: tool_index += 1 if finish_reason == "xml_tool_limit_reached": - logger.info("Stopping stream due to XML tool call limit") - break # Exit the async for loop + logger.info("Stopping stream processing after loop due to XML tool call limit") + break # --- After Streaming Loop --- @@ -529,10 +544,19 @@ class ResponseProcessor: self, llm_response: Any, thread_id: str, + prompt_messages: List[Dict[str, Any]], + llm_model: str, config: ProcessorConfig = ProcessorConfig() ) -> AsyncGenerator[Dict[str, Any], None]: """Process a non-streaming LLM response, handling tool calls and execution. 
- + + Args: + llm_response: Response from the LLM + thread_id: ID of the conversation thread + prompt_messages: List of messages sent to the LLM (the prompt) + llm_model: The name of the LLM model used + config: Configuration for parsing and execution + Yields: Complete message objects matching the DB schema. """ diff --git a/backend/agentpress/thread_manager.py b/backend/agentpress/thread_manager.py index e9ac32c8..987e4cce 100644 --- a/backend/agentpress/thread_manager.py +++ b/backend/agentpress/thread_manager.py @@ -161,6 +161,9 @@ class ThreadManager: native_max_auto_continues: int = 25, max_xml_tool_calls: int = 0, include_xml_examples: bool = False, + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low', + enable_context_manager: bool = True ) -> Union[Dict[str, Any], AsyncGenerator]: """Run a conversation thread with LLM integration and tool execution. @@ -178,6 +181,9 @@ class ThreadManager: finish_reason="tool_calls" (0 disables auto-continue) max_xml_tool_calls: Maximum number of XML tool calls to allow (0 = no limit) include_xml_examples: Whether to include XML tool examples in the system prompt + enable_thinking: Whether to enable thinking before making a decision + reasoning_effort: The effort level for reasoning + enable_context_manager: Whether to enable automatic context summarization. Returns: An async generator yielding response chunks or error dict @@ -187,6 +193,52 @@ class ThreadManager: logger.debug(f"Parameters: model={llm_model}, temperature={llm_temperature}, max_tokens={llm_max_tokens}") logger.debug(f"Auto-continue: max={native_max_auto_continues}, XML tool limit={max_xml_tool_calls}") + # Use a default config if none was provided (needed for XML examples check) + if processor_config is None: + processor_config = ProcessorConfig() + + # Apply max_xml_tool_calls if specified and not already set in config + if max_xml_tool_calls > 0 and not processor_config.max_xml_tool_calls: + processor_config.max_xml_tool_calls = max_xml_tool_calls + + # Create a working copy of the system prompt to potentially modify + working_system_prompt = system_prompt.copy() + + # Add XML examples to system prompt if requested, do this only ONCE before the loop + if include_xml_examples and processor_config.xml_tool_calling: + xml_examples = self.tool_registry.get_xml_examples() + if xml_examples: + examples_content = """ +--- XML TOOL CALLING --- + +In this environment you have access to a set of tools you can use to answer the user's question. The tools are specified in XML format. +Format your tool calls using the specified XML tags. Place parameters marked as 'attribute' within the opening tag (e.g., ``). Place parameters marked as 'content' between the opening and closing tags. Place parameters marked as 'element' within their own child tags (e.g., `value`). Refer to the examples provided below for the exact structure of each tool. +String and scalar parameters should be specified as attributes, while content goes between tags. +Note that spaces for string values are not stripped. The output is parsed with regular expressions. 
+ +Here are the XML tools available with examples: +""" + for tag_name, example in xml_examples.items(): + examples_content += f"<{tag_name}> Example: {example}\\n" + + system_content = working_system_prompt.get('content') + + if isinstance(system_content, str): + working_system_prompt['content'] += examples_content + logger.debug("Appended XML examples to string system prompt content.") + elif isinstance(system_content, list): + appended = False + for item in working_system_prompt['content']: # Modify the copy + if isinstance(item, dict) and item.get('type') == 'text' and 'text' in item: + item['text'] += examples_content + logger.debug("Appended XML examples to the first text block in list system prompt content.") + appended = True + break + if not appended: + logger.warning("System prompt content is a list but no text block found to append XML examples.") + else: + logger.warning(f"System prompt content is of unexpected type ({type(system_content)}), cannot add XML examples.") + # Control whether we need to auto-continue due to tool_calls finish reason auto_continue = True auto_continue_count = 0 @@ -195,81 +247,46 @@ class ThreadManager: async def _run_once(temp_msg=None): try: # Ensure processor_config is available in this scope - nonlocal processor_config - - # Use a default config if none was provided - if processor_config is None: - processor_config = ProcessorConfig() - - # Apply max_xml_tool_calls if specified and not already set - if max_xml_tool_calls > 0: - processor_config.max_xml_tool_calls = max_xml_tool_calls - - # Add XML examples to system prompt if requested - if include_xml_examples and processor_config.xml_tool_calling: - xml_examples = self.tool_registry.get_xml_examples() - if xml_examples: - # logger.debug(f"Adding {len(xml_examples)} XML examples to system prompt") - - # Create or append to content - if isinstance(system_prompt['content'], str): - examples_content = """ ---- XML TOOL CALLING --- - -In this environment you have access to a set of tools you can use to answer the user's question. The tools are specified in XML format. -{{ FORMATTING INSTRUCTIONS }} -String and scalar parameters should be specified as attributes, while content goes between tags. -Note that spaces for string values are not stripped. The output is parsed with regular expressions. - -Here are the XML tools available with examples: -""" - for tag_name, example in xml_examples.items(): - examples_content += f"<{tag_name}> Example: {example}\n" - - system_prompt['content'] += examples_content - else: - # If content is not a string (might be a list or dict), log a warning - logger.warning("System prompt content is not a string, cannot add XML examples") + nonlocal processor_config + # Note: processor_config is now guaranteed to exist due to check above # 1. Get messages from thread for LLM call messages = await self.get_llm_messages(thread_id) # 2. 
Check token count before proceeding - # Use litellm to count tokens in the messages token_count = 0 try: from litellm import token_counter - token_count = token_counter(model=llm_model, messages=[system_prompt] + messages) + # Use the potentially modified working_system_prompt for token counting + token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages) token_threshold = self.context_manager.token_threshold logger.info(f"Thread {thread_id} token count: {token_count}/{token_threshold} ({(token_count/token_threshold)*100:.1f}%)") - # If we're over the threshold, summarize the thread - if token_count >= token_threshold: + if token_count >= token_threshold and enable_context_manager: logger.info(f"Thread token count ({token_count}) exceeds threshold ({token_threshold}), summarizing...") - - # Create summary using context manager summarized = await self.context_manager.check_and_summarize_if_needed( thread_id=thread_id, add_message_callback=self.add_message, model=llm_model, - force=True # Force summarization + force=True ) - if summarized: - # If summarization was successful, get the updated messages - # This will now include the summary message and only messages after it logger.info("Summarization complete, fetching updated messages with summary") messages = await self.get_llm_messages(thread_id) - # Recount tokens after summarization - new_token_count = token_counter(model=llm_model, messages=[system_prompt] + messages) + # Recount tokens after summarization, using the modified prompt + new_token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages) logger.info(f"After summarization: token count reduced from {token_count} to {new_token_count}") else: logger.warning("Summarization failed or wasn't needed - proceeding with original messages") + elif not enable_context_manager: # Added condition for clarity + logger.info("Automatic summarization disabled. Skipping token count check and summarization.") + except Exception as e: logger.error(f"Error counting tokens or summarizing: {str(e)}") # 3. 
Prepare messages for LLM call + add temporary message if it exists - prepared_messages = [system_prompt] + # Use the working_system_prompt which may contain the XML examples + prepared_messages = [working_system_prompt] # Find the last user message index last_user_index = -1 @@ -306,13 +323,15 @@ Here are the XML tools available with examples: logger.debug("Making LLM API call") try: llm_response = await make_llm_api_call( - prepared_messages, + prepared_messages, # Pass the potentially modified messages llm_model, temperature=llm_temperature, max_tokens=llm_max_tokens, tools=openapi_tool_schemas, tool_choice=tool_choice if processor_config.native_tool_calling else None, - stream=stream + stream=stream, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort ) logger.debug("Successfully received raw LLM API response stream/object") @@ -326,22 +345,27 @@ Here are the XML tools available with examples: response_generator = self.response_processor.process_streaming_response( llm_response=llm_response, thread_id=thread_id, - config=processor_config + config=processor_config, + prompt_messages=prepared_messages, + llm_model=llm_model ) return response_generator else: logger.debug("Processing non-streaming response") try: - response = await self.response_processor.process_non_streaming_response( + # Return the async generator directly, don't await it + response_generator = self.response_processor.process_non_streaming_response( llm_response=llm_response, thread_id=thread_id, - config=processor_config + config=processor_config, + prompt_messages=prepared_messages, + llm_model=llm_model ) - return response + return response_generator # Return the generator except Exception as e: - logger.error(f"Error in non-streaming response: {str(e)}", exc_info=True) - raise + logger.error(f"Error setting up non-streaming response: {str(e)}", exc_info=True) + raise # Re-raise the exception to be caught by the outer handler except Exception as e: logger.error(f"Error in run_thread: {str(e)}", exc_info=True) @@ -358,8 +382,9 @@ Here are the XML tools available with examples: # Reset auto_continue for this iteration auto_continue = False - # Run the thread once - response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None) + # Run the thread once, passing the potentially modified system prompt + # Pass temp_msg only on the first iteration + response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None) # Handle error responses if isinstance(response_gen, dict) and "status" in response_gen and response_gen["status"] == "error": @@ -402,7 +427,8 @@ Here are the XML tools available with examples: # If auto-continue is disabled (max=0), just run once if native_max_auto_continues == 0: logger.info("Auto-continue is disabled (native_max_auto_continues=0)") - return await _run_once(temporary_message) + # Pass the potentially modified system prompt and temp message + return await _run_once(temporary_message) # Otherwise return the auto-continue wrapper generator return auto_continue_wrapper() diff --git a/backend/requirements.txt b/backend/requirements.txt index a8967c5b..c5ca0888 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,6 +1,6 @@ streamlit-quill==0.0.3 python-dotenv==1.0.1 -litellm>=1.44.0 +litellm>=1.66.2 click==8.1.7 questionary==2.0.1 requests>=2.31.0 @@ -22,4 +22,5 @@ certifi==2024.2.2 python-ripgrep==0.0.6 daytona_sdk>=0.12.0 boto3>=1.34.0 -exa-py>=1.9.1 +pydantic +tavily-python>=0.5.4 \ No newline at end of file 
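For context while reviewing, a minimal client sketch (not part of the diff) of the new request body accepted by POST /thread/{thread_id}/agent/start. The base URL, any router prefix, the bearer-token auth scheme, and the response shape are assumptions; short aliases such as "sonnet-3.7" are resolved server-side via MODEL_NAME_ALIASES, and omitted fields fall back to the AgentStartRequest defaults.

import asyncio

import httpx


async def start_agent_run(base_url: str, token: str, thread_id: str) -> dict:
    # "sonnet-3.7" maps to "anthropic/claude-3-7-sonnet-latest" through the
    # MODEL_NAME_ALIASES lookup in backend/agent/api.py; unknown names pass through unchanged.
    payload = {
        "model_name": "sonnet-3.7",
        "enable_thinking": True,
        "reasoning_effort": "low",
        "stream": True,
        "enable_context_manager": False,
    }
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{base_url}/thread/{thread_id}/agent/start",  # base_url and prefix are placeholders
            json=payload,
            headers={"Authorization": f"Bearer {token}"},  # auth scheme assumed, not shown in the diff
            timeout=30,
        )
        resp.raise_for_status()
        return resp.json()


# Placeholders only: asyncio.run(start_agent_run("http://localhost:8000", "<jwt>", "<thread-id>"))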
diff --git a/backend/sandbox/api.py b/backend/sandbox/api.py index de4d2e47..33effa6f 100644 --- a/backend/sandbox/api.py +++ b/backend/sandbox/api.py @@ -151,9 +151,11 @@ async def list_files( for file in files: # Convert file information to our model + # Ensure forward slashes are used for paths, regardless of OS + full_path = f"{path.rstrip('/')}/{file.name}" if path != '/' else f"/{file.name}" file_info = FileInfo( name=file.name, - path=os.path.join(path, file.name), + path=full_path, # Use the constructed path is_dir=file.is_dir, size=file.size, mod_time=str(file.mod_time), diff --git a/backend/services/llm.py b/backend/services/llm.py index e559c4bd..2bcfb7ae 100644 --- a/backend/services/llm.py +++ b/backend/services/llm.py @@ -17,6 +17,8 @@ import asyncio from openai import OpenAIError import litellm from utils.logger import logger +from datetime import datetime +import traceback # litellm.set_verbose=True litellm.modify_params=True @@ -82,7 +84,9 @@ def prepare_params( api_base: Optional[str] = None, stream: bool = False, top_p: Optional[float] = None, - model_id: Optional[str] = None + model_id: Optional[str] = None, + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' ) -> Dict[str, Any]: """Prepare parameters for the API call.""" params = { @@ -152,6 +156,75 @@ def prepare_params( params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0" logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}") + # Apply Anthropic prompt caching (minimal implementation) + # Check model name *after* potential modifications (like adding bedrock/ prefix) + effective_model_name = params.get("model", model_name) # Use model from params if set, else original + if "claude" in effective_model_name.lower() or "anthropic" in effective_model_name.lower(): + logger.debug("Applying minimal Anthropic prompt caching.") + messages = params["messages"] # Direct reference, modification affects params + + # Ensure messages is a list + if not isinstance(messages, list): + logger.warning(f"Messages is not a list ({type(messages)}), skipping Anthropic cache control.") + return params # Return early if messages format is unexpected + + # 1. Process the first message if it's a system prompt with string content + if messages and messages[0].get("role") == "system": + content = messages[0].get("content") + if isinstance(content, str): + # Wrap the string content in the required list structure + messages[0]["content"] = [ + {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}} + ] + logger.debug("Applied cache_control to system message (converted from string).") + elif isinstance(content, list): + # If content is already a list, check if the first text block needs cache_control + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + if "cache_control" not in item: + item["cache_control"] = {"type": "ephemeral"} + break # Apply to the first text block only for system prompt + else: + logger.warning("System message content is not a string or list, skipping cache_control.") + + # 2. 
Find and process the last user message + last_user_idx = -1 + for i in range(len(messages) - 1, -1, -1): + if messages[i].get("role") == "user": + last_user_idx = i + break + + if last_user_idx != -1: + last_user_message = messages[last_user_idx] + content = last_user_message.get("content") + + if isinstance(content, str): + # Wrap the string content in the required list structure + last_user_message["content"] = [ + {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}} + ] + logger.debug(f"Applied cache_control to last user message (string content, index {last_user_idx}).") + elif isinstance(content, list): + # Modify text blocks within the list directly + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + # Add cache_control if not already present + if "cache_control" not in item: + item["cache_control"] = {"type": "ephemeral"} + + else: + logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.") + + # Add reasoning_effort for Anthropic models if enabled + use_thinking = enable_thinking if enable_thinking is not None else False + is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower() + + if is_anthropic and use_thinking: + effort_level = reasoning_effort if reasoning_effort else 'low' + params["reasoning_effort"] = effort_level + params["temperature"] = 1.0 # Required by Anthropic when reasoning_effort is used + logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'") + return params async def make_llm_api_call( @@ -166,7 +239,9 @@ async def make_llm_api_call( api_base: Optional[str] = None, stream: bool = False, top_p: Optional[float] = None, - model_id: Optional[str] = None + model_id: Optional[str] = None, + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' ) -> Union[Dict[str, Any], AsyncGenerator]: """ Make an API call to a language model using LiteLLM. 
@@ -184,6 +259,8 @@ async def make_llm_api_call( stream: Whether to stream the response top_p: Top-p sampling parameter model_id: Optional ARN for Bedrock inference profiles + enable_thinking: Whether to enable thinking + reasoning_effort: Level of reasoning effort Returns: Union[Dict[str, Any], AsyncGenerator]: API response or stream @@ -192,7 +269,7 @@ async def make_llm_api_call( LLMRetryError: If API call fails after retries LLMError: For other API-related errors """ - logger.debug(f"Making LLM API call to model: {model_name}") + logger.debug(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})") params = prepare_params( messages=messages, model_name=model_name, @@ -205,7 +282,9 @@ async def make_llm_api_call( api_base=api_base, stream=stream, top_p=top_p, - model_id=model_id + model_id=model_id, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort ) last_error = None diff --git a/frontend/src/app/dashboard/page.tsx b/frontend/src/app/dashboard/page.tsx index 6a6bf875..99a5b62e 100644 --- a/frontend/src/app/dashboard/page.tsx +++ b/frontend/src/app/dashboard/page.tsx @@ -12,7 +12,7 @@ function DashboardContent() { const [isSubmitting, setIsSubmitting] = useState(false); const router = useRouter(); - const handleSubmit = async (message: string) => { + const handleSubmit = async (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => { if (!message.trim() || isSubmitting) return; setIsSubmitting(true); @@ -34,7 +34,11 @@ function DashboardContent() { await addUserMessage(thread.thread_id, message.trim()); // 4. Start the agent with the thread ID - const agentRun = await startAgent(thread.thread_id); + const agentRun = await startAgent(thread.thread_id, { + model_name: options?.model_name, + enable_thinking: options?.enable_thinking, + stream: true + }); // 5. Navigate to the new agent's thread page router.push(`/dashboard/agents/${thread.thread_id}`); diff --git a/frontend/src/components/thread/chat-input.tsx b/frontend/src/components/thread/chat-input.tsx index ebafed1b..117c3221 100644 --- a/frontend/src/components/thread/chat-input.tsx +++ b/frontend/src/components/thread/chat-input.tsx @@ -3,7 +3,7 @@ import React, { useState, useRef, useEffect } from 'react'; import { Textarea } from "@/components/ui/textarea"; import { Button } from "@/components/ui/button"; -import { Send, Square, Loader2, File, Upload, X, Paperclip, FileText } from "lucide-react"; +import { Send, Square, Loader2, File, Upload, X, Paperclip, FileText, ChevronDown, Cpu } from "lucide-react"; import { createClient } from "@/lib/supabase/client"; import { toast } from "sonner"; import { AnimatePresence, motion } from "framer-motion"; @@ -13,13 +13,22 @@ import { TooltipProvider, TooltipTrigger, } from "@/components/ui/tooltip"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; import { cn } from "@/lib/utils"; // Define API_URL const API_URL = process.env.NEXT_PUBLIC_BACKEND_URL || ''; +// Local storage keys +const STORAGE_KEY_MODEL = 'suna-preferred-model'; + interface ChatInputProps { - onSubmit: (message: string) => void; + onSubmit: (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => void; placeholder?: string; loading?: boolean; disabled?: boolean; @@ -40,7 +49,7 @@ interface UploadedFile { export function ChatInput({ onSubmit, - placeholder = "Type your message... 
(Enter to send, Shift+Enter for new line)", + placeholder = "Describe what you need help with...", loading = false, disabled = false, isAgentRunning = false, @@ -52,47 +61,63 @@ export function ChatInput({ sandboxId }: ChatInputProps) { const [inputValue, setInputValue] = useState(value || ""); + const [selectedModel, setSelectedModel] = useState("sonnet-3.7"); const textareaRef = useRef(null); const fileInputRef = useRef(null); const [uploadedFiles, setUploadedFiles] = useState([]); const [isUploading, setIsUploading] = useState(false); const [isDraggingOver, setIsDraggingOver] = useState(false); - // Allow controlled or uncontrolled usage + useEffect(() => { + if (typeof window !== 'undefined') { + try { + const savedModel = localStorage.getItem(STORAGE_KEY_MODEL); + if (savedModel) { + setSelectedModel(savedModel); + } + } catch (error) { + console.warn('Failed to load preferences from localStorage:', error); + } + } + }, []); + const isControlled = value !== undefined && onChange !== undefined; - // Update local state if controlled and value changes useEffect(() => { if (isControlled && value !== inputValue) { setInputValue(value); } }, [value, isControlled, inputValue]); - // Auto-focus on textarea when component loads useEffect(() => { if (autoFocus && textareaRef.current) { textareaRef.current.focus(); } }, [autoFocus]); - // Adjust textarea height based on content useEffect(() => { const textarea = textareaRef.current; if (!textarea) return; const adjustHeight = () => { textarea.style.height = 'auto'; - const newHeight = Math.min(textarea.scrollHeight, 200); // Max height of 200px + const newHeight = Math.min(Math.max(textarea.scrollHeight, 50), 200); // Min 50px, max 200px textarea.style.height = `${newHeight}px`; }; adjustHeight(); - // Adjust on window resize too window.addEventListener('resize', adjustHeight); return () => window.removeEventListener('resize', adjustHeight); }, [inputValue]); + const handleModelChange = (model: string) => { + setSelectedModel(model); + if (typeof window !== 'undefined') { + localStorage.setItem(STORAGE_KEY_MODEL, model); + } + }; + const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); if ((!inputValue.trim() && uploadedFiles.length === 0) || loading || (disabled && !isAgentRunning)) return; @@ -104,7 +129,6 @@ export function ChatInput({ let message = inputValue; - // Add file information to the message if files were uploaded if (uploadedFiles.length > 0) { const fileInfo = uploadedFiles.map(file => `[Uploaded file: ${file.name} (${formatFileSize(file.size)}) at ${file.path}]` @@ -112,13 +136,22 @@ export function ChatInput({ message = message ? 
`${message}\n\n${fileInfo}` : fileInfo; } - onSubmit(message); + let baseModelName = selectedModel; + let thinkingEnabled = false; + if (selectedModel === "sonnet-3.7-thinking") { + baseModelName = "sonnet-3.7"; + thinkingEnabled = true; + } + + onSubmit(message, { + model_name: baseModelName, + enable_thinking: thinkingEnabled + }); if (!isControlled) { setInputValue(""); } - // Reset the uploaded files after sending setUploadedFiles([]); }; @@ -175,7 +208,6 @@ export function ChatInput({ const files = Array.from(event.target.files); await uploadFiles(files); - // Reset the input event.target.value = ''; }; @@ -191,11 +223,9 @@ export function ChatInput({ continue; } - // Create a FormData object const formData = new FormData(); formData.append('file', file); - // Upload to workspace root by default const uploadPath = `/workspace/${file.name}`; formData.append('path', uploadPath); @@ -206,7 +236,6 @@ export function ChatInput({ throw new Error('No access token available'); } - // Upload using FormData const response = await fetch(`${API_URL}/sandboxes/${sandboxId}/files`, { method: 'POST', headers: { @@ -219,7 +248,6 @@ export function ChatInput({ throw new Error(`Upload failed: ${response.statusText}`); } - // Add to uploaded files newUploadedFiles.push({ name: file.name, path: uploadPath, @@ -229,7 +257,6 @@ export function ChatInput({ toast.success(`File uploaded: ${file.name}`); } - // Update the uploaded files state setUploadedFiles(prev => [...prev, ...newUploadedFiles]); } catch (error) { @@ -273,11 +300,18 @@ export function ChatInput({ } }; + const modelDisplayNames: { [key: string]: string } = { + "sonnet-3.7": "Sonnet 3.7", + "sonnet-3.7-thinking": "Sonnet 3.7 (Thinking)", + "gpt-4.1": "GPT-4.1", + "gemini-flash-2.5": "Gemini Flash 2.5" + }; + return (
[The remaining chat-input.tsx hunks are not recoverable here: the JSX markup was stripped during extraction. The surviving fragments show the input container and file chips restyled for the dark theme (bg-[#1a1a1a] with border-gray-800 on the container; bg-gray-800 / border-gray-700 on the chips, replacing the secondary-tinted classes) and the "Drop files to upload" drag-and-drop overlay retained.]
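To close, a minimal sketch (not part of the diff) of what the new thinking flags in backend/services/llm.py do to the prepared LiteLLM parameters for an Anthropic model. It assumes prepare_params is importable as shown and that the remaining parameters keep their defaults; it only checks behavior visible in this diff.

from services.llm import prepare_params

params = prepare_params(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize the latest agent run."},
    ],
    model_name="anthropic/claude-3-7-sonnet-latest",
    enable_thinking=True,
    reasoning_effort="medium",
)

# With thinking enabled on an Anthropic/Claude model, the diff sets
# reasoning_effort and forces temperature to 1.0.
assert params["reasoning_effort"] == "medium"
assert params["temperature"] == 1.0

# String system content and the last user message are wrapped into text blocks
# tagged for Anthropic prompt caching.
assert params["messages"][0]["content"][0]["cache_control"] == {"type": "ephemeral"}
assert params["messages"][-1]["content"][0]["cache_control"] == {"type": "ephemeral"}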