From 9beb4dc2152cc6dc6c92effb1ad332fc4f9ae1f3 Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Fri, 18 Apr 2025 04:41:55 +0100 Subject: [PATCH 01/13] fix non-stream async --- backend/agentpress/thread_manager.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/backend/agentpress/thread_manager.py b/backend/agentpress/thread_manager.py index e68a942c..f03459b4 100644 --- a/backend/agentpress/thread_manager.py +++ b/backend/agentpress/thread_manager.py @@ -334,15 +334,16 @@ Here are the XML tools available with examples: else: logger.debug("Processing non-streaming response") try: - response = await self.response_processor.process_non_streaming_response( + # Return the async generator directly, don't await it + response_generator = self.response_processor.process_non_streaming_response( llm_response=llm_response, thread_id=thread_id, config=processor_config ) - return response + return response_generator # Return the generator except Exception as e: - logger.error(f"Error in non-streaming response: {str(e)}", exc_info=True) - raise + logger.error(f"Error setting up non-streaming response: {str(e)}", exc_info=True) + raise # Re-raise the exception to be caught by the outer handler except Exception as e: logger.error(f"Error in run_thread: {str(e)}", exc_info=True) From 70882f9292ec1f9006e232dd9a29b8c76cb442ac Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Fri, 18 Apr 2025 04:49:27 +0100 Subject: [PATCH 02/13] claude prompt-caching --- backend/services/llm.py | 59 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/backend/services/llm.py b/backend/services/llm.py index e559c4bd..f3abe4df 100644 --- a/backend/services/llm.py +++ b/backend/services/llm.py @@ -152,6 +152,65 @@ def prepare_params( params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0" logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}") + # Apply Anthropic prompt caching (minimal implementation) + # Check model name *after* potential modifications (like adding bedrock/ prefix) + effective_model_name = params.get("model", model_name) # Use model from params if set, else original + if "claude" in effective_model_name.lower() or "anthropic" in effective_model_name.lower(): + logger.debug("Applying minimal Anthropic prompt caching.") + messages = params["messages"] # Direct reference, modification affects params + + # Ensure messages is a list + if not isinstance(messages, list): + logger.warning(f"Messages is not a list ({type(messages)}), skipping Anthropic cache control.") + return params # Return early if messages format is unexpected + + # 1. 
Process the first message if it's a system prompt with string content + if messages and messages[0].get("role") == "system": + content = messages[0].get("content") + if isinstance(content, str): + # Wrap the string content in the required list structure + messages[0]["content"] = [ + {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}} + ] + logger.debug("Applied cache_control to system message (converted from string).") + elif isinstance(content, list): + # If content is already a list, check if the first text block needs cache_control + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + if "cache_control" not in item: + item["cache_control"] = {"type": "ephemeral"} + break # Apply to the first text block only for system prompt + else: + logger.warning("System message content is not a string or list, skipping cache_control.") + + # 2. Find and process the last user message + last_user_idx = -1 + for i in range(len(messages) - 1, -1, -1): + if messages[i].get("role") == "user": + last_user_idx = i + break + + if last_user_idx != -1: + last_user_message = messages[last_user_idx] + content = last_user_message.get("content") + + if isinstance(content, str): + # Wrap the string content in the required list structure + last_user_message["content"] = [ + {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}} + ] + logger.debug(f"Applied cache_control to last user message (string content, index {last_user_idx}).") + elif isinstance(content, list): + # Modify text blocks within the list directly + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + # Add cache_control if not already present + if "cache_control" not in item: + item["cache_control"] = {"type": "ephemeral"} + + else: + logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.") + return params async def make_llm_api_call( From 3f340f740d3a4ab10ffe98de404417a6cfe23bd4 Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Fri, 18 Apr 2025 04:51:09 +0100 Subject: [PATCH 03/13] cast sandbox preview links to string for consistency --- backend/agent/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/agent/api.py b/backend/agent/api.py index 6ab8891f..edb54c35 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -285,8 +285,8 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id 'sandbox': { 'id': sandbox_id, 'pass': sandbox_pass, - 'vnc_preview': sandbox.get_preview_link(6080), - 'sandbox_url': sandbox.get_preview_link(8080) + 'vnc_preview': str(sandbox.get_preview_link(6080)), + 'sandbox_url': str(sandbox.get_preview_link(8080)) } }).eq('project_id', project_id).execute() From 3f06702ea4a34649394270725677f5fca914dc0a Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Fri, 18 Apr 2025 05:19:41 +0100 Subject: [PATCH 04/13] remove xml tool example duplication --- backend/agent/run.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/backend/agent/run.py b/backend/agent/run.py index 4565ac36..87929225 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -47,11 +47,7 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru if os.getenv("RAPID_API_KEY"): thread_manager.add_tool(DataProvidersTool) - xml_examples = "" - for tag_name, example in thread_manager.tool_registry.get_xml_examples().items(): - xml_examples += f"{example}\n" - - 
system_message = { "role": "system", "content": get_system_prompt() + "\n\n" + f"\n{xml_examples}\n" } + system_message = { "role": "system", "content": get_system_prompt() } iteration_count = 0 continue_execution = True From adc80366156720881417a1b94c5b405d95cf06d7 Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Fri, 18 Apr 2025 05:48:58 +0100 Subject: [PATCH 05/13] -U litellm --- backend/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/requirements.txt b/backend/requirements.txt index a8967c5b..7025281b 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,6 +1,6 @@ streamlit-quill==0.0.3 python-dotenv==1.0.1 -litellm>=1.44.0 +litellm>=1.66.2 click==8.1.7 questionary==2.0.1 requests>=2.31.0 @@ -23,3 +23,4 @@ python-ripgrep==0.0.6 daytona_sdk>=0.12.0 boto3>=1.34.0 exa-py>=1.9.1 +pydantic From c84ee59dc613436634d67e6706c0f2603cbbb228 Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Fri, 18 Apr 2025 05:49:41 +0100 Subject: [PATCH 06/13] thinking & reasoning --- backend/agent/api.py | 56 ++++++++-- backend/agent/run.py | 41 ++++++- backend/agentpress/response_processor.py | 133 +++++++++++++++++++---- backend/agentpress/thread_manager.py | 16 ++- backend/services/llm.py | 28 ++++- 5 files changed, 230 insertions(+), 44 deletions(-) diff --git a/backend/agent/api.py b/backend/agent/api.py index edb54c35..d300ca9d 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, Depends, Request +from fastapi import APIRouter, HTTPException, Depends, Request, Body from fastapi.responses import StreamingResponse import asyncio import json @@ -7,6 +7,7 @@ from datetime import datetime, timezone import uuid from typing import Optional, List, Dict, Any import jwt +from pydantic import BaseModel from agentpress.thread_manager import ThreadManager from services.supabase import DBConnection @@ -26,6 +27,12 @@ db = None # In-memory storage for active agent runs and their responses active_agent_runs: Dict[str, List[Any]] = {} +class AgentStartRequest(BaseModel): + model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest" + enable_thinking: Optional[bool] = False + reasoning_effort: Optional[str] = 'low' + stream: Optional[bool] = False # Default stream to False for API + def initialize( _thread_manager: ThreadManager, _db: DBConnection, @@ -237,9 +244,13 @@ async def _cleanup_agent_run(agent_run_id: str): # Non-fatal error, can continue @router.post("/thread/{thread_id}/agent/start") -async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)): +async def start_agent( + thread_id: str, + body: AgentStartRequest = Body(...), # Accept request body + user_id: str = Depends(get_current_user_id) +): """Start an agent for a specific thread in the background.""" - logger.info(f"Starting new agent for thread: {thread_id}") + logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}") client = await db.client # Verify user has access to this thread @@ -314,7 +325,17 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id # Run the agent in the background task = asyncio.create_task( - run_agent_background(agent_run_id, thread_id, instance_id, project_id, sandbox) + run_agent_background( + agent_run_id=agent_run_id, + thread_id=thread_id, + instance_id=instance_id, + project_id=project_id, + sandbox=sandbox, + 
model_name=body.model_name, + enable_thinking=body.enable_thinking, + reasoning_effort=body.reasoning_effort, + stream=body.stream # Pass stream parameter + ) ) # Set a callback to clean up when task is done @@ -441,9 +462,19 @@ async def stream_agent_run( } ) -async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str, sandbox): +async def run_agent_background( + agent_run_id: str, + thread_id: str, + instance_id: str, + project_id: str, + sandbox, + model_name: str, + enable_thinking: Optional[bool], + reasoning_effort: Optional[str], + stream: bool # Add stream parameter +): """Run the agent in the background and handle status updates.""" - logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})") + logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model={model_name}, thinking={enable_thinking}, effort={reasoning_effort}, stream={stream}") client = await db.client # Tracking variables @@ -561,9 +592,16 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s try: # Run the agent logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})") - agent_gen = run_agent(thread_id, stream=True, - thread_manager=thread_manager, project_id=project_id, - sandbox=sandbox) + agent_gen = run_agent( + thread_id=thread_id, + project_id=project_id, + stream=stream, # Pass stream parameter from API request + thread_manager=thread_manager, + sandbox=sandbox, + model_name=model_name, # Pass model_name + enable_thinking=enable_thinking, # Pass enable_thinking + reasoning_effort=reasoning_effort # Pass reasoning_effort + ) # Collect all responses to save to database all_responses = [] diff --git a/backend/agent/run.py b/backend/agent/run.py index 87929225..d4fddc61 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -21,7 +21,18 @@ from utils.billing import check_billing_status, get_account_id_from_thread load_dotenv() -async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150): +async def run_agent( + thread_id: str, + project_id: str, + sandbox, + stream: bool, + thread_manager: Optional[ThreadManager] = None, + native_max_auto_continues: int = 25, + max_iterations: int = 150, + model_name: str = "anthropic/claude-3-7-sonnet-latest", + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' +): """Run the development agent with specified configuration.""" if not thread_manager: @@ -112,7 +123,7 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru thread_id=thread_id, system_prompt=system_message, stream=stream, - llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"), + llm_model=model_name, llm_temperature=0, llm_max_tokens=64000, tool_choice="auto", @@ -128,6 +139,8 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru ), native_max_auto_continues=native_max_auto_continues, include_xml_examples=True, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort ) if isinstance(response, dict) and "status" in response and response["status"] == "error": @@ -250,7 +263,15 @@ async def test_agent(): print("\nšŸ‘‹ Test completed. 
Goodbye!") -async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager): +async def process_agent_response( + thread_id: str, + project_id: str, + thread_manager: ThreadManager, + stream: bool = True, + model_name: str = "anthropic/claude-3-7-sonnet-latest", + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' +): """Process the streaming response from the agent.""" chunk_counter = 0 current_response = "" @@ -259,9 +280,19 @@ async def process_agent_response(thread_id: str, project_id: str, thread_manager # Create a test sandbox for processing sandbox_pass = str(uuid4()) sandbox = create_sandbox(sandbox_pass) - print(f"\033[91mTest sandbox created: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}\033[0m") + print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m") - async for chunk in run_agent(thread_id=thread_id, project_id=project_id, sandbox=sandbox, stream=True, thread_manager=thread_manager, native_max_auto_continues=25): + async for chunk in run_agent( + thread_id=thread_id, + project_id=project_id, + sandbox=sandbox, + stream=stream, + thread_manager=thread_manager, + native_max_auto_continues=25, + model_name=model_name, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort + ): chunk_counter += 1 if chunk.get('type') == 'content' and 'content' in chunk: diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py index 0e250cae..468a6b09 100644 --- a/backend/agentpress/response_processor.py +++ b/backend/agentpress/response_processor.py @@ -98,6 +98,8 @@ class ResponseProcessor: llm_response: AsyncGenerator, thread_id: str, config: ProcessorConfig = ProcessorConfig(), + prompt_messages: Optional[List[Dict[str, Any]]] = None, + llm_model: Optional[str] = None ) -> AsyncGenerator: """Process a streaming LLM response, handling tool calls and execution. 
@@ -105,6 +107,8 @@ class ResponseProcessor: llm_response: Streaming response from the LLM thread_id: ID of the conversation thread config: Configuration for parsing and execution + prompt_messages: List of messages used for cost calculation + llm_model: Name of the LLM model used for cost calculation Yields: Formatted chunks of the response including content and tool results @@ -175,15 +179,20 @@ class ResponseProcessor: accumulated_content += chunk_content current_xml_content += chunk_content - # Calculate cost using prompt and completion - try: - cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content) - tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}]) - accumulated_cost += cost - accumulated_token_count += tcount - logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}") - except Exception as e: - logger.error(f"Error calculating cost: {str(e)}") + # Process reasoning content if present (Anthropic) + if hasattr(delta, 'reasoning_content') and delta.reasoning_content: + logger.info(f"[THINKING]: {delta.reasoning_content}") + accumulated_content += delta.reasoning_content # Append reasoning to main content + + # Calculate cost using prompt and completion - MOVED AFTER MESSAGE SAVE + # try: + # cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content) + # tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}]) + # accumulated_cost += cost + # accumulated_token_count += tcount + # logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}") + # except Exception as e: + # logger.error(f"Error calculating cost: {str(e)}") # Check if we've reached the XML tool call limit before yielding content if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls: @@ -489,6 +498,35 @@ class ResponseProcessor: "thread_run_id": thread_run_id } + # --- Cost Calculation (moved here) --- + if prompt_messages and llm_model and accumulated_content: + try: + cost = completion_cost( + model=llm_model, + messages=prompt_messages, + completion=accumulated_content + ) + token_count = token_counter( + model=llm_model, + messages=prompt_messages + [{"role": "assistant", "content": accumulated_content}] + ) + await self.add_message( + thread_id=thread_id, + type="cost", + content={ + "cost": cost, + "prompt_tokens": token_count - token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx + "completion_tokens": token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx + "total_tokens": token_count, + "model_name": llm_model + }, + is_llm_message=False + ) + logger.info(f"Calculated cost for streaming response: {cost:.6f} using model {llm_model}") + except Exception as e: + logger.error(f"Error calculating cost: {str(e)}") + # --- End Cost Calculation --- + # --- Process All Tool Calls Now --- if config.execute_tools: final_tool_calls_to_process = [] @@ -626,23 +664,24 @@ class ResponseProcessor: yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} finally: - # Yield a finish signal including the final assistant message ID - if last_assistant_message_id: - # Yield the overall run end signal + # Yield the detected finish reason if one exists and wasn't suppressed + if finish_reason and finish_reason != "xml_tool_limit_reached": yield { - "type": "thread_run_end", - 
"thread_run_id": thread_run_id - } - else: - # Yield the overall run end signal - yield { - "type": "thread_run_end", + "type": "finish", + "finish_reason": finish_reason, "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None } - + + # Yield a finish signal including the final assistant message ID + # Ensure thread_run_id is defined, even if an early error occurred + run_id = thread_run_id if 'thread_run_id' in locals() else str(uuid.uuid4()) # Fallback ID if needed + yield { + "type": "thread_run_end", + "thread_run_id": run_id + } + + # Remove old cost calculation code pass - # track the cost and token count - # todo: there is a bug as it adds every chunk to db because finally will run every time even in yield # await self.add_message( # thread_id=thread_id, # type="cost", @@ -658,7 +697,9 @@ class ResponseProcessor: self, llm_response: Any, thread_id: str, - config: ProcessorConfig = ProcessorConfig() + config: ProcessorConfig = ProcessorConfig(), + prompt_messages: Optional[List[Dict[str, Any]]] = None, + llm_model: Optional[str] = None ) -> AsyncGenerator[Dict[str, Any], None]: """Process a non-streaming LLM response, handling tool calls and execution. @@ -666,6 +707,8 @@ class ResponseProcessor: llm_response: Response from the LLM thread_id: ID of the conversation thread config: Configuration for parsing and execution + prompt_messages: List of messages used for cost calculation + llm_model: Name of the LLM model used for cost calculation Yields: Formatted response including content and tool results @@ -861,6 +904,50 @@ class ResponseProcessor: "thread_run_id": thread_run_id } + # --- Cost Calculation (moved here) --- + if prompt_messages and llm_model: + cost = None + # Attempt to get cost from LiteLLM response first + if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params: + cost = llm_response._hidden_params['response_cost'] + logger.info(f"Using pre-calculated cost from LiteLLM: {cost:.6f}") + + # If no pre-calculated cost, calculate manually + if cost is None: + try: + cost = completion_cost( + model=llm_model, + messages=prompt_messages, + completion=content # Use extracted content + ) + logger.info(f"Manually calculated cost for non-streaming response: {cost:.6f} using model {llm_model}") + except Exception as e: + logger.error(f"Error calculating cost: {str(e)}") + + # Add cost message if cost was determined + if cost is not None: + try: + # Approximate token counts + completion_tokens = token_counter(model=llm_model, messages=[{"role": "assistant", "content": content}]) + prompt_tokens = token_counter(model=llm_model, messages=prompt_messages) + total_tokens = prompt_tokens + completion_tokens + + await self.add_message( + thread_id=thread_id, + type="cost", + content={ + "cost": cost, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + "model_name": llm_model + }, + is_llm_message=False + ) + except Exception as e: + logger.error(f"Error saving cost message: {str(e)}") + # --- End Cost Calculation --- + except Exception as e: logger.error(f"Error processing response: {str(e)}", exc_info=True) yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} diff --git a/backend/agentpress/thread_manager.py b/backend/agentpress/thread_manager.py index f03459b4..8dbe7524 100644 --- a/backend/agentpress/thread_manager.py +++ b/backend/agentpress/thread_manager.py @@ -162,6 +162,8 @@ class ThreadManager: 
native_max_auto_continues: int = 25, max_xml_tool_calls: int = 0, include_xml_examples: bool = False, + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' ) -> Union[Dict[str, Any], AsyncGenerator]: """Run a conversation thread with LLM integration and tool execution. @@ -179,6 +181,8 @@ class ThreadManager: finish_reason="tool_calls" (0 disables auto-continue) max_xml_tool_calls: Maximum number of XML tool calls to allow (0 = no limit) include_xml_examples: Whether to include XML tool examples in the system prompt + enable_thinking: Whether to enable thinking before making a decision + reasoning_effort: The effort level for reasoning Returns: An async generator yielding response chunks or error dict @@ -313,7 +317,9 @@ Here are the XML tools available with examples: max_tokens=llm_max_tokens, tools=openapi_tool_schemas, tool_choice=tool_choice if processor_config.native_tool_calling else None, - stream=stream + stream=stream, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort ) logger.debug("Successfully received raw LLM API response stream/object") @@ -327,7 +333,9 @@ Here are the XML tools available with examples: response_generator = self.response_processor.process_streaming_response( llm_response=llm_response, thread_id=thread_id, - config=processor_config + config=processor_config, + prompt_messages=prepared_messages, + llm_model=llm_model ) return response_generator @@ -338,7 +346,9 @@ Here are the XML tools available with examples: response_generator = self.response_processor.process_non_streaming_response( llm_response=llm_response, thread_id=thread_id, - config=processor_config + config=processor_config, + prompt_messages=prepared_messages, + llm_model=llm_model ) return response_generator # Return the generator except Exception as e: diff --git a/backend/services/llm.py b/backend/services/llm.py index f3abe4df..2bcfb7ae 100644 --- a/backend/services/llm.py +++ b/backend/services/llm.py @@ -17,6 +17,8 @@ import asyncio from openai import OpenAIError import litellm from utils.logger import logger +from datetime import datetime +import traceback # litellm.set_verbose=True litellm.modify_params=True @@ -82,7 +84,9 @@ def prepare_params( api_base: Optional[str] = None, stream: bool = False, top_p: Optional[float] = None, - model_id: Optional[str] = None + model_id: Optional[str] = None, + enable_thinking: Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' ) -> Dict[str, Any]: """Prepare parameters for the API call.""" params = { @@ -211,6 +215,16 @@ def prepare_params( else: logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.") + # Add reasoning_effort for Anthropic models if enabled + use_thinking = enable_thinking if enable_thinking is not None else False + is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower() + + if is_anthropic and use_thinking: + effort_level = reasoning_effort if reasoning_effort else 'low' + params["reasoning_effort"] = effort_level + params["temperature"] = 1.0 # Required by Anthropic when reasoning_effort is used + logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'") + return params async def make_llm_api_call( @@ -225,7 +239,9 @@ async def make_llm_api_call( api_base: Optional[str] = None, stream: bool = False, top_p: Optional[float] = None, - model_id: Optional[str] = None + model_id: Optional[str] = None, + enable_thinking: 
Optional[bool] = False, + reasoning_effort: Optional[str] = 'low' ) -> Union[Dict[str, Any], AsyncGenerator]: """ Make an API call to a language model using LiteLLM. @@ -243,6 +259,8 @@ async def make_llm_api_call( stream: Whether to stream the response top_p: Top-p sampling parameter model_id: Optional ARN for Bedrock inference profiles + enable_thinking: Whether to enable thinking + reasoning_effort: Level of reasoning effort Returns: Union[Dict[str, Any], AsyncGenerator]: API response or stream @@ -251,7 +269,7 @@ async def make_llm_api_call( LLMRetryError: If API call fails after retries LLMError: For other API-related errors """ - logger.debug(f"Making LLM API call to model: {model_name}") + logger.debug(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})") params = prepare_params( messages=messages, model_name=model_name, @@ -264,7 +282,9 @@ async def make_llm_api_call( api_base=api_base, stream=stream, top_p=top_p, - model_id=model_id + model_id=model_id, + enable_thinking=enable_thinking, + reasoning_effort=reasoning_effort ) last_error = None From 6e4e2673d50b51b84ec315fd2dc6b9be94fa1c1a Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Fri, 18 Apr 2025 05:59:00 +0100 Subject: [PATCH 07/13] cost calculation --- backend/agentpress/response_processor.py | 197 ++++++++++------------- 1 file changed, 85 insertions(+), 112 deletions(-) diff --git a/backend/agentpress/response_processor.py b/backend/agentpress/response_processor.py index 468a6b09..6dff7bf2 100644 --- a/backend/agentpress/response_processor.py +++ b/backend/agentpress/response_processor.py @@ -97,18 +97,18 @@ class ResponseProcessor: self, llm_response: AsyncGenerator, thread_id: str, + prompt_messages: List[Dict[str, Any]], + llm_model: str, config: ProcessorConfig = ProcessorConfig(), - prompt_messages: Optional[List[Dict[str, Any]]] = None, - llm_model: Optional[str] = None ) -> AsyncGenerator: """Process a streaming LLM response, handling tool calls and execution. 
Args: llm_response: Streaming response from the LLM thread_id: ID of the conversation thread + prompt_messages: List of messages sent to the LLM (the prompt) + llm_model: The name of the LLM model used config: Configuration for parsing and execution - prompt_messages: List of messages used for cost calculation - llm_model: Name of the LLM model used for cost calculation Yields: Formatted chunks of the response including content and tool results @@ -173,27 +173,18 @@ class ResponseProcessor: if hasattr(chunk, 'choices') and chunk.choices: delta = chunk.choices[0].delta if hasattr(chunk.choices[0], 'delta') else None + # Check for and log Anthropic thinking content + if delta and hasattr(delta, 'reasoning_content') and delta.reasoning_content: + logger.info(f"[THINKING]: {delta.reasoning_content}") + # Append reasoning to main content to be saved in the final message + accumulated_content += delta.reasoning_content + # Process content chunk if delta and hasattr(delta, 'content') and delta.content: chunk_content = delta.content accumulated_content += chunk_content current_xml_content += chunk_content - # Process reasoning content if present (Anthropic) - if hasattr(delta, 'reasoning_content') and delta.reasoning_content: - logger.info(f"[THINKING]: {delta.reasoning_content}") - accumulated_content += delta.reasoning_content # Append reasoning to main content - - # Calculate cost using prompt and completion - MOVED AFTER MESSAGE SAVE - # try: - # cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content) - # tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}]) - # accumulated_cost += cost - # accumulated_token_count += tcount - # logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}") - # except Exception as e: - # logger.error(f"Error calculating cost: {str(e)}") - # Check if we've reached the XML tool call limit before yielding content if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls: # We've reached the limit, don't yield any more content @@ -362,7 +353,7 @@ class ResponseProcessor: # If we've reached the XML tool call limit, stop streaming if finish_reason == "xml_tool_limit_reached": - logger.info("Stopping stream due to XML tool call limit") + logger.info("Stopping stream processing after loop due to XML tool call limit") break # After streaming completes or is stopped due to limit, wait for any remaining tool executions @@ -483,6 +474,27 @@ class ResponseProcessor: is_llm_message=True ) + # Calculate and store cost AFTER adding the main assistant message + if accumulated_content: # Calculate cost if there was content (now includes reasoning) + try: + final_cost = completion_cost( + model=llm_model, # Use the passed model name + messages=prompt_messages, # Use the provided prompt messages + completion=accumulated_content + ) + if final_cost is not None and final_cost > 0: + logger.info(f"Calculated final cost for stream: {final_cost}") + await self.add_message( + thread_id=thread_id, + type="cost", + content={"cost": final_cost}, + is_llm_message=False # Cost is metadata, not LLM content + ) + else: + logger.info("Cost calculation resulted in zero or None, not storing cost message.") + except Exception as e: + logger.error(f"Error calculating final cost for stream: {str(e)}") + # Yield the assistant response end signal *immediately* after saving if last_assistant_message_id: yield { @@ -498,35 +510,6 @@ class ResponseProcessor: "thread_run_id": thread_run_id } - # 
--- Cost Calculation (moved here) --- - if prompt_messages and llm_model and accumulated_content: - try: - cost = completion_cost( - model=llm_model, - messages=prompt_messages, - completion=accumulated_content - ) - token_count = token_counter( - model=llm_model, - messages=prompt_messages + [{"role": "assistant", "content": accumulated_content}] - ) - await self.add_message( - thread_id=thread_id, - type="cost", - content={ - "cost": cost, - "prompt_tokens": token_count - token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx - "completion_tokens": token_counter(model=llm_model, messages=[{"role": "assistant", "content": accumulated_content}]), # Approx - "total_tokens": token_count, - "model_name": llm_model - }, - is_llm_message=False - ) - logger.info(f"Calculated cost for streaming response: {cost:.6f} using model {llm_model}") - except Exception as e: - logger.error(f"Error calculating cost: {str(e)}") - # --- End Cost Calculation --- - # --- Process All Tool Calls Now --- if config.execute_tools: final_tool_calls_to_process = [] @@ -664,24 +647,23 @@ class ResponseProcessor: yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} finally: - # Yield the detected finish reason if one exists and wasn't suppressed - if finish_reason and finish_reason != "xml_tool_limit_reached": + # Yield a finish signal including the final assistant message ID + if last_assistant_message_id: + # Yield the overall run end signal yield { - "type": "finish", - "finish_reason": finish_reason, + "type": "thread_run_end", + "thread_run_id": thread_run_id + } + else: + # Yield the overall run end signal + yield { + "type": "thread_run_end", "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None } - - # Yield a finish signal including the final assistant message ID - # Ensure thread_run_id is defined, even if an early error occurred - run_id = thread_run_id if 'thread_run_id' in locals() else str(uuid.uuid4()) # Fallback ID if needed - yield { - "type": "thread_run_end", - "thread_run_id": run_id - } - - # Remove old cost calculation code + pass + # track the cost and token count + # todo: there is a bug as it adds every chunk to db because finally will run every time even in yield # await self.add_message( # thread_id=thread_id, # type="cost", @@ -697,18 +679,18 @@ class ResponseProcessor: self, llm_response: Any, thread_id: str, - config: ProcessorConfig = ProcessorConfig(), - prompt_messages: Optional[List[Dict[str, Any]]] = None, - llm_model: Optional[str] = None + prompt_messages: List[Dict[str, Any]], + llm_model: str, + config: ProcessorConfig = ProcessorConfig() ) -> AsyncGenerator[Dict[str, Any], None]: """Process a non-streaming LLM response, handling tool calls and execution. 
Args: llm_response: Response from the LLM thread_id: ID of the conversation thread + prompt_messages: List of messages sent to the LLM (the prompt) + llm_model: The name of the LLM model used config: Configuration for parsing and execution - prompt_messages: List of messages used for cost calculation - llm_model: Name of the LLM model used for cost calculation Yields: Formatted response including content and tool results @@ -814,6 +796,41 @@ class ResponseProcessor: is_llm_message=True ) + # Calculate and store cost AFTER adding the main assistant message + if content or (config.native_tool_calling and 'native_tool_calls_for_message' in locals() and native_tool_calls_for_message): # Calculate cost if there's content or tool calls + try: + # Use the full response object for potentially more accurate cost calculation + # Pass model explicitly as it might not be reliably in response_object for all providers + # First check if response_cost is directly available in _hidden_params + final_cost = None + if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] != 0.0: + final_cost = llm_response._hidden_params['response_cost'] + logger.info(f"Using response_cost from _hidden_params: {final_cost}") + + if final_cost is None: # Fall back to calculating cost if direct cost not available or zero + logger.info("Calculating cost using completion_cost function.") + final_cost = completion_cost( + completion_response=llm_response, + model=llm_model, # Use the passed model name + # prompt_messages might be needed for some models/providers + # messages=prompt_messages, # Uncomment if needed + call_type="completion" # Assuming 'completion' type for this context + ) + + if final_cost is not None and final_cost > 0: + logger.info(f"Calculated final cost for non-stream: {final_cost}") + await self.add_message( + thread_id=thread_id, + type="cost", + content={"cost": final_cost}, + is_llm_message=False # Cost is metadata + ) + else: + logger.info("Final cost is zero or None, not storing cost message.") + + except Exception as e: + logger.error(f"Error calculating final cost for non-stream: {str(e)}") + # Yield content first yield { "type": "content", @@ -904,50 +921,6 @@ class ResponseProcessor: "thread_run_id": thread_run_id } - # --- Cost Calculation (moved here) --- - if prompt_messages and llm_model: - cost = None - # Attempt to get cost from LiteLLM response first - if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params: - cost = llm_response._hidden_params['response_cost'] - logger.info(f"Using pre-calculated cost from LiteLLM: {cost:.6f}") - - # If no pre-calculated cost, calculate manually - if cost is None: - try: - cost = completion_cost( - model=llm_model, - messages=prompt_messages, - completion=content # Use extracted content - ) - logger.info(f"Manually calculated cost for non-streaming response: {cost:.6f} using model {llm_model}") - except Exception as e: - logger.error(f"Error calculating cost: {str(e)}") - - # Add cost message if cost was determined - if cost is not None: - try: - # Approximate token counts - completion_tokens = token_counter(model=llm_model, messages=[{"role": "assistant", "content": content}]) - prompt_tokens = token_counter(model=llm_model, messages=prompt_messages) - total_tokens = prompt_tokens + completion_tokens - - await self.add_message( - thread_id=thread_id, - type="cost", - content={ - "cost": cost, - "prompt_tokens": prompt_tokens, - 
"completion_tokens": completion_tokens, - "total_tokens": total_tokens, - "model_name": llm_model - }, - is_llm_message=False - ) - except Exception as e: - logger.error(f"Error saving cost message: {str(e)}") - # --- End Cost Calculation --- - except Exception as e: logger.error(f"Error processing response: {str(e)}", exc_info=True) yield {"type": "error", "message": str(e), "thread_run_id": thread_run_id if 'thread_run_id' in locals() else None} From aa8c5c6d783f72a0f7fbb9eec58b142e8b6270b0 Mon Sep 17 00:00:00 2001 From: LE Quoc Dat Date: Fri, 18 Apr 2025 06:42:57 +0100 Subject: [PATCH 08/13] model mapping & UI frontend tested --- backend/agent/api.py | 8 +- backend/agent/run.py | 4 +- .../app/dashboard/agents/[threadId]/page.tsx | 65 ++--- frontend/src/app/dashboard/page.tsx | 8 +- frontend/src/components/thread/chat-input.tsx | 244 ++++++++++++------ frontend/src/lib/api.ts | 11 +- 6 files changed, 227 insertions(+), 113 deletions(-) diff --git a/backend/agent/api.py b/backend/agent/api.py index d300ca9d..a6df77d2 100644 --- a/backend/agent/api.py +++ b/backend/agent/api.py @@ -27,6 +27,12 @@ db = None # In-memory storage for active agent runs and their responses active_agent_runs: Dict[str, List[Any]] = {} +MODEL_NAME_ALIASES = { + "sonnet-3.7": "anthropic/claude-3-7-sonnet-latest", + "gpt-4.1": "openai/gpt-4.1-2025-04-14", + "gemini-flash-2.5": "openrouter/google/gemini-2.5-flash-preview", +} + class AgentStartRequest(BaseModel): model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest" enable_thinking: Optional[bool] = False @@ -331,7 +337,7 @@ async def start_agent( instance_id=instance_id, project_id=project_id, sandbox=sandbox, - model_name=body.model_name, + model_name=MODEL_NAME_ALIASES.get(body.model_name, body.model_name), enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort, stream=body.stream # Pass stream parameter diff --git a/backend/agent/run.py b/backend/agent/run.py index d4fddc61..4a2e1040 100644 --- a/backend/agent/run.py +++ b/backend/agent/run.py @@ -118,6 +118,8 @@ async def run_agent( except Exception as e: print(f"Error parsing browser state: {e}") # print(latest_browser_state.data[0]) + + max_tokens = 64000 if "sonnet" in model_name.lower() else None response = await thread_manager.run_thread( thread_id=thread_id, @@ -125,7 +127,7 @@ async def run_agent( stream=stream, llm_model=model_name, llm_temperature=0, - llm_max_tokens=64000, + llm_max_tokens=max_tokens, tool_choice="auto", max_xml_tool_calls=1, temporary_message=temporary_message, diff --git a/frontend/src/app/dashboard/agents/[threadId]/page.tsx b/frontend/src/app/dashboard/agents/[threadId]/page.tsx index 689d4749..44419be3 100644 --- a/frontend/src/app/dashboard/agents/[threadId]/page.tsx +++ b/frontend/src/app/dashboard/agents/[threadId]/page.tsx @@ -522,47 +522,54 @@ export default function ThreadPage({ params }: { params: Promise } }; }, [threadId, handleStreamAgent, agentRunId, agentStatus, isStreaming]); - const handleSubmitMessage = useCallback(async (message: string) => { - if (!message.trim()) return; + const handleSubmitMessage = async (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => { + if (agentStatus === 'running') { + if (agentRunId && onStopAgent) { + onStopAgent(); + } + return; + } setIsSending(true); + setAgentStatus('running'); + setSidePanelContent(null); + setCurrentPairIndex(null); try { - // Add the message optimistically to the UI - const userMessage: ApiMessage = { - role: 'user', - content: message - }; + // 
First, add user message to the thread + await addUserMessage(threadId, message); - setMessages(prev => [...prev, userMessage]); + // Then fetch updated messages to include the user message + const updatedMessages = await getMessages(threadId); + setMessages(updatedMessages as ApiMessage[]); + + // Clear any input setNewMessage(''); - scrollToBottom(); - // Send to the API and start agent in parallel - const [messageResult, agentResult] = await Promise.all([ - addUserMessage(threadId, userMessage.content).catch(err => { - throw new Error('Failed to send message: ' + err.message); - }), - startAgent(threadId).catch(err => { - throw new Error('Failed to start agent: ' + err.message); - }) - ]); + // Start the agent + console.log('[SUBMIT] Starting agent'); + const { agent_run_id } = await startAgent(threadId, { + model_name: options?.model_name, + enable_thinking: options?.enable_thinking, + stream: true + }); - setAgentRunId(agentResult.agent_run_id); - setAgentStatus('running'); + // Set agent run ID and start streaming + console.log(`[SUBMIT] Agent started with ID: ${agent_run_id}`); + setAgentRunId(agent_run_id); + handleStreamAgent(agent_run_id); - // Start streaming the agent's responses immediately - handleStreamAgent(agentResult.agent_run_id); - } catch (err) { - console.error('Error sending message:', err); - toast.error(err instanceof Error ? err.message : 'Failed to send message'); - - // Remove the optimistically added message on error - setMessages(prev => prev.slice(0, -1)); + // Scroll to the bottom after sending the message + setTimeout(() => scrollToBottom(), 100); + } catch (error) { + console.error('[SUBMIT] Error submitting message:', error); + const errorMessage = error instanceof Error ? error.message : 'Unknown error sending message'; + toast.error(errorMessage); + setAgentStatus('idle'); } finally { setIsSending(false); } - }, [threadId, handleStreamAgent]); + }; const handleStopAgent = useCallback(async () => { if (!agentRunId) { diff --git a/frontend/src/app/dashboard/page.tsx b/frontend/src/app/dashboard/page.tsx index 6a6bf875..99a5b62e 100644 --- a/frontend/src/app/dashboard/page.tsx +++ b/frontend/src/app/dashboard/page.tsx @@ -12,7 +12,7 @@ function DashboardContent() { const [isSubmitting, setIsSubmitting] = useState(false); const router = useRouter(); - const handleSubmit = async (message: string) => { + const handleSubmit = async (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => { if (!message.trim() || isSubmitting) return; setIsSubmitting(true); @@ -34,7 +34,11 @@ function DashboardContent() { await addUserMessage(thread.thread_id, message.trim()); // 4. Start the agent with the thread ID - const agentRun = await startAgent(thread.thread_id); + const agentRun = await startAgent(thread.thread_id, { + model_name: options?.model_name, + enable_thinking: options?.enable_thinking, + stream: true + }); // 5. 
Navigate to the new agent's thread page router.push(`/dashboard/agents/${thread.thread_id}`); diff --git a/frontend/src/components/thread/chat-input.tsx b/frontend/src/components/thread/chat-input.tsx index ebafed1b..e118ec08 100644 --- a/frontend/src/components/thread/chat-input.tsx +++ b/frontend/src/components/thread/chat-input.tsx @@ -3,7 +3,7 @@ import React, { useState, useRef, useEffect } from 'react'; import { Textarea } from "@/components/ui/textarea"; import { Button } from "@/components/ui/button"; -import { Send, Square, Loader2, File, Upload, X, Paperclip, FileText } from "lucide-react"; +import { Send, Square, Loader2, File, Upload, X, Paperclip, FileText, ChevronDown } from "lucide-react"; import { createClient } from "@/lib/supabase/client"; import { toast } from "sonner"; import { AnimatePresence, motion } from "framer-motion"; @@ -13,13 +13,23 @@ import { TooltipProvider, TooltipTrigger, } from "@/components/ui/tooltip"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; import { cn } from "@/lib/utils"; // Define API_URL const API_URL = process.env.NEXT_PUBLIC_BACKEND_URL || ''; +// Local storage keys +const STORAGE_KEY_MODEL = 'suna-preferred-model'; +const STORAGE_KEY_THINKING = 'suna-enable-thinking'; + interface ChatInputProps { - onSubmit: (message: string) => void; + onSubmit: (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => void; placeholder?: string; loading?: boolean; disabled?: boolean; @@ -40,7 +50,7 @@ interface UploadedFile { export function ChatInput({ onSubmit, - placeholder = "Type your message... (Enter to send, Shift+Enter for new line)", + placeholder = "Describe what you need help with...", loading = false, disabled = false, isAgentRunning = false, @@ -52,12 +62,35 @@ export function ChatInput({ sandboxId }: ChatInputProps) { const [inputValue, setInputValue] = useState(value || ""); + const [selectedModel, setSelectedModel] = useState("sonnet-3.7"); + const [enableThinking, setEnableThinking] = useState(false); const textareaRef = useRef(null); const fileInputRef = useRef(null); const [uploadedFiles, setUploadedFiles] = useState([]); const [isUploading, setIsUploading] = useState(false); const [isDraggingOver, setIsDraggingOver] = useState(false); + // Load saved preferences from localStorage on mount + useEffect(() => { + if (typeof window !== 'undefined') { + try { + // Load selected model + const savedModel = localStorage.getItem(STORAGE_KEY_MODEL); + if (savedModel) { + setSelectedModel(savedModel); + } + + // Load thinking preference + const savedThinking = localStorage.getItem(STORAGE_KEY_THINKING); + if (savedThinking === 'true') { + setEnableThinking(true); + } + } catch (error) { + console.warn('Failed to load preferences from localStorage:', error); + } + } + }, []); + // Allow controlled or uncontrolled usage const isControlled = value !== undefined && onChange !== undefined; @@ -82,7 +115,7 @@ export function ChatInput({ const adjustHeight = () => { textarea.style.height = 'auto'; - const newHeight = Math.min(textarea.scrollHeight, 200); // Max height of 200px + const newHeight = Math.min(Math.max(textarea.scrollHeight, 50), 200); // Min 50px, max 200px textarea.style.height = `${newHeight}px`; }; @@ -93,6 +126,31 @@ export function ChatInput({ return () => window.removeEventListener('resize', adjustHeight); }, [inputValue]); + const handleModelChange = (model: string) => { + setSelectedModel(model); + // Save to 
localStorage + if (typeof window !== 'undefined') { + localStorage.setItem(STORAGE_KEY_MODEL, model); + } + + // Reset thinking when changing away from sonnet-3.7 + if (model !== "sonnet-3.7") { + setEnableThinking(false); + if (typeof window !== 'undefined') { + localStorage.setItem(STORAGE_KEY_THINKING, 'false'); + } + } + }; + + const toggleThinking = () => { + const newValue = !enableThinking; + setEnableThinking(newValue); + // Save to localStorage + if (typeof window !== 'undefined') { + localStorage.setItem(STORAGE_KEY_THINKING, newValue.toString()); + } + }; + const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); if ((!inputValue.trim() && uploadedFiles.length === 0) || loading || (disabled && !isAgentRunning)) return; @@ -112,7 +170,10 @@ export function ChatInput({ message = message ? `${message}\n\n${fileInfo}` : fileInfo; } - onSubmit(message); + onSubmit(message, { + model_name: selectedModel, + enable_thinking: enableThinking + }); if (!isControlled) { setInputValue(""); @@ -273,11 +334,18 @@ export function ChatInput({ } }; + // Map of model display names + const modelDisplayNames = { + "sonnet-3.7": "Sonnet 3.7", + "gpt-4.1": "GPT-4.1", + "gemini-flash-2.5": "Gemini Flash 2.5" + }; + return (
0 ? "border-border" : "border-input", + "w-full border rounded-xl transition-all duration-200 shadow-sm bg-[#1a1a1a] border-gray-800", + uploadedFiles.length > 0 ? "border-border" : "border-gray-800", isDraggingOver ? "border-primary border-dashed bg-primary/5" : "" )} onDragOver={handleDragOver} @@ -300,18 +368,18 @@ export function ChatInput({ animate={{ opacity: 1, scale: 1 }} exit={{ opacity: 0, scale: 0.9 }} transition={{ duration: 0.15 }} - className="px-2 py-1 bg-secondary/20 rounded-full flex items-center gap-1.5 group border border-secondary/30 hover:border-secondary/50 transition-colors text-sm" + className="px-2 py-1 bg-gray-800 rounded-full flex items-center gap-1.5 group border border-gray-700 hover:border-gray-600 transition-colors text-sm" > {getFileIcon(file.name)} - {file.name} - + {file.name} + ({formatFileSize(file.size)})
-
+
)} -
- {isDraggingOver && ( -
-
- -

Drop files to upload

-
+
+ {/* Left side - Model selector and Think button */} + {!isAgentRunning && ( +
+ {/* Model selector button with dropdown */} + + + + + + handleModelChange("sonnet-3.7")} className="hover:bg-gray-800"> + Sonnet 3.7 + + handleModelChange("gpt-4.1")} className="hover:bg-gray-800"> + GPT-4.1 + + handleModelChange("gemini-flash-2.5")} className="hover:bg-gray-800"> + Gemini Flash 2.5 + + + + + {/* Think button - only for Sonnet 3.7 */} + {selectedModel === "sonnet-3.7" && ( + + )}
)} -