mirror of https://github.com/kortix-ai/suna.git
feat: add section in tool_list
This commit is contained in:
parent
7239ce4668
commit
457b548a65
|
@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
|
|||
from enum import Enum
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
|
@ -17,98 +16,74 @@ class Task(BaseModel):
|
|||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
content: str
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
updated_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
completed_at: Optional[str] = None
|
||||
|
||||
def update(self, content: Optional[str] = None, status: Optional[TaskStatus] = None):
|
||||
"""Update task content and/or status"""
|
||||
if content is not None:
|
||||
self.content = content
|
||||
|
||||
if status is not None:
|
||||
self.status = status
|
||||
if status == TaskStatus.COMPLETED:
|
||||
self.completed_at = datetime.now(timezone.utc).isoformat()
|
||||
elif status == TaskStatus.PENDING:
|
||||
self.completed_at = None
|
||||
|
||||
self.updated_at = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
class TaskUpdateRequest(BaseModel):
|
||||
id: str
|
||||
content: Optional[str] = None
|
||||
status: Optional[TaskStatus] = None
|
||||
section: str = "General" # Simplified: just a string
|
||||
|
||||
class TaskListTool(SandboxToolsBase):
|
||||
"""Tool for managing tasks stored in a single task_list message.
|
||||
|
||||
Provides simple CRUD operations with batch support for efficient task management.
|
||||
Tasks persist in a single message with type "task_list"
|
||||
"""
|
||||
"""Simplified task management tool with same external interface"""
|
||||
|
||||
def __init__(self, project_id: str, thread_manager, thread_id: str):
|
||||
super().__init__(project_id, thread_manager)
|
||||
self.thread_id = thread_id
|
||||
self.task_list_message_type = "task_list"
|
||||
|
||||
async def _find_task_list_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""Find the single task_list message in the thread"""
|
||||
async def _load_tasks(self) -> List[Task]:
|
||||
"""Load tasks from storage"""
|
||||
|
||||
try:
|
||||
client = await self.thread_manager.db.client
|
||||
result = await client.table('messages').select('*')\
|
||||
.eq('thread_id', self.thread_id)\
|
||||
.eq('type', self.task_list_message_type)\
|
||||
.order('created_at', desc=True).limit(1).execute()
|
||||
|
||||
# Look for the most recent task_list message
|
||||
result = await client.table('messages').select('*').eq('thread_id', self.thread_id).eq('type', self.task_list_message_type).order('created_at', desc=True).limit(1).execute()
|
||||
|
||||
if result.data:
|
||||
return result.data[0]
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding task_list message: {e}")
|
||||
return None
|
||||
|
||||
async def _get_tasks(self) -> List[Task]:
|
||||
"""Get tasks from the task_list message"""
|
||||
try:
|
||||
message = await self._find_task_list_message()
|
||||
|
||||
if message and message.get('content'):
|
||||
# Parse the message content to get tasks
|
||||
if isinstance(message['content'], str):
|
||||
content_data = json.loads(message['content'])
|
||||
else:
|
||||
content_data = message['content']
|
||||
if result.data and result.data[0].get('content'):
|
||||
content = result.data[0]['content']
|
||||
if isinstance(content, str):
|
||||
content = json.loads(content)
|
||||
|
||||
tasks_data = content_data.get('tasks', [])
|
||||
return [Task(**task_data) for task_data in tasks_data]
|
||||
# Handle both old nested format and new simple format
|
||||
tasks = []
|
||||
if 'tasks' in content:
|
||||
# New simple format
|
||||
tasks = [Task(**task_data) for task_data in content['tasks']]
|
||||
elif 'sections' in content:
|
||||
# Old nested format - convert to simple
|
||||
for section_data in content['sections']:
|
||||
section_name = section_data.get('title', 'General')
|
||||
for task_data in section_data.get('tasks', []):
|
||||
task_data['section'] = section_name
|
||||
tasks.append(Task(**task_data))
|
||||
|
||||
return tasks
|
||||
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting tasks from message: {e}")
|
||||
logger.error(f"Error loading tasks: {e}")
|
||||
return []
|
||||
|
||||
async def _save_tasks(self, tasks: List[Task]):
|
||||
"""Save tasks to the task_list message"""
|
||||
"""Save tasks to storage (simplified)"""
|
||||
try:
|
||||
client = await self.thread_manager.db.client
|
||||
|
||||
# Prepare content
|
||||
# Simple storage format
|
||||
content = {
|
||||
"tasks": [task.model_dump() for task in tasks]
|
||||
'tasks': [task.model_dump() for task in tasks]
|
||||
}
|
||||
|
||||
# Find existing task_list message
|
||||
existing_message = await self._find_task_list_message()
|
||||
# Find existing message
|
||||
result = await client.table('messages').select('message_id')\
|
||||
.eq('thread_id', self.thread_id)\
|
||||
.eq('type', self.task_list_message_type)\
|
||||
.order('created_at', desc=True).limit(1).execute()
|
||||
|
||||
if existing_message:
|
||||
# Update existing message
|
||||
await client.table('messages').update({
|
||||
'content': content
|
||||
}).eq('message_id', existing_message['message_id']).execute()
|
||||
if result.data:
|
||||
# Update existing
|
||||
await client.table('messages').update({'content': content})\
|
||||
.eq('message_id', result.data[0]['message_id']).execute()
|
||||
else:
|
||||
# Create new task_list message
|
||||
# Create new
|
||||
await client.table('messages').insert({
|
||||
'thread_id': self.thread_id,
|
||||
'type': self.task_list_message_type,
|
||||
|
@ -116,10 +91,39 @@ class TaskListTool(SandboxToolsBase):
|
|||
'is_llm_message': False,
|
||||
'metadata': {}
|
||||
}).execute()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving tasks to message: {e}")
|
||||
logger.error(f"Error saving tasks: {e}")
|
||||
raise
|
||||
|
||||
def _format_response(self, tasks: List[Task], message: str = "") -> Dict[str, Any]:
|
||||
"""Format tasks for response (maintains expected nested structure)"""
|
||||
# Group by section for backward compatibility
|
||||
sections_map = {}
|
||||
for task in tasks:
|
||||
section = task.section
|
||||
if section not in sections_map:
|
||||
sections_map[section] = []
|
||||
sections_map[section].append(task.model_dump())
|
||||
|
||||
sections = [
|
||||
{
|
||||
"id": section_name.lower().replace(" ", "_"),
|
||||
"title": section_name,
|
||||
"tasks": task_list
|
||||
}
|
||||
for section_name, task_list in sections_map.items()
|
||||
]
|
||||
|
||||
response = {
|
||||
"sections": sections,
|
||||
"total": len(tasks)
|
||||
}
|
||||
|
||||
if message:
|
||||
response["message"] = message
|
||||
|
||||
return response
|
||||
|
||||
@openapi_schema({
|
||||
"type": "function",
|
||||
|
@ -156,31 +160,17 @@ class TaskListTool(SandboxToolsBase):
|
|||
async def view_tasks(self, status_filter: str = "all") -> ToolResult:
|
||||
"""View tasks with optional status filter"""
|
||||
try:
|
||||
tasks = await self._get_tasks()
|
||||
tasks = await self._load_tasks()
|
||||
|
||||
# Filter if needed
|
||||
# Apply filter
|
||||
if status_filter != "all":
|
||||
tasks = [t for t in tasks if t.status.value == status_filter]
|
||||
|
||||
if not tasks:
|
||||
return ToolResult(
|
||||
success=True,
|
||||
output=json.dumps({
|
||||
"tasks": [],
|
||||
"message": f"No {status_filter} tasks found.",
|
||||
"filter": status_filter
|
||||
}, indent=2)
|
||||
)
|
||||
|
||||
message = f"No {status_filter} tasks found." if not tasks else ""
|
||||
response_data = self._format_response(tasks, message)
|
||||
response_data["filter"] = status_filter
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
output=json.dumps({
|
||||
"tasks": [task.model_dump() for task in tasks],
|
||||
"total": len(tasks),
|
||||
"filter": status_filter
|
||||
}, indent=2)
|
||||
)
|
||||
return ToolResult(success=True, output=json.dumps(response_data, indent=2))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error viewing tasks: {e}")
|
||||
|
@ -190,78 +180,103 @@ class TaskListTool(SandboxToolsBase):
|
|||
"type": "function",
|
||||
"function": {
|
||||
"name": "create_tasks",
|
||||
"description": "Create one or more tasks. Supports batch creation for efficiency.",
|
||||
"description": "Create tasks organized by sections. Supports batch creation for efficiency.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tasks": {
|
||||
"sections": {
|
||||
"type": "array",
|
||||
"description": "List of tasks to create",
|
||||
"description": "List of sections with their tasks",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Task description"
|
||||
"description": "Section title"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "completed", "cancelled"],
|
||||
"default": "pending",
|
||||
"description": "Initial task status"
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"description": "Tasks in this section",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Task description"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["pending", "completed", "cancelled"],
|
||||
"default": "pending",
|
||||
"description": "Initial task status"
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
},
|
||||
"minItems": 1
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
"required": ["title", "tasks"]
|
||||
},
|
||||
"minItems": 1
|
||||
}
|
||||
},
|
||||
"required": ["tasks"]
|
||||
"required": ["sections"]
|
||||
}
|
||||
}
|
||||
})
|
||||
@xml_schema(
|
||||
tag_name="create-tasks",
|
||||
mappings=[
|
||||
{"param_name": "tasks", "node_type": "element", "path": "tasks", "required": True}
|
||||
{"param_name": "sections", "node_type": "element", "path": "sections", "required": True}
|
||||
],
|
||||
example='''
|
||||
<function_calls>
|
||||
<invoke name="create_tasks">
|
||||
<parameter name="tasks">[
|
||||
{"content": "Research API documentation"},
|
||||
{"content": "Implement authentication"},
|
||||
{"content": "Write unit tests"},
|
||||
{"content": "Deploy to production"}
|
||||
<parameter name="sections">[
|
||||
{
|
||||
"title": "Setup & Planning",
|
||||
"tasks": [
|
||||
{"content": "Research API documentation"},
|
||||
{"content": "Setup development environment"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"title": "Development",
|
||||
"tasks": [
|
||||
{"content": "Create API client"},
|
||||
{"content": "Implement API integration"}
|
||||
]
|
||||
}
|
||||
]</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
'''
|
||||
)
|
||||
async def create_tasks(self, tasks: List[Dict[str, Any]]) -> ToolResult:
|
||||
"""Create multiple tasks in a single operation"""
|
||||
async def create_tasks(self, sections: List[Dict[str, Any]]) -> ToolResult:
|
||||
"""Create tasks organized by sections"""
|
||||
try:
|
||||
existing_tasks = await self._get_tasks()
|
||||
existing_tasks = await self._load_tasks()
|
||||
created_count = 0
|
||||
|
||||
# Validate input and create task objects
|
||||
created_tasks = []
|
||||
for task_data in tasks:
|
||||
new_task = Task(
|
||||
content=task_data["content"],
|
||||
status=TaskStatus(task_data.get("status", "pending"))
|
||||
)
|
||||
existing_tasks.append(new_task)
|
||||
created_tasks.append(new_task.model_dump())
|
||||
for section_data in sections:
|
||||
section_title = section_data["title"]
|
||||
|
||||
for task_data in section_data["tasks"]:
|
||||
new_task = Task(
|
||||
content=task_data["content"],
|
||||
status=TaskStatus(task_data.get("status", "pending")),
|
||||
section=section_title
|
||||
)
|
||||
existing_tasks.append(new_task)
|
||||
created_count += 1
|
||||
|
||||
await self._save_tasks(existing_tasks)
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
output=json.dumps({
|
||||
"message": f"Created {len(created_tasks)} tasks",
|
||||
"tasks": created_tasks
|
||||
}, indent=2)
|
||||
)
|
||||
message = f"Created {created_count} tasks"
|
||||
response_data = self._format_response(existing_tasks, message)
|
||||
|
||||
return ToolResult(success=True, output=json.dumps(response_data, indent=2))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating tasks: {e}")
|
||||
|
@ -271,7 +286,7 @@ class TaskListTool(SandboxToolsBase):
|
|||
"type": "function",
|
||||
"function": {
|
||||
"name": "update_tasks",
|
||||
"description": "Update one or more tasks. Can update content or status.",
|
||||
"description": "Update tasks by their IDs. Can change content, status, or move between sections.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -293,6 +308,10 @@ class TaskListTool(SandboxToolsBase):
|
|||
"type": "string",
|
||||
"enum": ["pending", "completed", "cancelled"],
|
||||
"description": "New task status (optional)"
|
||||
},
|
||||
"section": {
|
||||
"type": "string",
|
||||
"description": "New section name (optional)"
|
||||
}
|
||||
},
|
||||
"required": ["id"]
|
||||
|
@ -314,46 +333,39 @@ class TaskListTool(SandboxToolsBase):
|
|||
<invoke name="update_tasks">
|
||||
<parameter name="updates">[
|
||||
{"id": "task-id-1", "status": "completed"},
|
||||
{"id": "task-id-2", "content": "Updated task description"}
|
||||
{"id": "task-id-2", "content": "Updated description", "section": "New Section"}
|
||||
]</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
'''
|
||||
)
|
||||
async def update_tasks(self, updates: List[Dict[str, Any]]) -> ToolResult:
|
||||
"""Update multiple tasks in a single operation"""
|
||||
"""Update multiple tasks"""
|
||||
try:
|
||||
tasks = await self._get_tasks()
|
||||
tasks = await self._load_tasks()
|
||||
task_map = {task.id: task for task in tasks}
|
||||
updated_count = 0
|
||||
|
||||
# Create task map for quick lookup
|
||||
task_map = {task.id: task for task in tasks}
|
||||
|
||||
for update_data in updates:
|
||||
update_request = TaskUpdateRequest(**update_data)
|
||||
|
||||
if update_request.id not in task_map:
|
||||
continue
|
||||
|
||||
task = task_map[update_request.id]
|
||||
|
||||
if update_request.content is not None:
|
||||
task.update(content=update_request.content)
|
||||
|
||||
if update_request.status is not None:
|
||||
task.update(status=update_request.status)
|
||||
|
||||
updated_count += 1
|
||||
task_id = update_data["id"]
|
||||
if task_id in task_map:
|
||||
task = task_map[task_id]
|
||||
|
||||
if "content" in update_data:
|
||||
task.content = update_data["content"]
|
||||
if "status" in update_data:
|
||||
task.status = TaskStatus(update_data["status"])
|
||||
if "section" in update_data:
|
||||
task.section = update_data["section"]
|
||||
|
||||
updated_count += 1
|
||||
|
||||
await self._save_tasks(tasks)
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
output=json.dumps({
|
||||
"message": f"Updated {updated_count} tasks",
|
||||
"tasks": [task.model_dump() for task in tasks]
|
||||
}, indent=2)
|
||||
)
|
||||
message = f"Updated {updated_count} tasks"
|
||||
response_data = self._format_response(tasks, message)
|
||||
|
||||
return ToolResult(success=True, output=json.dumps(response_data, indent=2))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating tasks: {e}")
|
||||
|
@ -370,9 +382,7 @@ class TaskListTool(SandboxToolsBase):
|
|||
"task_ids": {
|
||||
"type": "array",
|
||||
"description": "List of task IDs to delete",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"items": {"type": "string"},
|
||||
"minItems": 1
|
||||
}
|
||||
},
|
||||
|
@ -394,24 +404,19 @@ class TaskListTool(SandboxToolsBase):
|
|||
'''
|
||||
)
|
||||
async def delete_tasks(self, task_ids: List[str]) -> ToolResult:
|
||||
"""Delete multiple tasks in a single operation"""
|
||||
"""Delete multiple tasks"""
|
||||
try:
|
||||
tasks = await self._get_tasks()
|
||||
|
||||
# Filter out deleted tasks
|
||||
tasks = await self._load_tasks()
|
||||
task_id_set = set(task_ids)
|
||||
remaining_tasks = [task for task in tasks if task.id not in task_id_set]
|
||||
deleted_count = len(tasks) - len(remaining_tasks)
|
||||
|
||||
await self._save_tasks(remaining_tasks)
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
output=json.dumps({
|
||||
"message": f"Deleted {deleted_count} tasks",
|
||||
"tasks": [task.model_dump() for task in remaining_tasks]
|
||||
}, indent=2)
|
||||
)
|
||||
message = f"Deleted {deleted_count} tasks"
|
||||
response_data = self._format_response(remaining_tasks, message)
|
||||
|
||||
return ToolResult(success=True, output=json.dumps(response_data, indent=2))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting tasks: {e}")
|
||||
|
@ -420,14 +425,19 @@ class TaskListTool(SandboxToolsBase):
|
|||
@openapi_schema({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "clear_all_tasks",
|
||||
"description": "Delete all tasks. Use with caution - this cannot be undone!",
|
||||
"name": "clear_tasks",
|
||||
"description": "Clear all tasks or tasks in specific sections.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"confirm": {
|
||||
"type": "boolean",
|
||||
"description": "Must be true to confirm clearing all tasks"
|
||||
"description": "Must be true to confirm clearing tasks"
|
||||
},
|
||||
"sections": {
|
||||
"type": "array",
|
||||
"description": "Section names to clear (optional - if not provided, all tasks will be cleared)",
|
||||
"items": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"required": ["confirm"]
|
||||
|
@ -435,36 +445,43 @@ class TaskListTool(SandboxToolsBase):
|
|||
}
|
||||
})
|
||||
@xml_schema(
|
||||
tag_name="clear-all-tasks",
|
||||
tag_name="clear-tasks",
|
||||
mappings=[
|
||||
{"param_name": "confirm", "node_type": "element", "path": "confirm", "required": True}
|
||||
{"param_name": "confirm", "node_type": "element", "path": "confirm", "required": True},
|
||||
{"param_name": "sections", "node_type": "element", "path": "sections", "required": False}
|
||||
],
|
||||
example='''
|
||||
<function_calls>
|
||||
<invoke name="clear_all_tasks">
|
||||
<invoke name="clear_tasks">
|
||||
<parameter name="confirm">true</parameter>
|
||||
<parameter name="sections">["Setup & Planning", "Development"]</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
'''
|
||||
)
|
||||
async def clear_all_tasks(self, confirm: bool) -> ToolResult:
|
||||
"""Clear all tasks"""
|
||||
async def clear_tasks(self, confirm: bool, sections: Optional[List[str]] = None) -> ToolResult:
|
||||
"""Clear all tasks or tasks in specific sections"""
|
||||
try:
|
||||
if not confirm:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
output="❌ Must confirm=true to clear all tasks"
|
||||
)
|
||||
return ToolResult(success=False, output="❌ Must confirm=true to clear tasks")
|
||||
|
||||
await self._save_tasks([])
|
||||
tasks = await self._load_tasks()
|
||||
|
||||
return ToolResult(
|
||||
success=True,
|
||||
output=json.dumps({
|
||||
"message": "All tasks have been cleared",
|
||||
"tasks": []
|
||||
}, indent=2)
|
||||
)
|
||||
if sections:
|
||||
section_set = set(sections)
|
||||
remaining_tasks = [task for task in tasks if task.section not in section_set]
|
||||
deleted_count = len(tasks) - len(remaining_tasks)
|
||||
message = f"Deleted {deleted_count} tasks from {len(sections)} sections"
|
||||
else:
|
||||
remaining_tasks = []
|
||||
deleted_count = len(tasks)
|
||||
message = f"Deleted all {deleted_count} tasks"
|
||||
|
||||
await self._save_tasks(remaining_tasks)
|
||||
|
||||
response_data = self._format_response(remaining_tasks, message)
|
||||
|
||||
return ToolResult(success=True, output=json.dumps(response_data, indent=2))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing tasks: {e}")
|
||||
|
|
Loading…
Reference in New Issue