# suna/sdk/kortix/stream.py
# 390 lines, 13 KiB, Python
"""
@fileoverview
This file is mostly important for dealing with the content of the streamed data,
perhaps the frontend if the stream doesn't need to be decoded in the backend.
"""
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Callable, Any, Literal, AsyncGenerator
import json
import re
import httpx
def try_parse_json(json_str: str) -> Optional[Any]:
    """Decode a JSON string, returning None instead of raising on bad input."""
    try:
        parsed = json.loads(json_str)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers non-string inputs such as None.
        return None
    return parsed
async def stream_from_url(url: str, **kwargs) -> AsyncGenerator[str, None]:
    """Stream non-empty lines from a URL via an HTTP GET request.

    Args:
        url: The URL to stream from.
        **kwargs: Extra arguments forwarded to httpx.AsyncClient.stream().

    Yields:
        str: Each non-blank line of the response, stripped of surrounding
        whitespace.
    """
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", url, **kwargs) as response:
            response.raise_for_status()
            async for raw_line in response.aiter_lines():
                stripped = raw_line.strip()
                if stripped:
                    yield stripped
@dataclass
class BaseStreamEvent:
    """The base structure for any event coming from the stream."""

    # ID of the conversation thread this event belongs to.
    thread_id: str
    # Discriminator deciding how the event payload is interpreted.
    type: Literal["status", "assistant", "tool"]
    # Whether this event originated from the LLM.
    is_llm_message: bool
    metadata: str  # Often a JSON string
    # Timestamps as strings — format presumably ISO 8601; TODO confirm against the API.
    created_at: str
    updated_at: str
    # Non-null only for complete (finalized) messages; None for chunks.
    message_id: Optional[str] = None
    # Raw content payload; frequently a JSON-encoded string.
    content: Optional[str] = None
    # Populated on "status" events.
    status: Optional[str] = None
    message: Optional[str] = None
@dataclass
class AssistantMessageChunk(BaseStreamEvent):
    """A streaming assistant-message fragment.

    Chunks never carry a `message_id`; `sequence` orders them for
    reassembly.
    """

    # Ordering key used to handle out-of-order chunk delivery.
    sequence: Optional[int] = None

    def __post_init__(self):
        # Chunks are never finalized messages and are always assistant-typed.
        self.message_id = None
        self.type = "assistant"
        # Normalize falsy content (None or "") to an empty string.
        self.content = self.content or ""
@dataclass
class CompleteMessage(BaseStreamEvent):
    """A final, complete message (assistant, tool, or status with an ID).

    Unlike streaming chunks, `message_id` is a non-null string.
    """

    def __post_init__(self):
        # Default both the identifier and the payload to "" when falsy.
        self.message_id = self.message_id or ""
        self.content = self.content or ""
@dataclass
class AssistantContentChunk:
    """The structure of the content within an AssistantMessageChunk."""

    # Always "assistant" for this payload shape.
    role: Literal["assistant"]
    # The text fragment carried by the chunk.
    content: str
@dataclass
class ProcessedStreamResult:
    """The structured result after processing the stream."""

    # Finalized messages collected from the stream.
    final_messages: List[CompleteMessage]
    # Thread the stream belonged to, when known.
    thread_id: Optional[str] = None
    # Run IDs associated with this result.
    run_ids: List[str] = field(default_factory=list)
@dataclass
class FunctionCallDetails:
    """Details about a function call in the stream."""

    # Tool name parsed from the <invoke name="..."> tag; None until known.
    name: Optional[str] = None
    # We could add `parameters: Optional[str]` here later
@dataclass
class ToolResultContent:
    """The content of a tool result message. The inner `content` is a JSON string."""

    role: Literal["user"]  # Or 'tool', depending on the API spec
    content: str  # A JSON string containing ToolExecutionResult
@dataclass
class ToolExecutionResult:
    """The parsed result of a tool execution."""

    # Raw dict describing the execution; empty when absent from the payload.
    tool_execution: Dict[str, Any]

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToolExecutionResult":
        """Build an instance from a raw payload dict; defaults to {} when the key is absent."""
        execution = data["tool_execution"] if "tool_execution" in data else {}
        return cls(tool_execution=execution)
@dataclass
class ToolResultMessage(CompleteMessage):
    """A finalized message carrying the result of a tool execution."""

    def __post_init__(self):
        # Apply the CompleteMessage defaults, then pin the discriminator.
        super().__post_init__()
        self.type = "tool"
# Type guard functions
def is_assistant_message_chunk(event: Any) -> bool:
    """Type guard: True when *event* looks like a streaming assistant chunk
    (assistant type, no message_id, integer sequence)."""
    if getattr(event, "type", None) != "assistant":
        return False
    # Chunks must explicitly carry message_id=None; a missing attribute or a
    # non-None ID disqualifies the event.
    if getattr(event, "sentinel-missing") if False else getattr(event, "message_id", "missing") is not None:
        return False
    return isinstance(getattr(event, "sequence", None), int)
def is_complete_message(event: Any) -> bool:
    """Type guard: True for any finalized message carrying a non-null message_id."""
    return getattr(event, "message_id", None) is not None
def is_tool_result_message(event: Any) -> bool:
    """Type guard: True for a finalized tool-result message."""
    has_tool_type = getattr(event, "type", None) == "tool"
    has_id = getattr(event, "message_id", None) is not None
    return has_tool_type and has_id
# Callback type definitions
# Fired once, when the first stream event arrives.
OnStreamStartCallback = Callable[[], None]
# Receives the full accumulated assistant text after each chunk.
OnTextUpdateCallback = Callable[[str], None]
# Receives each finalized assistant message.
OnMessageEndCallback = Callable[[CompleteMessage], None]
# Receives merged status payloads (event fields plus parsed content).
OnStatusUpdateCallback = Callable[[Any], None]
# Fired when an opening <function_calls> tag is detected.
OnFunctionCallStartCallback = Callable[[], None]
# Receives the function name once parsed from <invoke name="...">.
OnFunctionCallUpdateCallback = Callable[[FunctionCallDetails], None]
# Fired when the closing </function_calls> tag is detected.
OnFunctionCallEndCallback = Callable[[], None]
# Receives each finalized tool result message.
OnToolResultCallback = Callable[[ToolResultMessage], None]
@dataclass
class RealtimeCallbacks:
    """Callbacks for handling real-time stream events.

    All callbacks are optional; the processor simply skips any that are
    left unset.
    """

    on_stream_start: Optional[OnStreamStartCallback] = None
    on_text_update: Optional[OnTextUpdateCallback] = None
    on_message_end: Optional[OnMessageEndCallback] = None
    on_status_update: Optional[OnStatusUpdateCallback] = None
    on_function_call_start: Optional[OnFunctionCallStartCallback] = None
    on_function_call_update: Optional[OnFunctionCallUpdateCallback] = None
    on_function_call_end: Optional[OnFunctionCallEndCallback] = None
    on_tool_result: Optional[OnToolResultCallback] = None
# The three states of the <function_calls> parsing state machine.
ParsingState = Literal["text", "in_function_call", "function_call_ended"]


@dataclass
class StreamState:
    """Internal state for stream processing."""

    # Chunks received so far for the in-flight assistant message.
    chunks: List[AssistantMessageChunk] = field(default_factory=list)
    # Concatenated text rebuilt from the sequence-ordered chunks.
    full_text: str = ""
    # Current position in the function-call state machine.
    parsing_state: ParsingState = "text"
    # Details of the function call currently being parsed, if any.
    current_function_call: Optional[FunctionCallDetails] = None
class RealtimeStreamProcessor:
    """
    Example streaming parser for reference.
    Processes real-time stream data with callback support.

    Feed raw SSE lines ("data: {...}") to `process_line`. The processor
    reassembles out-of-order assistant chunks, drives the <function_calls>
    parsing state machine, and dispatches the optional callbacks configured
    via `RealtimeCallbacks`.
    """

    def __init__(self, callbacks: Optional[RealtimeCallbacks] = None):
        # Finalized messages keyed by message_id.
        self.messages: Dict[str, CompleteMessage] = {}
        self.state = StreamState()
        self.callbacks = callbacks or RealtimeCallbacks()
        self.stream_active = False
        # Extracts the tool name from e.g. <invoke name="search">.
        self.invoke_name_regex = re.compile(r'<invoke\s+name="([^"]+)"')
        # Cache of (sequence, decoded_text) pairs so the accumulated text can
        # be rebuilt without re-parsing every cached chunk's JSON on each
        # arrival (the previous implementation was quadratic in chunk count).
        self._ordered_chunk_texts: List[Any] = []

    def _create_default_state(self) -> StreamState:
        """Create a new default state."""
        return StreamState()

    def _start_stream_if_inactive(self) -> None:
        """Mark the stream active and fire on_stream_start exactly once."""
        if not self.stream_active:
            self.stream_active = True
            if self.callbacks.on_stream_start:
                self.callbacks.on_stream_start()

    def _handle_chunk(self, chunk: AssistantMessageChunk) -> None:
        """Handle an incoming assistant message chunk.

        Chunks may arrive out of order: they are kept sorted by `sequence`,
        the full text is rebuilt from the cached decoded contents, then the
        parsing state machine and on_text_update are driven.
        """
        self._start_stream_if_inactive()
        if not chunk.content:
            return
        chunk_content_data = try_parse_json(chunk.content)
        # Ignore chunks whose content is not JSON or lacks a "content" field.
        if not chunk_content_data or "content" not in chunk_content_data:
            return
        self.state.chunks.append(chunk)
        # Sort chunks by sequence to handle out-of-order delivery.
        self.state.chunks.sort(key=lambda c: c.sequence or 0)
        # Decode each chunk's JSON exactly once and cache the text. Sorting
        # on the sequence key only (not the whole tuple) keeps the sort
        # stable, so equal sequences preserve arrival order, matching the
        # chunk list above.
        self._ordered_chunk_texts.append(
            (chunk.sequence or 0, chunk_content_data["content"])
        )
        self._ordered_chunk_texts.sort(key=lambda pair: pair[0])
        self.state.full_text = "".join(
            text for _, text in self._ordered_chunk_texts
        )
        self._update_parsing_state(self.state)
        if self.callbacks.on_text_update:
            self.callbacks.on_text_update(self.state.full_text)

    def _update_parsing_state(self, state: StreamState) -> None:
        """Advance the <function_calls> state machine based on `full_text`."""
        if state.parsing_state == "text":
            if "<function_calls>" in state.full_text:
                state.parsing_state = "in_function_call"
                state.current_function_call = FunctionCallDetails(name=None)
                if self.callbacks.on_function_call_start:
                    self.callbacks.on_function_call_start()
                # Re-run immediately: the same text may already contain the
                # invoke name and/or the closing tag.
                self._update_parsing_state(state)
        elif state.parsing_state == "in_function_call":
            if state.current_function_call and state.current_function_call.name is None:
                match = self.invoke_name_regex.search(state.full_text)
                if match:
                    state.current_function_call.name = match.group(1)
                    if self.callbacks.on_function_call_update:
                        self.callbacks.on_function_call_update(
                            FunctionCallDetails(name=state.current_function_call.name)
                        )
            if "</function_calls>" in state.full_text:
                state.parsing_state = "function_call_ended"
                if self.callbacks.on_function_call_end:
                    self.callbacks.on_function_call_end()
                state.current_function_call = None
        elif state.parsing_state == "function_call_ended":
            # Terminal state: nothing further to parse for this message.
            pass

    def process_line(self, line: str) -> None:
        """Process a single stream line; lines without a "data:" prefix are ignored.

        Dispatches to the chunk / tool-result / complete-message / status
        handlers based on the event's `type` and `message_id` fields.
        """
        if not line.startswith("data:"):
            return
        try:
            event_data = try_parse_json(line[5:])  # Remove "data:" prefix
            # Guard against blank payloads and non-object JSON (a list or
            # string here would otherwise raise AttributeError on .get()).
            if not event_data or not isinstance(event_data, dict):
                return
            # Convert dict to appropriate object based on type
            if (
                event_data.get("type") == "assistant"
                and event_data.get("message_id") is None
                and "sequence" in event_data
            ):
                chunk = AssistantMessageChunk(
                    message_id=None,
                    thread_id=event_data.get("thread_id", ""),
                    type="assistant",
                    is_llm_message=event_data.get("is_llm_message", False),
                    metadata=event_data.get("metadata", ""),
                    created_at=event_data.get("created_at", ""),
                    updated_at=event_data.get("updated_at", ""),
                    content=event_data.get("content", ""),
                    sequence=event_data.get("sequence"),
                )
                self._handle_chunk(chunk)
            elif (
                event_data.get("type") == "tool"
                and event_data.get("message_id") is not None
            ):
                tool_message = ToolResultMessage(
                    message_id=event_data["message_id"],
                    thread_id=event_data.get("thread_id", ""),
                    type="tool",
                    is_llm_message=event_data.get("is_llm_message", False),
                    metadata=event_data.get("metadata", ""),
                    created_at=event_data.get("created_at", ""),
                    updated_at=event_data.get("updated_at", ""),
                    content=event_data.get("content", ""),
                )
                self._handle_tool_result_message(tool_message)
            elif (
                event_data.get("type") == "assistant"
                and event_data.get("message_id") is not None
            ):
                complete_message = CompleteMessage(
                    message_id=event_data["message_id"],
                    thread_id=event_data.get("thread_id", ""),
                    type="assistant",
                    is_llm_message=event_data.get("is_llm_message", False),
                    metadata=event_data.get("metadata", ""),
                    created_at=event_data.get("created_at", ""),
                    updated_at=event_data.get("updated_at", ""),
                    content=event_data.get("content", ""),
                )
                self._handle_complete_assistant_message(complete_message)
            elif event_data.get("type") == "status" and self.callbacks.on_status_update:
                status_details = try_parse_json(event_data.get("content", "{}"))
                # The content may legally parse to a non-dict (e.g. a list);
                # only a dict can be merged below, so fall back to {}.
                if not isinstance(status_details, dict):
                    status_details = {}
                if event_data.get("status"):
                    status_details["status_type"] = event_data["status"]
                if event_data.get("message"):
                    status_details["message"] = event_data["message"]
                # Merge event data with status details
                full_status = {**event_data, **status_details}
                self.callbacks.on_status_update(full_status)
        except Exception as error:
            # Best-effort processing: a malformed event must not kill the
            # stream, so log and continue.
            print(f"Failed to process stream line: {line}, error: {error}")

    def _handle_tool_result_message(self, message: ToolResultMessage) -> None:
        """Record a finalized tool result and fire on_tool_result."""
        if message.message_id:
            self.messages[message.message_id] = message
        if self.callbacks.on_tool_result:
            self.callbacks.on_tool_result(message)

    def _handle_complete_assistant_message(self, message: CompleteMessage) -> None:
        """Record a finalized assistant message, fire on_message_end, and reset chunk state."""
        if message.message_id:
            self.messages[message.message_id] = message
        if self.callbacks.on_message_end:
            self.callbacks.on_message_end(message)
        # A complete message supersedes any in-flight chunk accumulation.
        self.state = self._create_default_state()
        self._ordered_chunk_texts = []