# suna/sdk/kortix/api/threads.py
#
# Source listing metadata: 587 lines, 17 KiB, Python.

from dataclasses import dataclass, asdict
from typing import Optional, List, Dict, Any
import httpx
from datetime import datetime
# Import from shared models
from ..models import (
Role,
MessageType,
BaseMessage,
ChatMessage,
AgentRun,
ContentObject,
)
@dataclass
class MessageCreateRequest:
    """Request body for creating a structured message in a thread.

    ``type`` must be the value of a ``MessageType`` member; this is checked in
    ``__post_init__`` whenever an instance is constructed.
    """

    content: str
    type: str = "user"  # Should be a MessageType value
    is_llm_message: bool = True

    def __post_init__(self):
        """Validate that ``type`` is a valid MessageType value.

        Raises:
            ValueError: If ``type`` does not match any ``MessageType`` value.
        """
        try:
            MessageType(self.type)
        except ValueError:
            # ``from None`` suppresses the enum's own lookup error in the traceback;
            # the message below already lists every accepted value.
            raise ValueError(
                f"Invalid message type: {self.type}. Must be one of {[t.value for t in MessageType]}"
            ) from None

    @classmethod
    def create_user_message(cls, content: str) -> "MessageCreateRequest":
        """Create a user message that is sent to the LLM."""
        return cls(content=content, type=MessageType.USER.value, is_llm_message=True)

    @classmethod
    def create_system_message(cls, content: str) -> "MessageCreateRequest":
        """Create a system message that is not flagged as an LLM message.

        The literal ``"system"`` still goes through ``__post_init__`` validation.
        """
        return cls(content=content, type="system", is_llm_message=False)
@dataclass
class AgentStartRequest:
    """JSON body sent to the agent-start endpoint.

    Every field has a default, so ``AgentStartRequest()`` is a valid request.
    """

    # Model to run; None leaves the choice to the server.
    model_name: Optional[str] = None
    # Toggle for the model's thinking mode.
    enable_thinking: Optional[bool] = False
    # Reasoning effort level passed through to the server.
    reasoning_effort: Optional[str] = "low"
    # Whether the run should stream its output.
    stream: Optional[bool] = True
    # Toggle for server-side context management.
    enable_context_manager: Optional[bool] = False
    # Specific agent to run; None means no explicit agent is requested.
    agent_id: Optional[str] = None
@dataclass
class ProjectData:
    """Project record attached to a thread in API responses."""

    project_id: str
    name: str
    description: str
    account_id: str
    # Sandbox details as returned by the API; the schema is not defined here.
    sandbox: Dict[str, Any]
    is_public: bool
    # Timestamps are kept as the raw strings the API returns.
    created_at: str
    updated_at: str
@dataclass
class AgentRunApiResponse(AgentRun):
    """``AgentRun`` plus the extra agent identifiers the threads API includes."""

    # Both are optional because runs may not be tied to a specific agent/version.
    agent_id: Optional[str] = None
    agent_version_id: Optional[str] = None
@dataclass
class Thread:
    """A conversation thread, optionally enriched with project and run data.

    The trailing optional fields are only filled in by endpoints that return
    the extra related data.
    """

    thread_id: str
    account_id: str
    project_id: Optional[str]
    # Free-form metadata dict from the API.
    metadata: Dict[str, Any]
    is_public: bool
    created_at: str
    updated_at: str
    # Optional related data.
    project: Optional[ProjectData] = None
    message_count: Optional[int] = None
    recent_agent_runs: Optional[List[AgentRunApiResponse]] = None
@dataclass
class Message:
    """A message belonging to a thread, as returned by the messages endpoints."""

    message_id: str
    thread_id: str
    type: str  # Raw type string; see ``message_type`` for the MessageType view
    is_llm_message: bool
    content: Any  # Can be a string, a dict, or another object (e.g. ContentObject)
    created_at: str
    updated_at: str
    agent_id: str
    agent_version_id: str
    metadata: Any

    @property
    def message_type(self) -> MessageType:
        """Return ``type`` as a MessageType, falling back to USER for unknown values."""
        try:
            return MessageType(self.type)
        except ValueError:
            # Fallback for unknown message types
            return MessageType.USER

    @property
    def is_user_message(self) -> bool:
        """Return True only if ``type`` is the USER message type.

        This does not use the USER fallback of ``message_type``, so a message with
        an unrecognised ``type`` is not reported as a user message.
        """
        try:
            return MessageType(self.type) == MessageType.USER
        except ValueError:
            return False

    @property
    def is_assistant_message(self) -> bool:
        """Return True if this is an assistant message."""
        return self.message_type == MessageType.ASSISTANT

    def get_content_as_string(self) -> str:
        """Return the message content as a string.

        - ``str`` content is returned unchanged.
        - A dict with a ``"content"`` key yields that value, converted with
          ``str()`` if it is not already a string.
        - Anything else (including a dict without ``"content"``) is passed to ``str()``.
        """
        content = self.content
        if isinstance(content, str):
            return content
        if isinstance(content, dict) and "content" in content:
            inner = content["content"]
            # The inner value may itself be structured; always honour the str contract.
            return inner if isinstance(inner, str) else str(inner)
        return str(content)
@dataclass
class PaginationInfo:
    """Pagination block returned alongside paged list responses."""

    page: int  # Current page number
    limit: int  # Page size that was requested
    total: int  # Total number of items
    pages: int  # Total number of pages
@dataclass
class ThreadsResponse:
    """One page of threads together with its pagination information."""

    threads: List[Thread]
    pagination: PaginationInfo
@dataclass
class MessagesResponse:
    """Wrapper around the list of messages returned for a thread."""

    messages: List[Message]
@dataclass
class CreateThreadResponse:
    """IDs of the thread and project produced by a create-thread call."""

    thread_id: str
    project_id: str
@dataclass
class AgentResponse:
    """Agent configuration record as returned by the API.

    All fields are required at construction time; the ``Optional`` ones may
    simply hold ``None``.
    """

    # Identity and ownership.
    agent_id: str
    account_id: str
    name: str
    description: Optional[str]
    system_prompt: str
    # Tool / MCP configuration, kept as raw API structures.
    configured_mcps: List[Dict[str, Any]]
    custom_mcps: List[Dict[str, Any]]
    agentpress_tools: Dict[str, Any]
    is_default: bool
    # Marketplace-related fields.
    is_public: Optional[bool]
    marketplace_published_at: Optional[str]
    download_count: Optional[int]
    tags: Optional[List[str]]
    # Presentation.
    avatar: Optional[str]
    avatar_color: Optional[str]
    # Timestamps (raw strings).
    created_at: str
    updated_at: Optional[str]
    # Versioning.
    current_version_id: Optional[str]
    version_count: Optional[int]
    current_version: Optional[Any]
    metadata: Optional[Dict[str, Any]]
@dataclass
class ThreadAgentResponse:
    """Agent resolved for a thread, plus where that resolution came from."""

    agent: Optional[AgentResponse]  # None when no agent could be resolved
    source: str  # "thread", "default", "none", "missing"
    message: str
@dataclass
class AgentStartResponse:
    """Result of starting an agent: the new run's ID and its status string."""

    agent_run_id: str
    status: str
@dataclass
class AgentRunResponse:
    """Agent run summary.

    The field names are camelCase (``threadId``, ``startedAt``, ...) and are
    kept as-is, because they must match the keys of the payload they model.
    """

    id: str
    threadId: str
    status: str
    startedAt: Optional[str]
    completedAt: Optional[str]
    error: Optional[str]
@dataclass
class AgentRunsResponse:
    """List of agent runs, each left as the raw dict from the API."""

    agent_runs: List[Dict[str, Any]]
def to_dict(obj) -> Dict[str, Any]:
    """Serialize a dataclass (including nested dataclasses) into a plain dict.

    Any object that does not carry ``__dataclass_fields__`` is returned as-is.
    """
    looks_like_dataclass = hasattr(obj, "__dataclass_fields__")
    return asdict(obj) if looks_like_dataclass else obj
def from_dict(cls, data: Dict[str, Any]):
    """Create a dataclass instance from a dictionary, converting nested dataclasses.

    Conversion is recursive for fields typed as a dataclass, ``Optional[<dataclass>]``,
    ``List[<dataclass>]`` or ``Optional[List[<dataclass>]]``. A value is converted
    only when it is a ``dict`` (or, for lists, each item is a ``dict``). Values that
    are already dataclass instances, ``None``, or any other non-dict value are
    passed through unchanged.

    Keys that are not fields of ``cls`` are ignored, so extra fields in an API
    response do not break parsing.

    Args:
        cls: Target type. If it is not a dataclass, ``data`` is returned unchanged.
        data: Mapping of field names to values.

    Returns:
        An instance of ``cls``, or ``data`` itself when ``cls`` is not a dataclass.

    Raises:
        TypeError: If a required field of ``cls`` is missing from ``data``.
    """
    from typing import Union, get_args, get_origin

    if not hasattr(cls, "__dataclass_fields__"):
        return data

    def _unwrap_optional(tp):
        # Optional[X] is Union[X, None]; reduce it to X so the checks below see X.
        if get_origin(tp) is Union:
            non_none = [arg for arg in get_args(tp) if arg is not type(None)]
            if len(non_none) == 1:
                return non_none[0]
        return tp

    def _convert(tp, value):
        tp = _unwrap_optional(tp)
        if get_origin(tp) is list:
            args = get_args(tp)
            item_type = args[0] if args else None
            if value and hasattr(item_type, "__dataclass_fields__"):
                return [
                    from_dict(item_type, item) if isinstance(item, dict) else item
                    for item in value
                ]
            return value
        if isinstance(value, dict) and hasattr(tp, "__dataclass_fields__"):
            return from_dict(tp, value)
        return value

    field_types = {name: f.type for name, f in cls.__dataclass_fields__.items()}
    processed_data = {
        key: _convert(field_types[key], value)
        for key, value in data.items()
        if key in field_types  # drop unknown keys instead of crashing cls(**...)
    }
    return cls(**processed_data)
class ThreadsClient:
    """Async client for the thread, message and agent-run endpoints of the API.

    Use it as an async context manager (``async with ThreadsClient(...) as c:``),
    or call ``close()`` when done, so the underlying ``httpx.AsyncClient`` is
    released.
    """

    def __init__(
        self,
        base_url: str,
        auth_token: Optional[str] = None,
        custom_headers: Optional[Dict[str, str]] = None,
        timeout: float = 30.0,
    ):
        """Initialize the threads client.

        Args:
            base_url: The base URL for the API; a trailing slash is stripped.
            auth_token: Optional API key, sent as the ``X-API-Key`` header.
            custom_headers: Optional extra headers; they override the defaults.
            timeout: Request timeout in seconds.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Headers this client was configured with (kept as a public attribute).
        self.headers = {"Content-Type": "application/json"}
        if auth_token:
            self.headers["X-API-Key"] = auth_token
        if custom_headers:
            self.headers.update(custom_headers)
        # httpx merges client-level headers into every request, and a merged
        # Content-Type takes precedence over the one derived from json=/data=.
        # A client-wide "application/json" would therefore mislabel form posts
        # (create_thread). Keep Content-Type out of the client defaults and let
        # httpx set it per request from the body it encodes.
        client_headers = {
            key: value
            for key, value in self.headers.items()
            if key.lower() != "content-type"
        }
        self.client = httpx.AsyncClient(
            headers=client_headers, timeout=timeout, base_url=self.base_url
        )

    async def close(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Always close; returning None lets any in-flight exception propagate.
        await self.close()

    def _handle_response(self, response: httpx.Response) -> Dict[str, Any]:
        """Decode a successful response, or raise for an error status.

        Returns:
            The decoded JSON body for 2xx responses, or ``{}`` when the body is
            empty (e.g. 204 No Content).

        Raises:
            ValueError: On 404.
            PermissionError: On 403.
            RuntimeError: On any other status >= 400. The message uses the JSON
                ``detail`` field when present, otherwise the raw body text.
            httpx.HTTPStatusError: Raised by ``raise_for_status()`` for other
                non-2xx statuses.
        """
        status = response.status_code
        if 200 <= status < 300:
            # Empty bodies (204, or 2xx with no content) have nothing to decode.
            if status == 204 or not response.content:
                return {}
            return response.json()
        if status == 404:
            raise ValueError(f"Not found: {response.text}")
        if status == 403:
            raise PermissionError(f"Access denied: {response.text}")
        if status >= 400:
            error_message = response.text
            try:
                error_data = response.json()
            except ValueError:  # body is not JSON (JSONDecodeError subclasses ValueError)
                error_data = None
            if isinstance(error_data, dict):
                error_message = error_data.get("detail", response.text)
            raise RuntimeError(f"API error ({status}): {error_message}")
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _parse_thread(thread_data: Dict[str, Any]) -> Thread:
        """Build a Thread from its JSON dict, converting nested project and runs.

        A missing/empty ``recent_agent_runs`` becomes ``[]``, and a missing/empty
        ``project`` becomes ``None``.
        """
        project_data = None
        if thread_data.get("project"):
            project_data = from_dict(ProjectData, thread_data["project"])
        agent_runs_data = [
            from_dict(AgentRunApiResponse, run_data)
            for run_data in thread_data.get("recent_agent_runs") or []
        ]
        return from_dict(
            Thread,
            {
                **thread_data,
                "project": project_data,
                "recent_agent_runs": agent_runs_data,
            },
        )

    async def get_threads(
        self,
        page: int = 1,
        limit: int = 1000,
    ) -> ThreadsResponse:
        """Get all threads for the current user with associated project data.

        Args:
            page: Page number (1-based).
            limit: Number of items per page (max 1000).

        Returns:
            ThreadsResponse containing the requested page of threads.
        """
        params = {"page": page, "limit": limit}
        response = await self.client.get("/threads", params=params)
        data = self._handle_response(response)
        threads = [self._parse_thread(thread_data) for thread_data in data["threads"]]
        pagination = from_dict(PaginationInfo, data["pagination"])
        return ThreadsResponse(threads=threads, pagination=pagination)

    async def get_thread(self, thread_id: str) -> Thread:
        """Get a specific thread by ID with its related data.

        Args:
            thread_id: The thread ID.

        Returns:
            Thread including project, message count and recent agent runs when
            the API provides them.
        """
        response = await self.client.get(f"/threads/{thread_id}")
        data = self._handle_response(response)
        return self._parse_thread(data)

    async def get_thread_messages(
        self, thread_id: str, order: str = "desc"
    ) -> MessagesResponse:
        """Get ALL messages for a thread.

        Args:
            thread_id: The thread ID.
            order: Order by created_at: 'asc' or 'desc'.

        Returns:
            MessagesResponse containing all messages.
        """
        params = {"order": order}
        response = await self.client.get(
            f"/threads/{thread_id}/messages", params=params
        )
        data = self._handle_response(response)
        messages = [from_dict(Message, msg_data) for msg_data in data["messages"]]
        return MessagesResponse(messages=messages)

    async def add_message_to_thread(self, thread_id: str, message: str) -> Message:
        """Add a simple text message to a thread.

        Args:
            thread_id: The thread ID.
            message: The message content.

        Returns:
            The created message.
        """
        # The message is sent as a query parameter (params=), not as a JSON body.
        response = await self.client.post(
            f"/threads/{thread_id}/messages/add",
            params={"message": message},
        )
        data = self._handle_response(response)
        return from_dict(Message, data)

    async def delete_message_from_thread(self, thread_id: str, message_id: str) -> None:
        """Delete a message from a thread.

        Args:
            thread_id: The thread ID.
            message_id: The message ID.
        """
        response = await self.client.delete(
            f"/threads/{thread_id}/messages/{message_id}"
        )
        # Only checks the status; an empty (e.g. 204) body is accepted.
        self._handle_response(response)

    async def create_message(
        self, thread_id: str, request: MessageCreateRequest
    ) -> Message:
        """Create a structured message in a thread.

        Args:
            thread_id: The thread ID.
            request: The message creation request, sent as JSON.

        Returns:
            The created message.
        """
        response = await self.client.post(
            f"/threads/{thread_id}/messages", json=to_dict(request)
        )
        data = self._handle_response(response)
        return from_dict(Message, data)

    async def create_thread(self, name: Optional[str] = None) -> CreateThreadResponse:
        """Create a new thread, with an optional name.

        Args:
            name: Optional name for the thread/project.

        Returns:
            CreateThreadResponse containing the new thread ID and project ID.
        """
        # ``name`` is sent as a form field (data=); httpx encodes it as
        # application/x-www-form-urlencoded and sets the matching Content-Type.
        form_data = None if name is None else {"name": name}
        response = await self.client.post("/threads", data=form_data)
        data = self._handle_response(response)
        return from_dict(CreateThreadResponse, data)

    async def delete_thread(self, thread_id: str) -> None:
        """Delete a thread. Not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Not implemented")

    async def get_thread_agent(self, thread_id: str) -> ThreadAgentResponse:
        """Get the agent details for a specific thread.

        Args:
            thread_id: The thread ID.

        Returns:
            ThreadAgentResponse with agent details and source information.
        """
        # Agent endpoints use the singular "/thread/" prefix (as does start_agent).
        response = await self.client.get(f"/thread/{thread_id}/agent")
        data = self._handle_response(response)
        agent_data = None
        if data.get("agent"):
            agent_data = from_dict(AgentResponse, data["agent"])
        return ThreadAgentResponse(
            agent=agent_data, source=data["source"], message=data["message"]
        )

    async def start_agent(
        self, thread_id: str, request: AgentStartRequest
    ) -> AgentStartResponse:
        """Start an agent for a specific thread.

        Args:
            thread_id: The thread ID.
            request: The agent start request, sent as JSON.

        Returns:
            AgentStartResponse with agent run ID and status.
        """
        response = await self.client.post(
            f"/thread/{thread_id}/agent/start", json=to_dict(request)
        )
        data = self._handle_response(response)
        return from_dict(AgentStartResponse, data)

    async def stop_agent(self, agent_run_id: str) -> Dict[str, str]:
        """Stop a running agent.

        Args:
            agent_run_id: The agent run ID.

        Returns:
            The decoded status response.
        """
        response = await self.client.post(f"/agent-run/{agent_run_id}/stop")
        return self._handle_response(response)

    def get_agent_run_stream_url(
        self, agent_run_id: str, token: Optional[str] = None
    ) -> str:
        """Build the URL for streaming agent run responses.

        Args:
            agent_run_id: The agent run ID.
            token: Optional authentication token for streaming. When given, it is
                appended as a URL-encoded ``token`` query parameter.
                NOTE(review): assumes the stream endpoint reads ``token`` from the
                query string; confirm against the backend.

        Returns:
            The absolute streaming URL.
        """
        url = f"{self.base_url}/agent-run/{agent_run_id}/stream"
        if token:
            from urllib.parse import urlencode

            url = f"{url}?{urlencode({'token': token})}"
        return url
def create_threads_client(
    base_url: str,
    auth_token: Optional[str] = None,
    custom_headers: Optional[Dict[str, str]] = None,
    timeout: float = 120.0,
) -> ThreadsClient:
    """Factory for ThreadsClient.

    Note that the default timeout here (120 s) is longer than the
    ThreadsClient constructor's own default.

    Args:
        base_url: The base URL for the API.
        auth_token: Optional authentication token.
        custom_headers: Optional custom headers to include in requests.
        timeout: Request timeout in seconds.

    Returns:
        A freshly constructed ThreadsClient.
    """
    client_kwargs = {
        "base_url": base_url,
        "auth_token": auth_token,
        "custom_headers": custom_headers,
        "timeout": timeout,
    }
    return ThreadsClient(**client_kwargs)