mirror of https://github.com/kortix-ai/suna.git
commit
db46c1aee5
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, HTTPException, Depends, Request
+from fastapi import APIRouter, HTTPException, Depends, Request, Body
 from fastapi.responses import StreamingResponse
 import asyncio
 import json
@@ -7,6 +7,7 @@ from datetime import datetime, timezone
 import uuid
 from typing import Optional, List, Dict, Any
 import jwt
+from pydantic import BaseModel

 from agentpress.thread_manager import ThreadManager
 from services.supabase import DBConnection
@@ -26,6 +27,19 @@ db = None
 # In-memory storage for active agent runs and their responses
 active_agent_runs: Dict[str, List[Any]] = {}

+MODEL_NAME_ALIASES = {
+    "sonnet-3.7": "anthropic/claude-3-7-sonnet-latest",
+    "gpt-4.1": "openai/gpt-4.1-2025-04-14",
+    "gemini-flash-2.5": "openrouter/google/gemini-2.5-flash-preview",
+}
+
+class AgentStartRequest(BaseModel):
+    model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
+    enable_thinking: Optional[bool] = False
+    reasoning_effort: Optional[str] = 'low'
+    stream: Optional[bool] = True
+    enable_context_manager: Optional[bool] = False
+
 def initialize(
     _thread_manager: ThreadManager,
     _db: DBConnection,
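
The MODEL_NAME_ALIASES map and the AgentStartRequest body together decide which model name the background run receives: the endpoint further down resolves short aliases with MODEL_NAME_ALIASES.get(body.model_name, body.model_name). A minimal, self-contained sketch of that resolution, assuming only the values shown in this hunk (resolve_model_name is a hypothetical helper, not a function in the repository):

from typing import Optional
from pydantic import BaseModel

MODEL_NAME_ALIASES = {
    "sonnet-3.7": "anthropic/claude-3-7-sonnet-latest",
    "gpt-4.1": "openai/gpt-4.1-2025-04-14",
    "gemini-flash-2.5": "openrouter/google/gemini-2.5-flash-preview",
}

class AgentStartRequest(BaseModel):
    model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
    enable_thinking: Optional[bool] = False
    reasoning_effort: Optional[str] = 'low'
    stream: Optional[bool] = True
    enable_context_manager: Optional[bool] = False

def resolve_model_name(requested: str) -> str:
    # Hypothetical helper: unknown names pass through unchanged, so full
    # LiteLLM identifiers keep working alongside the short aliases.
    return MODEL_NAME_ALIASES.get(requested, requested)

body = AgentStartRequest(model_name="sonnet-3.7", enable_thinking=True)
assert resolve_model_name(body.model_name) == "anthropic/claude-3-7-sonnet-latest"
assert resolve_model_name("openai/gpt-4o") == "openai/gpt-4o"  # pass-through
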
@@ -237,9 +251,13 @@ async def _cleanup_agent_run(agent_run_id: str):
         # Non-fatal error, can continue

 @router.post("/thread/{thread_id}/agent/start")
-async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
+async def start_agent(
+    thread_id: str,
+    body: AgentStartRequest = Body(...),  # Accept request body
+    user_id: str = Depends(get_current_user_id)
+):
     """Start an agent for a specific thread in the background."""
-    logger.info(f"Starting new agent for thread: {thread_id}")
+    logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}, context_manager={body.enable_context_manager}")
     client = await db.client

     # Verify user has access to this thread
@@ -314,7 +332,18 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id
     # Run the agent in the background
     task = asyncio.create_task(
-        run_agent_background(agent_run_id, thread_id, instance_id, project_id, sandbox)
+        run_agent_background(
+            agent_run_id=agent_run_id,
+            thread_id=thread_id,
+            instance_id=instance_id,
+            project_id=project_id,
+            sandbox=sandbox,
+            model_name=MODEL_NAME_ALIASES.get(body.model_name, body.model_name),
+            enable_thinking=body.enable_thinking,
+            reasoning_effort=body.reasoning_effort,
+            stream=body.stream,
+            enable_context_manager=body.enable_context_manager
+        )
     )

     # Set a callback to clean up when task is done
@@ -441,9 +470,20 @@ async def stream_agent_run(
         }
     )

-async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str, sandbox):
+async def run_agent_background(
+    agent_run_id: str,
+    thread_id: str,
+    instance_id: str,
+    project_id: str,
+    sandbox,
+    model_name: str,
+    enable_thinking: Optional[bool],
+    reasoning_effort: Optional[str],
+    stream: bool,
+    enable_context_manager: bool
+):
     """Run the agent in the background and handle status updates."""
-    logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
+    logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model={model_name}, thinking={enable_thinking}, effort={reasoning_effort}, stream={stream}, context_manager={enable_context_manager}")
     client = await db.client

     # Tracking variables
@@ -561,9 +601,17 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
     try:
         # Run the agent
         logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
-        agent_gen = run_agent(thread_id, stream=True,
-                              thread_manager=thread_manager, project_id=project_id,
-                              sandbox=sandbox)
+        agent_gen = run_agent(
+            thread_id=thread_id,
+            project_id=project_id,
+            stream=stream,
+            thread_manager=thread_manager,
+            sandbox=sandbox,
+            model_name=model_name,
+            enable_thinking=enable_thinking,
+            reasoning_effort=reasoning_effort,
+            enable_context_manager=enable_context_manager
+        )

         # Collect all responses to save to database
         all_responses = []
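
With the request body in place, kicking off a run is a plain POST with JSON matching AgentStartRequest. A hedged sketch using httpx: the base URL and bearer-token scheme are assumptions for illustration; only the path, the JSON field names, and the agent_run_id response key come from this diff (the key also appears in the frontend api.ts change further down).

import asyncio
import httpx

async def start_agent_run(thread_id: str, token: str) -> str:
    # Base URL and auth header are illustrative placeholders.
    payload = {
        "model_name": "sonnet-3.7",   # short alias, resolved server-side
        "enable_thinking": False,
        "reasoning_effort": "low",
        "stream": True,
        "enable_context_manager": False,
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post(
            f"/thread/{thread_id}/agent/start",
            json=payload,
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        return resp.json()["agent_run_id"]

# Example usage (assumes a running backend and a valid JWT):
# asyncio.run(start_agent_run("thread-123", "jwt-token"))
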
@@ -22,7 +22,19 @@ from utils.billing import check_billing_status, get_account_id_from_thread
 load_dotenv()

-async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
+async def run_agent(
+    thread_id: str,
+    project_id: str,
+    sandbox,
+    stream: bool,
+    thread_manager: Optional[ThreadManager] = None,
+    native_max_auto_continues: int = 25,
+    max_iterations: int = 150,
+    model_name: str = "anthropic/claude-3-7-sonnet-latest",
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low',
+    enable_context_manager: bool = True
+):
     """Run the development agent with specified configuration."""

     if not thread_manager:
@@ -42,17 +54,15 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
     thread_manager.add_tool(SandboxDeployTool, sandbox=sandbox)
     thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool

-    if os.getenv("EXA_API_KEY"):
+    if os.getenv("TAVILY_API_KEY"):
         thread_manager.add_tool(WebSearchTool)
+    else:
+        print("TAVILY_API_KEY not found, WebSearchTool will not be available.")

     if os.getenv("RAPID_API_KEY"):
         thread_manager.add_tool(DataProvidersTool)

-    xml_examples = ""
-    for tag_name, example in thread_manager.tool_registry.get_xml_examples().items():
-        xml_examples += f"{example}\n"
-
-    system_message = { "role": "system", "content": get_system_prompt() + "\n\n" + f"<tool_examples>\n{xml_examples}\n</tool_examples>" }
+    system_message = { "role": "system", "content": get_system_prompt() }

     iteration_count = 0
     continue_execution = True
@@ -112,14 +122,16 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
             except Exception as e:
                 print(f"Error parsing browser state: {e}")
                 # print(latest_browser_state.data[0])

+        max_tokens = 64000 if "sonnet" in model_name.lower() else None
+
         response = await thread_manager.run_thread(
             thread_id=thread_id,
             system_prompt=system_message,
             stream=stream,
-            llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
+            llm_model=model_name,
             llm_temperature=0,
-            llm_max_tokens=64000,
+            llm_max_tokens=max_tokens,
             tool_choice="auto",
             max_xml_tool_calls=1,
             temporary_message=temporary_message,
@@ -133,6 +145,9 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
             ),
             native_max_auto_continues=native_max_auto_continues,
+            include_xml_examples=True,
+            enable_thinking=enable_thinking,
+            reasoning_effort=reasoning_effort,
+            enable_context_manager=enable_context_manager
         )

         if isinstance(response, dict) and "status" in response and response["status"] == "error":
@@ -267,7 +282,16 @@ async def test_agent():
     print("\n👋 Test completed. Goodbye!")

-async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
+async def process_agent_response(
+    thread_id: str,
+    project_id: str,
+    thread_manager: ThreadManager,
+    stream: bool = True,
+    model_name: str = "anthropic/claude-3-7-sonnet-latest",
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low',
+    enable_context_manager: bool = True
+):
     """Process the streaming response from the agent."""
     chunk_counter = 0
     current_response = ""
@@ -276,9 +300,20 @@ async def process_agent_response(thread_id: str, project_id: str, thread_manager
     # Create a test sandbox for processing
     sandbox_pass = str(uuid4())
    sandbox = create_sandbox(sandbox_pass)
-    print(f"\033[91mTest sandbox created: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}\033[0m")
+    print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m")

-    async for chunk in run_agent(thread_id=thread_id, project_id=project_id, sandbox=sandbox, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
+    async for chunk in run_agent(
+        thread_id=thread_id,
+        project_id=project_id,
+        sandbox=sandbox,
+        stream=stream,
+        thread_manager=thread_manager,
+        native_max_auto_continues=25,
+        model_name=model_name,
+        enable_thinking=enable_thinking,
+        reasoning_effort=reasoning_effort,
+        enable_context_manager=enable_context_manager
+    ):
         chunk_counter += 1
         # print(f"CHUNK: {chunk}") # Uncomment for debugging

@@ -1,4 +1,5 @@
-from exa_py import Exa
+from tavily import AsyncTavilyClient
+import httpx
 from typing import List, Optional
 from datetime import datetime
 import os
@@ -15,10 +16,12 @@ class WebSearchTool(Tool):
         # Load environment variables
         load_dotenv()
         # Use the provided API key or get it from environment variables
-        self.api_key = api_key or os.getenv("EXA_API_KEY")
+        self.api_key = api_key or os.getenv("TAVILY_API_KEY")
         if not self.api_key:
-            raise ValueError("EXA_API_KEY not found in environment variables")
-        self.exa = Exa(api_key=self.api_key)
+            raise ValueError("TAVILY_API_KEY not found in environment variables")
+
+        # Tavily asynchronous search client
+        self.tavily_client = AsyncTavilyClient(api_key=self.api_key)

     @openapi_schema({
         "type": "function",
@@ -111,57 +114,49 @@ class WebSearchTool(Tool):
             if not query or not isinstance(query, str):
                 return self.fail_response("A valid search query is required.")

-            # Basic parameters - use only the minimum required to avoid API errors
-            params = {
-                "query": query,
-                "type": "auto",
-                "livecrawl": "auto"
-            }
-
-            # Handle summary parameter (boolean conversion)
-            if summary is None:
-                params["summary"] = True
-            elif isinstance(summary, bool):
-                params["summary"] = summary
-            elif isinstance(summary, str):
-                params["summary"] = summary.lower() == "true"
-            else:
-                params["summary"] = True
-
-            # Handle num_results parameter (integer conversion)
+            # ---------- Tavily search parameters ----------
+            # num_results normalisation (1‑50)
             if num_results is None:
-                params["num_results"] = 20
+                num_results = 20
             elif isinstance(num_results, int):
-                params["num_results"] = max(1, min(num_results, 50))
+                num_results = max(1, min(num_results, 50))
             elif isinstance(num_results, str):
                 try:
-                    params["num_results"] = max(1, min(int(num_results), 50))
+                    num_results = max(1, min(int(num_results), 50))
                 except ValueError:
-                    params["num_results"] = 20
+                    num_results = 20
             else:
-                params["num_results"] = 20
-
-            # Execute the search with minimal parameters
-            search_response = self.exa.search_and_contents(**params)
-
-            # Format the results
+                num_results = 20
+
+            # Execute the search with Tavily
+            search_response = await self.tavily_client.search(
+                query=query,
+                max_results=num_results,
+                include_answer=False,
+                include_images=False,
+            )
+
+            # `tavily` may return a dict with `results` or a bare list
+            raw_results = (
+                search_response.get("results")
+                if isinstance(search_response, dict)
+                else search_response
+            )
+
             formatted_results = []
-            for result in search_response.results:
+            for result in raw_results:
                 formatted_result = {
-                    "Title": result.title,
-                    "URL": result.url
+                    "Title": result.get("title"),
+                    "URL": result.get("url"),
                 }

-                # Add optional fields if they exist
-                if hasattr(result, 'summary') and result.summary:
-                    formatted_result["Summary"] = result.summary
-
-                if hasattr(result, 'published_date') and result.published_date:
-                    formatted_result["Published Date"] = result.published_date
-
-                if hasattr(result, 'score'):
-                    formatted_result["Score"] = result.score
+                if summary:
+                    # Prefer full content; fall back to description
+                    if result.get("content"):
+                        formatted_result["Summary"] = result["content"]
+                    elif result.get("description"):
+                        formatted_result["Summary"] = result["description"]

                 formatted_results.append(formatted_result)

             return self.success_response(formatted_results)
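
The search path now goes through tavily-python's AsyncTavilyClient instead of exa_py. A small standalone sketch of the same call and result normalisation, assuming TAVILY_API_KEY is set in the environment; the call signature and the dict-or-list response handling simply mirror what the hunk above does, not an independent guarantee about the Tavily API.

import asyncio
import os

from tavily import AsyncTavilyClient

async def tavily_search(query: str, max_results: int = 5) -> list:
    # Mirrors the call made by WebSearchTool.web_search above.
    client = AsyncTavilyClient(api_key=os.environ["TAVILY_API_KEY"])
    response = await client.search(
        query=query,
        max_results=max_results,
        include_answer=False,
        include_images=False,
    )
    # Tolerate either a {"results": [...]} dict or a bare list, as the tool does.
    raw = response.get("results") if isinstance(response, dict) else response
    return [
        {"Title": r.get("title"), "URL": r.get("url"), "Summary": r.get("content")}
        for r in (raw or [])
    ]

# print(asyncio.run(tavily_search("rubber gym mats best prices comparison")))
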
@@ -243,26 +238,50 @@ class WebSearchTool(Tool):
             else:
                 return self.fail_response("URL must be a string.")

-            # Execute the crawl with the parsed URL
-            result = self.exa.get_contents(
-                [url],
-                text=True,
-                livecrawl="auto"
-            )
-
-            # Format the results to include all available fields
-            formatted_results = []
-            for content in result.results:
-                formatted_result = {
-                    "Title": content.title,
-                    "URL": content.url,
-                    "Text": content.text
-
-                # Add optional fields if they exist
-                if hasattr(content, 'published_date') and content.published_date:
-                    formatted_result["Published Date"] = content.published_date
+            # ---------- Tavily extract endpoint ----------
+            async with httpx.AsyncClient() as client:
+                headers = {
+                    "Authorization": f"Bearer {self.api_key}",
+                    "Content-Type": "application/json",
+                }
+
+                payload = {
+                    "urls": url,
+                    "include_images": False,
+                    "extract_depth": "basic",
+                }
+                response = await client.post(
+                    "https://api.tavily.com/extract",
+                    json=payload,
+                    headers=headers,
+                    timeout=60,
+                )
+                response.raise_for_status()
+                data = response.json()
+                print(f"--- Raw Tavily Response ---")
+                print(data)
+                print(f"--------------------------")
+
+            # Normalise Tavily extract output to a list of dicts
+            extracted = []
+            if isinstance(data, list):
+                extracted = data
+            elif isinstance(data, dict):
+                if "results" in data and isinstance(data["results"], list):
+                    extracted = data["results"]
+                elif "urls" in data and isinstance(data["urls"], dict):
+                    extracted = list(data["urls"].values())
+                else:
+                    extracted = [data]
+
+            formatted_results = []
+            for item in extracted:
+                formatted_result = {
+                    "Title": item.get("title"),
+                    "URL": item.get("url") or url,
+                    "Text": item.get("content") or item.get("text") or "",
+                }
+                if item.get("published_date"):
+                    formatted_result["Published Date"] = item["published_date"]
+                formatted_results.append(formatted_result)

             return self.success_response(formatted_results)
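
The extract path hits Tavily's /extract endpoint through raw httpx and then normalises whatever shape comes back: a bare list, a {"results": [...]} dict, or a {"urls": {...}} dict. The normalisation step is easy to exercise on its own; the sketch below just restates that logic with a fabricated sample payload (the shapes handled are the ones the code above anticipates, not a statement of the API contract).

from typing import Any

def normalise_extract_response(data: Any, requested_url: str) -> list:
    # Accept the three shapes handled by crawl_webpage above.
    if isinstance(data, list):
        items = data
    elif isinstance(data, dict):
        if isinstance(data.get("results"), list):
            items = data["results"]
        elif isinstance(data.get("urls"), dict):
            items = list(data["urls"].values())
        else:
            items = [data]
    else:
        items = []
    return [
        {
            "Title": item.get("title"),
            "URL": item.get("url") or requested_url,
            "Text": item.get("content") or item.get("text") or "",
        }
        for item in items
    ]

# Example with a made-up dict-shaped payload:
sample = {"results": [{"url": "https://example.com", "content": "Hello"}]}
print(normalise_extract_response(sample, "https://example.com"))
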
@@ -279,27 +298,27 @@ class WebSearchTool(Tool):
 if __name__ == "__main__":
     import asyncio

-    # async def test_web_search():
-    #     """Test function for the web search tool"""
-    #     search_tool = WebSearchTool()
-    #     result = await search_tool.web_search(
-    #         query="rubber gym mats best prices comparison",
-    #         summary=True,
-    #         num_results=20
-    #     )
-    #     print(result)
+    async def test_web_search():
+        """Test function for the web search tool"""
+        search_tool = WebSearchTool()
+        result = await search_tool.web_search(
+            query="rubber gym mats best prices comparison",
+            summary=True,
+            num_results=20
+        )
+        print(result)

     async def test_crawl_webpage():
         """Test function for the webpage crawl tool"""
         search_tool = WebSearchTool()
         result = await search_tool.crawl_webpage(
-            url="https://example.com"
+            url="https://google.com"
         )
         print(result)

     async def run_tests():
         """Run all test functions"""
-        # await test_web_search()
+        await test_web_search()
         await test_crawl_webpage()

-    asyncio.run(run_tests())
+    asyncio.run(run_tests())

@@ -96,10 +96,19 @@ class ResponseProcessor:
         self,
         llm_response: AsyncGenerator,
         thread_id: str,
+        prompt_messages: List[Dict[str, Any]],
+        llm_model: str,
         config: ProcessorConfig = ProcessorConfig(),
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """Process a streaming LLM response, handling tool calls and execution.

         Args:
             llm_response: Streaming response from the LLM
             thread_id: ID of the conversation thread
+            prompt_messages: List of messages sent to the LLM (the prompt)
+            llm_model: The name of the LLM model used
             config: Configuration for parsing and execution

         Yields:
             Complete message objects matching the DB schema, except for content chunks.
         """
@@ -144,8 +153,14 @@ class ResponseProcessor:
                 if hasattr(chunk, 'choices') and chunk.choices:
                     delta = chunk.choices[0].delta if hasattr(chunk.choices[0], 'delta') else None

-                    # --- Process Content Chunk ---
+                    # Check for and log Anthropic thinking content
+                    if delta and hasattr(delta, 'reasoning_content') and delta.reasoning_content:
+                        logger.info(f"[THINKING]: {delta.reasoning_content}")
+                        # Append reasoning to main content to be saved in the final message
+                        accumulated_content += delta.reasoning_content
+
+                    # Process content chunk
                     if delta and hasattr(delta, 'content') and delta.content:
                         chunk_content = delta.content
                         accumulated_content += chunk_content
@@ -263,8 +278,8 @@ class ResponseProcessor:
                     tool_index += 1

                 if finish_reason == "xml_tool_limit_reached":
-                    logger.info("Stopping stream due to XML tool call limit")
-                    break # Exit the async for loop
+                    logger.info("Stopping stream processing after loop due to XML tool call limit")
+                    break

             # --- After Streaming Loop ---

@@ -529,10 +544,19 @@ class ResponseProcessor:
         self,
         llm_response: Any,
         thread_id: str,
+        prompt_messages: List[Dict[str, Any]],
+        llm_model: str,
         config: ProcessorConfig = ProcessorConfig()
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """Process a non-streaming LLM response, handling tool calls and execution.

         Args:
             llm_response: Response from the LLM
             thread_id: ID of the conversation thread
+            prompt_messages: List of messages sent to the LLM (the prompt)
+            llm_model: The name of the LLM model used
             config: Configuration for parsing and execution

         Yields:
             Complete message objects matching the DB schema.
         """

@@ -161,6 +161,9 @@ class ThreadManager:
         native_max_auto_continues: int = 25,
         max_xml_tool_calls: int = 0,
         include_xml_examples: bool = False,
+        enable_thinking: Optional[bool] = False,
+        reasoning_effort: Optional[str] = 'low',
+        enable_context_manager: bool = True
     ) -> Union[Dict[str, Any], AsyncGenerator]:
         """Run a conversation thread with LLM integration and tool execution.

@@ -178,6 +181,9 @@ class ThreadManager:
                 finish_reason="tool_calls" (0 disables auto-continue)
             max_xml_tool_calls: Maximum number of XML tool calls to allow (0 = no limit)
             include_xml_examples: Whether to include XML tool examples in the system prompt
+            enable_thinking: Whether to enable thinking before making a decision
+            reasoning_effort: The effort level for reasoning
+            enable_context_manager: Whether to enable automatic context summarization.

         Returns:
             An async generator yielding response chunks or error dict
@@ -187,6 +193,52 @@ class ThreadManager:
         logger.debug(f"Parameters: model={llm_model}, temperature={llm_temperature}, max_tokens={llm_max_tokens}")
         logger.debug(f"Auto-continue: max={native_max_auto_continues}, XML tool limit={max_xml_tool_calls}")

+        # Use a default config if none was provided (needed for XML examples check)
+        if processor_config is None:
+            processor_config = ProcessorConfig()
+
+        # Apply max_xml_tool_calls if specified and not already set in config
+        if max_xml_tool_calls > 0 and not processor_config.max_xml_tool_calls:
+            processor_config.max_xml_tool_calls = max_xml_tool_calls
+
+        # Create a working copy of the system prompt to potentially modify
+        working_system_prompt = system_prompt.copy()
+
+        # Add XML examples to system prompt if requested, do this only ONCE before the loop
+        if include_xml_examples and processor_config.xml_tool_calling:
+            xml_examples = self.tool_registry.get_xml_examples()
+            if xml_examples:
+                examples_content = """
+--- XML TOOL CALLING ---
+
+In this environment you have access to a set of tools you can use to answer the user's question. The tools are specified in XML format.
+Format your tool calls using the specified XML tags. Place parameters marked as 'attribute' within the opening tag (e.g., `<tag attribute='value'>`). Place parameters marked as 'content' between the opening and closing tags. Place parameters marked as 'element' within their own child tags (e.g., `<tag><element>value</element></tag>`). Refer to the examples provided below for the exact structure of each tool.
+String and scalar parameters should be specified as attributes, while content goes between tags.
+Note that spaces for string values are not stripped. The output is parsed with regular expressions.
+
+Here are the XML tools available with examples:
+"""
+                for tag_name, example in xml_examples.items():
+                    examples_content += f"<{tag_name}> Example: {example}\\n"
+
+                system_content = working_system_prompt.get('content')
+
+                if isinstance(system_content, str):
+                    working_system_prompt['content'] += examples_content
+                    logger.debug("Appended XML examples to string system prompt content.")
+                elif isinstance(system_content, list):
+                    appended = False
+                    for item in working_system_prompt['content']: # Modify the copy
+                        if isinstance(item, dict) and item.get('type') == 'text' and 'text' in item:
+                            item['text'] += examples_content
+                            logger.debug("Appended XML examples to the first text block in list system prompt content.")
+                            appended = True
+                            break
+                    if not appended:
+                        logger.warning("System prompt content is a list but no text block found to append XML examples.")
+                else:
+                    logger.warning(f"System prompt content is of unexpected type ({type(system_content)}), cannot add XML examples.")
+
         # Control whether we need to auto-continue due to tool_calls finish reason
         auto_continue = True
         auto_continue_count = 0
@@ -195,81 +247,46 @@ class ThreadManager:
         async def _run_once(temp_msg=None):
             try:
-                # Ensure processor_config is available in this scope
-                nonlocal processor_config
-
-                # Use a default config if none was provided
-                if processor_config is None:
-                    processor_config = ProcessorConfig()
-
-                # Apply max_xml_tool_calls if specified and not already set
-                if max_xml_tool_calls > 0:
-                    processor_config.max_xml_tool_calls = max_xml_tool_calls
-
-                # Add XML examples to system prompt if requested
-                if include_xml_examples and processor_config.xml_tool_calling:
-                    xml_examples = self.tool_registry.get_xml_examples()
-                    if xml_examples:
-                        # logger.debug(f"Adding {len(xml_examples)} XML examples to system prompt")
-
-                        # Create or append to content
-                        if isinstance(system_prompt['content'], str):
-                            examples_content = """
---- XML TOOL CALLING ---
-
-In this environment you have access to a set of tools you can use to answer the user's question. The tools are specified in XML format.
-{{ FORMATTING INSTRUCTIONS }}
-String and scalar parameters should be specified as attributes, while content goes between tags.
-Note that spaces for string values are not stripped. The output is parsed with regular expressions.
-
-Here are the XML tools available with examples:
-"""
-                            for tag_name, example in xml_examples.items():
-                                examples_content += f"<{tag_name}> Example: {example}\n"
-
-                            system_prompt['content'] += examples_content
-                        else:
-                            # If content is not a string (might be a list or dict), log a warning
-                            logger.warning("System prompt content is not a string, cannot add XML examples")
+                nonlocal processor_config
+                # Note: processor_config is now guaranteed to exist due to check above

                 # 1. Get messages from thread for LLM call
                 messages = await self.get_llm_messages(thread_id)

                 # 2. Check token count before proceeding
                 # Use litellm to count tokens in the messages
                 token_count = 0
                 try:
                     from litellm import token_counter
-                    token_count = token_counter(model=llm_model, messages=[system_prompt] + messages)
+                    # Use the potentially modified working_system_prompt for token counting
+                    token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
                     token_threshold = self.context_manager.token_threshold
                     logger.info(f"Thread {thread_id} token count: {token_count}/{token_threshold} ({(token_count/token_threshold)*100:.1f}%)")

                     # If we're over the threshold, summarize the thread
-                    if token_count >= token_threshold:
+                    if token_count >= token_threshold and enable_context_manager:
                         logger.info(f"Thread token count ({token_count}) exceeds threshold ({token_threshold}), summarizing...")

                         # Create summary using context manager
                         summarized = await self.context_manager.check_and_summarize_if_needed(
                             thread_id=thread_id,
                             add_message_callback=self.add_message,
                             model=llm_model,
-                            force=True # Force summarization
+                            force=True
                         )

                         if summarized:
                             # If summarization was successful, get the updated messages
                             # This will now include the summary message and only messages after it
                             logger.info("Summarization complete, fetching updated messages with summary")
                             messages = await self.get_llm_messages(thread_id)
-                            # Recount tokens after summarization
-                            new_token_count = token_counter(model=llm_model, messages=[system_prompt] + messages)
+                            # Recount tokens after summarization, using the modified prompt
+                            new_token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
                             logger.info(f"After summarization: token count reduced from {token_count} to {new_token_count}")
                         else:
                             logger.warning("Summarization failed or wasn't needed - proceeding with original messages")
+                    elif not enable_context_manager: # Added condition for clarity
+                        logger.info("Automatic summarization disabled. Skipping token count check and summarization.")

                 except Exception as e:
                     logger.error(f"Error counting tokens or summarizing: {str(e)}")

                 # 3. Prepare messages for LLM call + add temporary message if it exists
-                prepared_messages = [system_prompt]
+                # Use the working_system_prompt which may contain the XML examples
+                prepared_messages = [working_system_prompt]

                 # Find the last user message index
                 last_user_index = -1
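
The token check above leans on litellm's token_counter, which takes the model name and the message list. A hedged, standalone restatement of that threshold comparison (the 120000 threshold is invented for illustration; in the repository the value comes from ContextManager.token_threshold):

from litellm import token_counter

def needs_summarization(model: str, system_prompt: dict, messages: list,
                        token_threshold: int, enable_context_manager: bool) -> bool:
    # Same check as run_thread: count the prompt plus history and compare
    # against the threshold, but only when the context manager is enabled.
    count = token_counter(model=model, messages=[system_prompt] + messages)
    return enable_context_manager and count >= token_threshold

system_prompt = {"role": "system", "content": "You are a helpful assistant."}
history = [{"role": "user", "content": "Hello!"}]
print(needs_summarization("anthropic/claude-3-7-sonnet-latest", system_prompt,
                          history, token_threshold=120000, enable_context_manager=True))
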
@@ -306,13 +323,15 @@ Here are the XML tools available with examples:
                 logger.debug("Making LLM API call")
                 try:
                     llm_response = await make_llm_api_call(
-                        prepared_messages,
+                        prepared_messages, # Pass the potentially modified messages
                         llm_model,
                         temperature=llm_temperature,
                         max_tokens=llm_max_tokens,
                         tools=openapi_tool_schemas,
                         tool_choice=tool_choice if processor_config.native_tool_calling else None,
-                        stream=stream
+                        stream=stream,
+                        enable_thinking=enable_thinking,
+                        reasoning_effort=reasoning_effort
                     )
                     logger.debug("Successfully received raw LLM API response stream/object")

@@ -326,22 +345,27 @@ Here are the XML tools available with examples:
                     response_generator = self.response_processor.process_streaming_response(
                         llm_response=llm_response,
                         thread_id=thread_id,
-                        config=processor_config
+                        config=processor_config,
+                        prompt_messages=prepared_messages,
+                        llm_model=llm_model
                     )

                     return response_generator
                 else:
                     logger.debug("Processing non-streaming response")
                     try:
-                        response = await self.response_processor.process_non_streaming_response(
+                        # Return the async generator directly, don't await it
+                        response_generator = self.response_processor.process_non_streaming_response(
                             llm_response=llm_response,
                             thread_id=thread_id,
-                            config=processor_config
+                            config=processor_config,
+                            prompt_messages=prepared_messages,
+                            llm_model=llm_model
                         )
-                        return response
+                        return response_generator # Return the generator
                     except Exception as e:
-                        logger.error(f"Error in non-streaming response: {str(e)}", exc_info=True)
-                        raise
+                        logger.error(f"Error setting up non-streaming response: {str(e)}", exc_info=True)
+                        raise # Re-raise the exception to be caught by the outer handler

             except Exception as e:
                 logger.error(f"Error in run_thread: {str(e)}", exc_info=True)
@@ -358,8 +382,9 @@ Here are the XML tools available with examples:
                 # Reset auto_continue for this iteration
                 auto_continue = False

-                # Run the thread once
-                response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None)
+                # Run the thread once, passing the potentially modified system prompt
+                # Pass temp_msg only on the first iteration
+                response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None)

                 # Handle error responses
                 if isinstance(response_gen, dict) and "status" in response_gen and response_gen["status"] == "error":
@@ -402,7 +427,8 @@ Here are the XML tools available with examples:
         # If auto-continue is disabled (max=0), just run once
         if native_max_auto_continues == 0:
             logger.info("Auto-continue is disabled (native_max_auto_continues=0)")
-            return await _run_once(temporary_message)
+            # Pass the potentially modified system prompt and temp message
+            return await _run_once(temporary_message)

         # Otherwise return the auto-continue wrapper generator
         return auto_continue_wrapper()

@@ -1,6 +1,6 @@
 streamlit-quill==0.0.3
 python-dotenv==1.0.1
-litellm>=1.44.0
+litellm>=1.66.2
 click==8.1.7
 questionary==2.0.1
 requests>=2.31.0
@@ -22,4 +22,5 @@ certifi==2024.2.2
 python-ripgrep==0.0.6
 daytona_sdk>=0.12.0
 boto3>=1.34.0
-exa-py>=1.9.1
+pydantic
+tavily-python>=0.5.4

@@ -151,9 +151,11 @@ async def list_files(
     for file in files:
         # Convert file information to our model
+        # Ensure forward slashes are used for paths, regardless of OS
+        full_path = f"{path.rstrip('/')}/{file.name}" if path != '/' else f"/{file.name}"
         file_info = FileInfo(
             name=file.name,
-            path=os.path.join(path, file.name),
+            path=full_path, # Use the constructed path
             is_dir=file.is_dir,
             size=file.size,
             mod_time=str(file.mod_time),
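
The list_files change drops os.path.join so that sandbox paths stay POSIX-style regardless of the host OS (os.path.join would insert backslashes on Windows). A tiny illustration of the same expression, with a hypothetical helper name:

import os

def sandbox_path(path: str, name: str) -> str:
    # Same expression as the list_files hunk above: always forward slashes.
    return f"{path.rstrip('/')}/{name}" if path != '/' else f"/{name}"

print(sandbox_path("/workspace", "notes.txt"))   # -> /workspace/notes.txt
print(sandbox_path("/", "notes.txt"))            # -> /notes.txt
print(os.path.join("/workspace", "notes.txt"))   # OS-dependent separator
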
@@ -17,6 +17,8 @@ import asyncio
 from openai import OpenAIError
 import litellm
 from utils.logger import logger
+from datetime import datetime
+import traceback

 # litellm.set_verbose=True
 litellm.modify_params=True
@@ -82,7 +84,9 @@ def prepare_params(
     api_base: Optional[str] = None,
     stream: bool = False,
     top_p: Optional[float] = None,
-    model_id: Optional[str] = None
+    model_id: Optional[str] = None,
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low'
 ) -> Dict[str, Any]:
     """Prepare parameters for the API call."""
     params = {
@@ -152,6 +156,75 @@ def prepare_params(
         params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
         logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")

+    # Apply Anthropic prompt caching (minimal implementation)
+    # Check model name *after* potential modifications (like adding bedrock/ prefix)
+    effective_model_name = params.get("model", model_name) # Use model from params if set, else original
+    if "claude" in effective_model_name.lower() or "anthropic" in effective_model_name.lower():
+        logger.debug("Applying minimal Anthropic prompt caching.")
+        messages = params["messages"] # Direct reference, modification affects params
+
+        # Ensure messages is a list
+        if not isinstance(messages, list):
+            logger.warning(f"Messages is not a list ({type(messages)}), skipping Anthropic cache control.")
+            return params # Return early if messages format is unexpected
+
+        # 1. Process the first message if it's a system prompt with string content
+        if messages and messages[0].get("role") == "system":
+            content = messages[0].get("content")
+            if isinstance(content, str):
+                # Wrap the string content in the required list structure
+                messages[0]["content"] = [
+                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
+                ]
+                logger.debug("Applied cache_control to system message (converted from string).")
+            elif isinstance(content, list):
+                # If content is already a list, check if the first text block needs cache_control
+                for item in content:
+                    if isinstance(item, dict) and item.get("type") == "text":
+                        if "cache_control" not in item:
+                            item["cache_control"] = {"type": "ephemeral"}
+                        break # Apply to the first text block only for system prompt
+            else:
+                logger.warning("System message content is not a string or list, skipping cache_control.")
+
+        # 2. Find and process the last user message
+        last_user_idx = -1
+        for i in range(len(messages) - 1, -1, -1):
+            if messages[i].get("role") == "user":
+                last_user_idx = i
+                break
+
+        if last_user_idx != -1:
+            last_user_message = messages[last_user_idx]
+            content = last_user_message.get("content")
+
+            if isinstance(content, str):
+                # Wrap the string content in the required list structure
+                last_user_message["content"] = [
+                    {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
+                ]
+                logger.debug(f"Applied cache_control to last user message (string content, index {last_user_idx}).")
+            elif isinstance(content, list):
+                # Modify text blocks within the list directly
+                for item in content:
+                    if isinstance(item, dict) and item.get("type") == "text":
+                        # Add cache_control if not already present
+                        if "cache_control" not in item:
+                            item["cache_control"] = {"type": "ephemeral"}
+
+            else:
+                logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.")
+
+    # Add reasoning_effort for Anthropic models if enabled
+    use_thinking = enable_thinking if enable_thinking is not None else False
+    is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
+
+    if is_anthropic and use_thinking:
+        effort_level = reasoning_effort if reasoning_effort else 'low'
+        params["reasoning_effort"] = effort_level
+        params["temperature"] = 1.0 # Required by Anthropic when reasoning_effort is used
+        logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")
+
     return params

 async def make_llm_api_call(
@@ -166,7 +239,9 @@ async def make_llm_api_call(
     api_base: Optional[str] = None,
     stream: bool = False,
     top_p: Optional[float] = None,
-    model_id: Optional[str] = None
+    model_id: Optional[str] = None,
+    enable_thinking: Optional[bool] = False,
+    reasoning_effort: Optional[str] = 'low'
 ) -> Union[Dict[str, Any], AsyncGenerator]:
     """
     Make an API call to a language model using LiteLLM.
@@ -184,6 +259,8 @@ async def make_llm_api_call(
         stream: Whether to stream the response
         top_p: Top-p sampling parameter
         model_id: Optional ARN for Bedrock inference profiles
+        enable_thinking: Whether to enable thinking
+        reasoning_effort: Level of reasoning effort

     Returns:
         Union[Dict[str, Any], AsyncGenerator]: API response or stream
@@ -192,7 +269,7 @@ async def make_llm_api_call(
         LLMRetryError: If API call fails after retries
         LLMError: For other API-related errors
     """
-    logger.debug(f"Making LLM API call to model: {model_name}")
+    logger.debug(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})")
     params = prepare_params(
         messages=messages,
         model_name=model_name,
@@ -205,7 +282,9 @@ async def make_llm_api_call(
         api_base=api_base,
         stream=stream,
         top_p=top_p,
-        model_id=model_id
+        model_id=model_id,
+        enable_thinking=enable_thinking,
+        reasoning_effort=reasoning_effort
     )

     last_error = None
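
The caching block above rewrites plain string content into Anthropic's block format with a cache_control marker on the system message and the last user message. A standalone sketch of just that message rewriting, under the assumption that downstream LiteLLM/Anthropic handling of the "ephemeral" marker is as the comments describe; apply_ephemeral_cache is a hypothetical helper, not a function in the repository:

def apply_ephemeral_cache(messages: list) -> list:
    # Wrap string content in the block structure used by prepare_params above.
    def wrap(msg: dict) -> None:
        content = msg.get("content")
        if isinstance(content, str):
            msg["content"] = [
                {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
            ]

    if messages and messages[0].get("role") == "system":
        wrap(messages[0])
    for msg in reversed(messages):          # last user message only
        if msg.get("role") == "user":
            wrap(msg)
            break
    return messages

msgs = [
    {"role": "system", "content": "You are Suna."},
    {"role": "user", "content": "Find rubber gym mats."},
]
print(apply_ephemeral_cache(msgs)[0]["content"][0]["cache_control"])
# -> {'type': 'ephemeral'}
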
@@ -12,7 +12,7 @@ function DashboardContent() {
   const [isSubmitting, setIsSubmitting] = useState(false);
   const router = useRouter();

-  const handleSubmit = async (message: string) => {
+  const handleSubmit = async (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => {
     if (!message.trim() || isSubmitting) return;

     setIsSubmitting(true);
@@ -34,7 +34,11 @@
       await addUserMessage(thread.thread_id, message.trim());

       // 4. Start the agent with the thread ID
-      const agentRun = await startAgent(thread.thread_id);
+      const agentRun = await startAgent(thread.thread_id, {
+        model_name: options?.model_name,
+        enable_thinking: options?.enable_thinking,
+        stream: true
+      });

       // 5. Navigate to the new agent's thread page
       router.push(`/dashboard/agents/${thread.thread_id}`);

@@ -3,7 +3,7 @@
 import React, { useState, useRef, useEffect } from 'react';
 import { Textarea } from "@/components/ui/textarea";
 import { Button } from "@/components/ui/button";
-import { Send, Square, Loader2, File, Upload, X, Paperclip, FileText } from "lucide-react";
+import { Send, Square, Loader2, File, Upload, X, Paperclip, FileText, ChevronDown, Cpu } from "lucide-react";
 import { createClient } from "@/lib/supabase/client";
 import { toast } from "sonner";
 import { AnimatePresence, motion } from "framer-motion";
@@ -13,13 +13,22 @@ import {
   TooltipProvider,
   TooltipTrigger,
 } from "@/components/ui/tooltip";
+import {
+  DropdownMenu,
+  DropdownMenuContent,
+  DropdownMenuItem,
+  DropdownMenuTrigger,
+} from "@/components/ui/dropdown-menu";
 import { cn } from "@/lib/utils";

 // Define API_URL
 const API_URL = process.env.NEXT_PUBLIC_BACKEND_URL || '';

+// Local storage keys
+const STORAGE_KEY_MODEL = 'suna-preferred-model';
+
 interface ChatInputProps {
-  onSubmit: (message: string) => void;
+  onSubmit: (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => void;
   placeholder?: string;
   loading?: boolean;
   disabled?: boolean;
@@ -40,7 +49,7 @@ interface UploadedFile {

 export function ChatInput({
   onSubmit,
-  placeholder = "Type your message... (Enter to send, Shift+Enter for new line)",
+  placeholder = "Describe what you need help with...",
   loading = false,
   disabled = false,
   isAgentRunning = false,
@@ -52,47 +61,63 @@
   sandboxId
 }: ChatInputProps) {
   const [inputValue, setInputValue] = useState(value || "");
+  const [selectedModel, setSelectedModel] = useState("sonnet-3.7");
   const textareaRef = useRef<HTMLTextAreaElement | null>(null);
   const fileInputRef = useRef<HTMLInputElement>(null);
   const [uploadedFiles, setUploadedFiles] = useState<UploadedFile[]>([]);
   const [isUploading, setIsUploading] = useState(false);
   const [isDraggingOver, setIsDraggingOver] = useState(false);

   // Allow controlled or uncontrolled usage
+  useEffect(() => {
+    if (typeof window !== 'undefined') {
+      try {
+        const savedModel = localStorage.getItem(STORAGE_KEY_MODEL);
+        if (savedModel) {
+          setSelectedModel(savedModel);
+        }
+      } catch (error) {
+        console.warn('Failed to load preferences from localStorage:', error);
+      }
+    }
+  }, []);
+
   const isControlled = value !== undefined && onChange !== undefined;

   // Update local state if controlled and value changes
   useEffect(() => {
     if (isControlled && value !== inputValue) {
       setInputValue(value);
     }
   }, [value, isControlled, inputValue]);

   // Auto-focus on textarea when component loads
   useEffect(() => {
     if (autoFocus && textareaRef.current) {
       textareaRef.current.focus();
     }
   }, [autoFocus]);

   // Adjust textarea height based on content
   useEffect(() => {
     const textarea = textareaRef.current;
     if (!textarea) return;

     const adjustHeight = () => {
       textarea.style.height = 'auto';
-      const newHeight = Math.min(textarea.scrollHeight, 200); // Max height of 200px
+      const newHeight = Math.min(Math.max(textarea.scrollHeight, 50), 200); // Min 50px, max 200px
       textarea.style.height = `${newHeight}px`;
     };

     adjustHeight();

     // Adjust on window resize too
     window.addEventListener('resize', adjustHeight);
     return () => window.removeEventListener('resize', adjustHeight);
   }, [inputValue]);

+  const handleModelChange = (model: string) => {
+    setSelectedModel(model);
+    if (typeof window !== 'undefined') {
+      localStorage.setItem(STORAGE_KEY_MODEL, model);
+    }
+  };
+
   const handleSubmit = async (e: React.FormEvent) => {
     e.preventDefault();
     if ((!inputValue.trim() && uploadedFiles.length === 0) || loading || (disabled && !isAgentRunning)) return;
@@ -104,7 +129,6 @@ export function ChatInput({

     let message = inputValue;
-
     // Add file information to the message if files were uploaded
     if (uploadedFiles.length > 0) {
       const fileInfo = uploadedFiles.map(file =>
         `[Uploaded file: ${file.name} (${formatFileSize(file.size)}) at ${file.path}]`
@@ -112,13 +136,22 @@ export function ChatInput({
       message = message ? `${message}\n\n${fileInfo}` : fileInfo;
     }

-    onSubmit(message);
+    let baseModelName = selectedModel;
+    let thinkingEnabled = false;
+    if (selectedModel === "sonnet-3.7-thinking") {
+      baseModelName = "sonnet-3.7";
+      thinkingEnabled = true;
+    }
+
+    onSubmit(message, {
+      model_name: baseModelName,
+      enable_thinking: thinkingEnabled
+    });

     if (!isControlled) {
       setInputValue("");
     }

     // Reset the uploaded files after sending
     setUploadedFiles([]);
   };
@@ -175,7 +208,6 @@ export function ChatInput({
     const files = Array.from(event.target.files);
     await uploadFiles(files);
-
     // Reset the input
     event.target.value = '';
   };
@@ -191,11 +223,9 @@ export function ChatInput({
           continue;
         }

-        // Create a FormData object
         const formData = new FormData();
         formData.append('file', file);
-
         // Upload to workspace root by default
         const uploadPath = `/workspace/${file.name}`;
         formData.append('path', uploadPath);
@@ -206,7 +236,6 @@ export function ChatInput({
           throw new Error('No access token available');
         }

-        // Upload using FormData
         const response = await fetch(`${API_URL}/sandboxes/${sandboxId}/files`, {
           method: 'POST',
           headers: {
@@ -219,7 +248,6 @@ export function ChatInput({
           throw new Error(`Upload failed: ${response.statusText}`);
         }

-        // Add to uploaded files
         newUploadedFiles.push({
           name: file.name,
           path: uploadPath,
@@ -229,7 +257,6 @@ export function ChatInput({
         toast.success(`File uploaded: ${file.name}`);
       }

-      // Update the uploaded files state
      setUploadedFiles(prev => [...prev, ...newUploadedFiles]);

     } catch (error) {
@@ -273,11 +300,18 @@ export function ChatInput({
     }
   };

+  const modelDisplayNames: { [key: string]: string } = {
+    "sonnet-3.7": "Sonnet 3.7",
+    "sonnet-3.7-thinking": "Sonnet 3.7 (Thinking)",
+    "gpt-4.1": "GPT-4.1",
+    "gemini-flash-2.5": "Gemini Flash 2.5"
+  };
+
   return (
     <div
       className={cn(
-        "w-full border rounded-lg transition-all duration-200 shadow-sm",
-        uploadedFiles.length > 0 ? "border-border" : "border-input",
+        "w-full border rounded-xl transition-all duration-200 shadow-sm bg-[#1a1a1a] border-gray-800",
+        uploadedFiles.length > 0 ? "border-border" : "border-gray-800",
         isDraggingOver ? "border-primary border-dashed bg-primary/5" : ""
       )}
       onDragOver={handleDragOver}
@@ -300,18 +334,18 @@ export function ChatInput({
                   animate={{ opacity: 1, scale: 1 }}
                   exit={{ opacity: 0, scale: 0.9 }}
                   transition={{ duration: 0.15 }}
-                  className="px-2 py-1 bg-secondary/20 rounded-full flex items-center gap-1.5 group border border-secondary/30 hover:border-secondary/50 transition-colors text-sm"
+                  className="px-2 py-1 bg-gray-800 rounded-full flex items-center gap-1.5 group border border-gray-700 hover:border-gray-600 transition-colors text-sm"
                 >
                   {getFileIcon(file.name)}
-                  <span className="font-medium truncate max-w-[120px]">{file.name}</span>
-                  <span className="text-xs text-muted-foreground flex-shrink-0">
+                  <span className="font-medium truncate max-w-[120px] text-gray-300">{file.name}</span>
+                  <span className="text-xs text-gray-400 flex-shrink-0">
                     ({formatFileSize(file.size)})
                   </span>
                   <Button
                     type="button"
                     variant="ghost"
                     size="icon"
-                    className="h-4 w-4 ml-0.5 rounded-full p-0 hover:bg-secondary/50"
+                    className="h-4 w-4 ml-0.5 rounded-full p-0 hover:bg-gray-700"
                     onClick={() => removeUploadedFile(index)}
                   >
                     <X className="h-3 w-3" />
@@ -319,42 +353,94 @@ export function ChatInput({
                 </motion.div>
               ))}
             </div>
-            <div className="h-px bg-border/40 my-2 mx-1" />
+            <div className="h-px bg-gray-800 my-2 mx-1" />
           </motion.div>
         )}
       </AnimatePresence>

       <div className="relative">
         {isDraggingOver && (
           <div className="absolute inset-0 flex items-center justify-center z-10 pointer-events-none">
             <div className="flex flex-col items-center">
               <Upload className="h-6 w-6 text-primary mb-2" />
               <p className="text-sm font-medium text-primary">Drop files to upload</p>
             </div>
           </div>
         )}
+        <div className="flex items-center px-3 py-3">
+          <div className="relative flex-1 flex items-center">
+            <Textarea
+              ref={textareaRef}
+              value={inputValue}
+              onChange={handleChange}
+              onKeyDown={handleKeyDown}
+              placeholder={
+                isAgentRunning
+                  ? "Agent is thinking..."
+                  : placeholder
+              }
+              className={cn(
+                "min-h-[50px] py-3 px-4 text-gray-200 resize-none border-0 focus-visible:ring-0 focus-visible:ring-offset-0 bg-transparent w-full text-base",
+                isDraggingOver ? "opacity-20" : ""
+              )}
+              disabled={loading || (disabled && !isAgentRunning)}
+              rows={1}
+            />
+          </div>
+
-        <Textarea
-          ref={textareaRef}
-          value={inputValue}
-          onChange={handleChange}
-          onKeyDown={handleKeyDown}
-          placeholder={
-            isAgentRunning
-              ? "Agent is thinking..."
-              : placeholder
-          }
-          className={cn(
-            "min-h-[50px] max-h-[200px] pr-24 resize-none border-0 focus-visible:ring-0 focus-visible:ring-offset-0 bg-white",
-            isDraggingOver ? "opacity-20" : "",
-            isAgentRunning ? "rounded-t-lg" : "rounded-lg"
+          <div className="flex items-center gap-3 ml-3">
+            {!isAgentRunning && (
+              <TooltipProvider>
+                <Tooltip>
+                  <TooltipTrigger asChild>
+                    <DropdownMenu>
+                      <DropdownMenuTrigger asChild>
+                        <Button
+                          variant="ghost"
+                          size="icon"
+                          className={cn(
+                            "h-8 w-8 rounded-full p-0 hover:bg-gray-800",
+                            selectedModel === "sonnet-3.7" ? "text-purple-400" :
+                            selectedModel === "sonnet-3.7-thinking" ? "text-violet-400" :
+                            selectedModel === "gpt-4.1" ? "text-green-400" :
+                            selectedModel === "gemini-flash-2.5" ? "text-blue-400" :
+                            "text-gray-400"
+                          )}
+                          aria-label="Select model"
+                        >
+                          <Cpu className="h-5 w-5" />
+                        </Button>
+                      </DropdownMenuTrigger>
+                      <DropdownMenuContent align="start" className="bg-gray-900 border-gray-800 text-gray-300">
+                        <DropdownMenuItem onClick={() => handleModelChange("sonnet-3.7")} className={cn(
+                          "hover:bg-gray-800 flex items-center justify-between",
+                          selectedModel === "sonnet-3.7" && "text-purple-400"
+                        )}>
+                          <span>Sonnet 3.7</span>
+                          {selectedModel === "sonnet-3.7" && <span className="ml-2">✓</span>}
+                        </DropdownMenuItem>
+                        <DropdownMenuItem onClick={() => handleModelChange("sonnet-3.7-thinking")} className={cn(
+                          "hover:bg-gray-800 flex items-center justify-between",
+                          selectedModel === "sonnet-3.7-thinking" && "text-violet-400"
+                        )}>
+                          <span>Sonnet 3.7 (Thinking)</span>
+                          {selectedModel === "sonnet-3.7-thinking" && <span className="ml-2">✓</span>}
+                        </DropdownMenuItem>
+                        <DropdownMenuItem onClick={() => handleModelChange("gpt-4.1")} className={cn(
+                          "hover:bg-gray-800 flex items-center justify-between",
+                          selectedModel === "gpt-4.1" && "text-green-400"
+                        )}>
+                          <span>GPT-4.1</span>
+                          {selectedModel === "gpt-4.1" && <span className="ml-2">✓</span>}
+                        </DropdownMenuItem>
+                        <DropdownMenuItem onClick={() => handleModelChange("gemini-flash-2.5")} className={cn(
+                          "hover:bg-gray-800 flex items-center justify-between",
+                          selectedModel === "gemini-flash-2.5" && "text-blue-400"
+                        )}>
+                          <span>Gemini Flash 2.5</span>
+                          {selectedModel === "gemini-flash-2.5" && <span className="ml-2">✓</span>}
+                        </DropdownMenuItem>
+                      </DropdownMenuContent>
+                    </DropdownMenu>
+                  </TooltipTrigger>
+                  <TooltipContent side="top" className="bg-gray-900 text-gray-300 border-gray-800">
+                    <p>Model: {modelDisplayNames[selectedModel as keyof typeof modelDisplayNames] || selectedModel}</p>
+                  </TooltipContent>
+                </Tooltip>
+              </TooltipProvider>
+            )}
-          disabled={loading || (disabled && !isAgentRunning)}
-          rows={1}
-        />
-
-        <div className="absolute right-2 bottom-2 flex items-center space-x-1.5">
           {/* Upload file button */}

           <TooltipProvider>
             <Tooltip>
               <TooltipTrigger asChild>
@@ -364,26 +450,25 @@ export function ChatInput({
                   variant="ghost"
                   size="icon"
                   className={cn(
-                    "h-8 w-8 rounded-full transition-all hover:bg-primary/10 hover:text-primary",
-                    isUploading && "text-primary"
+                    "h-8 w-8 rounded-full p-0 text-gray-400 hover:bg-gray-800",
+                    isUploading && "text-blue-400"
                   )}
                   disabled={loading || (disabled && !isAgentRunning) || isUploading}
                   aria-label="Upload files"
                 >
                   {isUploading ? (
-                    <Loader2 className="h-4 w-4 animate-spin" />
+                    <Loader2 className="h-5 w-5 animate-spin" />
                   ) : (
-                    <Paperclip className="h-4 w-4" />
+                    <Paperclip className="h-5 w-5" />
                   )}
                 </Button>
               </TooltipTrigger>
-              <TooltipContent side="top">
+              <TooltipContent side="top" className="bg-gray-900 text-gray-300 border-gray-800">
                 <p>Attach files</p>
               </TooltipContent>
             </Tooltip>
           </TooltipProvider>

           {/* Hidden file input */}
           <input
             type="file"
             ref={fileInputRef}
@@ -392,30 +477,6 @@ export function ChatInput({
             multiple
           />

-          {/* File browser button */}
-          {onFileBrowse && (
-            <TooltipProvider>
-              <Tooltip>
-                <TooltipTrigger asChild>
-                  <Button
-                    type="button"
-                    onClick={onFileBrowse}
-                    variant="ghost"
-                    size="icon"
-                    className="h-8 w-8 rounded-full transition-all hover:bg-primary/10 hover:text-primary"
-                    disabled={loading || (disabled && !isAgentRunning)}
-                    aria-label="Browse files"
-                  >
-                    <File className="h-4 w-4" />
-                  </Button>
-                </TooltipTrigger>
-                <TooltipContent side="top">
-                  <p>Browse workspace files</p>
-                </TooltipContent>
-              </Tooltip>
-            </TooltipProvider>
-          )}

           <TooltipProvider>
             <Tooltip>
               <TooltipTrigger asChild>
@@ -425,23 +486,23 @@ export function ChatInput({
                   variant={isAgentRunning ? "destructive" : "default"}
                   size="icon"
                   className={cn(
-                    "h-8 w-8 rounded-full",
-                    !isAgentRunning && "bg-primary hover:bg-primary/90",
-                    isAgentRunning && "bg-destructive hover:bg-destructive/90"
+                    "h-10 w-10 rounded-full",
+                    !isAgentRunning && "bg-gray-700 hover:bg-gray-600 text-gray-200",
+                    isAgentRunning && "bg-red-600 hover:bg-red-700"
                   )}
                   disabled={((!inputValue.trim() && uploadedFiles.length === 0) && !isAgentRunning) || loading || (disabled && !isAgentRunning)}
                   aria-label={isAgentRunning ? 'Stop agent' : 'Send message'}
                 >
                   {loading ? (
-                    <Loader2 className="h-4 w-4 animate-spin" />
+                    <Loader2 className="h-5 w-5 animate-spin" />
                   ) : isAgentRunning ? (
-                    <Square className="h-4 w-4" />
+                    <Square className="h-5 w-5" />
                   ) : (
-                    <Send className="h-4 w-4" />
+                    <Send className="h-5 w-5" />
                   )}
                 </Button>
               </TooltipTrigger>
-              <TooltipContent side="top">
+              <TooltipContent side="top" className="bg-gray-900 text-gray-300 border-gray-800">
                 <p>{isAgentRunning ? 'Stop agent' : 'Send message'}</p>
               </TooltipContent>
             </Tooltip>
@@ -453,15 +514,15 @@ export function ChatInput({
           <motion.div
             initial={{ opacity: 0, y: -10 }}
             animate={{ opacity: 1, y: 0 }}
-            className="py-2 px-3 flex items-center justify-center bg-white border-t rounded-b-lg"
+            className="py-2 px-3 flex items-center justify-center bg-gray-800 border-t border-gray-700 rounded-b-xl"
           >
-            <div className="text-xs text-muted-foreground flex items-center gap-2">
+            <div className="text-xs text-gray-400 flex items-center gap-2">
               <span className="inline-flex items-center">
                 <Loader2 className="h-3 w-3 animate-spin mr-1.5" />
                 Agent is thinking...
               </span>
-              <span className="text-muted-foreground/60 border-l pl-2">
-                Press <kbd className="inline-flex items-center justify-center p-0.5 mx-1 bg-muted border rounded text-xs"><Square className="h-2.5 w-2.5" /></kbd> to stop
+              <span className="text-gray-500 border-l border-gray-700 pl-2">
+                Press <kbd className="inline-flex items-center justify-center p-0.5 mx-1 bg-gray-700 border border-gray-600 rounded text-xs"><Square className="h-2.5 w-2.5" /></kbd> to stop
               </span>
             </div>
           </motion.div>

@@ -361,7 +361,15 @@ export const getMessages = async (threadId: string): Promise<Message[]> => {
 };

 // Agent APIs
-export const startAgent = async (threadId: string): Promise<{ agent_run_id: string }> => {
+export const startAgent = async (
+  threadId: string,
+  options?: {
+    model_name?: string;
+    enable_thinking?: boolean;
+    reasoning_effort?: string;
+    stream?: boolean;
+  }
+): Promise<{ agent_run_id: string }> => {
   try {
     const supabase = createClient();
     const { data: { session } } = await supabase.auth.getSession();
@@ -385,6 +393,8 @@ export const startAgent = async (threadId: string): Promise<{ agent_run_id: stri
       },
       // Add cache: 'no-store' to prevent caching
       cache: 'no-store',
+      // Add the body, stringifying the options or an empty object
+      body: JSON.stringify(options || {}),
     });

     if (!response.ok) {