Merge pull request #54 from kortix-ai/dat_context

Dat context
Dat LQ. 2025-04-18 17:17:43 +01:00 committed by GitHub
commit db46c1aee5
11 changed files with 575 additions and 266 deletions

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Depends, Request
from fastapi import APIRouter, HTTPException, Depends, Request, Body
from fastapi.responses import StreamingResponse
import asyncio
import json
@@ -7,6 +7,7 @@ from datetime import datetime, timezone
import uuid
from typing import Optional, List, Dict, Any
import jwt
from pydantic import BaseModel
from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
@@ -26,6 +27,19 @@ db = None
# In-memory storage for active agent runs and their responses
active_agent_runs: Dict[str, List[Any]] = {}
MODEL_NAME_ALIASES = {
"sonnet-3.7": "anthropic/claude-3-7-sonnet-latest",
"gpt-4.1": "openai/gpt-4.1-2025-04-14",
"gemini-flash-2.5": "openrouter/google/gemini-2.5-flash-preview",
}
class AgentStartRequest(BaseModel):
model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
enable_thinking: Optional[bool] = False
reasoning_effort: Optional[str] = 'low'
stream: Optional[bool] = True
enable_context_manager: Optional[bool] = False
def initialize(
_thread_manager: ThreadManager,
_db: DBConnection,
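
With AgentStartRequest in place, POST /thread/{thread_id}/agent/start now accepts an optional JSON body. A minimal client-side sketch of the new call, assuming a local deployment; the base URL, thread id, and bearer token are placeholders, and httpx merely stands in for any HTTP client:

import httpx

# All values below are placeholders for illustration.
BASE_URL = "http://localhost:8000"
THREAD_ID = "0c24e1a1-0000-0000-0000-000000000000"
JWT = "<your-jwt>"

resp = httpx.post(
    f"{BASE_URL}/thread/{THREAD_ID}/agent/start",
    headers={"Authorization": f"Bearer {JWT}"},
    json={
        "model_name": "sonnet-3.7",        # alias, resolved server-side via MODEL_NAME_ALIASES
        "enable_thinking": False,
        "reasoning_effort": "low",
        "stream": True,
        "enable_context_manager": False,
    },
)
resp.raise_for_status()
print(resp.json())  # expected shape: {"agent_run_id": "..."}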
@@ -237,9 +251,13 @@ async def _cleanup_agent_run(agent_run_id: str):
# Non-fatal error, can continue
@router.post("/thread/{thread_id}/agent/start")
async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
async def start_agent(
thread_id: str,
body: AgentStartRequest = Body(...), # Accept request body
user_id: str = Depends(get_current_user_id)
):
"""Start an agent for a specific thread in the background."""
logger.info(f"Starting new agent for thread: {thread_id}")
logger.info(f"Starting new agent for thread: {thread_id} with config: model={body.model_name}, thinking={body.enable_thinking}, effort={body.reasoning_effort}, stream={body.stream}, context_manager={body.enable_context_manager}")
client = await db.client
# Verify user has access to this thread
@@ -314,7 +332,18 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id
# Run the agent in the background
task = asyncio.create_task(
run_agent_background(agent_run_id, thread_id, instance_id, project_id, sandbox)
run_agent_background(
agent_run_id=agent_run_id,
thread_id=thread_id,
instance_id=instance_id,
project_id=project_id,
sandbox=sandbox,
model_name=MODEL_NAME_ALIASES.get(body.model_name, body.model_name),
enable_thinking=body.enable_thinking,
reasoning_effort=body.reasoning_effort,
stream=body.stream,
enable_context_manager=body.enable_context_manager
)
)
# Set a callback to clean up when task is done
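
Note the alias lookup when the background task is spawned: MODEL_NAME_ALIASES.get(body.model_name, body.model_name) resolves short names and passes everything else through untouched, so fully qualified LiteLLM model strings keep working. A quick illustration using the mapping from this diff:

MODEL_NAME_ALIASES = {
    "sonnet-3.7": "anthropic/claude-3-7-sonnet-latest",
    "gpt-4.1": "openai/gpt-4.1-2025-04-14",
    "gemini-flash-2.5": "openrouter/google/gemini-2.5-flash-preview",
}

print(MODEL_NAME_ALIASES.get("sonnet-3.7", "sonnet-3.7"))
# anthropic/claude-3-7-sonnet-latest
print(MODEL_NAME_ALIASES.get("anthropic/claude-3-7-sonnet-latest",
                             "anthropic/claude-3-7-sonnet-latest"))
# anthropic/claude-3-7-sonnet-latest (unknown keys fall back to the raw value)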
@@ -441,9 +470,20 @@ async def stream_agent_run(
}
)
async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str, sandbox):
async def run_agent_background(
agent_run_id: str,
thread_id: str,
instance_id: str,
project_id: str,
sandbox,
model_name: str,
enable_thinking: Optional[bool],
reasoning_effort: Optional[str],
stream: bool,
enable_context_manager: bool
):
"""Run the agent in the background and handle status updates."""
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model={model_name}, thinking={enable_thinking}, effort={reasoning_effort}, stream={stream}, context_manager={enable_context_manager}")
client = await db.client
# Tracking variables
@@ -561,9 +601,17 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
try:
# Run the agent
logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
agent_gen = run_agent(thread_id, stream=True,
thread_manager=thread_manager, project_id=project_id,
sandbox=sandbox)
agent_gen = run_agent(
thread_id=thread_id,
project_id=project_id,
stream=stream,
thread_manager=thread_manager,
sandbox=sandbox,
model_name=model_name,
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort,
enable_context_manager=enable_context_manager
)
# Collect all responses to save to database
all_responses = []

View File

@@ -22,7 +22,19 @@ from utils.billing import check_billing_status, get_account_id_from_thread
load_dotenv()
async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
async def run_agent(
thread_id: str,
project_id: str,
sandbox,
stream: bool,
thread_manager: Optional[ThreadManager] = None,
native_max_auto_continues: int = 25,
max_iterations: int = 150,
model_name: str = "anthropic/claude-3-7-sonnet-latest",
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low',
enable_context_manager: bool = True
):
"""Run the development agent with specified configuration."""
if not thread_manager:
@@ -42,17 +54,15 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
thread_manager.add_tool(SandboxDeployTool, sandbox=sandbox)
thread_manager.add_tool(MessageTool) # we are just doing this via prompt as there is no need to call it as a tool
if os.getenv("EXA_API_KEY"):
if os.getenv("TAVILY_API_KEY"):
thread_manager.add_tool(WebSearchTool)
else:
print("TAVILY_API_KEY not found, WebSearchTool will not be available.")
if os.getenv("RAPID_API_KEY"):
thread_manager.add_tool(DataProvidersTool)
xml_examples = ""
for tag_name, example in thread_manager.tool_registry.get_xml_examples().items():
xml_examples += f"{example}\n"
system_message = { "role": "system", "content": get_system_prompt() + "\n\n" + f"<tool_examples>\n{xml_examples}\n</tool_examples>" }
system_message = { "role": "system", "content": get_system_prompt() }
iteration_count = 0
continue_execution = True
@@ -112,14 +122,16 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
except Exception as e:
print(f"Error parsing browser state: {e}")
# print(latest_browser_state.data[0])
max_tokens = 64000 if "sonnet" in model_name.lower() else None
response = await thread_manager.run_thread(
thread_id=thread_id,
system_prompt=system_message,
stream=stream,
llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
llm_model=model_name,
llm_temperature=0,
llm_max_tokens=64000,
llm_max_tokens=max_tokens,
tool_choice="auto",
max_xml_tool_calls=1,
temporary_message=temporary_message,
@@ -133,6 +145,9 @@ async def run_agent(thread_id: str, project_id: str, sandbox, stream: bool = Tru
),
native_max_auto_continues=native_max_auto_continues,
include_xml_examples=True,
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort,
enable_context_manager=enable_context_manager
)
if isinstance(response, dict) and "status" in response and response["status"] == "error":
@@ -267,7 +282,16 @@ async def test_agent():
print("\n👋 Test completed. Goodbye!")
async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
async def process_agent_response(
thread_id: str,
project_id: str,
thread_manager: ThreadManager,
stream: bool = True,
model_name: str = "anthropic/claude-3-7-sonnet-latest",
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low',
enable_context_manager: bool = True
):
"""Process the streaming response from the agent."""
chunk_counter = 0
current_response = ""
@@ -276,9 +300,20 @@ async def process_agent_response(thread_id: str, project_id: str, thread_manager
# Create a test sandbox for processing
sandbox_pass = str(uuid4())
sandbox = create_sandbox(sandbox_pass)
print(f"\033[91mTest sandbox created: {sandbox.get_preview_link(6080)}/vnc_lite.html?password={sandbox_pass}\033[0m")
print(f"\033[91mTest sandbox created: {str(sandbox.get_preview_link(6080))}/vnc_lite.html?password={sandbox_pass}\033[0m")
async for chunk in run_agent(thread_id=thread_id, project_id=project_id, sandbox=sandbox, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
async for chunk in run_agent(
thread_id=thread_id,
project_id=project_id,
sandbox=sandbox,
stream=stream,
thread_manager=thread_manager,
native_max_auto_continues=25,
model_name=model_name,
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort,
enable_context_manager=enable_context_manager
):
chunk_counter += 1
# print(f"CHUNK: {chunk}") # Uncomment for debugging

View File

@@ -1,4 +1,5 @@
from exa_py import Exa
from tavily import AsyncTavilyClient
import httpx
from typing import List, Optional
from datetime import datetime
import os
@@ -15,10 +16,12 @@ class WebSearchTool(Tool):
# Load environment variables
load_dotenv()
# Use the provided API key or get it from environment variables
self.api_key = api_key or os.getenv("EXA_API_KEY")
self.api_key = api_key or os.getenv("TAVILY_API_KEY")
if not self.api_key:
raise ValueError("EXA_API_KEY not found in environment variables")
self.exa = Exa(api_key=self.api_key)
raise ValueError("TAVILY_API_KEY not found in environment variables")
# Tavily asynchronous search client
self.tavily_client = AsyncTavilyClient(api_key=self.api_key)
@openapi_schema({
"type": "function",
@@ -111,57 +114,49 @@
if not query or not isinstance(query, str):
return self.fail_response("A valid search query is required.")
# Basic parameters - use only the minimum required to avoid API errors
params = {
"query": query,
"type": "auto",
"livecrawl": "auto"
}
# Handle summary parameter (boolean conversion)
if summary is None:
params["summary"] = True
elif isinstance(summary, bool):
params["summary"] = summary
elif isinstance(summary, str):
params["summary"] = summary.lower() == "true"
else:
params["summary"] = True
# Handle num_results parameter (integer conversion)
# ---------- Tavily search parameters ----------
# num_results normalisation (1-50)
if num_results is None:
params["num_results"] = 20
num_results = 20
elif isinstance(num_results, int):
params["num_results"] = max(1, min(num_results, 50))
num_results = max(1, min(num_results, 50))
elif isinstance(num_results, str):
try:
params["num_results"] = max(1, min(int(num_results), 50))
num_results = max(1, min(int(num_results), 50))
except ValueError:
params["num_results"] = 20
num_results = 20
else:
params["num_results"] = 20
# Execute the search with minimal parameters
search_response = self.exa.search_and_contents(**params)
# Format the results
num_results = 20
# Execute the search with Tavily
search_response = await self.tavily_client.search(
query=query,
max_results=num_results,
include_answer=False,
include_images=False,
)
# `tavily` may return a dict with `results` or a bare list
raw_results = (
search_response.get("results")
if isinstance(search_response, dict)
else search_response
)
formatted_results = []
for result in search_response.results:
for result in raw_results:
formatted_result = {
"Title": result.title,
"URL": result.url
"Title": result.get("title"),
"URL": result.get("url"),
}
# Add optional fields if they exist
if hasattr(result, 'summary') and result.summary:
formatted_result["Summary"] = result.summary
if hasattr(result, 'published_date') and result.published_date:
formatted_result["Published Date"] = result.published_date
if hasattr(result, 'score'):
formatted_result["Score"] = result.score
if summary:
# Prefer full content; fall back to description
if result.get("content"):
formatted_result["Summary"] = result["content"]
elif result.get("description"):
formatted_result["Summary"] = result["description"]
formatted_results.append(formatted_result)
return self.success_response(formatted_results)
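
The num_results handling above defensively coerces whatever arrives (None, int, or numeric string) into the 1-50 range, defaulting to 20. The same logic as a standalone sketch; the helper name is invented for illustration:

def normalise_num_results(value, default=20, lo=1, hi=50):
    # None, unparseable strings, and unexpected types all fall back to the default.
    if isinstance(value, int):
        return max(lo, min(value, hi))
    if isinstance(value, str):
        try:
            return max(lo, min(int(value), hi))
        except ValueError:
            return default
    return default

assert normalise_num_results(None) == 20
assert normalise_num_results(500) == 50
assert normalise_num_results("3") == 3
assert normalise_num_results("lots") == 20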
@@ -243,26 +238,50 @@
else:
return self.fail_response("URL must be a string.")
# Execute the crawl with the parsed URL
result = self.exa.get_contents(
[url],
text=True,
livecrawl="auto"
)
# Format the results to include all available fields
formatted_results = []
for content in result.results:
formatted_result = {
"Title": content.title,
"URL": content.url,
"Text": content.text
# ---------- Tavily extract endpoint ----------
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
# Add optional fields if they exist
if hasattr(content, 'published_date') and content.published_date:
formatted_result["Published Date"] = content.published_date
payload = {
"urls": url,
"include_images": False,
"extract_depth": "basic",
}
response = await client.post(
"https://api.tavily.com/extract",
json=payload,
headers=headers,
timeout=60,
)
response.raise_for_status()
data = response.json()
print(f"--- Raw Tavily Response ---")
print(data)
print(f"--------------------------")
# Normalise Tavily extract output to a list of dicts
extracted = []
if isinstance(data, list):
extracted = data
elif isinstance(data, dict):
if "results" in data and isinstance(data["results"], list):
extracted = data["results"]
elif "urls" in data and isinstance(data["urls"], dict):
extracted = list(data["urls"].values())
else:
extracted = [data]
formatted_results = []
for item in extracted:
formatted_result = {
"Title": item.get("title"),
"URL": item.get("url") or url,
"Text": item.get("content") or item.get("text") or "",
}
if item.get("published_date"):
formatted_result["Published Date"] = item["published_date"]
formatted_results.append(formatted_result)
return self.success_response(formatted_results)
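
The extract handler guards against several plausible response shapes (a bare list, {"results": [...]}, or {"urls": {...}}), funnelling all of them into a flat list of dicts before formatting. A compact restatement of that branching, with invented sample payloads:

def normalise_extract_payload(data):
    # Mirrors the branching above: always hand back a list of result dicts.
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        if isinstance(data.get("results"), list):
            return data["results"]
        if isinstance(data.get("urls"), dict):
            return list(data["urls"].values())
        return [data]
    return []

assert normalise_extract_payload([{"url": "https://a"}]) == [{"url": "https://a"}]
assert normalise_extract_payload({"results": [{"url": "https://b"}]}) == [{"url": "https://b"}]
assert normalise_extract_payload({"urls": {"https://c": {"content": "hi"}}}) == [{"content": "hi"}]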
@@ -279,27 +298,27 @@
if __name__ == "__main__":
import asyncio
# async def test_web_search():
# """Test function for the web search tool"""
# search_tool = WebSearchTool()
# result = await search_tool.web_search(
# query="rubber gym mats best prices comparison",
# summary=True,
# num_results=20
# )
# print(result)
async def test_web_search():
"""Test function for the web search tool"""
search_tool = WebSearchTool()
result = await search_tool.web_search(
query="rubber gym mats best prices comparison",
summary=True,
num_results=20
)
print(result)
async def test_crawl_webpage():
"""Test function for the webpage crawl tool"""
search_tool = WebSearchTool()
result = await search_tool.crawl_webpage(
url="https://example.com"
url="https://google.com"
)
print(result)
async def run_tests():
"""Run all test functions"""
# await test_web_search()
await test_web_search()
await test_crawl_webpage()
asyncio.run(run_tests())
asyncio.run(run_tests())

View File

@@ -96,10 +96,19 @@ class ResponseProcessor:
self,
llm_response: AsyncGenerator,
thread_id: str,
prompt_messages: List[Dict[str, Any]],
llm_model: str,
config: ProcessorConfig = ProcessorConfig(),
) -> AsyncGenerator[Dict[str, Any], None]:
"""Process a streaming LLM response, handling tool calls and execution.
Args:
llm_response: Streaming response from the LLM
thread_id: ID of the conversation thread
prompt_messages: List of messages sent to the LLM (the prompt)
llm_model: The name of the LLM model used
config: Configuration for parsing and execution
Yields:
Complete message objects matching the DB schema, except for content chunks.
"""
@@ -144,8 +153,14 @@
if hasattr(chunk, 'choices') and chunk.choices:
delta = chunk.choices[0].delta if hasattr(chunk.choices[0], 'delta') else None
# Check for and log Anthropic thinking content
if delta and hasattr(delta, 'reasoning_content') and delta.reasoning_content:
logger.info(f"[THINKING]: {delta.reasoning_content}")
# Append reasoning to main content to be saved in the final message
accumulated_content += delta.reasoning_content
# --- Process Content Chunk ---
# Process content chunk
if delta and hasattr(delta, 'content') and delta.content:
chunk_content = delta.content
accumulated_content += chunk_content
@@ -263,8 +278,8 @@
tool_index += 1
if finish_reason == "xml_tool_limit_reached":
logger.info("Stopping stream due to XML tool call limit")
break # Exit the async for loop
logger.info("Stopping stream processing after loop due to XML tool call limit")
break
# --- After Streaming Loop ---
@@ -529,10 +544,19 @@
self,
llm_response: Any,
thread_id: str,
prompt_messages: List[Dict[str, Any]],
llm_model: str,
config: ProcessorConfig = ProcessorConfig()
) -> AsyncGenerator[Dict[str, Any], None]:
"""Process a non-streaming LLM response, handling tool calls and execution.
Args:
llm_response: Response from the LLM
thread_id: ID of the conversation thread
prompt_messages: List of messages sent to the LLM (the prompt)
llm_model: The name of the LLM model used
config: Configuration for parsing and execution
Yields:
Complete message objects matching the DB schema.
"""

View File

@@ -161,6 +161,9 @@ class ThreadManager:
native_max_auto_continues: int = 25,
max_xml_tool_calls: int = 0,
include_xml_examples: bool = False,
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low',
enable_context_manager: bool = True
) -> Union[Dict[str, Any], AsyncGenerator]:
"""Run a conversation thread with LLM integration and tool execution.
@@ -178,6 +181,9 @@
finish_reason="tool_calls" (0 disables auto-continue)
max_xml_tool_calls: Maximum number of XML tool calls to allow (0 = no limit)
include_xml_examples: Whether to include XML tool examples in the system prompt
enable_thinking: Whether to enable thinking before making a decision
reasoning_effort: The effort level for reasoning
enable_context_manager: Whether to enable automatic context summarization.
Returns:
An async generator yielding response chunks or error dict
@@ -187,6 +193,52 @@
logger.debug(f"Parameters: model={llm_model}, temperature={llm_temperature}, max_tokens={llm_max_tokens}")
logger.debug(f"Auto-continue: max={native_max_auto_continues}, XML tool limit={max_xml_tool_calls}")
# Use a default config if none was provided (needed for XML examples check)
if processor_config is None:
processor_config = ProcessorConfig()
# Apply max_xml_tool_calls if specified and not already set in config
if max_xml_tool_calls > 0 and not processor_config.max_xml_tool_calls:
processor_config.max_xml_tool_calls = max_xml_tool_calls
# Create a working copy of the system prompt to potentially modify
working_system_prompt = system_prompt.copy()
# Add XML examples to system prompt if requested, do this only ONCE before the loop
if include_xml_examples and processor_config.xml_tool_calling:
xml_examples = self.tool_registry.get_xml_examples()
if xml_examples:
examples_content = """
--- XML TOOL CALLING ---
In this environment you have access to a set of tools you can use to answer the user's question. The tools are specified in XML format.
Format your tool calls using the specified XML tags. Place parameters marked as 'attribute' within the opening tag (e.g., `<tag attribute='value'>`). Place parameters marked as 'content' between the opening and closing tags. Place parameters marked as 'element' within their own child tags (e.g., `<tag><element>value</element></tag>`). Refer to the examples provided below for the exact structure of each tool.
String and scalar parameters should be specified as attributes, while content goes between tags.
Note that spaces for string values are not stripped. The output is parsed with regular expressions.
Here are the XML tools available with examples:
"""
for tag_name, example in xml_examples.items():
examples_content += f"<{tag_name}> Example: {example}\\n"
system_content = working_system_prompt.get('content')
if isinstance(system_content, str):
working_system_prompt['content'] += examples_content
logger.debug("Appended XML examples to string system prompt content.")
elif isinstance(system_content, list):
appended = False
for item in working_system_prompt['content']: # Modify the copy
if isinstance(item, dict) and item.get('type') == 'text' and 'text' in item:
item['text'] += examples_content
logger.debug("Appended XML examples to the first text block in list system prompt content.")
appended = True
break
if not appended:
logger.warning("System prompt content is a list but no text block found to append XML examples.")
else:
logger.warning(f"System prompt content is of unexpected type ({type(system_content)}), cannot add XML examples.")
# Control whether we need to auto-continue due to tool_calls finish reason
auto_continue = True
auto_continue_count = 0
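
One caveat on working_system_prompt = system_prompt.copy(): dict.copy() is shallow. The string branch is safe because += on a str rebinds the copy's own key, but the list branch mutates text blocks that are still shared with the caller's original dict, as this small demonstration of plain Python semantics shows:

# List content: the inner dicts are shared between original and copy.
prompt = {"role": "system", "content": [{"type": "text", "text": "base"}]}
working = prompt.copy()
working["content"][0]["text"] += " + examples"
print(prompt["content"][0]["text"])   # "base + examples" - the original changed too

# String content: += rebinds the key on the copy only.
prompt2 = {"role": "system", "content": "base"}
working2 = prompt2.copy()
working2["content"] += " + examples"
print(prompt2["content"])             # "base" - unchanged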
@@ -195,81 +247,46 @@
async def _run_once(temp_msg=None):
try:
# Ensure processor_config is available in this scope
nonlocal processor_config
# Use a default config if none was provided
if processor_config is None:
processor_config = ProcessorConfig()
# Apply max_xml_tool_calls if specified and not already set
if max_xml_tool_calls > 0:
processor_config.max_xml_tool_calls = max_xml_tool_calls
# Add XML examples to system prompt if requested
if include_xml_examples and processor_config.xml_tool_calling:
xml_examples = self.tool_registry.get_xml_examples()
if xml_examples:
# logger.debug(f"Adding {len(xml_examples)} XML examples to system prompt")
# Create or append to content
if isinstance(system_prompt['content'], str):
examples_content = """
--- XML TOOL CALLING ---
In this environment you have access to a set of tools you can use to answer the user's question. The tools are specified in XML format.
{{ FORMATTING INSTRUCTIONS }}
String and scalar parameters should be specified as attributes, while content goes between tags.
Note that spaces for string values are not stripped. The output is parsed with regular expressions.
Here are the XML tools available with examples:
"""
for tag_name, example in xml_examples.items():
examples_content += f"<{tag_name}> Example: {example}\n"
system_prompt['content'] += examples_content
else:
# If content is not a string (might be a list or dict), log a warning
logger.warning("System prompt content is not a string, cannot add XML examples")
nonlocal processor_config
# Note: processor_config is now guaranteed to exist due to check above
# 1. Get messages from thread for LLM call
messages = await self.get_llm_messages(thread_id)
# 2. Check token count before proceeding
# Use litellm to count tokens in the messages
token_count = 0
try:
from litellm import token_counter
token_count = token_counter(model=llm_model, messages=[system_prompt] + messages)
# Use the potentially modified working_system_prompt for token counting
token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
token_threshold = self.context_manager.token_threshold
logger.info(f"Thread {thread_id} token count: {token_count}/{token_threshold} ({(token_count/token_threshold)*100:.1f}%)")
# If we're over the threshold, summarize the thread
if token_count >= token_threshold:
if token_count >= token_threshold and enable_context_manager:
logger.info(f"Thread token count ({token_count}) exceeds threshold ({token_threshold}), summarizing...")
# Create summary using context manager
summarized = await self.context_manager.check_and_summarize_if_needed(
thread_id=thread_id,
add_message_callback=self.add_message,
model=llm_model,
force=True # Force summarization
force=True
)
if summarized:
# If summarization was successful, get the updated messages
# This will now include the summary message and only messages after it
logger.info("Summarization complete, fetching updated messages with summary")
messages = await self.get_llm_messages(thread_id)
# Recount tokens after summarization
new_token_count = token_counter(model=llm_model, messages=[system_prompt] + messages)
# Recount tokens after summarization, using the modified prompt
new_token_count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
logger.info(f"After summarization: token count reduced from {token_count} to {new_token_count}")
else:
logger.warning("Summarization failed or wasn't needed - proceeding with original messages")
elif not enable_context_manager: # Added condition for clarity
logger.info("Automatic summarization disabled. Skipping token count check and summarization.")
except Exception as e:
logger.error(f"Error counting tokens or summarizing: {str(e)}")
# 3. Prepare messages for LLM call + add temporary message if it exists
prepared_messages = [system_prompt]
# Use the working_system_prompt which may contain the XML examples
prepared_messages = [working_system_prompt]
# Find the last user message index
last_user_index = -1
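
The gate above leans on LiteLLM's token_counter, which takes a model name and the full message list; summarisation now only fires when the count crosses the threshold and enable_context_manager is set. Roughly, as a sketch (the threshold value here is illustrative; the real one comes from the ContextManager):

from litellm import token_counter

TOKEN_THRESHOLD = 120_000  # placeholder; actually self.context_manager.token_threshold

def needs_summary(llm_model, working_system_prompt, messages, enable_context_manager):
    count = token_counter(model=llm_model, messages=[working_system_prompt] + messages)
    print(f"token count: {count}/{TOKEN_THRESHOLD} ({count / TOKEN_THRESHOLD * 100:.1f}%)")
    return enable_context_manager and count >= TOKEN_THRESHOLD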
@@ -306,13 +323,15 @@ Here are the XML tools available with examples:
logger.debug("Making LLM API call")
try:
llm_response = await make_llm_api_call(
prepared_messages,
prepared_messages, # Pass the potentially modified messages
llm_model,
temperature=llm_temperature,
max_tokens=llm_max_tokens,
tools=openapi_tool_schemas,
tool_choice=tool_choice if processor_config.native_tool_calling else None,
stream=stream
stream=stream,
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort
)
logger.debug("Successfully received raw LLM API response stream/object")
@@ -326,22 +345,27 @@ Here are the XML tools available with examples:
response_generator = self.response_processor.process_streaming_response(
llm_response=llm_response,
thread_id=thread_id,
config=processor_config
config=processor_config,
prompt_messages=prepared_messages,
llm_model=llm_model
)
return response_generator
else:
logger.debug("Processing non-streaming response")
try:
response = await self.response_processor.process_non_streaming_response(
# Return the async generator directly, don't await it
response_generator = self.response_processor.process_non_streaming_response(
llm_response=llm_response,
thread_id=thread_id,
config=processor_config
config=processor_config,
prompt_messages=prepared_messages,
llm_model=llm_model
)
return response
return response_generator # Return the generator
except Exception as e:
logger.error(f"Error in non-streaming response: {str(e)}", exc_info=True)
raise
logger.error(f"Error setting up non-streaming response: {str(e)}", exc_info=True)
raise # Re-raise the exception to be caught by the outer handler
except Exception as e:
logger.error(f"Error in run_thread: {str(e)}", exc_info=True)
@@ -358,8 +382,9 @@ Here are the XML tools available with examples:
# Reset auto_continue for this iteration
auto_continue = False
# Run the thread once
response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None)
# Run the thread once, passing the potentially modified system prompt
# Pass temp_msg only on the first iteration
response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None)
# Handle error responses
if isinstance(response_gen, dict) and "status" in response_gen and response_gen["status"] == "error":
@@ -402,7 +427,8 @@ Here are the XML tools available with examples:
# If auto-continue is disabled (max=0), just run once
if native_max_auto_continues == 0:
logger.info("Auto-continue is disabled (native_max_auto_continues=0)")
return await _run_once(temporary_message)
# Pass the potentially modified system prompt and temp message
return await _run_once(temporary_message)
# Otherwise return the auto-continue wrapper generator
return auto_continue_wrapper()

View File

@@ -1,6 +1,6 @@
streamlit-quill==0.0.3
python-dotenv==1.0.1
litellm>=1.44.0
litellm>=1.66.2
click==8.1.7
questionary==2.0.1
requests>=2.31.0
@@ -22,4 +22,5 @@ certifi==2024.2.2
python-ripgrep==0.0.6
daytona_sdk>=0.12.0
boto3>=1.34.0
exa-py>=1.9.1
pydantic
tavily-python>=0.5.4

View File

@@ -151,9 +151,11 @@ async def list_files(
for file in files:
# Convert file information to our model
# Ensure forward slashes are used for paths, regardless of OS
full_path = f"{path.rstrip('/')}/{file.name}" if path != '/' else f"/{file.name}"
file_info = FileInfo(
name=file.name,
path=os.path.join(path, file.name),
path=full_path, # Use the constructed path
is_dir=file.is_dir,
size=file.size,
mod_time=str(file.mod_time),
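
The replacement of os.path.join is deliberate: on a Windows host it joins with backslashes, while the comment above says paths must use forward slashes regardless of OS. A quick comparison using Python's platform-specific path modules directly:

import ntpath
import posixpath

path, name = "/workspace/docs", "readme.md"

print(ntpath.join(path, name))    # /workspace/docs\readme.md (what os.path.join does on Windows)
print(posixpath.join(path, name)) # /workspace/docs/readme.md

# The commit's approach, OS-independent:
full_path = f"{path.rstrip('/')}/{name}" if path != '/' else f"/{name}"
print(full_path)                  # /workspace/docs/readme.md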

View File

@@ -17,6 +17,8 @@ import asyncio
from openai import OpenAIError
import litellm
from utils.logger import logger
from datetime import datetime
import traceback
# litellm.set_verbose=True
litellm.modify_params=True
@@ -82,7 +84,9 @@ def prepare_params(
api_base: Optional[str] = None,
stream: bool = False,
top_p: Optional[float] = None,
model_id: Optional[str] = None
model_id: Optional[str] = None,
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low'
) -> Dict[str, Any]:
"""Prepare parameters for the API call."""
params = {
@@ -152,6 +156,75 @@
params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")
# Apply Anthropic prompt caching (minimal implementation)
# Check model name *after* potential modifications (like adding bedrock/ prefix)
effective_model_name = params.get("model", model_name) # Use model from params if set, else original
if "claude" in effective_model_name.lower() or "anthropic" in effective_model_name.lower():
logger.debug("Applying minimal Anthropic prompt caching.")
messages = params["messages"] # Direct reference, modification affects params
# Ensure messages is a list
if not isinstance(messages, list):
logger.warning(f"Messages is not a list ({type(messages)}), skipping Anthropic cache control.")
return params # Return early if messages format is unexpected
# 1. Process the first message if it's a system prompt with string content
if messages and messages[0].get("role") == "system":
content = messages[0].get("content")
if isinstance(content, str):
# Wrap the string content in the required list structure
messages[0]["content"] = [
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
]
logger.debug("Applied cache_control to system message (converted from string).")
elif isinstance(content, list):
# If content is already a list, check if the first text block needs cache_control
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
if "cache_control" not in item:
item["cache_control"] = {"type": "ephemeral"}
break # Apply to the first text block only for system prompt
else:
logger.warning("System message content is not a string or list, skipping cache_control.")
# 2. Find and process the last user message
last_user_idx = -1
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") == "user":
last_user_idx = i
break
if last_user_idx != -1:
last_user_message = messages[last_user_idx]
content = last_user_message.get("content")
if isinstance(content, str):
# Wrap the string content in the required list structure
last_user_message["content"] = [
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
]
logger.debug(f"Applied cache_control to last user message (string content, index {last_user_idx}).")
elif isinstance(content, list):
# Modify text blocks within the list directly
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
# Add cache_control if not already present
if "cache_control" not in item:
item["cache_control"] = {"type": "ephemeral"}
else:
logger.warning(f"Last user message (index {last_user_idx}) content is not a string or list ({type(content)}), skipping cache_control.")
# Add reasoning_effort for Anthropic models if enabled
use_thinking = enable_thinking if enable_thinking is not None else False
is_anthropic = "anthropic" in effective_model_name.lower() or "claude" in effective_model_name.lower()
if is_anthropic and use_thinking:
effort_level = reasoning_effort if reasoning_effort else 'low'
params["reasoning_effort"] = effort_level
params["temperature"] = 1.0 # Required by Anthropic when reasoning_effort is used
logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")
return params
async def make_llm_api_call(
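
To summarise the caching block: plain-string contents are rewritten into Anthropic's content-block format so a cache_control marker can ride along, once on the system message and once on the last user message. Sketched as a before/after (the message text is made up):

messages_before = [
    {"role": "system", "content": "You are a helpful agent."},
    {"role": "user", "content": "Hi"},
]

# After prepare_params(...) on a claude/anthropic model:
messages_after = [
    {"role": "system", "content": [
        {"type": "text", "text": "You are a helpful agent.",
         "cache_control": {"type": "ephemeral"}},
    ]},
    {"role": "user", "content": [
        {"type": "text", "text": "Hi",
         "cache_control": {"type": "ephemeral"}},
    ]},
]

# And when enable_thinking is set on an Anthropic model, two extra params are added:
#   params["reasoning_effort"] = "low"  (or the requested level)
#   params["temperature"] = 1.0         (required by Anthropic alongside reasoning_effort)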
@@ -166,7 +239,9 @@ async def make_llm_api_call(
api_base: Optional[str] = None,
stream: bool = False,
top_p: Optional[float] = None,
model_id: Optional[str] = None
model_id: Optional[str] = None,
enable_thinking: Optional[bool] = False,
reasoning_effort: Optional[str] = 'low'
) -> Union[Dict[str, Any], AsyncGenerator]:
"""
Make an API call to a language model using LiteLLM.
@@ -184,6 +259,8 @@
stream: Whether to stream the response
top_p: Top-p sampling parameter
model_id: Optional ARN for Bedrock inference profiles
enable_thinking: Whether to enable thinking
reasoning_effort: Level of reasoning effort
Returns:
Union[Dict[str, Any], AsyncGenerator]: API response or stream
@@ -192,7 +269,7 @@
LLMRetryError: If API call fails after retries
LLMError: For other API-related errors
"""
logger.debug(f"Making LLM API call to model: {model_name}")
logger.debug(f"Making LLM API call to model: {model_name} (Thinking: {enable_thinking}, Effort: {reasoning_effort})")
params = prepare_params(
messages=messages,
model_name=model_name,
@@ -205,7 +282,9 @@
api_base=api_base,
stream=stream,
top_p=top_p,
model_id=model_id
model_id=model_id,
enable_thinking=enable_thinking,
reasoning_effort=reasoning_effort
)
last_error = None

View File

@@ -12,7 +12,7 @@ function DashboardContent() {
const [isSubmitting, setIsSubmitting] = useState(false);
const router = useRouter();
const handleSubmit = async (message: string) => {
const handleSubmit = async (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => {
if (!message.trim() || isSubmitting) return;
setIsSubmitting(true);
@@ -34,7 +34,11 @@
await addUserMessage(thread.thread_id, message.trim());
// 4. Start the agent with the thread ID
const agentRun = await startAgent(thread.thread_id);
const agentRun = await startAgent(thread.thread_id, {
model_name: options?.model_name,
enable_thinking: options?.enable_thinking,
stream: true
});
// 5. Navigate to the new agent's thread page
router.push(`/dashboard/agents/${thread.thread_id}`);

View File

@@ -3,7 +3,7 @@
import React, { useState, useRef, useEffect } from 'react';
import { Textarea } from "@/components/ui/textarea";
import { Button } from "@/components/ui/button";
import { Send, Square, Loader2, File, Upload, X, Paperclip, FileText } from "lucide-react";
import { Send, Square, Loader2, File, Upload, X, Paperclip, FileText, ChevronDown, Cpu } from "lucide-react";
import { createClient } from "@/lib/supabase/client";
import { toast } from "sonner";
import { AnimatePresence, motion } from "framer-motion";
@@ -13,13 +13,22 @@ import {
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { cn } from "@/lib/utils";
// Define API_URL
const API_URL = process.env.NEXT_PUBLIC_BACKEND_URL || '';
// Local storage keys
const STORAGE_KEY_MODEL = 'suna-preferred-model';
interface ChatInputProps {
onSubmit: (message: string) => void;
onSubmit: (message: string, options?: { model_name?: string; enable_thinking?: boolean }) => void;
placeholder?: string;
loading?: boolean;
disabled?: boolean;
@@ -40,7 +49,7 @@ interface UploadedFile {
export function ChatInput({
onSubmit,
placeholder = "Type your message... (Enter to send, Shift+Enter for new line)",
placeholder = "Describe what you need help with...",
loading = false,
disabled = false,
isAgentRunning = false,
@@ -52,47 +61,63 @@
sandboxId
}: ChatInputProps) {
const [inputValue, setInputValue] = useState(value || "");
const [selectedModel, setSelectedModel] = useState("sonnet-3.7");
const textareaRef = useRef<HTMLTextAreaElement | null>(null);
const fileInputRef = useRef<HTMLInputElement>(null);
const [uploadedFiles, setUploadedFiles] = useState<UploadedFile[]>([]);
const [isUploading, setIsUploading] = useState(false);
const [isDraggingOver, setIsDraggingOver] = useState(false);
// Allow controlled or uncontrolled usage
useEffect(() => {
if (typeof window !== 'undefined') {
try {
const savedModel = localStorage.getItem(STORAGE_KEY_MODEL);
if (savedModel) {
setSelectedModel(savedModel);
}
} catch (error) {
console.warn('Failed to load preferences from localStorage:', error);
}
}
}, []);
const isControlled = value !== undefined && onChange !== undefined;
// Update local state if controlled and value changes
useEffect(() => {
if (isControlled && value !== inputValue) {
setInputValue(value);
}
}, [value, isControlled, inputValue]);
// Auto-focus on textarea when component loads
useEffect(() => {
if (autoFocus && textareaRef.current) {
textareaRef.current.focus();
}
}, [autoFocus]);
// Adjust textarea height based on content
useEffect(() => {
const textarea = textareaRef.current;
if (!textarea) return;
const adjustHeight = () => {
textarea.style.height = 'auto';
const newHeight = Math.min(textarea.scrollHeight, 200); // Max height of 200px
const newHeight = Math.min(Math.max(textarea.scrollHeight, 50), 200); // Min 50px, max 200px
textarea.style.height = `${newHeight}px`;
};
adjustHeight();
// Adjust on window resize too
window.addEventListener('resize', adjustHeight);
return () => window.removeEventListener('resize', adjustHeight);
}, [inputValue]);
const handleModelChange = (model: string) => {
setSelectedModel(model);
if (typeof window !== 'undefined') {
localStorage.setItem(STORAGE_KEY_MODEL, model);
}
};
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
if ((!inputValue.trim() && uploadedFiles.length === 0) || loading || (disabled && !isAgentRunning)) return;
@@ -104,7 +129,6 @@ export function ChatInput({
let message = inputValue;
// Add file information to the message if files were uploaded
if (uploadedFiles.length > 0) {
const fileInfo = uploadedFiles.map(file =>
`[Uploaded file: ${file.name} (${formatFileSize(file.size)}) at ${file.path}]`
@@ -112,13 +136,22 @@
message = message ? `${message}\n\n${fileInfo}` : fileInfo;
}
onSubmit(message);
let baseModelName = selectedModel;
let thinkingEnabled = false;
if (selectedModel === "sonnet-3.7-thinking") {
baseModelName = "sonnet-3.7";
thinkingEnabled = true;
}
onSubmit(message, {
model_name: baseModelName,
enable_thinking: thinkingEnabled
});
if (!isControlled) {
setInputValue("");
}
// Reset the uploaded files after sending
setUploadedFiles([]);
};
@@ -175,7 +208,6 @@
const files = Array.from(event.target.files);
await uploadFiles(files);
// Reset the input
event.target.value = '';
};
@@ -191,11 +223,9 @@
continue;
}
// Create a FormData object
const formData = new FormData();
formData.append('file', file);
// Upload to workspace root by default
const uploadPath = `/workspace/${file.name}`;
formData.append('path', uploadPath);
@@ -206,7 +236,6 @@
throw new Error('No access token available');
}
// Upload using FormData
const response = await fetch(`${API_URL}/sandboxes/${sandboxId}/files`, {
method: 'POST',
headers: {
@@ -219,7 +248,6 @@
throw new Error(`Upload failed: ${response.statusText}`);
}
// Add to uploaded files
newUploadedFiles.push({
name: file.name,
path: uploadPath,
@@ -229,7 +257,6 @@
toast.success(`File uploaded: ${file.name}`);
}
// Update the uploaded files state
setUploadedFiles(prev => [...prev, ...newUploadedFiles]);
} catch (error) {
@@ -273,11 +300,18 @@
}
};
const modelDisplayNames: { [key: string]: string } = {
"sonnet-3.7": "Sonnet 3.7",
"sonnet-3.7-thinking": "Sonnet 3.7 (Thinking)",
"gpt-4.1": "GPT-4.1",
"gemini-flash-2.5": "Gemini Flash 2.5"
};
return (
<div
className={cn(
"w-full border rounded-lg transition-all duration-200 shadow-sm",
uploadedFiles.length > 0 ? "border-border" : "border-input",
"w-full border rounded-xl transition-all duration-200 shadow-sm bg-[#1a1a1a] border-gray-800",
uploadedFiles.length > 0 ? "border-border" : "border-gray-800",
isDraggingOver ? "border-primary border-dashed bg-primary/5" : ""
)}
onDragOver={handleDragOver}
@@ -300,18 +334,18 @@
animate={{ opacity: 1, scale: 1 }}
exit={{ opacity: 0, scale: 0.9 }}
transition={{ duration: 0.15 }}
className="px-2 py-1 bg-secondary/20 rounded-full flex items-center gap-1.5 group border border-secondary/30 hover:border-secondary/50 transition-colors text-sm"
className="px-2 py-1 bg-gray-800 rounded-full flex items-center gap-1.5 group border border-gray-700 hover:border-gray-600 transition-colors text-sm"
>
{getFileIcon(file.name)}
<span className="font-medium truncate max-w-[120px]">{file.name}</span>
<span className="text-xs text-muted-foreground flex-shrink-0">
<span className="font-medium truncate max-w-[120px] text-gray-300">{file.name}</span>
<span className="text-xs text-gray-400 flex-shrink-0">
({formatFileSize(file.size)})
</span>
<Button
type="button"
variant="ghost"
size="icon"
className="h-4 w-4 ml-0.5 rounded-full p-0 hover:bg-secondary/50"
className="h-4 w-4 ml-0.5 rounded-full p-0 hover:bg-gray-700"
onClick={() => removeUploadedFile(index)}
>
<X className="h-3 w-3" />
@@ -319,42 +353,94 @@
</motion.div>
))}
</div>
<div className="h-px bg-border/40 my-2 mx-1" />
<div className="h-px bg-gray-800 my-2 mx-1" />
</motion.div>
)}
</AnimatePresence>
<div className="relative">
{isDraggingOver && (
<div className="absolute inset-0 flex items-center justify-center z-10 pointer-events-none">
<div className="flex flex-col items-center">
<Upload className="h-6 w-6 text-primary mb-2" />
<p className="text-sm font-medium text-primary">Drop files to upload</p>
</div>
</div>
)}
<div className="flex items-center px-3 py-3">
<div className="relative flex-1 flex items-center">
<Textarea
ref={textareaRef}
value={inputValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
placeholder={
isAgentRunning
? "Agent is thinking..."
: placeholder
}
className={cn(
"min-h-[50px] py-3 px-4 text-gray-200 resize-none border-0 focus-visible:ring-0 focus-visible:ring-offset-0 bg-transparent w-full text-base",
isDraggingOver ? "opacity-20" : ""
)}
disabled={loading || (disabled && !isAgentRunning)}
rows={1}
/>
</div>
<Textarea
ref={textareaRef}
value={inputValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
placeholder={
isAgentRunning
? "Agent is thinking..."
: placeholder
}
className={cn(
"min-h-[50px] max-h-[200px] pr-24 resize-none border-0 focus-visible:ring-0 focus-visible:ring-offset-0 bg-white",
isDraggingOver ? "opacity-20" : "",
isAgentRunning ? "rounded-t-lg" : "rounded-lg"
<div className="flex items-center gap-3 ml-3">
{!isAgentRunning && (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button
variant="ghost"
size="icon"
className={cn(
"h-8 w-8 rounded-full p-0 hover:bg-gray-800",
selectedModel === "sonnet-3.7" ? "text-purple-400" :
selectedModel === "sonnet-3.7-thinking" ? "text-violet-400" :
selectedModel === "gpt-4.1" ? "text-green-400" :
selectedModel === "gemini-flash-2.5" ? "text-blue-400" :
"text-gray-400"
)}
aria-label="Select model"
>
<Cpu className="h-5 w-5" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="start" className="bg-gray-900 border-gray-800 text-gray-300">
<DropdownMenuItem onClick={() => handleModelChange("sonnet-3.7")} className={cn(
"hover:bg-gray-800 flex items-center justify-between",
selectedModel === "sonnet-3.7" && "text-purple-400"
)}>
<span>Sonnet 3.7</span>
{selectedModel === "sonnet-3.7" && <span className="ml-2"></span>}
</DropdownMenuItem>
<DropdownMenuItem onClick={() => handleModelChange("sonnet-3.7-thinking")} className={cn(
"hover:bg-gray-800 flex items-center justify-between",
selectedModel === "sonnet-3.7-thinking" && "text-violet-400"
)}>
<span>Sonnet 3.7 (Thinking)</span>
{selectedModel === "sonnet-3.7-thinking" && <span className="ml-2"></span>}
</DropdownMenuItem>
<DropdownMenuItem onClick={() => handleModelChange("gpt-4.1")} className={cn(
"hover:bg-gray-800 flex items-center justify-between",
selectedModel === "gpt-4.1" && "text-green-400"
)}>
<span>GPT-4.1</span>
{selectedModel === "gpt-4.1" && <span className="ml-2"></span>}
</DropdownMenuItem>
<DropdownMenuItem onClick={() => handleModelChange("gemini-flash-2.5")} className={cn(
"hover:bg-gray-800 flex items-center justify-between",
selectedModel === "gemini-flash-2.5" && "text-blue-400"
)}>
<span>Gemini Flash 2.5</span>
{selectedModel === "gemini-flash-2.5" && <span className="ml-2"></span>}
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</TooltipTrigger>
<TooltipContent side="top" className="bg-gray-900 text-gray-300 border-gray-800">
<p>Model: {modelDisplayNames[selectedModel as keyof typeof modelDisplayNames] || selectedModel}</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
disabled={loading || (disabled && !isAgentRunning)}
rows={1}
/>
<div className="absolute right-2 bottom-2 flex items-center space-x-1.5">
{/* Upload file button */}
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
@@ -364,26 +450,25 @@
variant="ghost"
size="icon"
className={cn(
"h-8 w-8 rounded-full transition-all hover:bg-primary/10 hover:text-primary",
isUploading && "text-primary"
"h-8 w-8 rounded-full p-0 text-gray-400 hover:bg-gray-800",
isUploading && "text-blue-400"
)}
disabled={loading || (disabled && !isAgentRunning) || isUploading}
aria-label="Upload files"
>
{isUploading ? (
<Loader2 className="h-4 w-4 animate-spin" />
<Loader2 className="h-5 w-5 animate-spin" />
) : (
<Paperclip className="h-4 w-4" />
<Paperclip className="h-5 w-5" />
)}
</Button>
</TooltipTrigger>
<TooltipContent side="top">
<TooltipContent side="top" className="bg-gray-900 text-gray-300 border-gray-800">
<p>Attach files</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
{/* Hidden file input */}
<input
type="file"
ref={fileInputRef}
@@ -392,30 +477,6 @@
multiple
/>
{/* File browser button */}
{onFileBrowse && (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
type="button"
onClick={onFileBrowse}
variant="ghost"
size="icon"
className="h-8 w-8 rounded-full transition-all hover:bg-primary/10 hover:text-primary"
disabled={loading || (disabled && !isAgentRunning)}
aria-label="Browse files"
>
<File className="h-4 w-4" />
</Button>
</TooltipTrigger>
<TooltipContent side="top">
<p>Browse workspace files</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
@@ -425,23 +486,23 @@
variant={isAgentRunning ? "destructive" : "default"}
size="icon"
className={cn(
"h-8 w-8 rounded-full",
!isAgentRunning && "bg-primary hover:bg-primary/90",
isAgentRunning && "bg-destructive hover:bg-destructive/90"
"h-10 w-10 rounded-full",
!isAgentRunning && "bg-gray-700 hover:bg-gray-600 text-gray-200",
isAgentRunning && "bg-red-600 hover:bg-red-700"
)}
disabled={((!inputValue.trim() && uploadedFiles.length === 0) && !isAgentRunning) || loading || (disabled && !isAgentRunning)}
aria-label={isAgentRunning ? 'Stop agent' : 'Send message'}
>
{loading ? (
<Loader2 className="h-4 w-4 animate-spin" />
<Loader2 className="h-5 w-5 animate-spin" />
) : isAgentRunning ? (
<Square className="h-4 w-4" />
<Square className="h-5 w-5" />
) : (
<Send className="h-4 w-4" />
<Send className="h-5 w-5" />
)}
</Button>
</TooltipTrigger>
<TooltipContent side="top">
<TooltipContent side="top" className="bg-gray-900 text-gray-300 border-gray-800">
<p>{isAgentRunning ? 'Stop agent' : 'Send message'}</p>
</TooltipContent>
</Tooltip>
@@ -453,15 +514,15 @@
<motion.div
initial={{ opacity: 0, y: -10 }}
animate={{ opacity: 1, y: 0 }}
className="py-2 px-3 flex items-center justify-center bg-white border-t rounded-b-lg"
className="py-2 px-3 flex items-center justify-center bg-gray-800 border-t border-gray-700 rounded-b-xl"
>
<div className="text-xs text-muted-foreground flex items-center gap-2">
<div className="text-xs text-gray-400 flex items-center gap-2">
<span className="inline-flex items-center">
<Loader2 className="h-3 w-3 animate-spin mr-1.5" />
Agent is thinking...
</span>
<span className="text-muted-foreground/60 border-l pl-2">
Press <kbd className="inline-flex items-center justify-center p-0.5 mx-1 bg-muted border rounded text-xs"><Square className="h-2.5 w-2.5" /></kbd> to stop
<span className="text-gray-500 border-l border-gray-700 pl-2">
Press <kbd className="inline-flex items-center justify-center p-0.5 mx-1 bg-gray-700 border border-gray-600 rounded text-xs"><Square className="h-2.5 w-2.5" /></kbd> to stop
</span>
</div>
</motion.div>

View File

@@ -361,7 +361,15 @@ export const getMessages = async (threadId: string): Promise<Message[]> => {
};
// Agent APIs
export const startAgent = async (threadId: string): Promise<{ agent_run_id: string }> => {
export const startAgent = async (
threadId: string,
options?: {
model_name?: string;
enable_thinking?: boolean;
reasoning_effort?: string;
stream?: boolean;
}
): Promise<{ agent_run_id: string }> => {
try {
const supabase = createClient();
const { data: { session } } = await supabase.auth.getSession();
@@ -385,6 +393,8 @@ export const startAgent = async (threadId: string): Promise<{ agent_run_id: stri
},
// Add cache: 'no-store' to prevent caching
cache: 'no-store',
// Add the body, stringifying the options or an empty object
body: JSON.stringify(options || {}),
});
if (!response.ok) {