# suna/backend/triggers/trigger_service.py
#
# (source listing metadata: 306 lines, 11 KiB, Python)

import uuid
from dataclasses import dataclass, field, replace
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, Any, Optional, List

from services.supabase import DBConnection
from utils.logger import logger
class TriggerType(str, Enum):
    """Category of a trigger.

    Persisted as the member's string value (``trigger_type`` column). The
    ``str`` mixin lets members compare equal to those raw strings.
    """

    # Member order is part of the observable behavior (iteration order).
    SCHEDULE = "schedule"
    WEBHOOK = "webhook"
    EVENT = "event"
@dataclass
class TriggerEvent:
    """One occurrence of a trigger firing, passed to the provider for processing."""

    # Which trigger fired, which agent owns it, and what kind of trigger it is.
    trigger_id: str
    agent_id: str
    trigger_type: TriggerType
    # Payload as supplied by the caller, unmodified.
    raw_data: Dict[str, Any]
    # Timezone-aware UTC creation time unless explicitly provided.
    timestamp: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    # Free-form extra data; each instance gets its own fresh dict.
    context: Dict[str, Any] = field(default_factory=dict)
@dataclass
class TriggerResult:
    """Outcome of processing a trigger event and what should run next."""

    # Whether processing succeeded; on failure ``error_message`` explains why.
    success: bool
    # Dispatch flags: run the agent, run a workflow, or neither.
    should_execute_agent: bool = False
    should_execute_workflow: bool = False
    # Prompt for an agent run, if one was requested.
    agent_prompt: Optional[str] = None
    # Target workflow and its input, if a workflow run was requested.
    workflow_id: Optional[str] = None
    workflow_input: Optional[Dict[str, Any]] = None
    # Variables for the execution; each instance gets its own fresh dict.
    execution_variables: Dict[str, Any] = field(default_factory=dict)
    error_message: Optional[str] = None
@dataclass
class Trigger:
    """In-memory representation of one stored agent trigger."""

    # Identity: the trigger itself, its owning agent, and its backing provider.
    trigger_id: str
    agent_id: str
    provider_id: str
    trigger_type: TriggerType
    # Human-facing metadata.
    name: str
    description: Optional[str]
    # Inactive triggers are rejected when an event is processed.
    is_active: bool
    # Provider-specific configuration.
    config: Dict[str, Any]
    # Audit timestamps.
    created_at: datetime
    updated_at: datetime
class TriggerService:
    """Create, read, update, delete, and fire agent triggers.

    Trigger rows are stored in the ``agent_triggers`` table, and processed
    events are appended to ``trigger_event_logs``. Provider-specific work is
    delegated to the provider service from
    ``.provider_service.get_provider_service``. That work covers config
    validation, trigger-type lookup, remote setup/teardown/delete, and event
    interpretation.
    """

    def __init__(self, db_connection: DBConnection):
        self._db = db_connection

    def _get_provider_service(self):
        """Return the provider service bound to this service's DB connection."""
        # Imported lazily as before; presumably this avoids a circular import
        # with .provider_service (NOTE(review): confirm).
        from .provider_service import get_provider_service
        return get_provider_service(self._db)

    async def create_trigger(
        self,
        agent_id: str,
        provider_id: str,
        name: str,
        config: Dict[str, Any],
        description: Optional[str] = None
    ) -> Trigger:
        """Validate, set up with the provider, and persist a new active trigger.

        Args:
            agent_id: Agent that will own the trigger.
            provider_id: Provider that validates the config and backs the trigger.
            name: Human-readable trigger name.
            config: Provider-specific configuration. It is replaced by the
                provider's validated version before use.
            description: Optional free-text description.

        Returns:
            The newly created ``Trigger``, with ``is_active`` set to ``True``.

        Raises:
            ValueError: If the provider reports that setup failed.
            Exception: Anything raised while saving the row. It is re-raised
                after a best-effort teardown of the provider setup.
        """
        trigger_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)
        provider_service = self._get_provider_service()
        validated_config = await provider_service.validate_trigger_config(provider_id, config)
        trigger_type = await provider_service.get_provider_trigger_type(provider_id)
        trigger = Trigger(
            trigger_id=trigger_id,
            agent_id=agent_id,
            provider_id=provider_id,
            trigger_type=trigger_type,
            name=name,
            description=description,
            is_active=True,
            config=validated_config,
            created_at=now,
            updated_at=now,
        )
        setup_success = await provider_service.setup_trigger(trigger)
        if not setup_success:
            raise ValueError(f"Failed to setup trigger with provider: {provider_id}")
        try:
            await self._save_trigger(trigger)
        except Exception:
            # The provider side is already set up. Without a rollback it would
            # outlive a trigger that the database never recorded.
            try:
                await provider_service.teardown_trigger(trigger)
            except Exception as cleanup_error:
                logger.warning(
                    f"Failed to roll back provider setup for trigger {trigger_id}: {cleanup_error}"
                )
            raise
        logger.info(f"Created trigger {trigger_id} for agent {agent_id}")
        return trigger

    async def get_trigger(self, trigger_id: str) -> Optional[Trigger]:
        """Return the trigger with *trigger_id*, or ``None`` if no row matches."""
        client = await self._db.client
        result = await client.table('agent_triggers').select('*').eq('trigger_id', trigger_id).execute()
        if not result.data:
            return None
        return self._map_to_trigger(result.data[0])

    async def get_agent_triggers(self, agent_id: str) -> List[Trigger]:
        """Return all triggers owned by *agent_id* (possibly an empty list)."""
        client = await self._db.client
        result = await client.table('agent_triggers').select('*').eq('agent_id', agent_id).execute()
        # Guard against a None payload so an empty result never raises.
        return [self._map_to_trigger(row) for row in (result.data or [])]

    async def update_trigger(
        self,
        trigger_id: str,
        config: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
        is_active: Optional[bool] = None
    ) -> Trigger:
        """Apply a partial update and reconcile the provider side.

        Only arguments that are not ``None`` are applied.

        - A new ``config`` is first validated by the provider. The previous
          setup is then torn down and, if the trigger ends up active, set up
          again with the new config.
        - A change of ``is_active`` alone only sets up (enable) or tears down
          (disable) the trigger.

        Returns:
            The updated ``Trigger``.

        Raises:
            ValueError: If the trigger does not exist, or if provider setup
                fails. On a setup failure the database row is not updated,
                but any teardown performed beforehand has already happened.
        """
        trigger = await self.get_trigger(trigger_id)
        if not trigger:
            raise ValueError(f"Trigger not found: {trigger_id}")
        # Snapshot the stored state before mutating it. A config-change
        # teardown must target what was actually set up (the old config),
        # not the incoming one.
        previous = replace(trigger)
        provider_service = None
        if config is not None:
            provider_service = self._get_provider_service()
            config = await provider_service.validate_trigger_config(trigger.provider_id, config)
        if name is not None:
            trigger.name = name
        if description is not None:
            trigger.description = description
        if is_active is not None:
            trigger.is_active = is_active
        if config is not None:
            trigger.config = config
        trigger.updated_at = datetime.now(timezone.utc)

        # Reconcile the provider when the config changes or activation toggles.
        config_changed = config is not None
        activation_toggled = is_active is not None and previous.is_active != trigger.is_active
        if config_changed or activation_toggled:
            if provider_service is None:
                provider_service = self._get_provider_service()
            if config_changed:
                # Full teardown of the old setup, then re-setup if still active.
                await provider_service.teardown_trigger(previous)
                if trigger.is_active:
                    setup_success = await provider_service.setup_trigger(trigger)
                    if not setup_success:
                        raise ValueError(f"Failed to update trigger setup: {trigger_id}")
            elif trigger.is_active:
                # Only activation toggled on: set up with the unchanged config.
                setup_success = await provider_service.setup_trigger(trigger)
                if not setup_success:
                    raise ValueError(f"Failed to enable trigger: {trigger_id}")
            else:
                # Only activation toggled off.
                await provider_service.teardown_trigger(trigger)
        await self._update_trigger(trigger)
        logger.info(f"Updated trigger {trigger_id}")
        return trigger

    async def delete_trigger(self, trigger_id: str) -> bool:
        """Tear down and remotely delete a trigger, then delete its row.

        Both provider calls are best effort. Their failures are logged and do
        not prevent the database row from being deleted.

        Returns:
            ``True`` if a row was deleted, ``False`` if the trigger was not
            found or the delete returned no rows.
        """
        trigger = await self.get_trigger(trigger_id)
        if not trigger:
            return False
        provider_service = self._get_provider_service()
        # First disable remotely so that webhooks stop quickly.
        try:
            await provider_service.teardown_trigger(trigger)
        except Exception as e:
            logger.warning(f"Failed to teardown trigger {trigger_id} before delete: {e}")
        # Then request a remote delete, if the provider supports it.
        try:
            await provider_service.delete_remote_trigger(trigger)
        except Exception as e:
            logger.warning(f"Failed to delete remote trigger {trigger_id}: {e}")
        client = await self._db.client
        result = await client.table('agent_triggers').delete().eq('trigger_id', trigger_id).execute()
        success = bool(result.data)
        if success:
            logger.info(f"Deleted trigger {trigger_id}")
        return success

    async def process_trigger_event(self, trigger_id: str, raw_data: Dict[str, Any]) -> TriggerResult:
        """Run an incoming payload through the trigger's provider and log it.

        Returns:
            A failed ``TriggerResult`` if the trigger is missing or inactive.
            Otherwise, the provider's result. A failure to write the event log
            is only logged as a warning.
        """
        trigger = await self.get_trigger(trigger_id)
        if not trigger:
            return TriggerResult(success=False, error_message=f"Trigger not found: {trigger_id}")
        if not trigger.is_active:
            return TriggerResult(success=False, error_message=f"Trigger is inactive: {trigger_id}")
        event = TriggerEvent(
            trigger_id=trigger_id,
            agent_id=trigger.agent_id,
            trigger_type=trigger.trigger_type,
            raw_data=raw_data,
        )
        provider_service = self._get_provider_service()
        result = await provider_service.process_event(trigger, event)
        try:
            await self._log_trigger_event(event, result)
        except Exception as e:
            logger.warning(f"Failed to log trigger event: {e}")
        return result

    @staticmethod
    def _config_with_provider(trigger: Trigger) -> Dict[str, Any]:
        """Return the trigger's config with ``provider_id`` embedded for storage."""
        # provider_id is not written as a separate column by this service. It
        # round-trips through config, and _map_to_trigger strips it back out.
        return {**trigger.config, "provider_id": trigger.provider_id}

    async def _save_trigger(self, trigger: Trigger) -> None:
        """Insert *trigger* as a new ``agent_triggers`` row."""
        client = await self._db.client
        await client.table('agent_triggers').insert({
            'trigger_id': trigger.trigger_id,
            'agent_id': trigger.agent_id,
            'trigger_type': trigger.trigger_type.value,
            'name': trigger.name,
            'description': trigger.description,
            'is_active': trigger.is_active,
            'config': self._config_with_provider(trigger),
            'created_at': trigger.created_at.isoformat(),
            'updated_at': trigger.updated_at.isoformat()
        }).execute()

    async def _update_trigger(self, trigger: Trigger) -> None:
        """Overwrite the mutable columns of *trigger*'s existing row."""
        client = await self._db.client
        await client.table('agent_triggers').update({
            'trigger_type': trigger.trigger_type.value,
            'name': trigger.name,
            'description': trigger.description,
            'is_active': trigger.is_active,
            'config': self._config_with_provider(trigger),
            'updated_at': trigger.updated_at.isoformat()
        }).eq('trigger_id', trigger.trigger_id).execute()

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse an ISO-8601 timestamp, accepting a trailing ``Z`` for UTC."""
        # datetime.fromisoformat only accepts 'Z' from Python 3.11 onward.
        return datetime.fromisoformat(value.replace('Z', '+00:00'))

    def _map_to_trigger(self, data: Dict[str, Any]) -> Trigger:
        """Convert an ``agent_triggers`` row into a ``Trigger``."""
        # `or {}` also covers a NULL config column; .get's default alone does not.
        config_data = data.get('config') or {}
        # Prefer the provider_id saved in config. Otherwise infer it, for
        # backwards compatibility.
        provider_id = config_data.get('provider_id')
        if not provider_id:
            # Older event-based Composio triggers did not persist provider_id,
            # so infer it from Composio-specific config keys.
            if 'composio_trigger_id' in config_data or 'trigger_slug' in config_data:
                provider_id = 'composio'
            else:
                provider_id = data['trigger_type']
        clean_config = {k: v for k, v in config_data.items() if k != 'provider_id'}
        return Trigger(
            trigger_id=data['trigger_id'],
            agent_id=data['agent_id'],
            provider_id=provider_id,
            trigger_type=TriggerType(data['trigger_type']),
            name=data['name'],
            description=data.get('description'),
            is_active=data.get('is_active', True),
            config=clean_config,
            created_at=self._parse_timestamp(data['created_at']),
            updated_at=self._parse_timestamp(data['updated_at']),
        )

    async def _log_trigger_event(self, event: TriggerEvent, result: TriggerResult) -> None:
        """Append one ``trigger_event_logs`` row describing *event* and *result*."""
        client = await self._db.client
        await client.table('trigger_event_logs').insert({
            'log_id': str(uuid.uuid4()),
            'trigger_id': event.trigger_id,
            'agent_id': event.agent_id,
            'trigger_type': event.trigger_type.value,
            'event_data': event.raw_data,
            'success': result.success,
            'should_execute_agent': result.should_execute_agent,
            'should_execute_workflow': result.should_execute_workflow,
            'agent_prompt': result.agent_prompt,
            'workflow_id': result.workflow_id,
            'workflow_input': result.workflow_input,
            'execution_variables': result.execution_variables,
            'error_message': result.error_message,
            'event_timestamp': event.timestamp.isoformat(),
            'logged_at': datetime.now(timezone.utc).isoformat()
        }).execute()
def get_trigger_service(db_connection: DBConnection) -> TriggerService:
    """Build a new ``TriggerService`` bound to *db_connection*.

    A fresh instance is created on every call; nothing is cached here.
    """
    service = TriggerService(db_connection)
    return service