mirror of https://github.com/kortix-ai/suna.git
Merge branch 'main' into feat/ux
This commit is contained in:
commit
f33f9bb14e
|
@ -154,7 +154,6 @@ async def run_agent(
|
||||||
else:
|
else:
|
||||||
logger.warning("Browser state found but no screenshot data.")
|
logger.warning("Browser state found but no screenshot data.")
|
||||||
|
|
||||||
await client.table('messages').delete().eq('message_id', latest_browser_state_msg.data[0]["message_id"]).execute()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error parsing browser state: {e}")
|
logger.error(f"Error parsing browser state: {e}")
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ from agentpress.tool import ToolResult, openapi_schema, xml_schema
|
||||||
from agentpress.thread_manager import ThreadManager
|
from agentpress.thread_manager import ThreadManager
|
||||||
from sandbox.tool_base import SandboxToolsBase
|
from sandbox.tool_base import SandboxToolsBase
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
|
from utils.s3_upload_utils import upload_base64_image
|
||||||
|
|
||||||
|
|
||||||
class SandboxBrowserTool(SandboxToolsBase):
|
class SandboxBrowserTool(SandboxToolsBase):
|
||||||
|
@ -30,7 +31,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
||||||
await self._ensure_sandbox()
|
await self._ensure_sandbox()
|
||||||
|
|
||||||
# Build the curl command
|
# Build the curl command
|
||||||
url = f"http://localhost:8002/api/automation/{endpoint}"
|
url = f"http://localhost:8003/api/automation/{endpoint}"
|
||||||
|
|
||||||
if method == "GET" and params:
|
if method == "GET" and params:
|
||||||
query_params = "&".join([f"{k}={v}" for k, v in params.items()])
|
query_params = "&".join([f"{k}={v}" for k, v in params.items()])
|
||||||
|
@ -59,7 +60,17 @@ class SandboxBrowserTool(SandboxToolsBase):
|
||||||
|
|
||||||
logger.info("Browser automation request completed successfully")
|
logger.info("Browser automation request completed successfully")
|
||||||
|
|
||||||
# Add full result to thread messages for state tracking
|
if "screenshot_base64" in result:
|
||||||
|
try:
|
||||||
|
image_url = await upload_base64_image(result["screenshot_base64"])
|
||||||
|
result["image_url"] = image_url
|
||||||
|
# Remove base64 data from result to keep it clean
|
||||||
|
del result["screenshot_base64"]
|
||||||
|
logger.debug(f"Uploaded screenshot to {image_url}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to upload screenshot: {e}")
|
||||||
|
result["image_upload_error"] = str(e)
|
||||||
|
|
||||||
added_message = await self.thread_manager.add_message(
|
added_message = await self.thread_manager.add_message(
|
||||||
thread_id=self.thread_id,
|
thread_id=self.thread_id,
|
||||||
type="browser_state",
|
type="browser_state",
|
||||||
|
@ -67,17 +78,13 @@ class SandboxBrowserTool(SandboxToolsBase):
|
||||||
is_llm_message=False
|
is_llm_message=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return tool-specific success response
|
|
||||||
success_response = {
|
success_response = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": result.get("message", "Browser action completed successfully")
|
"message": result.get("message", "Browser action completed successfully")
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add message ID if available
|
|
||||||
if added_message and 'message_id' in added_message:
|
if added_message and 'message_id' in added_message:
|
||||||
success_response['message_id'] = added_message['message_id']
|
success_response['message_id'] = added_message['message_id']
|
||||||
|
|
||||||
# Add relevant browser-specific info
|
|
||||||
if result.get("url"):
|
if result.get("url"):
|
||||||
success_response["url"] = result["url"]
|
success_response["url"] = result["url"]
|
||||||
if result.get("title"):
|
if result.get("title"):
|
||||||
|
@ -86,9 +93,10 @@ class SandboxBrowserTool(SandboxToolsBase):
|
||||||
success_response["elements_found"] = result["element_count"]
|
success_response["elements_found"] = result["element_count"]
|
||||||
if result.get("pixels_below"):
|
if result.get("pixels_below"):
|
||||||
success_response["scrollable_content"] = result["pixels_below"] > 0
|
success_response["scrollable_content"] = result["pixels_below"] > 0
|
||||||
# Add OCR text when available
|
|
||||||
if result.get("ocr_text"):
|
if result.get("ocr_text"):
|
||||||
success_response["ocr_text"] = result["ocr_text"]
|
success_response["ocr_text"] = result["ocr_text"]
|
||||||
|
if result.get("image_url"):
|
||||||
|
success_response["image_url"] = result["image_url"]
|
||||||
|
|
||||||
return self.success_response(success_response)
|
return self.success_response(success_response)
|
||||||
|
|
||||||
|
@ -104,6 +112,7 @@ class SandboxBrowserTool(SandboxToolsBase):
|
||||||
logger.debug(traceback.format_exc())
|
logger.debug(traceback.format_exc())
|
||||||
return self.fail_response(f"Error executing browser action: {e}")
|
return self.fail_response(f"Error executing browser action: {e}")
|
||||||
|
|
||||||
|
|
||||||
@openapi_schema({
|
@openapi_schema({
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
|
|
|
@ -978,7 +978,7 @@ class ResponseProcessor:
|
||||||
if value is not None:
|
if value is not None:
|
||||||
params[mapping.param_name] = value
|
params[mapping.param_name] = value
|
||||||
parsing_details["attributes"][mapping.param_name] = value # Store raw attribute
|
parsing_details["attributes"][mapping.param_name] = value # Store raw attribute
|
||||||
logger.info(f"Found attribute {mapping.param_name}: {value}")
|
# logger.info(f"Found attribute {mapping.param_name}: {value}")
|
||||||
|
|
||||||
elif mapping.node_type == "element":
|
elif mapping.node_type == "element":
|
||||||
# Extract element content
|
# Extract element content
|
||||||
|
@ -986,7 +986,7 @@ class ResponseProcessor:
|
||||||
if content is not None:
|
if content is not None:
|
||||||
params[mapping.param_name] = content.strip()
|
params[mapping.param_name] = content.strip()
|
||||||
parsing_details["elements"][mapping.param_name] = content.strip() # Store raw element content
|
parsing_details["elements"][mapping.param_name] = content.strip() # Store raw element content
|
||||||
logger.info(f"Found element {mapping.param_name}: {content.strip()}")
|
# logger.info(f"Found element {mapping.param_name}: {content.strip()}")
|
||||||
|
|
||||||
elif mapping.node_type == "text":
|
elif mapping.node_type == "text":
|
||||||
# Extract text content
|
# Extract text content
|
||||||
|
@ -994,7 +994,7 @@ class ResponseProcessor:
|
||||||
if content is not None:
|
if content is not None:
|
||||||
params[mapping.param_name] = content.strip()
|
params[mapping.param_name] = content.strip()
|
||||||
parsing_details["text_content"] = content.strip() # Store raw text content
|
parsing_details["text_content"] = content.strip() # Store raw text content
|
||||||
logger.info(f"Found text content for {mapping.param_name}: {content.strip()}")
|
# logger.info(f"Found text content for {mapping.param_name}: {content.strip()}")
|
||||||
|
|
||||||
elif mapping.node_type == "content":
|
elif mapping.node_type == "content":
|
||||||
# Extract root content
|
# Extract root content
|
||||||
|
@ -1002,7 +1002,7 @@ class ResponseProcessor:
|
||||||
if content is not None:
|
if content is not None:
|
||||||
params[mapping.param_name] = content.strip()
|
params[mapping.param_name] = content.strip()
|
||||||
parsing_details["root_content"] = content.strip() # Store raw root content
|
parsing_details["root_content"] = content.strip() # Store raw root content
|
||||||
logger.info(f"Found root content for {mapping.param_name}")
|
# logger.info(f"Found root content for {mapping.param_name}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing mapping {mapping}: {e}")
|
logger.error(f"Error processing mapping {mapping}: {e}")
|
||||||
|
|
|
@ -1,7 +1,3 @@
|
||||||
# This is a Docker Compose file for the backend service. For self-hosting, look at the root docker-compose.yml file.
|
|
||||||
|
|
||||||
version: "3.8"
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
api:
|
api:
|
||||||
build:
|
build:
|
||||||
|
@ -133,10 +129,11 @@ services:
|
||||||
- "127.0.0.1:6379:6379"
|
- "127.0.0.1:6379:6379"
|
||||||
volumes:
|
volumes:
|
||||||
- redis_data:/data
|
- redis_data:/data
|
||||||
|
- ./services/docker/redis.conf:/usr/local/etc/redis/redis.conf:ro
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
command: redis-server --appendonly yes --bind 0.0.0.0 --protected-mode no --maxmemory 8gb --maxmemory-policy allkeys-lru
|
command: redis-server /usr/local/etc/redis/redis.conf --appendonly yes --bind 0.0.0.0 --protected-mode no --maxmemory 8gb --maxmemory-policy allkeys-lru
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
|
|
|
@ -19,7 +19,8 @@ You can modify the sandbox environment for development or to add new capabilitie
|
||||||
2. Build a custom image:
|
2. Build a custom image:
|
||||||
```
|
```
|
||||||
cd backend/sandbox/docker
|
cd backend/sandbox/docker
|
||||||
docker-compose build
|
docker compose build
|
||||||
|
docker push kortix/suna:0.1.2
|
||||||
```
|
```
|
||||||
3. Test your changes locally using docker-compose
|
3. Test your changes locally using docker-compose
|
||||||
|
|
||||||
|
@ -30,3 +31,15 @@ To use your custom sandbox image:
|
||||||
1. Change the `image` parameter in `docker-compose.yml` (that defines the image name `kortix/suna:___`)
|
1. Change the `image` parameter in `docker-compose.yml` (that defines the image name `kortix/suna:___`)
|
||||||
2. Update the same image name in `backend/sandbox/sandbox.py` in the `create_sandbox` function
|
2. Update the same image name in `backend/sandbox/sandbox.py` in the `create_sandbox` function
|
||||||
3. If using Daytona for deployment, update the image reference there as well
|
3. If using Daytona for deployment, update the image reference there as well
|
||||||
|
|
||||||
|
## Publishing New Versions
|
||||||
|
|
||||||
|
When publishing a new version of the sandbox:
|
||||||
|
|
||||||
|
1. Update the version number in `docker-compose.yml` (e.g., from `0.1.2` to `0.1.3`)
|
||||||
|
2. Build the new image: `docker compose build`
|
||||||
|
3. Push the new version: `docker push kortix/suna:0.1.3`
|
||||||
|
4. Update all references to the image version in:
|
||||||
|
- `backend/utils/config.py`
|
||||||
|
- Daytona images
|
||||||
|
- Any other services using this image
|
|
@ -68,6 +68,9 @@ RUN apt-get update && apt-get install -y \
|
||||||
iputils-ping \
|
iputils-ping \
|
||||||
dnsutils \
|
dnsutils \
|
||||||
sudo \
|
sudo \
|
||||||
|
# OCR Tools
|
||||||
|
tesseract-ocr \
|
||||||
|
tesseract-ocr-eng \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Install Node.js and npm
|
# Install Node.js and npm
|
||||||
|
@ -113,11 +116,13 @@ ENV PYTHONUNBUFFERED=1
|
||||||
ENV CHROME_PATH=/ms-playwright/chromium-*/chrome-linux/chrome
|
ENV CHROME_PATH=/ms-playwright/chromium-*/chrome-linux/chrome
|
||||||
ENV ANONYMIZED_TELEMETRY=false
|
ENV ANONYMIZED_TELEMETRY=false
|
||||||
ENV DISPLAY=:99
|
ENV DISPLAY=:99
|
||||||
ENV RESOLUTION=1920x1080x24
|
ENV RESOLUTION=1024x768x24
|
||||||
ENV VNC_PASSWORD=vncpassword
|
ENV VNC_PASSWORD=vncpassword
|
||||||
ENV CHROME_PERSISTENT_SESSION=true
|
ENV CHROME_PERSISTENT_SESSION=true
|
||||||
ENV RESOLUTION_WIDTH=1920
|
ENV RESOLUTION_WIDTH=1024
|
||||||
ENV RESOLUTION_HEIGHT=1080
|
ENV RESOLUTION_HEIGHT=768
|
||||||
|
# Add Chrome flags to prevent multiple tabs/windows
|
||||||
|
ENV CHROME_FLAGS="--single-process --no-first-run --no-default-browser-check --disable-background-networking --disable-background-timer-throttling --disable-backgrounding-occluded-windows --disable-breakpad --disable-component-extensions-with-background-pages --disable-dev-shm-usage --disable-extensions --disable-features=TranslateUI --disable-ipc-flooding-protection --disable-renderer-backgrounding --enable-features=NetworkServiceInProcess2 --force-color-profile=srgb --metrics-recording-only --mute-audio --no-sandbox --disable-gpu"
|
||||||
|
|
||||||
# Set up supervisor configuration
|
# Set up supervisor configuration
|
||||||
RUN mkdir -p /var/log/supervisor
|
RUN mkdir -p /var/log/supervisor
|
||||||
|
|
|
@ -15,8 +15,6 @@ import traceback
|
||||||
import pytesseract
|
import pytesseract
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
from utils.logger import logger
|
|
||||||
from services.supabase import DBConnection
|
|
||||||
|
|
||||||
#######################################################
|
#######################################################
|
||||||
# Action model definitions
|
# Action model definitions
|
||||||
|
@ -261,16 +259,15 @@ class BrowserActionResult(BaseModel):
|
||||||
url: Optional[str] = None
|
url: Optional[str] = None
|
||||||
title: Optional[str] = None
|
title: Optional[str] = None
|
||||||
elements: Optional[str] = None # Formatted string of clickable elements
|
elements: Optional[str] = None # Formatted string of clickable elements
|
||||||
screenshot_base64: Optional[str] = None # For backward compatibility
|
screenshot_base64: Optional[str] = None
|
||||||
screenshot_url: Optional[str] = None
|
|
||||||
pixels_above: int = 0
|
pixels_above: int = 0
|
||||||
pixels_below: int = 0
|
pixels_below: int = 0
|
||||||
content: Optional[str] = None
|
content: Optional[str] = None
|
||||||
ocr_text: Optional[str] = None
|
ocr_text: Optional[str] = None # Added field for OCR text
|
||||||
|
|
||||||
# Additional metadata
|
# Additional metadata
|
||||||
element_count: int = 0
|
element_count: int = 0 # Number of interactive elements found
|
||||||
interactive_elements: Optional[List[Dict[str, Any]]] = None
|
interactive_elements: Optional[List[Dict[str, Any]]] = None # Simplified list of interactive elements
|
||||||
viewport_width: Optional[int] = None
|
viewport_width: Optional[int] = None
|
||||||
viewport_height: Optional[int] = None
|
viewport_height: Optional[int] = None
|
||||||
|
|
||||||
|
@ -291,7 +288,6 @@ class BrowserAutomation:
|
||||||
self.include_attributes = ["id", "href", "src", "alt", "aria-label", "placeholder", "name", "role", "title", "value"]
|
self.include_attributes = ["id", "href", "src", "alt", "aria-label", "placeholder", "name", "role", "title", "value"]
|
||||||
self.screenshot_dir = os.path.join(os.getcwd(), "screenshots")
|
self.screenshot_dir = os.path.join(os.getcwd(), "screenshots")
|
||||||
os.makedirs(self.screenshot_dir, exist_ok=True)
|
os.makedirs(self.screenshot_dir, exist_ok=True)
|
||||||
self.db = DBConnection() # Initialize DB connection
|
|
||||||
|
|
||||||
# Register routes
|
# Register routes
|
||||||
self.router.on_startup.append(self.startup)
|
self.router.on_startup.append(self.startup)
|
||||||
|
@ -360,12 +356,12 @@ class BrowserAutomation:
|
||||||
self.current_page_index = 0
|
self.current_page_index = 0
|
||||||
except Exception as page_error:
|
except Exception as page_error:
|
||||||
print(f"Error finding existing page, creating new one. ( {page_error})")
|
print(f"Error finding existing page, creating new one. ( {page_error})")
|
||||||
page = await self.browser.new_page()
|
page = await self.browser.new_page(viewport={'width': 1024, 'height': 768})
|
||||||
print("New page created successfully")
|
print("New page created successfully")
|
||||||
self.pages.append(page)
|
self.pages.append(page)
|
||||||
self.current_page_index = 0
|
self.current_page_index = 0
|
||||||
# Navigate to about:blank to ensure page is ready
|
# Navigate directly to google.com instead of about:blank
|
||||||
# await page.goto("google.com", timeout=30000)
|
await page.goto("https://www.google.com", wait_until="domcontentloaded", timeout=30000)
|
||||||
print("Navigated to google.com")
|
print("Navigated to google.com")
|
||||||
|
|
||||||
print("Browser initialization completed successfully")
|
print("Browser initialization completed successfully")
|
||||||
|
@ -603,95 +599,50 @@ class BrowserAutomation:
|
||||||
is_top_element=True
|
is_top_element=True
|
||||||
)
|
)
|
||||||
dummy_map = {1: dummy_root}
|
dummy_map = {1: dummy_root}
|
||||||
|
current_url = "unknown"
|
||||||
|
try:
|
||||||
|
if 'page' in locals():
|
||||||
|
current_url = page.url
|
||||||
|
except:
|
||||||
|
pass
|
||||||
return DOMState(
|
return DOMState(
|
||||||
element_tree=dummy_root,
|
element_tree=dummy_root,
|
||||||
selector_map=dummy_map,
|
selector_map=dummy_map,
|
||||||
url=page.url if 'page' in locals() else "about:blank",
|
url=current_url,
|
||||||
title="Error page",
|
title="Error page",
|
||||||
pixels_above=0,
|
pixels_above=0,
|
||||||
pixels_below=0
|
pixels_below=0
|
||||||
)
|
)
|
||||||
|
|
||||||
async def take_screenshot(self) -> str:
|
async def take_screenshot(self) -> str:
|
||||||
"""Take a screenshot and return as base64 encoded string or S3 URL"""
|
"""Take a screenshot and return as base64 encoded string"""
|
||||||
try:
|
try:
|
||||||
page = await self.get_current_page()
|
page = await self.get_current_page()
|
||||||
screenshot_bytes = await page.screenshot(type='jpeg', quality=60, full_page=False)
|
|
||||||
|
|
||||||
client = await self.db.client
|
# Wait for network to be idle and DOM to be stable
|
||||||
|
try:
|
||||||
if client:
|
await page.wait_for_load_state("networkidle", timeout=60000) # Increased timeout to 60s
|
||||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
|
||||||
random_id = random.randint(1000, 9999)
|
|
||||||
filename = f"screenshot_{timestamp}_{random_id}.jpg"
|
|
||||||
|
|
||||||
logger.info(f"Attempting to upload screenshot: {filename}")
|
|
||||||
result = await self.upload_to_storage(client, screenshot_bytes, filename)
|
|
||||||
|
|
||||||
if isinstance(result, dict) and result.get("is_s3") and result.get("url"):
|
|
||||||
if await self.verify_file_exists(client, filename):
|
|
||||||
logger.info(f"Screenshot upload verified: {filename}")
|
|
||||||
else:
|
|
||||||
logger.error(f"Screenshot upload failed verification: {filename}")
|
|
||||||
return base64.b64encode(screenshot_bytes).decode('utf-8')
|
|
||||||
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
logger.warning("No Supabase client available, falling back to base64")
|
|
||||||
return base64.b64encode(screenshot_bytes).decode('utf-8')
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error taking screenshot: {str(e)}")
|
print(f"Warning: Network idle timeout, proceeding anyway: {e}")
|
||||||
traceback.print_exc()
|
|
||||||
return ""
|
|
||||||
|
|
||||||
async def upload_to_storage(self, client, file_bytes: bytes, filename: str) -> str:
|
# Wait for any animations to complete
|
||||||
"""Upload file to Supabase Storage and return the URL"""
|
# await page.wait_for_timeout(1000) # Wait 1 second for animations
|
||||||
try:
|
|
||||||
bucket_name = 'screenshots'
|
|
||||||
|
|
||||||
buckets = client.storage.list_buckets()
|
# Take screenshot with increased timeout and better options
|
||||||
if not any(bucket.name == bucket_name for bucket in buckets):
|
screenshot_bytes = await page.screenshot(
|
||||||
logger.info(f"Creating bucket: {bucket_name}")
|
type='jpeg',
|
||||||
try:
|
quality=60,
|
||||||
client.storage.create_bucket(bucket_name)
|
full_page=False,
|
||||||
logger.info("Bucket created successfully")
|
timeout=60000, # Increased timeout to 60s
|
||||||
except Exception as e:
|
scale='device' # Use device scale factor
|
||||||
logger.error(f"Failed to create bucket: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.info(f"Uploading file: {filename}")
|
|
||||||
try:
|
|
||||||
result = client.storage.from_(bucket_name).upload(
|
|
||||||
path=filename,
|
|
||||||
file=file_bytes,
|
|
||||||
file_options={"content-type": "image/jpeg"}
|
|
||||||
)
|
)
|
||||||
logger.info("File upload successful")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to upload file: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
file_url = client.storage.from_(bucket_name).get_public_url(filename)
|
return base64.b64encode(screenshot_bytes).decode('utf-8')
|
||||||
logger.info(f"Generated URL: {file_url}")
|
|
||||||
|
|
||||||
return {"url": file_url, "is_s3": True}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in upload_to_storage: {str(e)}")
|
print(f"Error taking screenshot: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return base64.b64encode(file_bytes).decode('utf-8')
|
# Return an empty string rather than failing
|
||||||
|
return ""
|
||||||
async def verify_file_exists(self, client, filename: str) -> bool:
|
|
||||||
"""Verify that a file exists in the storage bucket"""
|
|
||||||
logger.info(f"=== Verifying file exists: {filename} ===")
|
|
||||||
try:
|
|
||||||
bucket_name = 'screenshots'
|
|
||||||
files = client.storage.from_(bucket_name).list()
|
|
||||||
exists = any(f['name'] == filename for f in files)
|
|
||||||
logger.info(f"File verification result: {'exists' if exists else 'not found'}")
|
|
||||||
return exists
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error verifying file: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def save_screenshot_to_file(self) -> str:
|
async def save_screenshot_to_file(self) -> str:
|
||||||
"""Take a screenshot and save to file, returning the path"""
|
"""Take a screenshot and save to file, returning the path"""
|
||||||
|
@ -734,32 +685,20 @@ class BrowserAutomation:
|
||||||
"""Helper method to get updated browser state after any action
|
"""Helper method to get updated browser state after any action
|
||||||
Returns a tuple of (dom_state, screenshot, elements, metadata)
|
Returns a tuple of (dom_state, screenshot, elements, metadata)
|
||||||
"""
|
"""
|
||||||
logger.info(f"=== Starting get_updated_browser_state for action: {action_name} ===")
|
|
||||||
try:
|
try:
|
||||||
# Wait a moment for any potential async processes to settle
|
# Wait a moment for any potential async processes to settle
|
||||||
logger.info("Waiting for async processes to settle")
|
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
# Get updated state
|
# Get updated state
|
||||||
logger.info("Getting current DOM state")
|
|
||||||
dom_state = await self.get_current_dom_state()
|
dom_state = await self.get_current_dom_state()
|
||||||
logger.info(f"DOM state retrieved - URL: {dom_state.url}, Title: {dom_state.title}")
|
|
||||||
|
|
||||||
logger.info("Taking screenshot")
|
|
||||||
screenshot = await self.take_screenshot()
|
screenshot = await self.take_screenshot()
|
||||||
logger.info(f"Screenshot result type: {'dict' if isinstance(screenshot, dict) else 'base64 string'}")
|
|
||||||
if isinstance(screenshot, dict) and screenshot.get("url"):
|
|
||||||
logger.info(f"Screenshot URL: {screenshot['url']}")
|
|
||||||
|
|
||||||
# Format elements for output
|
# Format elements for output
|
||||||
logger.info("Formatting clickable elements")
|
|
||||||
elements = dom_state.element_tree.clickable_elements_to_string(
|
elements = dom_state.element_tree.clickable_elements_to_string(
|
||||||
include_attributes=self.include_attributes
|
include_attributes=self.include_attributes
|
||||||
)
|
)
|
||||||
logger.info(f"Found {len(dom_state.selector_map)} clickable elements")
|
|
||||||
|
|
||||||
# Collect additional metadata
|
# Collect additional metadata
|
||||||
logger.info("Collecting metadata")
|
|
||||||
page = await self.get_current_page()
|
page = await self.get_current_page()
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
|
@ -785,9 +724,8 @@ class BrowserAutomation:
|
||||||
|
|
||||||
metadata['interactive_elements'] = interactive_elements
|
metadata['interactive_elements'] = interactive_elements
|
||||||
|
|
||||||
# Get viewport dimensions
|
# Get viewport dimensions - Fix syntax error in JavaScript
|
||||||
try:
|
try:
|
||||||
logger.info("Getting viewport dimensions")
|
|
||||||
viewport = await page.evaluate("""
|
viewport = await page.evaluate("""
|
||||||
() => {
|
() => {
|
||||||
return {
|
return {
|
||||||
|
@ -798,24 +736,21 @@ class BrowserAutomation:
|
||||||
""")
|
""")
|
||||||
metadata['viewport_width'] = viewport.get('width', 0)
|
metadata['viewport_width'] = viewport.get('width', 0)
|
||||||
metadata['viewport_height'] = viewport.get('height', 0)
|
metadata['viewport_height'] = viewport.get('height', 0)
|
||||||
logger.info(f"Viewport dimensions: {metadata['viewport_width']}x{metadata['viewport_height']}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting viewport dimensions: {e}")
|
print(f"Error getting viewport dimensions: {e}")
|
||||||
metadata['viewport_width'] = 0
|
metadata['viewport_width'] = 0
|
||||||
metadata['viewport_height'] = 0
|
metadata['viewport_height'] = 0
|
||||||
|
|
||||||
# Extract OCR text from screenshot if available
|
# Extract OCR text from screenshot if available
|
||||||
ocr_text = ""
|
ocr_text = ""
|
||||||
if screenshot:
|
if screenshot:
|
||||||
logger.info("Extracting OCR text from screenshot")
|
|
||||||
ocr_text = await self.extract_ocr_text_from_screenshot(screenshot)
|
ocr_text = await self.extract_ocr_text_from_screenshot(screenshot)
|
||||||
metadata['ocr_text'] = ocr_text
|
metadata['ocr_text'] = ocr_text
|
||||||
logger.info(f"OCR text length: {len(ocr_text)} characters")
|
|
||||||
|
|
||||||
logger.info(f"=== Completed get_updated_browser_state for {action_name} ===")
|
print(f"Got updated state after {action_name}: {len(dom_state.selector_map)} elements")
|
||||||
return dom_state, screenshot, elements, metadata
|
return dom_state, screenshot, elements, metadata
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in get_updated_browser_state for {action_name}: {e}")
|
print(f"Error getting updated state after {action_name}: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
# Return empty values in case of error
|
# Return empty values in case of error
|
||||||
return None, "", "", {}
|
return None, "", "", {}
|
||||||
|
@ -824,17 +759,10 @@ class BrowserAutomation:
|
||||||
elements: str, metadata: dict, error: str = "", content: str = None,
|
elements: str, metadata: dict, error: str = "", content: str = None,
|
||||||
fallback_url: str = None) -> BrowserActionResult:
|
fallback_url: str = None) -> BrowserActionResult:
|
||||||
"""Helper method to build a consistent BrowserActionResult"""
|
"""Helper method to build a consistent BrowserActionResult"""
|
||||||
|
# Ensure elements is never None to avoid display issues
|
||||||
if elements is None:
|
if elements is None:
|
||||||
elements = ""
|
elements = ""
|
||||||
|
|
||||||
screenshot_base64 = None
|
|
||||||
screenshot_url = None
|
|
||||||
|
|
||||||
if isinstance(screenshot, dict) and screenshot.get("is_s3"):
|
|
||||||
screenshot_url = screenshot.get("url")
|
|
||||||
else:
|
|
||||||
screenshot_base64 = screenshot
|
|
||||||
|
|
||||||
return BrowserActionResult(
|
return BrowserActionResult(
|
||||||
success=success,
|
success=success,
|
||||||
message=message,
|
message=message,
|
||||||
|
@ -842,8 +770,7 @@ class BrowserAutomation:
|
||||||
url=dom_state.url if dom_state else fallback_url or "",
|
url=dom_state.url if dom_state else fallback_url or "",
|
||||||
title=dom_state.title if dom_state else "",
|
title=dom_state.title if dom_state else "",
|
||||||
elements=elements,
|
elements=elements,
|
||||||
screenshot_base64=screenshot_base64,
|
screenshot_base64=screenshot,
|
||||||
screenshot_url=screenshot_url,
|
|
||||||
pixels_above=dom_state.pixels_above if dom_state else 0,
|
pixels_above=dom_state.pixels_above if dom_state else 0,
|
||||||
pixels_below=dom_state.pixels_below if dom_state else 0,
|
pixels_below=dom_state.pixels_below if dom_state else 0,
|
||||||
content=content,
|
content=content,
|
||||||
|
@ -2157,4 +2084,4 @@ if __name__ == '__main__':
|
||||||
asyncio.run(test_browser_api_2())
|
asyncio.run(test_browser_api_2())
|
||||||
else:
|
else:
|
||||||
print("Starting API server")
|
print("Starting API server")
|
||||||
uvicorn.run("browser_api:api_app", host="0.0.0.0", port=8002)
|
uvicorn.run("browser_api:api_app", host="0.0.0.0", port=8003)
|
|
@ -6,7 +6,7 @@ services:
|
||||||
dockerfile: ${DOCKERFILE:-Dockerfile}
|
dockerfile: ${DOCKERFILE:-Dockerfile}
|
||||||
args:
|
args:
|
||||||
TARGETPLATFORM: ${TARGETPLATFORM:-linux/amd64}
|
TARGETPLATFORM: ${TARGETPLATFORM:-linux/amd64}
|
||||||
image: kortix/suna:0.1.2
|
image: kortix/suna:0.1.2.8
|
||||||
ports:
|
ports:
|
||||||
- "6080:6080" # noVNC web interface
|
- "6080:6080" # noVNC web interface
|
||||||
- "5901:5901" # VNC port
|
- "5901:5901" # VNC port
|
||||||
|
@ -27,6 +27,7 @@ services:
|
||||||
- VNC_PASSWORD=${VNC_PASSWORD:-vncpassword}
|
- VNC_PASSWORD=${VNC_PASSWORD:-vncpassword}
|
||||||
- CHROME_DEBUGGING_PORT=9222
|
- CHROME_DEBUGGING_PORT=9222
|
||||||
- CHROME_DEBUGGING_HOST=localhost
|
- CHROME_DEBUGGING_HOST=localhost
|
||||||
|
- CHROME_FLAGS=${CHROME_FLAGS:-"--single-process --no-first-run --no-default-browser-check --disable-background-networking --disable-background-timer-throttling --disable-backgrounding-occluded-windows --disable-breakpad --disable-component-extensions-with-background-pages --disable-dev-shm-usage --disable-extensions --disable-features=TranslateUI --disable-ipc-flooding-protection --disable-renderer-backgrounding --enable-features=NetworkServiceInProcess2 --force-color-profile=srgb --metrics-recording-only --mute-audio --no-sandbox --disable-gpu"}
|
||||||
volumes:
|
volumes:
|
||||||
- /tmp/.X11-unix:/tmp/.X11-unix
|
- /tmp/.X11-unix:/tmp/.X11-unix
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
timeout 120
|
|
@ -6,6 +6,9 @@ from typing import Optional
|
||||||
from supabase import create_async_client, AsyncClient
|
from supabase import create_async_client, AsyncClient
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
from utils.config import config
|
from utils.config import config
|
||||||
|
import base64
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
class DBConnection:
|
class DBConnection:
|
||||||
"""Singleton database connection manager using Supabase."""
|
"""Singleton database connection manager using Supabase."""
|
||||||
|
@ -66,4 +69,45 @@ class DBConnection:
|
||||||
raise RuntimeError("Database not initialized")
|
raise RuntimeError("Database not initialized")
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
|
async def upload_base64_image(self, base64_data: str, bucket_name: str = "browser-screenshots") -> str:
|
||||||
|
"""Upload a base64 encoded image to Supabase storage and return the URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base64_data (str): Base64 encoded image data (with or without data URL prefix)
|
||||||
|
bucket_name (str): Name of the storage bucket to upload to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Public URL of the uploaded image
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Remove data URL prefix if present
|
||||||
|
if base64_data.startswith('data:'):
|
||||||
|
base64_data = base64_data.split(',')[1]
|
||||||
|
|
||||||
|
# Decode base64 data
|
||||||
|
image_data = base64.b64decode(base64_data)
|
||||||
|
|
||||||
|
# Generate unique filename
|
||||||
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
|
unique_id = str(uuid.uuid4())[:8]
|
||||||
|
filename = f"image_{timestamp}_{unique_id}.png"
|
||||||
|
|
||||||
|
# Upload to Supabase storage
|
||||||
|
client = await self.client
|
||||||
|
storage_response = await client.storage.from_(bucket_name).upload(
|
||||||
|
filename,
|
||||||
|
image_data,
|
||||||
|
{"content-type": "image/png"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get public URL
|
||||||
|
public_url = await client.storage.from_(bucket_name).get_public_url(filename)
|
||||||
|
|
||||||
|
logger.debug(f"Successfully uploaded image to {public_url}")
|
||||||
|
return public_url
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error uploading base64 image: {e}")
|
||||||
|
raise RuntimeError(f"Failed to upload image: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -155,11 +155,11 @@ class Configuration:
|
||||||
STRIPE_DEFAULT_TRIAL_DAYS: int = 14
|
STRIPE_DEFAULT_TRIAL_DAYS: int = 14
|
||||||
|
|
||||||
# Stripe Product IDs
|
# Stripe Product IDs
|
||||||
STRIPE_PRODUCT_ID_PROD: str = 'prod_SCl7AQ2C8kK1CD' # Production product ID
|
STRIPE_PRODUCT_ID_PROD: str = 'prod_SCl7AQ2C8kK1CD'
|
||||||
STRIPE_PRODUCT_ID_STAGING: str = 'prod_SCgIj3G7yPOAWY' # Staging product ID
|
STRIPE_PRODUCT_ID_STAGING: str = 'prod_SCgIj3G7yPOAWY'
|
||||||
|
|
||||||
# Sandbox configuration
|
# Sandbox configuration
|
||||||
SANDBOX_IMAGE_NAME = "kortix/suna:0.1.2"
|
SANDBOX_IMAGE_NAME = "kortix/suna:0.1.2.8"
|
||||||
SANDBOX_ENTRYPOINT = "/usr/bin/supervisord -n -c /etc/supervisor/conf.d/supervisord.conf"
|
SANDBOX_ENTRYPOINT = "/usr/bin/supervisord -n -c /etc/supervisor/conf.d/supervisord.conf"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
"""
|
||||||
|
Utility functions for handling image operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from utils.logger import logger
|
||||||
|
from services.supabase import DBConnection
|
||||||
|
|
||||||
|
async def upload_base64_image(base64_data: str, bucket_name: str = "browser-screenshots") -> str:
|
||||||
|
"""Upload a base64 encoded image to Supabase storage and return the URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base64_data (str): Base64 encoded image data (with or without data URL prefix)
|
||||||
|
bucket_name (str): Name of the storage bucket to upload to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Public URL of the uploaded image
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Remove data URL prefix if present
|
||||||
|
if base64_data.startswith('data:'):
|
||||||
|
base64_data = base64_data.split(',')[1]
|
||||||
|
|
||||||
|
# Decode base64 data
|
||||||
|
image_data = base64.b64decode(base64_data)
|
||||||
|
|
||||||
|
# Generate unique filename
|
||||||
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
|
unique_id = str(uuid.uuid4())[:8]
|
||||||
|
filename = f"image_{timestamp}_{unique_id}.png"
|
||||||
|
|
||||||
|
# Upload to Supabase storage
|
||||||
|
db = DBConnection()
|
||||||
|
client = await db.client
|
||||||
|
storage_response = await client.storage.from_(bucket_name).upload(
|
||||||
|
filename,
|
||||||
|
image_data,
|
||||||
|
{"content-type": "image/png"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get public URL
|
||||||
|
public_url = await client.storage.from_(bucket_name).get_public_url(filename)
|
||||||
|
|
||||||
|
logger.debug(f"Successfully uploaded image to {public_url}")
|
||||||
|
return public_url
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error uploading base64 image: {e}")
|
||||||
|
raise RuntimeError(f"Failed to upload image: {str(e)}")
|
|
@ -3,7 +3,8 @@ services:
|
||||||
image: redis:7-alpine
|
image: redis:7-alpine
|
||||||
volumes:
|
volumes:
|
||||||
- redis_data:/data
|
- redis_data:/data
|
||||||
command: redis-server --save 60 1 --loglevel warning
|
- ./backend/services/docker/redis.conf:/usr/local/etc/redis/redis.conf:ro
|
||||||
|
command: redis-server /usr/local/etc/redis/redis.conf --save 60 1 --loglevel warning
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
|
@ -12,6 +13,9 @@ services:
|
||||||
|
|
||||||
rabbitmq:
|
rabbitmq:
|
||||||
image: rabbitmq
|
image: rabbitmq
|
||||||
|
ports:
|
||||||
|
- "5672:5672"
|
||||||
|
- "15672:15672"
|
||||||
volumes:
|
volumes:
|
||||||
- rabbitmq_data:/var/lib/rabbitmq
|
- rabbitmq_data:/var/lib/rabbitmq
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
|
@ -1041,37 +1041,9 @@ export default function ThreadPage({
|
||||||
setDebugMode(debugParam === 'true');
|
setDebugMode(debugParam === 'true');
|
||||||
}, [searchParams]);
|
}, [searchParams]);
|
||||||
|
|
||||||
const handleUpgrade = useCallback(() => {
|
|
||||||
router.push('/settings/billing');
|
|
||||||
}, [router]);
|
|
||||||
|
|
||||||
// Check user tier and show dialog if needed
|
|
||||||
useEffect(() => {
|
|
||||||
if (initialLoadCompleted.current && billingStatusQuery.data) {
|
|
||||||
const isPro = billingStatusQuery.data.subscription?.plan_name?.toLowerCase().includes('pro');
|
|
||||||
console.log("Billing check for upgrade dialog:", {
|
|
||||||
isPro,
|
|
||||||
subscription: billingStatusQuery.data.subscription,
|
|
||||||
userId: threadQuery.data?.created_by || 'user'
|
|
||||||
});
|
|
||||||
|
|
||||||
// Always show dialog for debugging
|
|
||||||
setShowUpgradeDialog(true);
|
|
||||||
console.log("Setting showUpgradeDialog to true");
|
|
||||||
} else {
|
|
||||||
console.log("Not showing upgrade dialog:", {
|
|
||||||
initialLoadCompleted: initialLoadCompleted.current,
|
|
||||||
hasBillingData: !!billingStatusQuery.data
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [initialLoadCompleted, billingStatusQuery.data, threadQuery.data]);
|
|
||||||
|
|
||||||
// Main rendering function for the thread page
|
|
||||||
if (!initialLoadCompleted.current || isLoading) {
|
if (!initialLoadCompleted.current || isLoading) {
|
||||||
// Use the new ThreadSkeleton component instead of inline skeleton
|
|
||||||
return <ThreadSkeleton isSidePanelOpen={isSidePanelOpen} />;
|
return <ThreadSkeleton isSidePanelOpen={isSidePanelOpen} />;
|
||||||
} else if (error) {
|
} else if (error) {
|
||||||
// Error state...
|
|
||||||
return (
|
return (
|
||||||
<div className="flex h-screen">
|
<div className="flex h-screen">
|
||||||
<div
|
<div
|
||||||
|
|
|
@ -509,7 +509,7 @@ function LoginContent() {
|
||||||
|
|
||||||
{/* Forgot Password Dialog */}
|
{/* Forgot Password Dialog */}
|
||||||
<Dialog open={forgotPasswordOpen} onOpenChange={setForgotPasswordOpen}>
|
<Dialog open={forgotPasswordOpen} onOpenChange={setForgotPasswordOpen}>
|
||||||
<DialogContent className="sm:max-w-md rounded-xl bg-[#F3F4F6] dark:bg-[#F9FAFB]/[0.02] border border-border">
|
<DialogContent className="sm:max-w-md rounded-xl bg-[#F3F4F6] dark:bg-[#17171A] border border-border [&>button]:hidden">
|
||||||
<DialogHeader>
|
<DialogHeader>
|
||||||
<div className="flex items-center justify-between">
|
<div className="flex items-center justify-between">
|
||||||
<DialogTitle className="text-xl font-medium">
|
<DialogTitle className="text-xl font-medium">
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
import { useEffect, useCallback, useRef, useState } from 'react';
|
import { useEffect, useCallback, useRef, useState } from 'react';
|
||||||
import Script from 'next/script';
|
import Script from 'next/script';
|
||||||
import { createClient } from '@/lib/supabase/client';
|
import { createClient } from '@/lib/supabase/client';
|
||||||
|
import { useTheme } from 'next-themes';
|
||||||
|
|
||||||
// Add type declarations for Google One Tap
|
// Add type declarations for Google One Tap
|
||||||
declare global {
|
declare global {
|
||||||
|
@ -68,6 +69,7 @@ interface GoogleSignInProps {
|
||||||
export default function GoogleSignIn({ returnUrl }: GoogleSignInProps) {
|
export default function GoogleSignIn({ returnUrl }: GoogleSignInProps) {
|
||||||
const googleClientId = process.env.NEXT_PUBLIC_GOOGLE_CLIENT_ID;
|
const googleClientId = process.env.NEXT_PUBLIC_GOOGLE_CLIENT_ID;
|
||||||
const [isLoading, setIsLoading] = useState(false);
|
const [isLoading, setIsLoading] = useState(false);
|
||||||
|
const { resolvedTheme } = useTheme();
|
||||||
|
|
||||||
const handleGoogleSignIn = useCallback(
|
const handleGoogleSignIn = useCallback(
|
||||||
async (response: GoogleSignInResponse) => {
|
async (response: GoogleSignInResponse) => {
|
||||||
|
@ -184,7 +186,7 @@ export default function GoogleSignIn({ returnUrl }: GoogleSignInProps) {
|
||||||
if (buttonContainer) {
|
if (buttonContainer) {
|
||||||
window.google.accounts.id.renderButton(buttonContainer, {
|
window.google.accounts.id.renderButton(buttonContainer, {
|
||||||
type: 'standard',
|
type: 'standard',
|
||||||
theme: 'outline',
|
theme: resolvedTheme === 'dark' ? 'filled_black' : 'outline',
|
||||||
size: 'large',
|
size: 'large',
|
||||||
text: 'continue_with',
|
text: 'continue_with',
|
||||||
shape: 'pill',
|
shape: 'pill',
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
'use client';
|
||||||
|
|
||||||
|
import React, { createContext, useContext, useCallback, useEffect, useRef } from 'react';
|
||||||
|
import { useBillingStatusQuery } from '@/hooks/react-query/threads/use-billing-status';
|
||||||
|
import { BillingStatusResponse } from '@/lib/api';
|
||||||
|
import { isLocalMode } from '@/lib/config';
|
||||||
|
|
||||||
|
interface BillingContextType {
|
||||||
|
billingStatus: BillingStatusResponse | null;
|
||||||
|
isLoading: boolean;
|
||||||
|
error: Error | null;
|
||||||
|
checkBillingStatus: () => Promise<boolean>;
|
||||||
|
lastCheckTime: number | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const BillingContext = createContext<BillingContextType | null>(null);
|
||||||
|
|
||||||
|
export function BillingProvider({ children }: { children: React.ReactNode }) {
|
||||||
|
const billingStatusQuery = useBillingStatusQuery();
|
||||||
|
const lastCheckRef = useRef<number | null>(null);
|
||||||
|
const checkInProgressRef = useRef<boolean>(false);
|
||||||
|
|
||||||
|
const checkBillingStatus = useCallback(async (force = false): Promise<boolean> => {
|
||||||
|
if (isLocalMode()) {
|
||||||
|
console.log('Running in local development mode - billing checks are disabled');
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (checkInProgressRef.current) {
|
||||||
|
return !billingStatusQuery.data?.can_run;
|
||||||
|
}
|
||||||
|
|
||||||
|
const now = Date.now();
|
||||||
|
if (!force && lastCheckRef.current && now - lastCheckRef.current < 60000) {
|
||||||
|
return !billingStatusQuery.data?.can_run;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
checkInProgressRef.current = true;
|
||||||
|
if (force || billingStatusQuery.isStale) {
|
||||||
|
await billingStatusQuery.refetch();
|
||||||
|
}
|
||||||
|
lastCheckRef.current = now;
|
||||||
|
return !billingStatusQuery.data?.can_run;
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Error checking billing status:', err);
|
||||||
|
return false;
|
||||||
|
} finally {
|
||||||
|
checkInProgressRef.current = false;
|
||||||
|
}
|
||||||
|
}, [billingStatusQuery]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!billingStatusQuery.data) {
|
||||||
|
checkBillingStatus(true);
|
||||||
|
}
|
||||||
|
}, [checkBillingStatus, billingStatusQuery.data]);
|
||||||
|
|
||||||
|
const value = {
|
||||||
|
billingStatus: billingStatusQuery.data || null,
|
||||||
|
isLoading: billingStatusQuery.isLoading,
|
||||||
|
error: billingStatusQuery.error,
|
||||||
|
checkBillingStatus,
|
||||||
|
lastCheckTime: lastCheckRef.current,
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<BillingContext.Provider value={value}>
|
||||||
|
{children}
|
||||||
|
</BillingContext.Provider>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useBilling() {
|
||||||
|
const context = useContext(BillingContext);
|
||||||
|
if (!context) {
|
||||||
|
throw new Error('useBilling must be used within a BillingProvider');
|
||||||
|
}
|
||||||
|
return context;
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
import { createQueryHook } from "@/hooks/use-query";
|
import { createQueryHook } from "@/hooks/use-query";
|
||||||
import { threadKeys } from "./keys";
|
import { threadKeys } from "./keys";
|
||||||
import { checkBillingStatus } from "@/lib/api";
|
import { checkBillingStatus, BillingStatusResponse } from "@/lib/api";
|
||||||
|
import { Query } from "@tanstack/react-query";
|
||||||
|
|
||||||
export const useBillingStatusQuery = (enabled = true) =>
|
export const useBillingStatusQuery = (enabled = true) =>
|
||||||
createQueryHook(
|
createQueryHook(
|
||||||
|
@ -10,5 +11,17 @@ export const useBillingStatusQuery = (enabled = true) =>
|
||||||
enabled,
|
enabled,
|
||||||
retry: 1,
|
retry: 1,
|
||||||
staleTime: 1000 * 60 * 5,
|
staleTime: 1000 * 60 * 5,
|
||||||
|
gcTime: 1000 * 60 * 10, // 10 minutes (using gcTime instead of cacheTime)
|
||||||
|
refetchOnWindowFocus: false, // Disable refetch on window focus
|
||||||
|
refetchOnMount: false, // Disable refetch on component mount
|
||||||
|
refetchOnReconnect: false, // Disable refetch on reconnect
|
||||||
|
// Only refetch if the data is stale and the query is enabled
|
||||||
|
refetchInterval: (query: Query<BillingStatusResponse, Error>) => {
|
||||||
|
// If we have data and it indicates the user can't run, check more frequently
|
||||||
|
if (query.state.data && !query.state.data.can_run) {
|
||||||
|
return 1000 * 60; // Check every minute if user can't run
|
||||||
|
}
|
||||||
|
return false; // Don't refetch automatically otherwise
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)();
|
)();
|
||||||
|
|
Loading…
Reference in New Issue