# suna/backend/triggers/providers/schedule_provider.py
#
# 259 lines
# 11 KiB
# Python

import asyncio
import json
import os
from datetime import datetime, timezone
from typing import Dict, Any, Optional
from qstash.client import QStash
from utils.logger import logger
from ..core import TriggerProvider, TriggerType, TriggerEvent, TriggerResult, TriggerConfig, ProviderDefinition
class ScheduleTriggerProvider(TriggerProvider):
    """Schedule trigger provider using Upstash QStash.

    Every trigger is backed by one QStash schedule. On each tick of the
    trigger's cron expression, QStash POSTs a JSON payload (trigger id, agent
    id, prompt, schedule name, cron expression) to
    ``<WEBHOOK_BASE_URL>/api/triggers/qstash/webhook``. The QStash schedule id
    is stored in ``trigger_config.config['qstash_schedule_id']`` so that
    teardown, pause, resume, update and health checks can address it later.

    Configuration is read from the environment at construction time:

    * ``QSTASH_TOKEN``: QStash API token. When it is missing, ``self.qstash``
      is ``None``. ``validate_config`` then raises ``ValueError``, and the
      operations that need the QStash API log an error and report failure
      (``False`` / ``[]``).
    * ``WEBHOOK_BASE_URL``: public base URL that QStash delivers to
      (defaults to ``http://localhost:8000``).
    """

    # Endpoint, appended to the webhook base URL, that receives QStash deliveries.
    _WEBHOOK_PATH = "/api/triggers/qstash/webhook"
    # Delivery settings applied to every schedule this provider creates or updates.
    _SCHEDULE_RETRIES = 3
    _SCHEDULE_DELAY = "5s"

    def __init__(self, provider_definition: Optional[ProviderDefinition] = None):
        """Read QStash settings from the environment and build the client.

        Args:
            provider_definition: Optional definition forwarded to the base provider.
        """
        super().__init__(TriggerType.SCHEDULE, provider_definition)
        self.qstash_token = os.getenv("QSTASH_TOKEN")
        self.webhook_base_url = os.getenv("WEBHOOK_BASE_URL", "http://localhost:8000")
        if not self.qstash_token:
            logger.warning("QSTASH_TOKEN not found. QStash provider will not work without it.")
            self.qstash = None
        else:
            self.qstash = QStash(token=self.qstash_token)

    # ------------------------------------------------------------------ helpers

    def _client_ready(self, action: str) -> bool:
        """Return True if a QStash client is configured; otherwise log why *action* cannot run."""
        if self.qstash is None:
            logger.error(f"Cannot {action}: QSTASH_TOKEN is not configured")
            return False
        return True

    def _webhook_url(self) -> str:
        """Return the full URL QStash should deliver to (no doubled '/' for a trailing-slash base)."""
        return f"{self.webhook_base_url.rstrip('/')}{self._WEBHOOK_PATH}"

    @staticmethod
    def _build_webhook_payload(trigger_config: TriggerConfig) -> Dict[str, Any]:
        """Build the JSON body that QStash sends on every schedule tick.

        Raises:
            KeyError: if ``agent_prompt`` or ``cron_expression`` is missing from the config.
        """
        return {
            "trigger_id": trigger_config.trigger_id,
            "agent_id": trigger_config.agent_id,
            "agent_prompt": trigger_config.config['agent_prompt'],
            "schedule_name": trigger_config.name,
            "cron_expression": trigger_config.config['cron_expression'],
            "event_type": "scheduled",
            "provider": "qstash"
        }

    @staticmethod
    def _build_schedule_headers(trigger_config: TriggerConfig) -> Dict[str, str]:
        """Build the HTTP headers QStash forwards with every delivery."""
        return {
            "Content-Type": "application/json",
            "X-Schedule-Provider": "qstash",
            "X-Trigger-ID": trigger_config.trigger_id,
            "X-Agent-ID": trigger_config.agent_id
        }

    @staticmethod
    def _schedule_is_active(schedule: Any) -> bool:
        """Best-effort check of whether an SDK schedule object is running (not paused).

        Prefers an explicit ``is_active`` attribute and otherwise derives the
        answer from ``paused``. Returns False when neither attribute exists.
        NOTE(review): the qstash Python SDK's Schedule appears to expose
        ``paused`` rather than ``is_active``; confirm against the installed version.
        """
        if hasattr(schedule, 'is_active'):
            return bool(schedule.is_active)
        if hasattr(schedule, 'paused'):
            return not schedule.paused
        return False

    async def _upsert_schedule(self, trigger_config: TriggerConfig, schedule_id: Optional[str] = None) -> str:
        """Create (or, when *schedule_id* is given, overwrite) the trigger's QStash schedule.

        Returns:
            The schedule id returned by ``qstash.schedule.create``.

        Raises:
            KeyError: if required config keys are missing.
            Exception: whatever the QStash SDK raises on API failure.
        """
        create_kwargs: Dict[str, Any] = {
            "destination": self._webhook_url(),
            "cron": trigger_config.config['cron_expression'],
            "body": json.dumps(self._build_webhook_payload(trigger_config)),
            "headers": self._build_schedule_headers(trigger_config),
            "retries": self._SCHEDULE_RETRIES,
            "delay": self._SCHEDULE_DELAY,
        }
        if schedule_id:
            # Passing an existing id asks QStash to update that schedule instead of
            # creating a new one (Upstash-Schedule-Id). Only sent when set, so plain
            # creation keeps working on SDK versions without this keyword.
            create_kwargs["schedule_id"] = schedule_id
        # The QStash SDK client is synchronous; run it off the event loop.
        return await asyncio.to_thread(self.qstash.schedule.create, **create_kwargs)

    # --------------------------------------------------------------- lifecycle

    async def validate_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate schedule configuration.

        Args:
            config: Trigger config; must contain ``cron_expression`` and ``agent_prompt``.

        Returns:
            The same *config* object, unchanged.

        Raises:
            ValueError: if QStash is not configured, a required key is missing,
                ``croniter`` is not installed, or the cron expression is invalid.
        """
        if not self.qstash:
            raise ValueError("QSTASH_TOKEN environment variable is required for QStash scheduling")
        if 'cron_expression' not in config:
            raise ValueError("cron_expression is required for QStash schedule triggers")
        if 'agent_prompt' not in config:
            raise ValueError("agent_prompt is required for schedule triggers")
        try:
            import croniter
            croniter.croniter(config['cron_expression'])
        except ImportError as e:
            raise ValueError("croniter package is required for cron expressions. Please install it with: pip install croniter") from e
        except Exception as e:
            raise ValueError(f"Invalid cron expression: {str(e)}") from e
        return config

    async def setup_trigger(self, trigger_config: TriggerConfig) -> bool:
        """Set up a scheduled trigger using QStash.

        On success, stores the new schedule id in
        ``trigger_config.config['qstash_schedule_id']``.

        Returns:
            True if the schedule was created, False otherwise (the error is logged).
        """
        if not self._client_ready(f"set up QStash scheduled trigger {trigger_config.trigger_id}"):
            return False
        try:
            schedule_id = await self._upsert_schedule(trigger_config)
        except Exception as e:
            logger.error(f"Error setting up QStash scheduled trigger {trigger_config.trigger_id}: {e}")
            return False
        trigger_config.config['qstash_schedule_id'] = schedule_id
        logger.info(f"Successfully created QStash schedule {schedule_id} for trigger {trigger_config.trigger_id}")
        return True

    async def teardown_trigger(self, trigger_config: TriggerConfig) -> bool:
        """Remove a scheduled trigger from QStash.

        Returns:
            True if the schedule was deleted, or if the trigger has no stored
            schedule id (nothing to remove). False if deletion failed.
        """
        schedule_id = trigger_config.config.get('qstash_schedule_id')
        if not schedule_id:
            logger.warning(f"No QStash schedule ID found for trigger {trigger_config.trigger_id}")
            return True
        if not self._client_ready(f"remove QStash scheduled trigger {trigger_config.trigger_id}"):
            return False
        try:
            await asyncio.to_thread(self.qstash.schedule.delete, schedule_id=schedule_id)
        except Exception as e:
            logger.error(f"Error removing QStash scheduled trigger {trigger_config.trigger_id}: {e}")
            return False
        logger.info(f"Successfully deleted QStash schedule {schedule_id}")
        return True

    async def process_event(self, event: TriggerEvent) -> TriggerResult:
        """Turn a QStash delivery into an agent-execution request.

        Reads the payload built by ``_build_webhook_payload`` from
        ``event.raw_data`` and falls back to defaults for missing fields.

        Returns:
            A successful ``TriggerResult`` asking for agent execution, or a
            failed one carrying the error message.
        """
        try:
            raw_data = event.raw_data
            agent_prompt = raw_data.get('agent_prompt', 'Execute scheduled task')
            execution_variables = {
                'scheduled_at': event.timestamp.isoformat(),
                'trigger_id': event.trigger_id,
                'agent_id': event.agent_id,
                'schedule_name': raw_data.get('schedule_name', 'Scheduled Task'),
                'execution_source': 'qstash',
                'cron_expression': raw_data.get('cron_expression'),
                'qstash_message_id': raw_data.get('messageId')
            }
            return TriggerResult(
                success=True,
                should_execute_agent=True,
                agent_prompt=agent_prompt,
                execution_variables=execution_variables
            )
        except Exception as e:
            return TriggerResult(
                success=False,
                error_message=f"Error processing QStash scheduled trigger event: {str(e)}"
            )

    async def health_check(self, trigger_config: TriggerConfig) -> bool:
        """Check whether the trigger's QStash schedule exists and is not paused.

        Returns:
            False if there is no stored schedule id, the client is not
            configured, the lookup fails, or the schedule is inactive.
        """
        schedule_id = trigger_config.config.get('qstash_schedule_id')
        if not schedule_id:
            return False
        if not self._client_ready(f"health-check QStash scheduled trigger {trigger_config.trigger_id}"):
            return False
        try:
            schedule = await asyncio.to_thread(self.qstash.schedule.get, schedule_id=schedule_id)
        except Exception as e:
            logger.error(f"Health check failed for QStash scheduled trigger {trigger_config.trigger_id}: {e}")
            return False
        return self._schedule_is_active(schedule)

    async def pause_trigger(self, trigger_config: TriggerConfig) -> bool:
        """Pause the trigger's QStash schedule.

        Returns:
            True on success; False if there is no schedule id or the call failed.
        """
        schedule_id = trigger_config.config.get('qstash_schedule_id')
        if not schedule_id:
            return False
        if not self._client_ready(f"pause QStash schedule {schedule_id}"):
            return False
        try:
            # The client attribute is `schedule` (singular), as used by create/get/delete.
            await asyncio.to_thread(self.qstash.schedule.pause, schedule_id=schedule_id)
        except Exception as e:
            logger.error(f"Error pausing QStash schedule: {e}")
            return False
        logger.info(f"Successfully paused QStash schedule {schedule_id}")
        return True

    async def resume_trigger(self, trigger_config: TriggerConfig) -> bool:
        """Resume the trigger's paused QStash schedule.

        Returns:
            True on success; False if there is no schedule id or the call failed.
        """
        schedule_id = trigger_config.config.get('qstash_schedule_id')
        if not schedule_id:
            return False
        if not self._client_ready(f"resume QStash schedule {schedule_id}"):
            return False
        try:
            await asyncio.to_thread(self.qstash.schedule.resume, schedule_id=schedule_id)
        except Exception as e:
            logger.error(f"Error resuming QStash schedule: {e}")
            return False
        logger.info(f"Successfully resumed QStash schedule {schedule_id}")
        return True

    async def update_trigger(self, trigger_config: TriggerConfig) -> bool:
        """Update the trigger's QStash schedule with its current config.

        Re-submits the schedule (destination, cron, payload, headers) with the
        stored ``qstash_schedule_id`` so that QStash updates it in place. If
        no id is stored yet, a new schedule is created. Either way, the id
        QStash returns is written back to ``trigger_config.config``.

        Returns:
            True on success, False otherwise (the error is logged).
        """
        if not self._client_ready(f"update QStash scheduled trigger {trigger_config.trigger_id}"):
            return False
        schedule_id = trigger_config.config.get('qstash_schedule_id')
        try:
            new_schedule_id = await self._upsert_schedule(trigger_config, schedule_id)
        except Exception as e:
            logger.error(f"Error updating QStash schedule: {e}")
            return False
        # Keep the previous id if the SDK returned nothing usable.
        trigger_config.config['qstash_schedule_id'] = new_schedule_id or schedule_id
        logger.info(f"Successfully updated QStash schedule {new_schedule_id or schedule_id}")
        return True

    def get_webhook_url(self, trigger_id: str, base_url: str) -> Optional[str]:
        """Return the webhook URL for QStash schedules.

        All schedule triggers share one endpoint, so *trigger_id* is unused;
        the trigger is identified by the delivered payload and headers instead.
        """
        return f"{base_url.rstrip('/')}{self._WEBHOOK_PATH}"

    async def list_schedules(self) -> list:
        """List all QStash schedules on the account as plain dicts.

        Returns:
            One dict per schedule (id, destination, cron, is_active,
            created_at, next_delivery), or ``[]`` if the client is not
            configured or the call fails.
        """
        if not self._client_ready("list QStash schedules"):
            return []
        try:
            schedules_data = await asyncio.to_thread(self.qstash.schedule.list)
        except Exception as e:
            logger.error(f"Error listing QStash schedules: {e}")
            return []
        return [
            {
                'id': getattr(schedule, 'schedule_id', None),
                'destination': getattr(schedule, 'destination', None),
                'cron': getattr(schedule, 'cron', None),
                'is_active': self._schedule_is_active(schedule),
                'created_at': getattr(schedule, 'created_at', None),
                'next_delivery': getattr(schedule, 'next_delivery', None)
            }
            for schedule in schedules_data
        ]