mirror of https://github.com/kortix-ai/suna.git
chore(dev_): merge main into mcp-5
This commit is contained in:
commit
5249346550
|
@ -43,3 +43,7 @@ FIRECRAWL_URL=
|
|||
DAYTONA_API_KEY=
|
||||
DAYTONA_SERVER_URL=
|
||||
DAYTONA_TARGET=
|
||||
|
||||
LANGFUSE_PUBLIC_KEY="pk-REDACTED"
|
||||
LANGFUSE_SECRET_KEY="sk-REDACTED"
|
||||
LANGFUSE_HOST="https://cloud.langfuse.com"
|
||||
|
|
|
@ -23,7 +23,10 @@ from utils.logger import logger
|
|||
from utils.auth_utils import get_account_id_from_thread
|
||||
from services.billing import check_billing_status
|
||||
from agent.tools.sb_vision_tool import SandboxVisionTool
|
||||
from services.langfuse import langfuse
|
||||
from langfuse.client import StatefulTraceClient
|
||||
from agent.gemini_prompt import get_gemini_system_prompt
|
||||
|
||||
load_dotenv()
|
||||
|
||||
async def run_agent(
|
||||
|
@ -37,7 +40,8 @@ async def run_agent(
|
|||
enable_thinking: Optional[bool] = False,
|
||||
reasoning_effort: Optional[str] = 'low',
|
||||
enable_context_manager: bool = True,
|
||||
agent_config: Optional[dict] = None
|
||||
agent_config: Optional[dict] = None,
|
||||
trace: Optional[StatefulTraceClient] = None
|
||||
):
|
||||
"""Run the development agent with specified configuration."""
|
||||
logger.info(f"🚀 Starting agent with model: {model_name}")
|
||||
|
@ -58,6 +62,10 @@ async def run_agent(
|
|||
if not project.data or len(project.data) == 0:
|
||||
raise ValueError(f"Project {project_id} not found")
|
||||
|
||||
if not trace:
|
||||
logger.warning("No trace provided, creating a new one")
|
||||
trace = langfuse.trace(name="agent_run", id=thread_id, session_id=thread_id, metadata={"project_id": project_id})
|
||||
|
||||
project_data = project.data[0]
|
||||
sandbox_info = project_data.get('sandbox', {})
|
||||
if not sandbox_info.get('id'):
|
||||
|
@ -253,6 +261,7 @@ async def run_agent(
|
|||
elif "gpt-4" in model_name.lower():
|
||||
max_tokens = 4096
|
||||
|
||||
generation = trace.generation(name="thread_manager.run_thread")
|
||||
try:
|
||||
# Make the LLM call and process the response
|
||||
response = await thread_manager.run_thread(
|
||||
|
@ -277,7 +286,9 @@ async def run_agent(
|
|||
include_xml_examples=True,
|
||||
enable_thinking=enable_thinking,
|
||||
reasoning_effort=reasoning_effort,
|
||||
enable_context_manager=enable_context_manager
|
||||
enable_context_manager=enable_context_manager,
|
||||
generation=generation,
|
||||
trace=trace
|
||||
)
|
||||
|
||||
if isinstance(response, dict) and "status" in response and response["status"] == "error":
|
||||
|
@ -291,6 +302,7 @@ async def run_agent(
|
|||
# Process the response
|
||||
error_detected = False
|
||||
try:
|
||||
full_response = ""
|
||||
async for chunk in response:
|
||||
# If we receive an error chunk, we should stop after this iteration
|
||||
if isinstance(chunk, dict) and chunk.get('type') == 'status' and chunk.get('status') == 'error':
|
||||
|
@ -311,6 +323,7 @@ async def run_agent(
|
|||
|
||||
# The actual text content is nested within
|
||||
assistant_text = assistant_content_json.get('content', '')
|
||||
full_response += assistant_text
|
||||
if isinstance(assistant_text, str): # Ensure it's a string
|
||||
# Check for the closing tags as they signal the end of the tool usage
|
||||
if '</ask>' in assistant_text or '</complete>' in assistant_text or '</web-browser-takeover>' in assistant_text:
|
||||
|
@ -334,15 +347,19 @@ async def run_agent(
|
|||
# Check if we should stop based on the last tool call or error
|
||||
if error_detected:
|
||||
logger.info(f"Stopping due to error detected in response")
|
||||
generation.end(output=full_response, status_message="error_detected", level="ERROR")
|
||||
break
|
||||
|
||||
if last_tool_call in ['ask', 'complete', 'web-browser-takeover']:
|
||||
logger.info(f"Agent decided to stop with tool: {last_tool_call}")
|
||||
generation.end(output=full_response, status_message="agent_stopped")
|
||||
continue_execution = False
|
||||
|
||||
except Exception as e:
|
||||
# Just log the error and re-raise to stop all iterations
|
||||
error_msg = f"Error during response streaming: {str(e)}"
|
||||
logger.error(f"Error: {error_msg}")
|
||||
generation.end(output=full_response, status_message=error_msg, level="ERROR")
|
||||
yield {
|
||||
"type": "status",
|
||||
"status": "error",
|
||||
|
@ -362,6 +379,10 @@ async def run_agent(
|
|||
}
|
||||
# Stop execution immediately on any error
|
||||
break
|
||||
generation.end(output=full_response)
|
||||
|
||||
langfuse.flush() # Flush Langfuse events at the end of the run
|
||||
|
||||
|
||||
|
||||
# # TESTING
|
||||
|
|
|
@ -21,6 +21,7 @@ from litellm import completion_cost
|
|||
from agentpress.tool import Tool, ToolResult
|
||||
from agentpress.tool_registry import ToolRegistry
|
||||
from utils.logger import logger
|
||||
from langfuse.client import StatefulTraceClient
|
||||
|
||||
# Type alias for XML result adding strategy
|
||||
XmlAddingStrategy = Literal["user_message", "assistant_message", "inline_edit"]
|
||||
|
@ -99,6 +100,7 @@ class ResponseProcessor:
|
|||
prompt_messages: List[Dict[str, Any]],
|
||||
llm_model: str,
|
||||
config: ProcessorConfig = ProcessorConfig(),
|
||||
trace: Optional[StatefulTraceClient] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Process a streaming LLM response, handling tool calls and execution.
|
||||
|
||||
|
@ -209,7 +211,7 @@ class ResponseProcessor:
|
|||
if started_msg_obj: yield started_msg_obj
|
||||
yielded_tool_indices.add(tool_index) # Mark status as yielded
|
||||
|
||||
execution_task = asyncio.create_task(self._execute_tool(tool_call))
|
||||
execution_task = asyncio.create_task(self._execute_tool(tool_call, trace))
|
||||
pending_tool_executions.append({
|
||||
"task": execution_task, "tool_call": tool_call,
|
||||
"tool_index": tool_index, "context": context
|
||||
|
@ -587,7 +589,8 @@ class ResponseProcessor:
|
|||
thread_id: str,
|
||||
prompt_messages: List[Dict[str, Any]],
|
||||
llm_model: str,
|
||||
config: ProcessorConfig = ProcessorConfig()
|
||||
config: ProcessorConfig = ProcessorConfig(),
|
||||
trace: Optional[StatefulTraceClient] = None,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Process a non-streaming LLM response, handling tool calls and execution.
|
||||
|
||||
|
@ -1057,8 +1060,11 @@ class ResponseProcessor:
|
|||
return parsed_data
|
||||
|
||||
# Tool execution methods
|
||||
async def _execute_tool(self, tool_call: Dict[str, Any]) -> ToolResult:
|
||||
async def _execute_tool(self, tool_call: Dict[str, Any], trace: Optional[StatefulTraceClient] = None) -> ToolResult:
|
||||
"""Execute a single tool call and return the result."""
|
||||
span = None
|
||||
if trace:
|
||||
span = trace.span(name=f"execute_tool.{tool_call['function_name']}", input=tool_call["arguments"])
|
||||
try:
|
||||
function_name = tool_call["function_name"]
|
||||
arguments = tool_call["arguments"]
|
||||
|
@ -1078,14 +1084,20 @@ class ResponseProcessor:
|
|||
tool_fn = available_functions.get(function_name)
|
||||
if not tool_fn:
|
||||
logger.error(f"Tool function '{function_name}' not found in registry")
|
||||
if span:
|
||||
span.end(status_message="tool_not_found", level="ERROR")
|
||||
return ToolResult(success=False, output=f"Tool function '{function_name}' not found")
|
||||
|
||||
logger.debug(f"Found tool function for '{function_name}', executing...")
|
||||
result = await tool_fn(**arguments)
|
||||
logger.info(f"Tool execution complete: {function_name} -> {result}")
|
||||
if span:
|
||||
span.end(status_message="tool_executed", output=result)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {tool_call['function_name']}: {str(e)}", exc_info=True)
|
||||
if span:
|
||||
span.end(status_message="tool_execution_error", output=f"Error executing tool: {str(e)}", level="ERROR")
|
||||
return ToolResult(success=False, output=f"Error executing tool: {str(e)}")
|
||||
|
||||
async def _execute_tools(
|
||||
|
|
|
@ -22,6 +22,8 @@ from agentpress.response_processor import (
|
|||
)
|
||||
from services.supabase import DBConnection
|
||||
from utils.logger import logger
|
||||
from langfuse.client import StatefulGenerationClient, StatefulTraceClient
|
||||
import datetime
|
||||
|
||||
# Type alias for tool choice
|
||||
ToolChoice = Literal["auto", "required", "none"]
|
||||
|
@ -161,7 +163,9 @@ class ThreadManager:
|
|||
include_xml_examples: bool = False,
|
||||
enable_thinking: Optional[bool] = False,
|
||||
reasoning_effort: Optional[str] = 'low',
|
||||
enable_context_manager: bool = True
|
||||
enable_context_manager: bool = True,
|
||||
generation: Optional[StatefulGenerationClient] = None,
|
||||
trace: Optional[StatefulTraceClient] = None
|
||||
) -> Union[Dict[str, Any], AsyncGenerator]:
|
||||
"""Run a conversation thread with LLM integration and tool execution.
|
||||
|
||||
|
@ -322,6 +326,20 @@ Here are the XML tools available with examples:
|
|||
# 5. Make LLM API call
|
||||
logger.debug("Making LLM API call")
|
||||
try:
|
||||
if generation:
|
||||
generation.update(
|
||||
input=prepared_messages,
|
||||
start_time=datetime.datetime.now(datetime.timezone.utc),
|
||||
model=llm_model,
|
||||
model_parameters={
|
||||
"max_tokens": llm_max_tokens,
|
||||
"temperature": llm_temperature,
|
||||
"enable_thinking": enable_thinking,
|
||||
"reasoning_effort": reasoning_effort,
|
||||
"tool_choice": tool_choice,
|
||||
"tools": openapi_tool_schemas,
|
||||
}
|
||||
)
|
||||
llm_response = await make_llm_api_call(
|
||||
prepared_messages, # Pass the potentially modified messages
|
||||
llm_model,
|
||||
|
@ -347,7 +365,8 @@ Here are the XML tools available with examples:
|
|||
thread_id=thread_id,
|
||||
config=processor_config,
|
||||
prompt_messages=prepared_messages,
|
||||
llm_model=llm_model
|
||||
llm_model=llm_model,
|
||||
trace=trace
|
||||
)
|
||||
|
||||
return response_generator
|
||||
|
@ -359,7 +378,8 @@ Here are the XML tools available with examples:
|
|||
thread_id=thread_id,
|
||||
config=processor_config,
|
||||
prompt_messages=prepared_messages,
|
||||
llm_model=llm_model
|
||||
llm_model=llm_model,
|
||||
trace=trace
|
||||
)
|
||||
return response_generator # Return the generator
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
|
@ -249,6 +249,18 @@ files = [
|
|||
[package.extras]
|
||||
visualize = ["Twisted (>=16.1.1)", "graphviz (>0.5.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "backoff"
|
||||
version = "2.2.1"
|
||||
description = "Function decoration for backoff and retry"
|
||||
optional = false
|
||||
python-versions = ">=3.7,<4.0"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"},
|
||||
{file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.37.34"
|
||||
|
@ -1217,6 +1229,33 @@ files = [
|
|||
[package.dependencies]
|
||||
referencing = ">=0.31.0"
|
||||
|
||||
[[package]]
|
||||
name = "langfuse"
|
||||
version = "2.60.5"
|
||||
description = "A client library for accessing langfuse"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "langfuse-2.60.5-py3-none-any.whl", hash = "sha256:fd27d52017f36d6fa5ca652615213a2535dc93dd88c3375eeb811af26384d285"},
|
||||
{file = "langfuse-2.60.5.tar.gz", hash = "sha256:a33ecddc98cf6d12289372e63071b77b72230e7bc8260ee349f1465d53bf425b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.4.0,<5.0.0"
|
||||
backoff = ">=1.10.0"
|
||||
httpx = ">=0.15.4,<1.0"
|
||||
idna = ">=3.7,<4.0"
|
||||
packaging = ">=23.2,<25.0"
|
||||
pydantic = ">=1.10.7,<3.0"
|
||||
requests = ">=2,<3"
|
||||
wrapt = ">=1.14,<2.0"
|
||||
|
||||
[package.extras]
|
||||
langchain = ["langchain (>=0.0.309)"]
|
||||
llama-index = ["llama-index (>=0.10.12,<2.0.0)"]
|
||||
openai = ["openai (>=0.27.8)"]
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.66.1"
|
||||
|
@ -3539,4 +3578,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
|||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "6163a36d6c3507a20552400544de78f7b48a92a98c8c68db7c98263465bf275a"
|
||||
content-hash = "8bf5f2b60329678979d6eceb2c9860e92b9f2f68cad75651239fdacfc3964633"
|
||||
|
|
|
@ -52,6 +52,7 @@ stripe = "^12.0.1"
|
|||
dramatiq = "^1.17.1"
|
||||
pika = "^1.3.2"
|
||||
prometheus-client = "^0.21.1"
|
||||
langfuse = "^2.60.5"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
agentpress = "agentpress.cli:main"
|
||||
|
|
|
@ -32,3 +32,4 @@ stripe>=12.0.1
|
|||
dramatiq>=1.17.1
|
||||
pika>=1.3.2
|
||||
prometheus-client>=0.21.1
|
||||
langfuse>=2.60.5
|
||||
|
|
|
@ -13,6 +13,7 @@ from services.supabase import DBConnection
|
|||
from services import redis
|
||||
from dramatiq.brokers.rabbitmq import RabbitmqBroker
|
||||
import os
|
||||
from services.langfuse import langfuse
|
||||
|
||||
rabbitmq_host = os.getenv('RABBITMQ_HOST', 'rabbitmq')
|
||||
rabbitmq_port = int(os.getenv('RABBITMQ_PORT', 5672))
|
||||
|
@ -101,6 +102,7 @@ async def run_agent_background(
|
|||
logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True)
|
||||
stop_signal_received = True # Stop the run if the checker fails
|
||||
|
||||
trace = langfuse.trace(name="agent_run", id=agent_run_id, session_id=thread_id, metadata={"project_id": project_id, "instance_id": instance_id})
|
||||
try:
|
||||
# Setup Pub/Sub listener for control signals
|
||||
pubsub = await redis.create_pubsub()
|
||||
|
@ -111,13 +113,15 @@ async def run_agent_background(
|
|||
# Ensure active run key exists and has TTL
|
||||
await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL)
|
||||
|
||||
|
||||
# Initialize agent generator
|
||||
agent_gen = run_agent(
|
||||
thread_id=thread_id, project_id=project_id, stream=stream,
|
||||
thread_manager=thread_manager, model_name=model_name,
|
||||
enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
|
||||
enable_context_manager=enable_context_manager,
|
||||
agent_config=agent_config
|
||||
agent_config=agent_config,
|
||||
trace=trace
|
||||
)
|
||||
|
||||
final_status = "running"
|
||||
|
@ -127,6 +131,7 @@ async def run_agent_background(
|
|||
if stop_signal_received:
|
||||
logger.info(f"Agent run {agent_run_id} stopped by signal.")
|
||||
final_status = "stopped"
|
||||
trace.span(name="agent_run_stopped").end(status_message="agent_run_stopped", level="WARNING")
|
||||
break
|
||||
|
||||
# Store response in Redis list and publish notification
|
||||
|
@ -151,6 +156,7 @@ async def run_agent_background(
|
|||
duration = (datetime.now(timezone.utc) - start_time).total_seconds()
|
||||
logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})")
|
||||
completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"}
|
||||
trace.span(name="agent_run_completed").end(status_message="agent_run_completed")
|
||||
await redis.rpush(response_list_key, json.dumps(completion_message))
|
||||
await redis.publish(response_channel, "new") # Notify about the completion message
|
||||
|
||||
|
@ -176,6 +182,7 @@ async def run_agent_background(
|
|||
duration = (datetime.now(timezone.utc) - start_time).total_seconds()
|
||||
logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})")
|
||||
final_status = "failed"
|
||||
trace.span(name="agent_run_failed").end(status_message=error_message, level="ERROR")
|
||||
|
||||
# Push error message to Redis list
|
||||
error_response = {"type": "status", "status": "error", "message": error_message}
|
||||
|
|
|
@ -553,7 +553,8 @@ async def create_checkout_session(
|
|||
metadata={
|
||||
'user_id': current_user_id,
|
||||
'product_id': product_id
|
||||
}
|
||||
},
|
||||
allow_promotion_codes=True
|
||||
)
|
||||
|
||||
# Update customer status to potentially active (will be confirmed by webhook)
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
import os
|
||||
from langfuse import Langfuse
|
||||
|
||||
public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||
secret_key = os.getenv("LANGFUSE_SECRET_KEY")
|
||||
host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
|
||||
|
||||
enabled = False
|
||||
if public_key and secret_key:
|
||||
enabled = True
|
||||
|
||||
langfuse = Langfuse(enabled=enabled)
|
|
@ -0,0 +1,25 @@
|
|||
DROP POLICY IF EXISTS "Give read only access to internal users" ON threads;
|
||||
|
||||
CREATE POLICY "Give read only access to internal users" ON threads
|
||||
FOR SELECT
|
||||
USING (
|
||||
((auth.jwt() ->> 'email'::text) ~~ '%@kortix.ai'::text)
|
||||
);
|
||||
|
||||
|
||||
DROP POLICY IF EXISTS "Give read only access to internal users" ON messages;
|
||||
|
||||
CREATE POLICY "Give read only access to internal users" ON messages
|
||||
FOR SELECT
|
||||
USING (
|
||||
((auth.jwt() ->> 'email'::text) ~~ '%@kortix.ai'::text)
|
||||
);
|
||||
|
||||
|
||||
DROP POLICY IF EXISTS "Give read only access to internal users" ON projects;
|
||||
|
||||
CREATE POLICY "Give read only access to internal users" ON projects
|
||||
FOR SELECT
|
||||
USING (
|
||||
((auth.jwt() ->> 'email'::text) ~~ '%@kortix.ai'::text)
|
||||
);
|
|
@ -162,6 +162,11 @@ class Configuration:
|
|||
SANDBOX_IMAGE_NAME = "kortix/suna:0.1.2.8"
|
||||
SANDBOX_ENTRYPOINT = "/usr/bin/supervisord -n -c /etc/supervisor/conf.d/supervisord.conf"
|
||||
|
||||
# LangFuse configuration
|
||||
LANGFUSE_PUBLIC_KEY: Optional[str] = None
|
||||
LANGFUSE_SECRET_KEY: Optional[str] = None
|
||||
LANGFUSE_HOST: str = "https://cloud.langfuse.com"
|
||||
|
||||
@property
|
||||
def STRIPE_PRODUCT_ID(self) -> str:
|
||||
if self.ENV_MODE == EnvMode.STAGING:
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
ports:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
- ./backend/services/docker/redis.conf:/usr/local/etc/redis/redis.conf:ro
|
||||
|
|
|
@ -1226,7 +1226,7 @@ export default function ThreadPage({
|
|||
isMobile ? "w-full px-4" : "max-w-3xl"
|
||||
)}>
|
||||
{threadAgentLoading || threadAgentError ? (
|
||||
<div className="space-y-3">
|
||||
<div className="space-y-3 mb-6">
|
||||
<Skeleton className="h-4 w-32" />
|
||||
<Skeleton className="h-12 w-full rounded-lg" />
|
||||
</div>
|
||||
|
|
|
@ -63,7 +63,6 @@ export function NavAgents() {
|
|||
const isPerformingActionRef = useRef(false);
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const [isMultiSelectActive, setIsMultiSelectActive] = useState(false);
|
||||
const [selectedThreads, setSelectedThreads] = useState<Set<string>>(new Set());
|
||||
const [deleteProgress, setDeleteProgress] = useState(0);
|
||||
const [totalToDelete, setTotalToDelete] = useState(0);
|
||||
|
@ -147,10 +146,9 @@ export function NavAgents() {
|
|||
|
||||
// Function to handle thread click with loading state
|
||||
const handleThreadClick = (e: React.MouseEvent<HTMLAnchorElement>, threadId: string, url: string) => {
|
||||
// If multi-select is active, prevent navigation and toggle selection
|
||||
if (isMultiSelectActive) {
|
||||
// If thread is selected, prevent navigation
|
||||
if (selectedThreads.has(threadId)) {
|
||||
e.preventDefault();
|
||||
toggleThreadSelection(threadId);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -160,7 +158,12 @@ export function NavAgents() {
|
|||
}
|
||||
|
||||
// Toggle thread selection for multi-select
|
||||
const toggleThreadSelection = (threadId: string) => {
|
||||
const toggleThreadSelection = (threadId: string, e?: React.MouseEvent) => {
|
||||
if (e) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
}
|
||||
|
||||
setSelectedThreads(prev => {
|
||||
const newSelection = new Set(prev);
|
||||
if (newSelection.has(threadId)) {
|
||||
|
@ -172,15 +175,6 @@ export function NavAgents() {
|
|||
});
|
||||
};
|
||||
|
||||
// Toggle multi-select mode
|
||||
const toggleMultiSelect = () => {
|
||||
setIsMultiSelectActive(!isMultiSelectActive);
|
||||
// Clear selections when toggling off
|
||||
if (isMultiSelectActive) {
|
||||
setSelectedThreads(new Set());
|
||||
}
|
||||
};
|
||||
|
||||
// Select all threads
|
||||
const selectAllThreads = () => {
|
||||
const allThreadIds = combinedThreads.map(thread => thread.threadId);
|
||||
|
@ -310,7 +304,6 @@ export function NavAgents() {
|
|||
|
||||
// Reset states
|
||||
setSelectedThreads(new Set());
|
||||
setIsMultiSelectActive(false);
|
||||
setDeleteProgress(0);
|
||||
setTotalToDelete(0);
|
||||
},
|
||||
|
@ -332,7 +325,6 @@ export function NavAgents() {
|
|||
|
||||
// Reset states
|
||||
setSelectedThreads(new Set());
|
||||
setIsMultiSelectActive(false);
|
||||
setThreadToDelete(null);
|
||||
isPerformingActionRef.current = false;
|
||||
setDeleteProgress(0);
|
||||
|
@ -352,19 +344,15 @@ export function NavAgents() {
|
|||
return (
|
||||
<SidebarGroup>
|
||||
<div className="flex justify-between items-center">
|
||||
<SidebarGroupLabel>
|
||||
<History className="h-2 w-2 mr-2" />
|
||||
History
|
||||
</SidebarGroupLabel>
|
||||
<SidebarGroupLabel>Tasks</SidebarGroupLabel>
|
||||
{state !== 'collapsed' ? (
|
||||
<div className="flex items-center space-x-1">
|
||||
{isMultiSelectActive ? (
|
||||
{selectedThreads.size > 0 ? (
|
||||
<>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={deselectAllThreads}
|
||||
disabled={selectedThreads.size === 0}
|
||||
className="h-7 w-7"
|
||||
>
|
||||
<X className="h-4 w-4" />
|
||||
|
@ -382,41 +370,12 @@ export function NavAgents() {
|
|||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={handleMultiDelete}
|
||||
disabled={selectedThreads.size === 0}
|
||||
className="h-7 w-7 text-destructive"
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={toggleMultiSelect}
|
||||
className="h-7 px-2 text-xs"
|
||||
>
|
||||
Done
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onClick={toggleMultiSelect}
|
||||
className="h-7 w-7"
|
||||
disabled={combinedThreads.length === 0}
|
||||
>
|
||||
<div className="h-4 w-4 border rounded border-foreground/30 flex items-center justify-center">
|
||||
{isMultiSelectActive && <Check className="h-3 w-3" />}
|
||||
</div>
|
||||
<span className="sr-only">Select</span>
|
||||
</Button>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>Select Multiple</TooltipContent>
|
||||
</Tooltip>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>
|
||||
|
@ -431,7 +390,6 @@ export function NavAgents() {
|
|||
</TooltipTrigger>
|
||||
<TooltipContent>New Agent</TooltipContent>
|
||||
</Tooltip>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
) : null}
|
||||
|
@ -476,7 +434,7 @@ export function NavAgents() {
|
|||
const isSelected = selectedThreads.has(thread.threadId);
|
||||
|
||||
return (
|
||||
<SidebarMenuItem key={`thread-${thread.threadId}`}>
|
||||
<SidebarMenuItem key={`thread-${thread.threadId}`} className="group">
|
||||
{state === 'collapsed' ? (
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
|
@ -494,13 +452,7 @@ export function NavAgents() {
|
|||
handleThreadClick(e, thread.threadId, thread.url)
|
||||
}
|
||||
>
|
||||
{isMultiSelectActive ? (
|
||||
<div
|
||||
className={`h-4 w-4 border rounded flex items-center justify-center ${isSelected ? 'bg-primary border-primary' : 'border-foreground'}`}
|
||||
>
|
||||
{isSelected && <Check className="h-3 w-3 text-white" />}
|
||||
</div>
|
||||
) : isThreadLoading ? (
|
||||
{isThreadLoading ? (
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<MessagesSquare className="h-4 w-4" />
|
||||
|
@ -513,15 +465,16 @@ export function NavAgents() {
|
|||
<TooltipContent>{thread.projectName}</TooltipContent>
|
||||
</Tooltip>
|
||||
) : (
|
||||
<div className="relative">
|
||||
<SidebarMenuButton
|
||||
asChild
|
||||
className={
|
||||
className={`relative ${
|
||||
isActive
|
||||
? 'bg-primary/10 text-accent-foreground font-medium'
|
||||
? 'bg-accent text-accent-foreground font-medium'
|
||||
: isSelected
|
||||
? 'bg-primary/10'
|
||||
: ''
|
||||
}
|
||||
}`}
|
||||
>
|
||||
<Link
|
||||
href={thread.url}
|
||||
|
@ -530,31 +483,50 @@ export function NavAgents() {
|
|||
}
|
||||
className="flex items-center"
|
||||
>
|
||||
{isMultiSelectActive ? (
|
||||
<div
|
||||
className={`h-4 w-4 border flex-shrink-0 hover:bg-muted transition rounded mr-2 flex items-center justify-center ${isSelected ? 'bg-primary border-primary' : 'border-foreground/30'}`}
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
toggleThreadSelection(thread.threadId);
|
||||
}}
|
||||
>
|
||||
{isSelected && <Check className="h-3 w-3 text-white" />}
|
||||
</div>
|
||||
) : null}
|
||||
<div className="flex items-center group/icon relative">
|
||||
{/* Show checkbox on hover or when selected, otherwise show MessagesSquare */}
|
||||
{isThreadLoading ? (
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<MessagesSquare className="h-4 w-4" />
|
||||
<>
|
||||
{/* MessagesSquare icon - hidden on hover if not selected */}
|
||||
<MessagesSquare
|
||||
className={`h-4 w-4 transition-opacity duration-150 ${
|
||||
isSelected ? 'opacity-0' : 'opacity-100 group-hover/icon:opacity-0'
|
||||
}`}
|
||||
/>
|
||||
|
||||
{/* Checkbox - appears on hover or when selected */}
|
||||
<div
|
||||
className={`absolute inset-0 flex items-center justify-center transition-opacity duration-150 ${
|
||||
isSelected
|
||||
? 'opacity-100'
|
||||
: 'opacity-0 group-hover/icon:opacity-100'
|
||||
}`}
|
||||
onClick={(e) => toggleThreadSelection(thread.threadId, e)}
|
||||
>
|
||||
<div
|
||||
className={`h-4 w-4 border rounded cursor-pointer hover:bg-muted/50 transition-colors flex items-center justify-center ${
|
||||
isSelected
|
||||
? 'bg-primary border-primary'
|
||||
: 'border-muted-foreground/30 bg-background'
|
||||
}`}
|
||||
>
|
||||
{isSelected && <Check className="h-3 w-3 text-primary-foreground" />}
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
<span>{thread.projectName}</span>
|
||||
</div>
|
||||
<span className="ml-2">{thread.projectName}</span>
|
||||
</Link>
|
||||
</SidebarMenuButton>
|
||||
</div>
|
||||
)}
|
||||
{state !== 'collapsed' && !isMultiSelectActive && (
|
||||
{state !== 'collapsed' && !isSelected && (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<SidebarMenuAction showOnHover>
|
||||
<SidebarMenuAction showOnHover className="group-hover:opacity-100">
|
||||
<MoreHorizontal />
|
||||
<span className="sr-only">More</span>
|
||||
</SidebarMenuAction>
|
||||
|
|
|
@ -7,7 +7,7 @@ import { useAvailableModels } from '@/hooks/react-query/subscriptions/use-model'
|
|||
|
||||
export const STORAGE_KEY_MODEL = 'suna-preferred-model';
|
||||
export const STORAGE_KEY_CUSTOM_MODELS = 'customModels';
|
||||
export const DEFAULT_FREE_MODEL_ID = 'gemini-flash-2.5';
|
||||
export const DEFAULT_FREE_MODEL_ID = 'deepseek';
|
||||
export const DEFAULT_PREMIUM_MODEL_ID = 'claude-sonnet-4';
|
||||
|
||||
export type SubscriptionStatus = 'no_subscription' | 'active';
|
||||
|
@ -269,7 +269,7 @@ export const useModelSelection = () => {
|
|||
models = [
|
||||
{
|
||||
id: DEFAULT_FREE_MODEL_ID,
|
||||
label: 'Gemini Flash 2.5',
|
||||
label: 'DeepSeek',
|
||||
requiresSubscription: false,
|
||||
description: MODELS[DEFAULT_FREE_MODEL_ID]?.description || MODEL_TIERS.free.baseDescription,
|
||||
priority: MODELS[DEFAULT_FREE_MODEL_ID]?.priority || 50
|
||||
|
|
Loading…
Reference in New Issue