suna/core/ui/thread_management.py

60 lines
2.3 KiB
Python
Raw Normal View History

2024-10-06 01:04:15 +08:00
import streamlit as st
import requests
from core.ui.utils import API_BASE_URL
2024-10-08 03:13:11 +08:00
from datetime import datetime
2024-10-06 01:04:15 +08:00
def display_thread_management():
2024-10-08 03:13:11 +08:00
st.subheader("Thread Management")
2024-10-06 01:04:15 +08:00
2024-10-08 03:13:11 +08:00
if st.button(" Create New Thread", key="create_thread_button"):
create_new_thread()
2024-10-06 01:04:15 +08:00
2024-10-08 03:13:11 +08:00
display_thread_selector()
2024-10-06 01:04:15 +08:00
def create_new_thread():
response = requests.post(f"{API_BASE_URL}/threads/")
if response.status_code == 200:
thread_id = response.json()['thread_id']
st.session_state.selected_thread = thread_id
2024-10-08 03:13:11 +08:00
st.success(f"New thread created with ID: {thread_id}")
2024-10-06 01:04:15 +08:00
st.rerun()
else:
st.error("Failed to create a new thread.")
def display_thread_selector():
threads_response = requests.get(f"{API_BASE_URL}/threads/")
if threads_response.status_code == 200:
threads = threads_response.json()
2024-10-08 03:13:11 +08:00
# Sort threads by creation date if available, otherwise by thread_id
def sort_key(thread):
if 'creation_date' in thread:
try:
return datetime.strptime(thread['creation_date'], "%Y-%m-%d %H:%M:%S")
except ValueError:
st.warning(f"Invalid date format for thread {thread['thread_id']}")
return thread['thread_id']
sorted_threads = sorted(threads, key=sort_key, reverse=True)
thread_options = []
for thread in sorted_threads:
if 'creation_date' in thread:
thread_options.append(f"{thread['thread_id']} - Created: {thread['creation_date']}")
else:
thread_options.append(thread['thread_id'])
if st.session_state.selected_thread is None and sorted_threads:
st.session_state.selected_thread = sorted_threads[0]['thread_id']
2024-10-06 01:04:15 +08:00
selected_thread = st.selectbox(
"🔍 Select Thread",
thread_options,
key="thread_select",
2024-10-08 03:13:11 +08:00
index=next((i for i, t in enumerate(sorted_threads) if t['thread_id'] == st.session_state.selected_thread), 0)
2024-10-06 01:04:15 +08:00
)
if selected_thread:
2024-10-08 03:13:11 +08:00
st.session_state.selected_thread = selected_thread.split(' - ')[0] if ' - ' in selected_thread else selected_thread
else:
st.error(f"Failed to fetch threads. Status code: {threads_response.status_code}")