mirror of https://github.com/kortix-ai/suna.git
230 lines
8.2 KiB
Python
230 lines
8.2 KiB
Python
import streamlit as st
|
|
from datetime import datetime
|
|
from db_connection import DBConnection
|
|
import asyncio
|
|
import json
|
|
|
|
def format_message_content(content):
|
|
"""Format message content handling various formats."""
|
|
try:
|
|
if isinstance(content, str):
|
|
# Try to parse JSON strings
|
|
try:
|
|
parsed = json.loads(content)
|
|
if isinstance(parsed, (dict, list)):
|
|
return json.dumps(parsed, indent=2)
|
|
except json.JSONDecodeError:
|
|
return content
|
|
elif isinstance(content, list):
|
|
formatted_content = []
|
|
for item in content:
|
|
if item.get('type') == 'text':
|
|
formatted_content.append(item['text'])
|
|
elif item.get('type') == 'image_url':
|
|
formatted_content.append("[Image]")
|
|
return "\n".join(formatted_content)
|
|
return json.dumps(content, indent=2)
|
|
except:
|
|
return str(content)
|
|
|
|
async def load_threads():
|
|
"""Load all thread IDs from the database."""
|
|
db = DBConnection()
|
|
rows = await db.fetch_all(
|
|
"""
|
|
SELECT id, created_at
|
|
FROM threads
|
|
ORDER BY created_at DESC
|
|
"""
|
|
)
|
|
return rows
|
|
|
|
async def load_thread_content(thread_id: str, filters: dict):
|
|
"""Load messages from a thread with filters."""
|
|
db = DBConnection()
|
|
|
|
query_parts = ["SELECT type, content, include_in_llm_message_history, created_at FROM messages WHERE thread_id = $1"]
|
|
params = [thread_id]
|
|
param_count = 1 # Keep track of parameter count
|
|
|
|
if filters.get('message_types'):
|
|
# Convert comma-separated string to list and clean up whitespace
|
|
types_list = [t.strip() for t in filters['message_types'].split(',') if t.strip()]
|
|
if types_list:
|
|
placeholders = ','.join(f'${param_count + i + 1}' for i in range(len(types_list)))
|
|
query_parts.append(f"AND type IN ({placeholders})")
|
|
params.extend(types_list)
|
|
param_count += len(types_list)
|
|
|
|
if filters.get('exclude_message_types'):
|
|
# Convert comma-separated string to list and clean up whitespace
|
|
exclude_types_list = [t.strip() for t in filters['exclude_message_types'].split(',') if t.strip()]
|
|
if exclude_types_list:
|
|
placeholders = ','.join(f'${param_count + i + 1}' for i in range(len(exclude_types_list)))
|
|
query_parts.append(f"AND type NOT IN ({placeholders})")
|
|
params.extend(exclude_types_list)
|
|
param_count += len(exclude_types_list)
|
|
|
|
if filters.get('before_timestamp'):
|
|
param_count += 1
|
|
query_parts.append(f"AND created_at < ${param_count}")
|
|
params.append(filters['before_timestamp'])
|
|
|
|
if filters.get('after_timestamp'):
|
|
param_count += 1
|
|
query_parts.append(f"AND created_at > ${param_count}")
|
|
params.append(filters['after_timestamp'])
|
|
|
|
if filters.get('include_in_llm_message_history') is not None:
|
|
param_count += 1
|
|
query_parts.append(f"AND include_in_llm_message_history = ${param_count}")
|
|
params.append(filters['include_in_llm_message_history'])
|
|
|
|
# Add ordering
|
|
order_direction = "DESC" if filters.get('order', 'asc').lower() == 'desc' else "ASC"
|
|
query_parts.append(f"ORDER BY created_at {order_direction}")
|
|
|
|
# Add limit and offset
|
|
if filters.get('limit'):
|
|
param_count += 1
|
|
query_parts.append(f"LIMIT ${param_count}")
|
|
params.append(filters['limit'])
|
|
|
|
if filters.get('offset'):
|
|
param_count += 1
|
|
query_parts.append(f"OFFSET ${param_count}")
|
|
params.append(filters['offset'])
|
|
|
|
query = " ".join(query_parts)
|
|
rows = await db.fetch_all(query, tuple(params))
|
|
return rows
|
|
|
|
def render_message(msg_type: str, content: str, include_in_llm: bool, timestamp: str):
|
|
"""Render a message using Streamlit components."""
|
|
# Message type and metadata
|
|
col1, col2 = st.columns([3, 1])
|
|
with col1:
|
|
st.text(f"Type: {msg_type}")
|
|
with col2:
|
|
st.text("🟢 LLM" if include_in_llm else "⚫ Non-LLM")
|
|
|
|
# Convert PostgreSQL timestamp to string before parsing
|
|
timestamp_str = str(timestamp)
|
|
st.text(f"Time: {datetime.fromisoformat(timestamp_str).strftime('%Y-%m-%d %H:%M:%S')}")
|
|
|
|
# Message content
|
|
st.code(content, language="json")
|
|
|
|
# Separator
|
|
st.divider()
|
|
|
|
def main():
|
|
st.title("Thread Viewer")
|
|
|
|
# Initialize thread data in session state
|
|
if 'threads' not in st.session_state:
|
|
st.session_state.threads = asyncio.run(load_threads())
|
|
|
|
# Thread selection in sidebar
|
|
st.sidebar.title("Select Thread")
|
|
|
|
if not st.session_state.threads:
|
|
st.warning("No threads found in database")
|
|
return
|
|
|
|
# Format thread options with creation date
|
|
thread_options = {
|
|
# Convert PostgreSQL timestamp to string before parsing
|
|
f"{row[0]} ({datetime.fromisoformat(str(row[1])).strftime('%Y-%m-%d %H:%M')})"
|
|
: row[0] for row in st.session_state.threads
|
|
}
|
|
|
|
selected_thread_display = st.sidebar.selectbox(
|
|
"Choose a thread",
|
|
options=list(thread_options.keys()),
|
|
)
|
|
|
|
if selected_thread_display:
|
|
selected_thread_id = thread_options[selected_thread_display]
|
|
|
|
# Display thread ID in sidebar
|
|
st.sidebar.text(f"Thread ID: {selected_thread_id}")
|
|
|
|
# Add refresh button
|
|
if st.sidebar.button("🔄 Refresh Thread"):
|
|
st.session_state.threads = asyncio.run(load_threads())
|
|
st.rerun()
|
|
|
|
# Advanced filtering options in sidebar
|
|
st.sidebar.title("Filter Options")
|
|
|
|
# Message type filters
|
|
col1, col2 = st.sidebar.columns(2)
|
|
with col1:
|
|
message_types = st.text_input(
|
|
"Include Types",
|
|
help="Enter message types to include, separated by commas"
|
|
)
|
|
with col2:
|
|
exclude_message_types = st.text_input(
|
|
"Exclude Types",
|
|
help="Enter message types to exclude, separated by commas"
|
|
)
|
|
|
|
# Limit and offset
|
|
col1, col2 = st.sidebar.columns(2)
|
|
with col1:
|
|
limit = st.number_input("Limit", min_value=1, value=50)
|
|
with col2:
|
|
offset = st.number_input("Offset", min_value=0, value=0)
|
|
|
|
# Timestamp filters
|
|
st.sidebar.subheader("Time Range")
|
|
before_timestamp = st.sidebar.date_input("Before Date", value=None)
|
|
after_timestamp = st.sidebar.date_input("After Date", value=None)
|
|
|
|
# LLM history filter
|
|
include_in_llm = st.sidebar.radio(
|
|
"LLM History Filter",
|
|
options=["All Messages", "LLM Only", "Non-LLM Only"]
|
|
)
|
|
|
|
# Sort order
|
|
order = st.sidebar.radio("Sort Order", ["Ascending", "Descending"])
|
|
|
|
# Prepare filters
|
|
filters = {
|
|
'message_types': message_types if message_types else None,
|
|
'exclude_message_types': exclude_message_types if exclude_message_types else None,
|
|
'limit': limit,
|
|
'offset': offset,
|
|
'order': 'desc' if order == "Descending" else 'asc'
|
|
}
|
|
|
|
# Add timestamp filters if selected
|
|
if before_timestamp:
|
|
filters['before_timestamp'] = before_timestamp.isoformat()
|
|
if after_timestamp:
|
|
filters['after_timestamp'] = after_timestamp.isoformat()
|
|
|
|
# Add LLM history filter
|
|
if include_in_llm == "LLM Only":
|
|
filters['include_in_llm_message_history'] = True
|
|
elif include_in_llm == "Non-LLM Only":
|
|
filters['include_in_llm_message_history'] = False
|
|
|
|
# Load messages with filters
|
|
messages = asyncio.run(load_thread_content(selected_thread_id, filters))
|
|
|
|
if not messages:
|
|
st.info("No messages found with current filters")
|
|
return
|
|
|
|
# Display messages
|
|
for msg_type, content, include_in_llm, timestamp in messages:
|
|
formatted_content = format_message_content(content)
|
|
render_message(msg_type, formatted_content, include_in_llm, timestamp)
|
|
|
|
if __name__ == "__main__":
|
|
main()
|