mirror of https://github.com/kortix-ai/suna.git
Merge branch 'main' into feat/ux
This commit is contained in:
commit
d01c13ecbf
|
@ -122,18 +122,29 @@ async def run_agent(
|
|||
try:
|
||||
browser_content = json.loads(latest_browser_state_msg.data[0]["content"])
|
||||
screenshot_base64 = browser_content.get("screenshot_base64")
|
||||
# Create a copy of the browser state without screenshot
|
||||
screenshot_url = browser_content.get("screenshot_url")
|
||||
|
||||
# Create a copy of the browser state without screenshot data
|
||||
browser_state_text = browser_content.copy()
|
||||
browser_state_text.pop('screenshot_base64', None)
|
||||
browser_state_text.pop('screenshot_url', None)
|
||||
browser_state_text.pop('screenshot_url_base64', None)
|
||||
|
||||
if browser_state_text:
|
||||
temp_message_content_list.append({
|
||||
"type": "text",
|
||||
"text": f"The following is the current state of the browser:\n{json.dumps(browser_state_text, indent=2)}"
|
||||
})
|
||||
if screenshot_base64:
|
||||
|
||||
# Prioritize screenshot_url if available
|
||||
if screenshot_url:
|
||||
temp_message_content_list.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": screenshot_url,
|
||||
}
|
||||
})
|
||||
elif screenshot_base64:
|
||||
# Fallback to base64 if URL not available
|
||||
temp_message_content_list.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
|
@ -141,7 +152,7 @@ async def run_agent(
|
|||
}
|
||||
})
|
||||
else:
|
||||
logger.warning("Browser state found but no screenshot base64 data.")
|
||||
logger.warning("Browser state found but no screenshot data.")
|
||||
|
||||
await client.table('messages').delete().eq('message_id', latest_browser_state_msg.data[0]["message_id"]).execute()
|
||||
except Exception as e:
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -49,6 +49,9 @@ vncdotool = "^1.2.0"
|
|||
tavily-python = "^0.5.4"
|
||||
pytesseract = "^0.3.13"
|
||||
stripe = "^12.0.1"
|
||||
dramatiq = "^1.17.1"
|
||||
pika = "^1.3.2"
|
||||
prometheus-client = "^0.21.1"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
agentpress = "agentpress.cli:main"
|
||||
|
|
|
@ -15,6 +15,8 @@ import traceback
|
|||
import pytesseract
|
||||
from PIL import Image
|
||||
import io
|
||||
from utils.logger import logger
|
||||
from services.supabase import DBConnection
|
||||
|
||||
#######################################################
|
||||
# Action model definitions
|
||||
|
@ -259,15 +261,16 @@ class BrowserActionResult(BaseModel):
|
|||
url: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
elements: Optional[str] = None # Formatted string of clickable elements
|
||||
screenshot_base64: Optional[str] = None
|
||||
screenshot_base64: Optional[str] = None # For backward compatibility
|
||||
screenshot_url: Optional[str] = None
|
||||
pixels_above: int = 0
|
||||
pixels_below: int = 0
|
||||
content: Optional[str] = None
|
||||
ocr_text: Optional[str] = None # Added field for OCR text
|
||||
ocr_text: Optional[str] = None
|
||||
|
||||
# Additional metadata
|
||||
element_count: int = 0 # Number of interactive elements found
|
||||
interactive_elements: Optional[List[Dict[str, Any]]] = None # Simplified list of interactive elements
|
||||
element_count: int = 0
|
||||
interactive_elements: Optional[List[Dict[str, Any]]] = None
|
||||
viewport_width: Optional[int] = None
|
||||
viewport_height: Optional[int] = None
|
||||
|
||||
|
@ -288,6 +291,7 @@ class BrowserAutomation:
|
|||
self.include_attributes = ["id", "href", "src", "alt", "aria-label", "placeholder", "name", "role", "title", "value"]
|
||||
self.screenshot_dir = os.path.join(os.getcwd(), "screenshots")
|
||||
os.makedirs(self.screenshot_dir, exist_ok=True)
|
||||
self.db = DBConnection() # Initialize DB connection
|
||||
|
||||
# Register routes
|
||||
self.router.on_startup.append(self.startup)
|
||||
|
@ -609,15 +613,85 @@ class BrowserAutomation:
|
|||
)
|
||||
|
||||
async def take_screenshot(self) -> str:
|
||||
"""Take a screenshot and return as base64 encoded string"""
|
||||
"""Take a screenshot and return as base64 encoded string or S3 URL"""
|
||||
try:
|
||||
page = await self.get_current_page()
|
||||
screenshot_bytes = await page.screenshot(type='jpeg', quality=60, full_page=False)
|
||||
return base64.b64encode(screenshot_bytes).decode('utf-8')
|
||||
|
||||
client = await self.db.client
|
||||
|
||||
if client:
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
random_id = random.randint(1000, 9999)
|
||||
filename = f"screenshot_{timestamp}_{random_id}.jpg"
|
||||
|
||||
logger.info(f"Attempting to upload screenshot: {filename}")
|
||||
result = await self.upload_to_storage(client, screenshot_bytes, filename)
|
||||
|
||||
if isinstance(result, dict) and result.get("is_s3") and result.get("url"):
|
||||
if await self.verify_file_exists(client, filename):
|
||||
logger.info(f"Screenshot upload verified: {filename}")
|
||||
else:
|
||||
logger.error(f"Screenshot upload failed verification: {filename}")
|
||||
return base64.b64encode(screenshot_bytes).decode('utf-8')
|
||||
|
||||
return result
|
||||
else:
|
||||
logger.warning("No Supabase client available, falling back to base64")
|
||||
return base64.b64encode(screenshot_bytes).decode('utf-8')
|
||||
except Exception as e:
|
||||
print(f"Error taking screenshot: {e}")
|
||||
# Return an empty string rather than failing
|
||||
logger.error(f"Error taking screenshot: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return ""
|
||||
|
||||
async def upload_to_storage(self, client, file_bytes: bytes, filename: str) -> str:
|
||||
"""Upload file to Supabase Storage and return the URL"""
|
||||
try:
|
||||
bucket_name = 'screenshots'
|
||||
|
||||
buckets = client.storage.list_buckets()
|
||||
if not any(bucket.name == bucket_name for bucket in buckets):
|
||||
logger.info(f"Creating bucket: {bucket_name}")
|
||||
try:
|
||||
client.storage.create_bucket(bucket_name)
|
||||
logger.info("Bucket created successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create bucket: {str(e)}")
|
||||
raise
|
||||
|
||||
logger.info(f"Uploading file: {filename}")
|
||||
try:
|
||||
result = client.storage.from_(bucket_name).upload(
|
||||
path=filename,
|
||||
file=file_bytes,
|
||||
file_options={"content-type": "image/jpeg"}
|
||||
)
|
||||
logger.info("File upload successful")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload file: {str(e)}")
|
||||
raise
|
||||
|
||||
file_url = client.storage.from_(bucket_name).get_public_url(filename)
|
||||
logger.info(f"Generated URL: {file_url}")
|
||||
|
||||
return {"url": file_url, "is_s3": True}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in upload_to_storage: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return base64.b64encode(file_bytes).decode('utf-8')
|
||||
|
||||
async def verify_file_exists(self, client, filename: str) -> bool:
|
||||
"""Verify that a file exists in the storage bucket"""
|
||||
logger.info(f"=== Verifying file exists: {filename} ===")
|
||||
try:
|
||||
bucket_name = 'screenshots'
|
||||
files = client.storage.from_(bucket_name).list()
|
||||
exists = any(f['name'] == filename for f in files)
|
||||
logger.info(f"File verification result: {'exists' if exists else 'not found'}")
|
||||
return exists
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying file: {str(e)}")
|
||||
return False
|
||||
|
||||
async def save_screenshot_to_file(self) -> str:
|
||||
"""Take a screenshot and save to file, returning the path"""
|
||||
|
@ -660,20 +734,32 @@ class BrowserAutomation:
|
|||
"""Helper method to get updated browser state after any action
|
||||
Returns a tuple of (dom_state, screenshot, elements, metadata)
|
||||
"""
|
||||
logger.info(f"=== Starting get_updated_browser_state for action: {action_name} ===")
|
||||
try:
|
||||
# Wait a moment for any potential async processes to settle
|
||||
logger.info("Waiting for async processes to settle")
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Get updated state
|
||||
logger.info("Getting current DOM state")
|
||||
dom_state = await self.get_current_dom_state()
|
||||
logger.info(f"DOM state retrieved - URL: {dom_state.url}, Title: {dom_state.title}")
|
||||
|
||||
logger.info("Taking screenshot")
|
||||
screenshot = await self.take_screenshot()
|
||||
logger.info(f"Screenshot result type: {'dict' if isinstance(screenshot, dict) else 'base64 string'}")
|
||||
if isinstance(screenshot, dict) and screenshot.get("url"):
|
||||
logger.info(f"Screenshot URL: {screenshot['url']}")
|
||||
|
||||
# Format elements for output
|
||||
logger.info("Formatting clickable elements")
|
||||
elements = dom_state.element_tree.clickable_elements_to_string(
|
||||
include_attributes=self.include_attributes
|
||||
)
|
||||
logger.info(f"Found {len(dom_state.selector_map)} clickable elements")
|
||||
|
||||
# Collect additional metadata
|
||||
logger.info("Collecting metadata")
|
||||
page = await self.get_current_page()
|
||||
metadata = {}
|
||||
|
||||
|
@ -699,8 +785,9 @@ class BrowserAutomation:
|
|||
|
||||
metadata['interactive_elements'] = interactive_elements
|
||||
|
||||
# Get viewport dimensions - Fix syntax error in JavaScript
|
||||
# Get viewport dimensions
|
||||
try:
|
||||
logger.info("Getting viewport dimensions")
|
||||
viewport = await page.evaluate("""
|
||||
() => {
|
||||
return {
|
||||
|
@ -711,33 +798,43 @@ class BrowserAutomation:
|
|||
""")
|
||||
metadata['viewport_width'] = viewport.get('width', 0)
|
||||
metadata['viewport_height'] = viewport.get('height', 0)
|
||||
logger.info(f"Viewport dimensions: {metadata['viewport_width']}x{metadata['viewport_height']}")
|
||||
except Exception as e:
|
||||
print(f"Error getting viewport dimensions: {e}")
|
||||
logger.error(f"Error getting viewport dimensions: {e}")
|
||||
metadata['viewport_width'] = 0
|
||||
metadata['viewport_height'] = 0
|
||||
|
||||
# Extract OCR text from screenshot if available
|
||||
ocr_text = ""
|
||||
if screenshot:
|
||||
logger.info("Extracting OCR text from screenshot")
|
||||
ocr_text = await self.extract_ocr_text_from_screenshot(screenshot)
|
||||
metadata['ocr_text'] = ocr_text
|
||||
logger.info(f"OCR text length: {len(ocr_text)} characters")
|
||||
|
||||
print(f"Got updated state after {action_name}: {len(dom_state.selector_map)} elements")
|
||||
logger.info(f"=== Completed get_updated_browser_state for {action_name} ===")
|
||||
return dom_state, screenshot, elements, metadata
|
||||
except Exception as e:
|
||||
print(f"Error getting updated state after {action_name}: {e}")
|
||||
logger.error(f"Error in get_updated_browser_state for {action_name}: {e}")
|
||||
traceback.print_exc()
|
||||
# Return empty values in case of error
|
||||
return None, "", "", {}
|
||||
|
||||
def build_action_result(self, success: bool, message: str, dom_state, screenshot: str,
|
||||
elements: str, metadata: dict, error: str = "", content: str = None,
|
||||
fallback_url: str = None) -> BrowserActionResult:
|
||||
elements: str, metadata: dict, error: str = "", content: str = None,
|
||||
fallback_url: str = None) -> BrowserActionResult:
|
||||
"""Helper method to build a consistent BrowserActionResult"""
|
||||
# Ensure elements is never None to avoid display issues
|
||||
if elements is None:
|
||||
elements = ""
|
||||
|
||||
screenshot_base64 = None
|
||||
screenshot_url = None
|
||||
|
||||
if isinstance(screenshot, dict) and screenshot.get("is_s3"):
|
||||
screenshot_url = screenshot.get("url")
|
||||
else:
|
||||
screenshot_base64 = screenshot
|
||||
|
||||
return BrowserActionResult(
|
||||
success=success,
|
||||
message=message,
|
||||
|
@ -745,7 +842,8 @@ class BrowserAutomation:
|
|||
url=dom_state.url if dom_state else fallback_url or "",
|
||||
title=dom_state.title if dom_state else "",
|
||||
elements=elements,
|
||||
screenshot_base64=screenshot,
|
||||
screenshot_base64=screenshot_base64,
|
||||
screenshot_url=screenshot_url,
|
||||
pixels_above=dom_state.pixels_above if dom_state else 0,
|
||||
pixels_below=dom_state.pixels_below if dom_state else 0,
|
||||
content=content,
|
||||
|
|
|
@ -100,22 +100,23 @@ def setup_logger(name: str = 'agentpress') -> logging.Logger:
|
|||
except Exception as e:
|
||||
print(f"Error setting up file handler: {e}")
|
||||
|
||||
# Console handler - WARNING in production, INFO in other environments
|
||||
# Console handler - WARNING in production, DEBUG in other environments
|
||||
try:
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
if config.ENV_MODE == EnvMode.PRODUCTION:
|
||||
console_handler.setLevel(logging.WARNING)
|
||||
else:
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
|
||||
console_formatter = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - %(message)s'
|
||||
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
|
||||
)
|
||||
console_handler.setFormatter(console_formatter)
|
||||
|
||||
# Add console handler to logger
|
||||
logger.addHandler(console_handler)
|
||||
print(f"Added console handler with level: {console_handler.level}")
|
||||
logger.info(f"Added console handler with level: {console_handler.level}")
|
||||
logger.info(f"Log file will be created at: {log_dir}")
|
||||
except Exception as e:
|
||||
print(f"Error setting up console handler: {e}")
|
||||
|
||||
|
|
|
@ -12,6 +12,20 @@ services:
|
|||
timeout: 5s
|
||||
retries: 3
|
||||
|
||||
rabbitmq:
|
||||
image: rabbitmq
|
||||
# ports:
|
||||
# - "127.0.0.1:5672:5672"
|
||||
volumes:
|
||||
- rabbitmq_data:/var/lib/rabbitmq
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "rabbitmq-diagnostics", "-q", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
|
||||
backend:
|
||||
build:
|
||||
context: ./backend
|
||||
|
@ -29,6 +43,8 @@ services:
|
|||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
rabbitmq:
|
||||
condition: service_healthy
|
||||
|
||||
frontend:
|
||||
build:
|
||||
|
@ -45,4 +61,5 @@ services:
|
|||
- backend
|
||||
|
||||
volumes:
|
||||
redis-data:
|
||||
redis-data:
|
||||
rabbitmq_data:
|
|
@ -3,37 +3,19 @@
|
|||
import React, {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react';
|
||||
import Image from 'next/image';
|
||||
import { useRouter, useSearchParams } from 'next/navigation';
|
||||
import {
|
||||
ArrowDown,
|
||||
CheckCircle,
|
||||
CircleDashed,
|
||||
AlertTriangle,
|
||||
Info,
|
||||
File,
|
||||
ChevronRight,
|
||||
} from 'lucide-react';
|
||||
import {
|
||||
addUserMessage,
|
||||
startAgent,
|
||||
stopAgent,
|
||||
getAgentRuns,
|
||||
getMessages,
|
||||
getProject,
|
||||
getThread,
|
||||
updateProject,
|
||||
BillingError,
|
||||
Project,
|
||||
Message as BaseApiMessageType,
|
||||
BillingError,
|
||||
checkBillingStatus,
|
||||
} from '@/lib/api';
|
||||
import { toast } from 'sonner';
|
||||
import { Skeleton } from '@/components/ui/skeleton';
|
||||
import { ChatInput } from '@/components/thread/chat-input/chat-input';
|
||||
import { FileViewerModal } from '@/components/thread/file-viewer-modal';
|
||||
import { SiteHeader } from '@/components/thread/thread-site-header';
|
||||
|
@ -58,7 +40,11 @@ import {
|
|||
import {
|
||||
safeJsonParse,
|
||||
} from '@/components/thread/utils';
|
||||
|
||||
import { useThreadQuery } from '@/hooks/react-query/threads/use-threads';
|
||||
import { useAddUserMessageMutation, useMessagesQuery } from '@/hooks/react-query/threads/use-messages';
|
||||
import { useProjectQuery } from '@/hooks/react-query/threads/use-project';
|
||||
import { useAgentRunsQuery, useStartAgentMutation, useStopAgentMutation } from '@/hooks/react-query/threads/use-agent-run';
|
||||
import { useBillingStatusQuery } from '@/hooks/react-query/threads/use-billing-status';
|
||||
|
||||
// Extend the base Message type with the expected database fields
|
||||
interface ApiMessageType extends BaseApiMessageType {
|
||||
|
@ -136,6 +122,18 @@ export default function ThreadPage({
|
|||
// Add debug mode state - check for debug=true in URL
|
||||
const [debugMode, setDebugMode] = useState(false);
|
||||
|
||||
const threadQuery = useThreadQuery(threadId);
|
||||
const messagesQuery = useMessagesQuery(threadId);
|
||||
const projectId = threadQuery.data?.project_id || '';
|
||||
const projectQuery = useProjectQuery(projectId);
|
||||
const agentRunsQuery = useAgentRunsQuery(threadId);
|
||||
const billingStatusQuery = useBillingStatusQuery();
|
||||
|
||||
const addUserMessageMutation = useAddUserMessageMutation();
|
||||
const startAgentMutation = useStartAgentMutation();
|
||||
const stopAgentMutation = useStopAgentMutation();
|
||||
|
||||
|
||||
const handleProjectRenamed = useCallback((newName: string) => {
|
||||
setProjectName(newName);
|
||||
}, []);
|
||||
|
@ -356,89 +354,86 @@ export default function ThreadPage({
|
|||
messagesEndRef.current?.scrollIntoView({ behavior });
|
||||
};
|
||||
|
||||
// Effect to load initial data using React Query
|
||||
useEffect(() => {
|
||||
let isMounted = true;
|
||||
|
||||
async function loadData() {
|
||||
async function initializeData() {
|
||||
if (!initialLoadCompleted.current) setIsLoading(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
if (!threadId) throw new Error('Thread ID is required');
|
||||
|
||||
const threadData = await getThread(threadId).catch((err) => {
|
||||
throw new Error('Failed to load thread data: ' + err.message);
|
||||
});
|
||||
// Check if we have thread data
|
||||
if (threadQuery.isError) {
|
||||
throw new Error('Failed to load thread data: ' + threadQuery.error);
|
||||
}
|
||||
|
||||
if (!isMounted) return;
|
||||
|
||||
if (threadData?.project_id) {
|
||||
const projectData = await getProject(threadData.project_id);
|
||||
if (isMounted && projectData) {
|
||||
// Set project data
|
||||
setProject(projectData);
|
||||
// Process project data when available
|
||||
if (projectQuery.data) {
|
||||
// Set project data
|
||||
setProject(projectQuery.data);
|
||||
|
||||
// Make sure sandbox ID is set correctly
|
||||
if (typeof projectData.sandbox === 'string') {
|
||||
setSandboxId(projectData.sandbox);
|
||||
} else if (projectData.sandbox?.id) {
|
||||
setSandboxId(projectData.sandbox.id);
|
||||
}
|
||||
// Make sure sandbox ID is set correctly
|
||||
if (typeof projectQuery.data.sandbox === 'string') {
|
||||
setSandboxId(projectQuery.data.sandbox);
|
||||
} else if (projectQuery.data.sandbox?.id) {
|
||||
setSandboxId(projectQuery.data.sandbox.id);
|
||||
}
|
||||
|
||||
setProjectName(projectData.name || '');
|
||||
setProjectName(projectQuery.data.name || '');
|
||||
}
|
||||
|
||||
// Process messages data when available
|
||||
if (messagesQuery.data && !messagesLoadedRef.current) {
|
||||
// Map API message type to UnifiedMessage type
|
||||
const unifiedMessages = (messagesQuery.data || [])
|
||||
.filter((msg) => msg.type !== 'status')
|
||||
.map((msg: ApiMessageType) => ({
|
||||
message_id: msg.message_id || null,
|
||||
thread_id: msg.thread_id || threadId,
|
||||
type: (msg.type || 'system') as UnifiedMessage['type'],
|
||||
is_llm_message: Boolean(msg.is_llm_message),
|
||||
content: msg.content || '',
|
||||
metadata: msg.metadata || '{}',
|
||||
created_at: msg.created_at || new Date().toISOString(),
|
||||
updated_at: msg.updated_at || new Date().toISOString(),
|
||||
}));
|
||||
|
||||
setMessages(unifiedMessages);
|
||||
console.log('[PAGE] Loaded Messages (excluding status, keeping browser_state):', unifiedMessages.length);
|
||||
messagesLoadedRef.current = true;
|
||||
|
||||
if (!hasInitiallyScrolled.current) {
|
||||
scrollToBottom('auto');
|
||||
hasInitiallyScrolled.current = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!messagesLoadedRef.current) {
|
||||
const messagesData = await getMessages(threadId);
|
||||
if (isMounted) {
|
||||
// Map API message type to UnifiedMessage type
|
||||
const unifiedMessages = (messagesData || [])
|
||||
.filter((msg) => msg.type !== 'status')
|
||||
.map((msg: ApiMessageType) => ({
|
||||
message_id: msg.message_id || null,
|
||||
thread_id: msg.thread_id || threadId,
|
||||
type: (msg.type || 'system') as UnifiedMessage['type'],
|
||||
is_llm_message: Boolean(msg.is_llm_message),
|
||||
content: msg.content || '',
|
||||
metadata: msg.metadata || '{}',
|
||||
created_at: msg.created_at || new Date().toISOString(),
|
||||
updated_at: msg.updated_at || new Date().toISOString(),
|
||||
}));
|
||||
// Check for active agent runs
|
||||
if (agentRunsQuery.data && !agentRunsCheckedRef.current && isMounted) {
|
||||
console.log('[PAGE] Checking for active agent runs...');
|
||||
agentRunsCheckedRef.current = true;
|
||||
|
||||
setMessages(unifiedMessages);
|
||||
console.log('[PAGE] Loaded Messages (excluding status, keeping browser_state):', unifiedMessages.length);
|
||||
messagesLoadedRef.current = true;
|
||||
|
||||
if (!hasInitiallyScrolled.current) {
|
||||
scrollToBottom('auto');
|
||||
hasInitiallyScrolled.current = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!agentRunsCheckedRef.current && isMounted) {
|
||||
try {
|
||||
console.log('[PAGE] Checking for active agent runs...');
|
||||
const agentRuns = await getAgentRuns(threadId);
|
||||
agentRunsCheckedRef.current = true;
|
||||
|
||||
const activeRun = agentRuns.find((run) => run.status === 'running');
|
||||
if (activeRun && isMounted) {
|
||||
console.log('[PAGE] Found active run on load:', activeRun.id);
|
||||
setAgentRunId(activeRun.id);
|
||||
} else {
|
||||
console.log('[PAGE] No active agent runs found');
|
||||
if (isMounted) setAgentStatus('idle');
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('[PAGE] Error checking for active runs:', err);
|
||||
agentRunsCheckedRef.current = true;
|
||||
const activeRun = agentRunsQuery.data.find((run) => run.status === 'running');
|
||||
if (activeRun && isMounted) {
|
||||
console.log('[PAGE] Found active run on load:', activeRun.id);
|
||||
setAgentRunId(activeRun.id);
|
||||
} else {
|
||||
console.log('[PAGE] No active agent runs found');
|
||||
if (isMounted) setAgentStatus('idle');
|
||||
}
|
||||
}
|
||||
|
||||
initialLoadCompleted.current = true;
|
||||
// Mark initialization as complete when we have the core data
|
||||
if (threadQuery.data && messagesQuery.data && agentRunsQuery.data) {
|
||||
initialLoadCompleted.current = true;
|
||||
setIsLoading(false);
|
||||
}
|
||||
|
||||
} catch (err) {
|
||||
console.error('Error loading thread data:', err);
|
||||
if (isMounted) {
|
||||
|
@ -446,18 +441,27 @@ export default function ThreadPage({
|
|||
err instanceof Error ? err.message : 'Failed to load thread';
|
||||
setError(errorMessage);
|
||||
toast.error(errorMessage);
|
||||
setIsLoading(false);
|
||||
}
|
||||
} finally {
|
||||
if (isMounted) setIsLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
loadData();
|
||||
if (threadId) {
|
||||
initializeData();
|
||||
}
|
||||
|
||||
return () => {
|
||||
isMounted = false;
|
||||
};
|
||||
}, [threadId]);
|
||||
}, [
|
||||
threadId,
|
||||
threadQuery.data,
|
||||
threadQuery.isError,
|
||||
threadQuery.error,
|
||||
projectQuery.data,
|
||||
messagesQuery.data,
|
||||
agentRunsQuery.data
|
||||
]);
|
||||
|
||||
const handleSubmitMessage = useCallback(
|
||||
async (
|
||||
|
@ -483,10 +487,18 @@ export default function ThreadPage({
|
|||
scrollToBottom('smooth');
|
||||
|
||||
try {
|
||||
const results = await Promise.allSettled([
|
||||
addUserMessage(threadId, message),
|
||||
startAgent(threadId, options),
|
||||
]);
|
||||
// Use React Query mutations instead of direct API calls
|
||||
const messagePromise = addUserMessageMutation.mutateAsync({
|
||||
threadId,
|
||||
message
|
||||
});
|
||||
|
||||
const agentPromise = startAgentMutation.mutateAsync({
|
||||
threadId,
|
||||
options
|
||||
});
|
||||
|
||||
const results = await Promise.allSettled([messagePromise, agentPromise]);
|
||||
|
||||
// Handle failure to add the user message
|
||||
if (results[0].status === 'rejected') {
|
||||
|
@ -525,6 +537,11 @@ export default function ThreadPage({
|
|||
// If agent started successfully
|
||||
const agentResult = results[1].value;
|
||||
setAgentRunId(agentResult.agent_run_id);
|
||||
|
||||
// Refresh queries after successful operations
|
||||
messagesQuery.refetch();
|
||||
agentRunsQuery.refetch();
|
||||
|
||||
} catch (err) {
|
||||
// Catch errors from addUserMessage or non-BillingError agent start errors
|
||||
console.error('Error sending message or starting agent:', err);
|
||||
|
@ -540,8 +557,8 @@ export default function ThreadPage({
|
|||
setIsSending(false);
|
||||
}
|
||||
},
|
||||
[threadId, project?.account_id],
|
||||
); // Ensure project.account_id is a dependency
|
||||
[threadId, project?.account_id, addUserMessageMutation, startAgentMutation, messagesQuery, agentRunsQuery],
|
||||
);
|
||||
|
||||
const handleStopAgent = useCallback(async () => {
|
||||
console.log(`[PAGE] Requesting agent stop via hook.`);
|
||||
|
@ -549,11 +566,18 @@ export default function ThreadPage({
|
|||
|
||||
// First stop the streaming and let the hook handle refetching
|
||||
await stopStreaming();
|
||||
|
||||
// We don't need to refetch messages here since the hook will do that
|
||||
// The centralizing of refetching in the hook simplifies this logic
|
||||
}, [stopStreaming]);
|
||||
|
||||
|
||||
// Use React Query's stopAgentMutation if we have an agent run ID
|
||||
if (agentRunId) {
|
||||
try {
|
||||
await stopAgentMutation.mutateAsync(agentRunId);
|
||||
// Refresh agent runs after stopping
|
||||
agentRunsQuery.refetch();
|
||||
} catch (error) {
|
||||
console.error('Error stopping agent:', error);
|
||||
}
|
||||
}
|
||||
}, [stopStreaming, agentRunId, stopAgentMutation, agentRunsQuery]);
|
||||
|
||||
useEffect(() => {
|
||||
const lastMsg = messages[messages.length - 1];
|
||||
|
@ -561,7 +585,7 @@ export default function ThreadPage({
|
|||
if ((isNewUserMessage || agentStatus === 'running') && !userHasScrolled) {
|
||||
scrollToBottom('smooth');
|
||||
}
|
||||
}, [messages, agentStatus, userHasScrolled, scrollToBottom]);
|
||||
}, [messages, agentStatus, userHasScrolled]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!latestMessageRef.current || messages.length === 0) return;
|
||||
|
@ -571,7 +595,7 @@ export default function ThreadPage({
|
|||
);
|
||||
observer.observe(latestMessageRef.current);
|
||||
return () => observer.disconnect();
|
||||
}, [messages, streamingTextContent, streamingToolCall, setShowScrollButton]);
|
||||
}, [messages, streamingTextContent, streamingToolCall]);
|
||||
|
||||
useEffect(() => {
|
||||
console.log(`[PAGE] 🔄 Page AgentStatus: ${agentStatus}, Hook Status: ${streamHookStatus}, Target RunID: ${agentRunId || 'none'}, Hook RunID: ${currentHookRunId || 'none'}`);
|
||||
|
@ -930,57 +954,34 @@ export default function ThreadPage({
|
|||
}
|
||||
}, [projectName]);
|
||||
|
||||
// Add another useEffect to ensure messages are refreshed when agent status changes to idle
|
||||
// Update messages when they change in the query
|
||||
useEffect(() => {
|
||||
if (
|
||||
agentStatus === 'idle' &&
|
||||
streamHookStatus !== 'streaming' &&
|
||||
streamHookStatus !== 'connecting'
|
||||
) {
|
||||
console.log(
|
||||
'[PAGE] Agent status changed to idle, ensuring messages are up to date',
|
||||
);
|
||||
// Only do this if we're not in the initial loading state
|
||||
if (!isLoading && initialLoadCompleted.current) {
|
||||
// Double-check messages after a short delay to ensure we have latest content
|
||||
const timer = setTimeout(() => {
|
||||
getMessages(threadId)
|
||||
.then((messagesData) => {
|
||||
if (messagesData) {
|
||||
console.log(
|
||||
`[PAGE] Backup refetch completed with ${messagesData.length} messages`,
|
||||
);
|
||||
// Map API message type to UnifiedMessage type
|
||||
const unifiedMessages = (messagesData || [])
|
||||
.filter((msg) => msg.type !== 'status')
|
||||
.map((msg: ApiMessageType) => ({
|
||||
message_id: msg.message_id || null,
|
||||
thread_id: msg.thread_id || threadId,
|
||||
type: (msg.type || 'system') as UnifiedMessage['type'],
|
||||
is_llm_message: Boolean(msg.is_llm_message),
|
||||
content: msg.content || '',
|
||||
metadata: msg.metadata || '{}',
|
||||
created_at: msg.created_at || new Date().toISOString(),
|
||||
updated_at: msg.updated_at || new Date().toISOString(),
|
||||
}));
|
||||
if (messagesQuery.data && messagesQuery.status === 'success') {
|
||||
// Only update if we're not in initial loading and the agent isn't running
|
||||
if (!isLoading && agentStatus !== 'running' && agentStatus !== 'connecting') {
|
||||
// Map API message type to UnifiedMessage type
|
||||
const unifiedMessages = (messagesQuery.data || [])
|
||||
.filter((msg) => msg.type !== 'status')
|
||||
.map((msg: ApiMessageType) => ({
|
||||
message_id: msg.message_id || null,
|
||||
thread_id: msg.thread_id || threadId,
|
||||
type: (msg.type || 'system') as UnifiedMessage['type'],
|
||||
is_llm_message: Boolean(msg.is_llm_message),
|
||||
content: msg.content || '',
|
||||
metadata: msg.metadata || '{}',
|
||||
created_at: msg.created_at || new Date().toISOString(),
|
||||
updated_at: msg.updated_at || new Date().toISOString(),
|
||||
}));
|
||||
|
||||
setMessages(unifiedMessages);
|
||||
// Reset auto-opened panel to allow tool detection with fresh messages
|
||||
setAutoOpenedPanel(false);
|
||||
scrollToBottom('smooth');
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error('Error in backup message refetch:', err);
|
||||
});
|
||||
}, 1000);
|
||||
|
||||
return () => clearTimeout(timer);
|
||||
setMessages(unifiedMessages);
|
||||
// Reset auto-opened panel to allow tool detection with fresh messages
|
||||
setAutoOpenedPanel(false);
|
||||
scrollToBottom('smooth');
|
||||
}
|
||||
}
|
||||
}, [agentStatus, threadId, isLoading, streamHookStatus]);
|
||||
}, [messagesQuery.data, messagesQuery.status, isLoading, agentStatus, threadId]);
|
||||
|
||||
// Update the checkBillingStatus function
|
||||
// Check billing status and handle billing limit
|
||||
const checkBillingLimits = useCallback(async () => {
|
||||
// Skip billing checks in local development mode
|
||||
if (isLocalMode()) {
|
||||
|
@ -991,9 +992,11 @@ export default function ThreadPage({
|
|||
}
|
||||
|
||||
try {
|
||||
const result = await checkBillingStatus();
|
||||
// Use React Query to get billing status
|
||||
await billingStatusQuery.refetch();
|
||||
const result = billingStatusQuery.data;
|
||||
|
||||
if (!result.can_run) {
|
||||
if (result && !result.can_run) {
|
||||
setBillingData({
|
||||
currentUsage: result.subscription?.minutes_limit || 0,
|
||||
limit: result.subscription?.minutes_limit || 0,
|
||||
|
@ -1008,9 +1011,9 @@ export default function ThreadPage({
|
|||
console.error('Error checking billing status:', err);
|
||||
return false;
|
||||
}
|
||||
}, [project?.account_id]);
|
||||
}, [project?.account_id, billingStatusQuery]);
|
||||
|
||||
// Update useEffect to use the renamed function
|
||||
// Check billing when agent status changes
|
||||
useEffect(() => {
|
||||
const previousStatus = previousAgentStatus.current;
|
||||
|
||||
|
@ -1023,7 +1026,7 @@ export default function ThreadPage({
|
|||
previousAgentStatus.current = agentStatus;
|
||||
}, [agentStatus, checkBillingLimits]);
|
||||
|
||||
// Update other useEffect to use the renamed function
|
||||
// Check billing on initial load
|
||||
useEffect(() => {
|
||||
if (project?.account_id && initialLoadCompleted.current) {
|
||||
console.log('Checking billing status on page load');
|
||||
|
@ -1031,7 +1034,7 @@ export default function ThreadPage({
|
|||
}
|
||||
}, [project?.account_id, checkBillingLimits, initialLoadCompleted]);
|
||||
|
||||
// Update the last useEffect to use the renamed function
|
||||
// Check billing after messages loaded
|
||||
useEffect(() => {
|
||||
if (messagesLoadedRef.current && project?.account_id && !isLoading) {
|
||||
console.log('Checking billing status after messages loaded');
|
||||
|
@ -1085,10 +1088,10 @@ export default function ThreadPage({
|
|||
? 'This thread either does not exist or you do not have access to it.'
|
||||
: error
|
||||
}
|
||||
</p >
|
||||
</div >
|
||||
</div >
|
||||
</div >
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<ToolCallSidePanel
|
||||
isOpen={isSidePanelOpen && initialLoadCompleted.current}
|
||||
onClose={() => {
|
||||
|
@ -1128,7 +1131,7 @@ export default function ThreadPage({
|
|||
onDismiss={() => setShowBillingAlert(false)}
|
||||
isOpen={showBillingAlert}
|
||||
/>
|
||||
</div >
|
||||
</div>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
|
|
|
@ -58,7 +58,7 @@ export const PaywallDialog: React.FC<PaywallDialogProps> = ({
|
|||
strayBackdrops.forEach(element => element.remove());
|
||||
};
|
||||
}, []);
|
||||
|
||||
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
document.body.classList.remove('overflow-hidden');
|
||||
|
|
|
@ -84,6 +84,7 @@ export function BrowserToolView({
|
|||
browserStateMessage.content,
|
||||
{},
|
||||
);
|
||||
console.log('Browser state content: ', browserStateContent)
|
||||
screenshotBase64 = browserStateContent?.screenshot_base64 || null;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
import { createQueryKeys } from "@/hooks/use-query";
|
||||
|
||||
export const threadKeys = createQueryKeys({
|
||||
all: ['threads'] as const,
|
||||
details: (threadId: string) => ['thread', threadId] as const,
|
||||
messages: (threadId: string) => ['thread', threadId, 'messages'] as const,
|
||||
project: (projectId: string) => ['project', projectId] as const,
|
||||
agentRuns: (threadId: string) => ['thread', threadId, 'agent-runs'] as const,
|
||||
billingStatus: ['billing', 'status'] as const,
|
||||
});
|
|
@ -0,0 +1,39 @@
|
|||
import { createMutationHook, createQueryHook } from "@/hooks/use-query";
|
||||
import { threadKeys } from "./keys";
|
||||
import { BillingError, getAgentRuns, startAgent, stopAgent } from "@/lib/api";
|
||||
|
||||
export const useAgentRunsQuery = (threadId: string) =>
|
||||
createQueryHook(
|
||||
threadKeys.agentRuns(threadId),
|
||||
() => getAgentRuns(threadId),
|
||||
{
|
||||
enabled: !!threadId,
|
||||
retry: 1,
|
||||
}
|
||||
)();
|
||||
|
||||
export const useStartAgentMutation = () =>
|
||||
createMutationHook(
|
||||
({
|
||||
threadId,
|
||||
options,
|
||||
}: {
|
||||
threadId: string;
|
||||
options?: {
|
||||
model_name?: string;
|
||||
enable_thinking?: boolean;
|
||||
reasoning_effort?: string;
|
||||
stream?: boolean;
|
||||
};
|
||||
}) => startAgent(threadId, options),
|
||||
{
|
||||
onError: (error) => {
|
||||
if (!(error instanceof BillingError)) {
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
}
|
||||
)();
|
||||
|
||||
export const useStopAgentMutation = () =>
|
||||
createMutationHook((agentRunId: string) => stopAgent(agentRunId))();
|
|
@ -0,0 +1,14 @@
|
|||
import { createQueryHook } from "@/hooks/use-query";
|
||||
import { threadKeys } from "./keys";
|
||||
import { checkBillingStatus } from "@/lib/api";
|
||||
|
||||
export const useBillingStatusQuery = (enabled = true) =>
|
||||
createQueryHook(
|
||||
threadKeys.billingStatus,
|
||||
() => checkBillingStatus(),
|
||||
{
|
||||
enabled,
|
||||
retry: 1,
|
||||
staleTime: 1000 * 60 * 5,
|
||||
}
|
||||
)();
|
|
@ -0,0 +1,24 @@
|
|||
import { createMutationHook, createQueryHook } from "@/hooks/use-query";
|
||||
import { threadKeys } from "./keys";
|
||||
import { addUserMessage, getMessages } from "@/lib/api";
|
||||
|
||||
export const useMessagesQuery = (threadId: string) =>
|
||||
createQueryHook(
|
||||
threadKeys.messages(threadId),
|
||||
() => getMessages(threadId),
|
||||
{
|
||||
enabled: !!threadId,
|
||||
retry: 1,
|
||||
}
|
||||
)();
|
||||
|
||||
export const useAddUserMessageMutation = () =>
|
||||
createMutationHook(
|
||||
({
|
||||
threadId,
|
||||
message,
|
||||
}: {
|
||||
threadId: string;
|
||||
message: string;
|
||||
}) => addUserMessage(threadId, message)
|
||||
)();
|
|
@ -0,0 +1,27 @@
|
|||
import { createMutationHook, createQueryHook } from "@/hooks/use-query";
|
||||
import { threadKeys } from "./keys";
|
||||
import { getProject, Project, updateProject } from "@/lib/api";
|
||||
|
||||
export const useProjectQuery = (projectId: string | undefined) =>
|
||||
createQueryHook(
|
||||
threadKeys.project(projectId || ""),
|
||||
() =>
|
||||
projectId
|
||||
? getProject(projectId)
|
||||
: Promise.reject("No project ID"),
|
||||
{
|
||||
enabled: !!projectId,
|
||||
retry: 1,
|
||||
}
|
||||
)();
|
||||
|
||||
export const useUpdateProjectMutation = () =>
|
||||
createMutationHook(
|
||||
({
|
||||
projectId,
|
||||
data,
|
||||
}: {
|
||||
projectId: string;
|
||||
data: Partial<Project>;
|
||||
}) => updateProject(projectId, data)
|
||||
)();
|
|
@ -0,0 +1,13 @@
|
|||
import { createQueryHook } from "@/hooks/use-query";
|
||||
import { threadKeys } from "./keys";
|
||||
import { getThread } from "@/lib/api";
|
||||
|
||||
export const useThreadQuery = (threadId: string) =>
|
||||
createQueryHook(
|
||||
threadKeys.details(threadId),
|
||||
() => getThread(threadId),
|
||||
{
|
||||
enabled: !!threadId,
|
||||
retry: 1,
|
||||
}
|
||||
)();
|
Loading…
Reference in New Issue