mirror of https://github.com/kortix-ai/suna.git
196 lines
8.2 KiB
Python
196 lines
8.2 KiB
Python
import streamlit as st
|
|
import requests
|
|
from agentpress.ui.utils import API_BASE_URL
|
|
from datetime import datetime
|
|
|
|
def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools):
    """Assemble the request body for a plain (non-agent) thread run.

    Args:
        model_name: Model identifier, forwarded under ``model_name``.
        temperature: Sampling temperature, forwarded unchanged.
        max_tokens: Token limit, forwarded unchanged.
        system_message: Raw text, wrapped into a ``{"role": "system", "content": ...}`` message.
        additional_system_message: Extra system text, forwarded unchanged.
        selected_tools: Tool list, forwarded under ``tools`` (same object, not copied).

    Returns:
        A new dict, presumably used as the JSON body passed to ``run_thread``.
        ``tool_choice`` is always ``"auto"``.
    """
    wrapped_system = {"role": "system", "content": system_message}
    # tool_choice is pinned so the payload always carries an explicit value.
    return dict(
        system_message=wrapped_system,
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        tools=selected_tools,
        additional_system_message=additional_system_message,
        tool_choice="auto",
    )
|
|
|
|
def prepare_run_thread_agent_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools, autonomous_iterations_amount, continue_instructions):
    """Assemble the request body for an autonomous agent run of a thread.

    Args:
        model_name: Model identifier, forwarded under ``model_name``.
        temperature: Sampling temperature, forwarded unchanged.
        max_tokens: Token limit, forwarded unchanged.
        system_message: Raw text, wrapped into a ``{"role": "system", "content": ...}`` message.
        additional_system_message: Extra system text, forwarded unchanged.
        selected_tools: Tool list, forwarded under ``tools`` (same object, not copied).
        autonomous_iterations_amount: Forwarded unchanged under the same key.
        continue_instructions: Forwarded unchanged under the same key.

    Returns:
        A new dict, presumably used as the JSON body passed to ``run_thread_agent``.
        Unlike the single-run payload, no ``tool_choice`` key is set.
    """
    payload = {"system_message": {"role": "system", "content": system_message}}
    # Keyword-argument order is preserved, so key order matches the original literal.
    payload.update(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        tools=selected_tools,
        additional_system_message=additional_system_message,
        autonomous_iterations_amount=autonomous_iterations_amount,
        continue_instructions=continue_instructions,
    )
    return payload
|
|
|
|
def run_thread(thread_id, run_thread_data):
    """Trigger a run of ``thread_id`` through the API and render the outcome.

    POSTs ``run_thread_data`` as JSON to ``{API_BASE_URL}/threads/{thread_id}/run/``
    while a spinner is shown. On success it:
      * shows a success banner with the response's ``status`` (``'Unknown'`` if absent),
      * stores ``id`` (if present) in ``st.session_state.latest_run_id``,
      * renders the response via ``display_response_content`` plus a raw JSON dump,
      * returns the decoded JSON.

    Any ``requests`` error (including non-2xx via ``raise_for_status``) or other
    exception is reported with ``st.error``, and ``None`` is returned.
    """
    endpoint = f"{API_BASE_URL}/threads/{thread_id}/run/"
    with st.spinner("Running thread..."):
        try:
            resp = requests.post(endpoint, json=run_thread_data)
            resp.raise_for_status()
            payload = resp.json()

            run_status = payload.get('status', 'Unknown')
            st.success(f"Thread run completed successfully! Status: {run_status}")

            if 'id' in payload:
                # Remembered across reruns; presumably read elsewhere in the UI.
                st.session_state.latest_run_id = payload['id']

            # Rendering stays inside the try so its failures hit the generic handler below.
            st.subheader("Response Content")
            display_response_content(payload)

            st.subheader("Full Response Data")
            st.json(payload)

            return payload
        except requests.exceptions.RequestException as err:
            st.error(f"Failed to run thread. Error: {err}")
            failed_response = getattr(err, 'response', None)
            if failed_response is not None:
                # Show the server's body to help diagnose HTTP errors.
                st.text("Response content:")
                st.text(failed_response.text)
        except Exception as err:
            st.error(f"An unexpected error occurred: {err}")

    return None
|
|
|
|
def run_thread_agent(thread_id, run_thread_agent_data):
    """Trigger an autonomous agent run of ``thread_id`` and render the outcome.

    POSTs ``run_thread_agent_data`` as JSON to
    ``{API_BASE_URL}/threads/{thread_id}/run_agent/`` while a spinner is shown.
    On success it shows a banner with the response's ``status`` (``'Unknown'``
    if the key is absent), renders it via ``display_agent_response_content``
    plus a raw JSON dump, and returns the decoded JSON.

    Any ``requests`` error (including non-2xx via ``raise_for_status``) or other
    exception is reported with ``st.error``, and ``None`` is returned.
    """
    with st.spinner("Running thread agent..."):
        try:
            run_thread_response = requests.post(
                f"{API_BASE_URL}/threads/{thread_id}/run_agent/",
                json=run_thread_agent_data,
            )
            run_thread_response.raise_for_status()
            response_data = run_thread_response.json()

            # .get (as in run_thread): a successful response that lacks "status"
            # must not be reported as "An unexpected error occurred".
            st.success(f"Thread agent run completed successfully! Status: {response_data.get('status', 'Unknown')}")

            st.subheader("Agent Response")
            display_agent_response_content(response_data)

            # Display the full response data
            st.subheader("Full Agent Response Data")
            st.json(response_data)

            return response_data
        except requests.exceptions.RequestException as e:
            st.error(f"Failed to run thread agent. Error: {str(e)}")
            if e.response is not None:
                # Show the server's body to help diagnose HTTP errors.
                st.text("Response content:")
                st.text(e.response.text)
        except Exception as e:
            st.error(f"An unexpected error occurred: {str(e)}")

    # Explicit failure result.
    return None
|
|
|
|
def display_response_content(response_data):
    """Render one chat-completion-style response in Streamlit.

    If ``response_data`` is a dict with a non-empty ``choices`` list, the first
    choice's message is shown: role, content, and each tool call's function
    name and arguments (as JSON code). Anything else (not a dict, no
    ``choices``, or an empty or ``None`` ``choices``) is dumped raw with ``st.json``.
    """
    choices = response_data.get('choices') if isinstance(response_data, dict) else None
    if not choices:
        # Fall back to the raw dump instead of failing on choices[0]
        # when the list is empty or null.
        st.json(response_data)
        return

    message = choices[0]['message']
    st.write(f"**Role:** {message['role']}")
    st.write(f"**Content:** {message['content']}")

    tool_calls = message.get('tool_calls')
    if tool_calls:
        st.write("**Tool Calls:**")
        for tool_call in tool_calls:
            st.write(f"- Function: `{tool_call['function']['name']}`")
            st.code(tool_call['function']['arguments'], language="json")
|
|
|
|
def display_agent_response_content(response_data):
    """Render the result of an autonomous agent run in Streamlit.

    Shows the run's status and iteration counters, one expander per entry in
    ``iterations`` (each rendered with ``display_response_content``), and the
    ``final_config`` as JSON.

    Missing top-level keys are displayed as ``'N/A'`` (or the section is
    skipped) instead of raising ``KeyError``. Previously only
    ``iterations_count`` was tolerated.
    """
    st.write(f"**Status:** {response_data.get('status', 'N/A')}")
    st.write(f"**Total Iterations:** {response_data.get('total_iterations', 'N/A')}")
    st.write(f"**Completed Iterations:** {response_data.get('iterations_count', 'N/A')}")

    # `or []` also covers an explicit null.
    for number, iteration in enumerate(response_data.get('iterations') or [], start=1):
        with st.expander(f"Iteration {number}"):
            display_response_content(iteration)

    st.write("**Final Configuration:**")
    final_config = response_data.get('final_config')
    if final_config is None:
        st.write("N/A")
    else:
        st.json(final_config)
|
|
|
|
def fetch_thread_runs(thread_id, limit):
    """Fetch up to ``limit`` runs of ``thread_id`` from the API.

    Returns the decoded JSON body on HTTP 200. On any other status code, or
    when the request itself fails (connection error, etc.), an error is shown
    with ``st.error`` and ``[]`` is returned.
    """
    try:
        # Let requests build and URL-encode the query string.
        response = requests.get(
            f"{API_BASE_URL}/threads/{thread_id}/runs",
            params={"limit": limit},
        )
    except requests.exceptions.RequestException as e:
        st.error(f"Failed to retrieve runs. Error: {e}")
        return []

    if response.status_code == 200:
        return response.json()
    st.error("Failed to retrieve runs.")
    return []
|
|
|
|
def format_timestamp(timestamp):
    """Convert a Unix timestamp (seconds) into a ``YYYY-MM-DD HH:MM:SS`` string.

    The conversion uses ``datetime.fromtimestamp``, so the result is in the
    machine's local time zone. Any falsy input (``None``, ``0``, ``''``) yields ``'N/A'``.
    """
    if not timestamp:
        return 'N/A'
    moment = datetime.fromtimestamp(timestamp)
    return moment.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
|
def display_runs(runs):
    """Render one collapsed Streamlit expander per run dict in ``runs``.

    Each expander shows:
      * timestamps (left column) and model/config fields (right column),
      * the system message as JSON,
      * tools / usage / response format / last error / continue instructions,
        each only when the value is truthy,
      * a "Stop Run" button when ``status == "in_progress"``,
      * a "Refresh Status" button that fetches the latest status, merges it
        into the run dict in place, and reruns the script.

    Every key listed below is read with ``run[key]``, so a run missing one
    raises ``KeyError``.
    """
    # (label, key) pairs in display order.
    timestamp_fields = (
        ("Created At", 'created_at'),
        ("Started At", 'started_at'),
        ("Completed At", 'completed_at'),
        ("Cancelled At", 'cancelled_at'),
        ("Failed At", 'failed_at'),
    )
    config_fields = (
        ("Model", 'model'),
        ("Temperature", 'temperature'),
        ("Top P", 'top_p'),
        ("Max Tokens", 'max_tokens'),
        ("Tool Choice", 'tool_choice'),
        ("Execute Tools Async", 'execute_tools_async'),
        ("Autonomous Iterations", 'autonomous_iterations_amount'),
    )
    optional_json_sections = (
        ("Tools", 'tools'),
        ("Usage", 'usage'),
        ("Response Format", 'response_format'),
    )

    for run in runs:
        run_id = run['id']
        with st.expander(f"Run {run_id} - Status: {run['status']}", expanded=False):
            left, right = st.columns(2)
            with left:
                for label, key in timestamp_fields:
                    st.write(f"**{label}:** {format_timestamp(run[key])}")
            with right:
                for label, key in config_fields:
                    st.write(f"**{label}:** {run[key]}")

            st.write("**System Message:**")
            st.json(run['system_message'])

            for label, key in optional_json_sections:
                if run[key]:
                    st.write(f"**{label}:**")
                    st.json(run[key])

            if run['last_error']:
                st.error("**Last Error:**")
                st.code(run['last_error'])

            if run['continue_instructions']:
                st.write("**Continue Instructions:**")
                st.text(run['continue_instructions'])

            # Button keys embed the run id so widgets stay unique across runs.
            if run['status'] == "in_progress":
                if st.button(f"Stop Run {run_id}", key=f"stop_button_{run_id}"):
                    stop_thread_run(run['thread_id'], run_id)
                    st.rerun()

            if st.button(f"Refresh Status for Run {run_id}", key=f"refresh_button_{run_id}"):
                latest = get_thread_run_status(run['thread_id'], run_id)
                if latest:
                    # Mutates the caller's dict before rerunning.
                    run.update(latest)
                    st.rerun()
|
|
|
|
def stop_thread_run(thread_id, run_id):
    """Ask the API to stop run ``run_id`` of thread ``thread_id``.

    On HTTP 200 it shows a success message and returns the decoded JSON body.
    On any other status code, or when the request itself fails (connection
    error, etc.), it shows an error with ``st.error`` and returns ``None``
    rather than raising.
    """
    try:
        response = requests.post(f"{API_BASE_URL}/threads/{thread_id}/runs/{run_id}/stop")
    except requests.exceptions.RequestException as e:
        # Previously this escaped as an uncaught traceback in the UI.
        st.error(f"Failed to stop thread run. Error: {e}")
        return None

    if response.status_code == 200:
        st.success("Thread run stopped successfully.")
        return response.json()
    st.error(f"Failed to stop thread run. Status code: {response.status_code}")
    return None
|
|
|
|
def get_thread_run_status(thread_id, run_id):
    """Fetch the current status record of run ``run_id`` of thread ``thread_id``.

    On HTTP 200 it returns the decoded JSON body. On any other status code, or
    when the request itself fails (connection error, etc.), it shows an error
    with ``st.error`` and returns ``None`` rather than raising.
    """
    try:
        response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs/{run_id}/status")
    except requests.exceptions.RequestException as e:
        # Previously this escaped as an uncaught traceback in the UI.
        st.error(f"Failed to get thread run status. Error: {e}")
        return None

    if response.status_code == 200:
        return response.json()
    st.error(f"Failed to get thread run status. Status code: {response.status_code}")
    return None
|