2024-10-06 01:04:15 +08:00
import json
import logging
import asyncio
2024-10-17 04:08:46 +08:00
from typing import List , Dict , Any , Optional , Callable
from sqlalchemy import select
from agentpress . db import Database , Thread , ThreadRun
2024-10-10 22:21:39 +08:00
from agentpress . llm import make_llm_api_call
2024-10-17 04:08:46 +08:00
from datetime import datetime , UTC
from agentpress . tool import ToolResult
2024-10-10 22:21:39 +08:00
from agentpress . tool_registry import ToolRegistry
2024-10-08 03:13:11 +08:00
import uuid
2024-10-17 04:08:46 +08:00
from tools . files_tool import FilesTool
2024-10-10 22:21:39 +08:00
2024-10-06 01:04:15 +08:00
class ThreadManager:
    """Manage conversation threads and the LLM runs executed against them.

    A thread is a ``Thread`` row whose ``messages`` column holds a JSON-encoded
    list of chat messages. A run is a ``ThreadRun`` row describing one
    invocation of :meth:`run_thread` (settings, status, usage, iterations).

    Attributes:
        db: Database wrapper providing ``get_async_session()``.
        tool_registry: Registry used to look up tool schemas and instances.
        run_config: Public attributes of the current ``ThreadRun``; the hook
            callables passed to :meth:`run_thread` may edit it, and edits are
            copied back onto the run.
        current_iteration: 1-based index of the iteration being executed.
    """

    # Run statuses that tell an in-flight run to stop. "stopping" is what
    # stop_thread_run() writes, so it has to be recognised here as well.
    _STOP_STATUSES = ("stopped", "stopping", "cancelled", "queued")

    # Content stored for tool calls that never produced a response.
    _INTERRUPTED_TOOL_CONTENT = (
        "ToolResult(success=False, output='Execution interrupted. Session was stopped.')"
    )

    def __init__(self, db: "Database"):
        """Create a manager bound to ``db`` with an empty tool registry."""
        self.db = db
        self.tool_registry = ToolRegistry()
        self.run_config: Dict[str, Any] = {}
        self.current_iteration: int = 0

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _now_ts() -> int:
        """Return the current UTC time as integer epoch seconds."""
        return int(datetime.now(UTC).timestamp())

    @staticmethod
    def _public_attrs(obj: Any) -> Dict[str, Any]:
        """Return the instance attributes of ``obj`` whose names do not start with ``_``."""
        return {k: v for k, v in vars(obj).items() if not k.startswith('_')}

    @staticmethod
    def _read_field(source: Any, name: str) -> Any:
        """Read ``name`` from a dict or an object, returning None when absent."""
        if isinstance(source, dict):
            return source.get(name)
        return getattr(source, name, None)

    @staticmethod
    def _loads_or_none(value: Any) -> Any:
        """Decode a JSON string; pass through falsy or already-decoded values."""
        if not value:
            return None
        return json.loads(value) if isinstance(value, str) else value

    @staticmethod
    def _tool_call_fields(tool_call: Any) -> tuple:
        """Return ``(id, function name, raw arguments)`` for a tool call.

        Accepts both attribute-style objects (as returned by the LLM client)
        and plain dicts (as built by :meth:`handle_tool_parser_response`).
        """
        if isinstance(tool_call, dict):
            function = tool_call.get('function') or {}
            return tool_call.get('id'), function.get('name'), function.get('arguments')
        return tool_call.id, tool_call.function.name, tool_call.function.arguments

    @classmethod
    def _fill_missing_tool_responses(cls, messages: List[Dict[str, Any]]) -> bool:
        """Insert failed tool responses for unanswered tool calls, in place.

        Looks at the last assistant message carrying ``tool_calls`` and, for
        every call that has no later ``tool`` message with the same
        ``tool_call_id``, inserts an "execution interrupted" tool message
        directly after that assistant message.

        Returns:
            True if ``messages`` was modified, False otherwise.
        """
        assistant_index = next(
            (
                i for i in range(len(messages) - 1, -1, -1)
                if messages[i].get('role') == 'assistant' and 'tool_calls' in messages[i]
            ),
            None,
        )
        if assistant_index is None:
            return False

        answered_ids = {
            m.get('tool_call_id')
            for m in messages[assistant_index + 1:]
            if m.get('role') == 'tool'
        }
        missing = [
            {
                "role": "tool",
                "tool_call_id": call['id'],
                "name": call['function']['name'],
                "content": cls._INTERRUPTED_TOOL_CONTENT,
            }
            for call in (messages[assistant_index]['tool_calls'] or [])
            if call['id'] not in answered_ids
        ]
        if not missing:
            return False
        messages[assistant_index + 1:assistant_index + 1] = missing
        return True

    @staticmethod
    def _prepare_message(message_data: Dict[str, Any],
                         images: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
        """Return a storable copy of ``message_data`` with ``images`` attached.

        ``ToolResult`` values are converted to strings. When ``images`` is
        given, ``content`` becomes a list of parts and each image is appended
        as an ``image_url`` part with a base64 data URL. The caller's dict and
        content list are left untouched.
        """
        message = {
            key: str(value) if isinstance(value, ToolResult) else value
            for key, value in message_data.items()
        }
        if images:
            content = message['content']
            if isinstance(content, str):
                parts = [{"type": "text", "text": content}]
            elif isinstance(content, list):
                parts = list(content)
            else:
                parts = []
            parts.extend(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:{image['content_type']};base64,{image['base64']}",
                        "detail": "high",
                    },
                }
                for image in images
            )
            message['content'] = parts
        return message

    def _apply_run_config(self, thread_run: "ThreadRun") -> None:
        """Copy ``self.run_config`` onto ``thread_run``.

        ``iterations`` is kept as a list in ``run_config`` but stored on the
        run as a JSON string, matching how it is read back in
        :meth:`serialize_thread_run`.
        """
        for key, value in self.run_config.items():
            if key == 'iterations' and not isinstance(value, str):
                value = json.dumps(value)
            setattr(thread_run, key, value)

    def _refresh_run_config(self, thread_run: "ThreadRun") -> None:
        """Update ``self.run_config`` from ``thread_run``, keeping ``iterations`` a list."""
        iterations = self.run_config.get('iterations', [])
        self.run_config.update(self._public_attrs(thread_run))
        self.run_config['iterations'] = iterations

    async def _run_hook(self, hook: Optional[Callable], thread_run: "ThreadRun") -> None:
        """Call ``hook`` (if any), then persist whatever it changed in ``run_config``."""
        if not hook:
            return
        hook()
        self._apply_run_config(thread_run)
        await self.update_thread_run(thread_run)

    async def _stop_requested(self, thread_id: Any, run_id: Optional[str]) -> bool:
        """Return True if run ``run_id`` should stop; False when no run id is known."""
        if run_id is None:
            return False
        return await self.should_stop(thread_id, run_id)

    # ------------------------------------------------------------------
    # Thread and message management
    # ------------------------------------------------------------------

    async def create_thread(self) -> int:
        """Create a thread with an empty message list and return its id."""
        async with self.db.get_async_session() as session:
            new_thread = Thread(messages=json.dumps([]))
            session.add(new_thread)
            await session.commit()
            await session.refresh(new_thread)  # Ensure thread_id is populated
            return new_thread.thread_id

    async def add_message(self, thread_id: int, message_data: Dict[str, Any],
                          images: Optional[List[Dict[str, Any]]] = None):
        """Append a message to a thread.

        Before a ``user`` message is appended, any tool calls of the last
        assistant message that never received a response get a failed tool
        response, so the stored history stays well-formed.

        Args:
            thread_id: Id of the thread to append to.
            message_data: Chat message; must contain ``role`` (and ``content``
                when ``images`` is given). Not modified.
            images: Optional images, each a dict with ``content_type`` and
                ``base64`` keys.

        Raises:
            ValueError: If the thread does not exist.
            Exception: Any error while building or saving the message is
                logged, the session rolled back, and the error re-raised.
        """
        # Log only the image count: the payloads are base64 blobs.
        logging.info(f"Adding message to thread {thread_id} with {len(images) if images else 0} image(s)")
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                raise ValueError(f"Thread with id {thread_id} not found")
            try:
                messages = json.loads(thread.messages)
                message = self._prepare_message(message_data, images)
                if message['role'] == 'user':
                    # Done on the list loaded in this session so the repair is
                    # part of the same write instead of being overwritten.
                    self._fill_missing_tool_responses(messages)
                messages.append(message)
                thread.messages = json.dumps(messages)
                await session.commit()
                logging.info(f"Message with role {message['role']} added to thread {thread_id}")
            except Exception as e:
                await session.rollback()
                logging.error(f"Failed to add message to thread {thread_id}: {e}")
                raise

    async def get_message(self, thread_id: int, message_index: int) -> Optional[Dict[str, Any]]:
        """Return the message at ``message_index``, or None if thread or index is missing."""
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                return None
            messages = json.loads(thread.messages)
            if -len(messages) <= message_index < len(messages):
                return messages[message_index]
            return None

    async def modify_message(self, thread_id: int, message_index: int, new_message_data: Dict[str, Any]):
        """Replace the message at ``message_index``.

        Raises:
            ValueError: If the thread does not exist or the index is out of range.
        """
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                raise ValueError(f"Thread with id {thread_id} not found")
            try:
                messages = json.loads(thread.messages)
                if not -len(messages) <= message_index < len(messages):
                    raise ValueError(f"Message index {message_index} is out of range")
                messages[message_index] = new_message_data
                thread.messages = json.dumps(messages)
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    async def remove_message(self, thread_id: int, message_index: int):
        """Delete the message at ``message_index``; out-of-range indices are ignored.

        Raises:
            ValueError: If the thread does not exist.
        """
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                raise ValueError(f"Thread with id {thread_id} not found")
            try:
                messages = json.loads(thread.messages)
                if -len(messages) <= message_index < len(messages):
                    del messages[message_index]
                    thread.messages = json.dumps(messages)
                    await session.commit()
            except Exception:
                await session.rollback()
                raise

    async def list_messages(self, thread_id: int, hide_tool_msgs: bool = False,
                            only_latest_assistant: bool = False,
                            regular_list: bool = True) -> List[Dict[str, Any]]:
        """Return the messages of a thread, optionally filtered.

        Args:
            thread_id: Thread to read.
            hide_tool_msgs: Drop ``tool`` messages and strip ``tool_calls``
                from the remaining ones.
            only_latest_assistant: Return just the last assistant message
                (or an empty list); takes precedence over the other filters.
            regular_list: Keep only system/assistant/tool/user messages.

        Returns:
            The filtered list; empty if the thread does not exist.
        """
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                return []
            messages = json.loads(thread.messages)

        if only_latest_assistant:
            for msg in reversed(messages):
                if msg.get('role') == 'assistant':
                    return [msg]
            return []

        filtered_messages = messages
        if hide_tool_msgs:
            filtered_messages = [
                {k: v for k, v in msg.items() if k != 'tool_calls'}
                for msg in filtered_messages
                if msg.get('role') != 'tool'
            ]
        if regular_list:
            filtered_messages = [
                msg for msg in filtered_messages
                if msg.get('role') in ['system', 'assistant', 'tool', 'user']
            ]
        return filtered_messages

    async def cleanup_incomplete_tool_calls(self, thread_id: int):
        """Store failed tool responses for tool calls that were never answered.

        Returns:
            True if the thread was modified, False otherwise (including when
            the thread does not exist).
        """
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                return False
            # Work on the raw stored list so no message is dropped on write-back.
            messages = json.loads(thread.messages)
            if not self._fill_missing_tool_responses(messages):
                return False
            thread.messages = json.dumps(messages)
            await session.commit()
            return True

    # ------------------------------------------------------------------
    # Running a thread
    # ------------------------------------------------------------------

    async def run_thread(self, settings: Dict[str, Any]) -> Dict[str, Any]:
        """Run the LLM against a thread for one or more iterations.

        Args:
            settings: Run configuration. Required keys: ``thread_id``,
                ``model_name``, ``system_message``. Optional keys:
                ``temperature``, ``max_tokens``, ``top_p``, ``tool_choice``,
                ``execute_tools_async``, ``tools`` (tool names),
                ``response_format``, ``autonomous_iterations_amount``,
                ``continue_instructions``, ``hide_tool_msgs``,
                ``additional_message``, ``use_tool_parser`` and the
                zero-argument hooks ``initializer``, ``pre_iteration``,
                ``after_iteration``, ``finalizer``. Hooks may edit
                ``self.run_config``; edits are copied onto the run.

        Returns:
            A summary dict of the completed run, or
            ``{"status": "stopped", ...}`` if the run was stopped.

        Raises:
            Exception: Any failure is recorded on the run (when it was
                created) and re-raised.
        """
        # Bound before the try block so the except handler can tell whether
        # the run record was ever created.
        thread_run = None
        try:
            thread_id = settings['thread_id']
            thread_run = ThreadRun(
                id=str(uuid.uuid4()),
                thread_id=thread_id,
                status="queued",
                model=settings['model_name'],
                temperature=settings.get('temperature', 0.7),
                max_tokens=settings.get('max_tokens'),
                top_p=settings.get('top_p'),
                tool_choice=settings.get('tool_choice', 'auto'),
                execute_tools_async=settings.get('execute_tools_async', True),
                system_message=json.dumps(settings['system_message']),
                tools=json.dumps(settings.get('tools')),
                response_format=json.dumps(settings.get('response_format')),
                autonomous_iterations_amount=settings.get('autonomous_iterations_amount', 1),
                continue_instructions=settings.get('continue_instructions'),
            )
            async with self.db.get_async_session() as session:
                session.add(thread_run)
                await session.commit()

            thread_run.status = "in_progress"
            thread_run.started_at = self._now_ts()
            await self.update_thread_run(thread_run)

            self.run_config = self._public_attrs(thread_run)
            self.run_config['iterations'] = []
            await self._run_hook(settings.get('initializer'), thread_run)

            full_tools = None
            if settings.get('tools'):
                registered = (self.tool_registry.get_tool(name) for name in settings['tools'])
                full_tools = [tool['schema'] for tool in registered if tool]

            self.current_iteration = 0
            for iteration in range(settings.get('autonomous_iterations_amount', 1)):
                self.current_iteration = iteration + 1

                if await self.should_stop(thread_id, thread_run.id):
                    thread_run.status = "stopped"
                    thread_run.cancelled_at = self._now_ts()
                    await self.update_thread_run(thread_run)
                    return {"status": "stopped", "message": "Thread run cancelled"}

                await self._run_hook(settings.get('pre_iteration'), thread_run)

                # Read from the run rather than settings so that hook edits
                # to run_config['continue_instructions'] take effect.
                continue_instructions = thread_run.continue_instructions
                if iteration > 0 and continue_instructions:
                    await self.add_message(thread_id, {"role": "user", "content": continue_instructions})

                messages = await self.list_messages(thread_id, hide_tool_msgs=settings.get('hide_tool_msgs', False))
                prepared_messages = [settings['system_message']] + messages
                if settings.get('additional_message'):
                    prepared_messages.append(settings['additional_message'])

                response = await make_llm_api_call(
                    prepared_messages,
                    settings['model_name'],
                    temperature=thread_run.temperature,
                    max_tokens=thread_run.max_tokens,
                    tools=full_tools,
                    tool_choice=thread_run.tool_choice,
                    stream=False,
                    top_p=thread_run.top_p,
                    response_format=json.loads(thread_run.response_format) if thread_run.response_format else None,
                )

                usage = getattr(response, 'usage', None)
                usage_dict = self.serialize_usage(usage) if usage else None
                thread_run.usage = usage_dict
                # Keep run_config in step, otherwise the next sync would
                # overwrite the usage with the stale value.
                self.run_config['usage'] = usage_dict

                # Store tool calls as plain dicts so the message is JSON-serialisable.
                assistant_message = self.create_assistant_message_with_tools(response.choices[0].message)
                await self.add_message(thread_id, assistant_message)

                use_tool_parser = settings.get('use_tool_parser', False)
                if settings.get('tools') is None or use_tool_parser:
                    await self.handle_response_without_tools(thread_id, response, use_tool_parser)
                else:
                    await self.handle_response_with_tools(
                        thread_id,
                        response,
                        settings.get('execute_tools_async', True),
                        run_id=thread_run.id,
                    )

                self.run_config['iterations'].append({
                    "iteration": self.current_iteration,
                    "response": self.serialize_choice(response.choices[0]),
                    "usage": usage_dict,
                })

                after_iteration = settings.get('after_iteration')
                if after_iteration:
                    after_iteration()
                    self._apply_run_config(thread_run)
                thread_run.iterations = json.dumps(self.run_config['iterations'])
                await self.update_thread_run(thread_run)

            thread_run.status = "completed"
            thread_run.completed_at = self._now_ts()
            await self.update_thread_run(thread_run)

            self._refresh_run_config(thread_run)
            await self._run_hook(settings.get('finalizer'), thread_run)

            return {
                "id": thread_run.id,
                "status": thread_run.status,
                "iterations": self.run_config['iterations'],
                "total_iterations": len(self.run_config['iterations']),
                "usage": thread_run.usage,
                "model": settings['model_name'],
                "object": "chat.completion",
                "created": self._now_ts(),
            }
        except Exception as e:
            if thread_run is not None:
                thread_run.status = "failed"
                thread_run.failed_at = self._now_ts()
                thread_run.last_error = str(e)
                try:
                    await self.update_thread_run(thread_run)
                except Exception:
                    # Do not let a persistence failure hide the original error.
                    logging.exception("Failed to persist the failed state of the thread run")
                self._refresh_run_config(thread_run)
                if settings.get('finalizer'):
                    settings['finalizer']()
            raise

    async def update_thread_run(self, thread_run: "ThreadRun"):
        """Persist ``thread_run`` and reload its attributes from the database."""
        async with self.db.get_async_session() as session:
            session.add(thread_run)
            await session.commit()
            await session.refresh(thread_run)

    # ------------------------------------------------------------------
    # Response handling
    # ------------------------------------------------------------------

    async def handle_response_without_tools(self, thread_id: int, response: Any, use_tool_parser: bool):
        """Handle a response when native tool calling is not used.

        With ``use_tool_parser`` the response text is scanned for embedded
        tool calls; otherwise nothing is done because ``run_thread`` has
        already stored the assistant message.
        """
        response_content = response.choices[0].message['content']
        if use_tool_parser:
            await self.handle_tool_parser_response(thread_id, response_content)

    async def handle_tool_parser_response(self, thread_id: int, response_content: str):
        """Extract ``function_calls`` JSON from the response text and run the tools.

        If no parsable tool-call JSON is found, the text is stored as a plain
        assistant message. Otherwise an assistant message with ``tool_calls``
        is stored, the tools are executed and their results stored.
        """
        # NOTE(review): when reached via run_thread the assistant message has
        # already been stored once, so this adds a second copy — confirm
        # whether that is intended before changing it.
        import re  # not imported at file level

        plain_message = {"role": "assistant", "content": response_content}
        tool_call_match = re.search(r'\{[\s\S]*"function_calls"[\s\S]*\}', response_content or "")
        if not tool_call_match:
            await self.add_message(thread_id, plain_message)
            return

        try:
            tool_call_json = json.loads(tool_call_match.group())
        except json.JSONDecodeError:
            logging.error("Failed to parse tool call JSON from response")
            await self.add_message(thread_id, plain_message)
            return

        assistant_message = {
            "role": "assistant",
            "content": response_content,
            "tool_calls": [
                {
                    "id": f"call_{i}",
                    "type": "function",
                    "function": {
                        "name": call['name'],
                        "arguments": json.dumps(call['arguments']),
                    },
                }
                for i, call in enumerate(tool_call_json.get('function_calls', []))
            ],
        }
        await self.add_message(thread_id, assistant_message)
        available_functions = self.get_available_functions()
        tool_results = await self.execute_tools(
            assistant_message['tool_calls'], available_functions, thread_id, execute_tools_async=True
        )
        await self.process_tool_results(thread_id, tool_results)

    async def handle_response_with_tools(self, thread_id: int, response: Any, execute_tools_async: bool,
                                         run_id: Optional[str] = None):
        """Execute the tool calls of a response and store their results.

        Args:
            thread_id: Thread the response belongs to.
            response: LLM response; ``choices[0].message`` is inspected.
            execute_tools_async: Run tool calls concurrently when True.
            run_id: Id of the active run, used for stop checks. Without it
                no stop check is performed.

        Returns:
            A ``{"status": "stopped", ...}`` dict if the run was stopped,
            otherwise None.
        """
        try:
            response_message = response.choices[0].message
            tool_calls = response_message.get('tool_calls') or []
            # The assistant message has already been added in run_thread.
            available_functions = self.get_available_functions()

            if await self._stop_requested(thread_id, run_id):
                return {"status": "stopped", "message": "Session cancelled"}

            if tool_calls:
                tool_results = await self.execute_tools(
                    tool_calls, available_functions, thread_id, execute_tools_async, run_id=run_id
                )
                await self.process_tool_results(thread_id, tool_results)
                if await self._stop_requested(thread_id, run_id):
                    return {"status": "stopped", "message": "Session cancelled after tool execution"}
        except AttributeError as e:
            logging.error(f"AttributeError: {e}")

    def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]:
        """Build a JSON-serialisable assistant message from an LLM message.

        ``content`` falls back to an empty string; ``tool_calls`` is included
        only when the message has any, converted to plain dicts.
        """
        message: Dict[str, Any] = {
            "role": "assistant",
            "content": response_message.get('content') or "",
        }
        tool_calls = response_message.get('tool_calls')
        if tool_calls:
            serialized = []
            for tool_call in tool_calls:
                call_id, name, arguments = self._tool_call_fields(tool_call)
                serialized.append({
                    "id": call_id,
                    "type": "function",
                    "function": {"name": name, "arguments": arguments},
                })
            message["tool_calls"] = serialized
        return message

    # ------------------------------------------------------------------
    # Tool execution
    # ------------------------------------------------------------------

    def get_available_functions(self) -> Dict[str, Callable]:
        """Map function names to bound methods of all registered tool instances.

        Every callable defined directly on a tool's class whose name does not
        start with ``__`` is included.
        """
        available_functions: Dict[str, Callable] = {}
        for tool_info in self.tool_registry.get_all_tools().values():
            tool_instance = tool_info['instance']
            for func_name, func in vars(type(tool_instance)).items():
                if callable(func) and not func_name.startswith("__"):
                    available_functions[func_name] = getattr(tool_instance, func_name)
        return available_functions

    async def execute_tools(self, tool_calls: List[Any], available_functions: Dict[str, Callable],
                            thread_id: int, execute_tools_async: bool,
                            run_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Execute ``tool_calls`` concurrently or sequentially and return their results."""
        if execute_tools_async:
            return await self.execute_tools_async(tool_calls, available_functions, thread_id, run_id=run_id)
        return await self.execute_tools_sync(tool_calls, available_functions, thread_id, run_id=run_id)

    async def _run_tool_call(self, tool_call: Any,
                             available_functions: Dict[str, Callable]) -> Optional[Dict[str, Any]]:
        """Parse one tool call and execute it.

        Returns:
            A ``tool`` message dict, or None if the function is not available.
            Unparsable arguments yield a failed tool message.
        """
        tool_call_id, function_name, raw_arguments = self._tool_call_fields(tool_call)
        try:
            function_args = json.loads(raw_arguments)
        except (json.JSONDecodeError, TypeError) as e:
            error_message = f"Error parsing arguments for {function_name}: {str(e)}"
            logging.error(error_message)
            logging.error(f"Problematic JSON: {raw_arguments}")
            return {
                "role": "tool",
                "tool_call_id": tool_call_id,
                "name": function_name,
                "content": str(ToolResult(success=False, output=error_message)),
            }

        function_to_call = available_functions.get(function_name)
        if not function_to_call:
            logging.warning(f"Function {function_name} not found in available functions")
            return None
        return await self.execute_tool(function_to_call, function_args, function_name, tool_call_id)

    async def execute_tools_async(self, tool_calls, available_functions, thread_id, run_id=None):
        """Execute all tool calls concurrently.

        Returns:
            The non-empty results in call order. A call skipped because the
            run was stopped contributes a ``{"status": "stopped", ...}`` marker.
        """
        async def execute_single_tool(tool_call):
            if await self._stop_requested(thread_id, run_id):
                return {"status": "stopped", "message": "Session cancelled"}
            return await self._run_tool_call(tool_call, available_functions)

        tool_results = await asyncio.gather(*(execute_single_tool(tool_call) for tool_call in tool_calls))
        return [result for result in tool_results if result]

    async def execute_tools_sync(self, tool_calls, available_functions, thread_id, run_id=None):
        """Execute tool calls one after another.

        Returns:
            The results gathered so far; if the run is stopped part-way, a
            ``{"status": "stopped", ...}`` marker is appended and the
            remaining calls are skipped.
        """
        tool_results = []
        for tool_call in tool_calls:
            if await self._stop_requested(thread_id, run_id):
                # Keep the results of tools that already ran.
                tool_results.append({"status": "stopped", "message": "Session cancelled"})
                return tool_results
            result = await self._run_tool_call(tool_call, available_functions)
            if result:
                tool_results.append(result)
        return tool_results

    async def process_tool_results(self, thread_id: int, tool_results: List[Dict[str, Any]]):
        """Store tool results as thread messages.

        Each result is either a tool message itself or a dict wrapping one
        under ``tool_message``. Entries without a ``role`` (stop markers) are
        skipped.
        """
        for result in tool_results:
            message = result.get('tool_message', result)
            if 'role' not in message:
                continue
            await self.add_message(thread_id, message)

    async def execute_tool(self, function_to_call, function_args, function_name, tool_call_id):
        """Call one tool function and wrap its outcome as a ``tool`` message.

        Both coroutine functions and plain functions are supported. Any
        exception is converted into a failed ``ToolResult``.
        """
        import inspect

        try:
            function_response = function_to_call(**function_args)
            if inspect.isawaitable(function_response):
                function_response = await function_response
        except Exception as e:
            error_message = f"Error in {function_name}: {str(e)}"
            function_response = ToolResult(success=False, output=error_message)
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": function_name,
            "content": str(function_response),
        }

    # ------------------------------------------------------------------
    # Run control and queries
    # ------------------------------------------------------------------

    async def should_stop(self, thread_id: str, run_id: str) -> bool:
        """Return True if run ``run_id`` exists and has a stop status.

        ``thread_id`` is accepted for interface compatibility and not used.
        """
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            return bool(run and run.status in self._STOP_STATUSES)

    async def stop_thread_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
        """Mark an in-progress run as "stopping".

        Returns:
            The serialized run, or None if it does not exist, belongs to
            another thread, or is not in progress.
        """
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            if run and run.thread_id == thread_id and run.status == "in_progress":
                run.status = "stopping"
                await session.commit()
                return self.serialize_thread_run(run)
            return None

    async def save_thread_run(self, thread_id: str):
        """Copy the thread's messages onto its latest run, creating one if needed.

        Raises:
            ValueError: If the thread does not exist.
        """
        # NOTE(review): `messages`, `last_updated_date` and `creation_date`
        # are not used on ThreadRun anywhere else in this file — confirm the
        # model defines them.
        async with self.db.get_async_session() as session:
            thread = await session.get(Thread, thread_id)
            if not thread:
                raise ValueError(f"Thread with id {thread_id} not found")
            messages = json.loads(thread.messages)
            creation_date = datetime.now().isoformat()

            # Get the latest ThreadRun for this thread
            stmt = (
                select(ThreadRun)
                .where(ThreadRun.thread_id == thread_id)
                .order_by(ThreadRun.created_at.desc())
                .limit(1)
            )
            result = await session.execute(stmt)
            latest_thread_run = result.scalar_one_or_none()

            if latest_thread_run:
                latest_thread_run.messages = json.dumps(messages)
                latest_thread_run.last_updated_date = creation_date
            else:
                session.add(ThreadRun(
                    thread_id=thread_id,
                    messages=json.dumps(messages),
                    creation_date=creation_date,
                    status='completed',
                ))
            await session.commit()

    async def get_thread(self, thread_id: int) -> "Optional[Thread]":
        """Return the thread with ``thread_id``, or None."""
        async with self.db.get_async_session() as session:
            return await session.get(Thread, thread_id)

    async def update_thread_run_with_error(self, thread_id: int, error_message: str):
        """Mark the latest run of a thread as failed with ``error_message``.

        Uses the same fields as ``run_thread``'s failure path (status
        "failed", ``last_error``, ``failed_at``). Does nothing if the thread
        has no run.
        """
        async with self.db.get_async_session() as session:
            stmt = (
                select(ThreadRun)
                .where(ThreadRun.thread_id == thread_id)
                .order_by(ThreadRun.created_at.desc())
                .limit(1)
            )
            result = await session.execute(stmt)
            thread_run = result.scalar_one_or_none()
            if thread_run:
                thread_run.status = 'failed'
                thread_run.failed_at = self._now_ts()
                thread_run.last_error = error_message  # Store the full error message
                await session.commit()

    async def get_threads(self) -> "List[Thread]":
        """Return all threads, highest ``thread_id`` first."""
        async with self.db.get_async_session() as session:
            result = await session.execute(select(Thread).order_by(Thread.thread_id.desc()))
            return result.scalars().all()

    async def get_latest_thread_run(self, thread_id: str):
        """Return a summary dict of the most recently created run, or None."""
        async with self.db.get_async_session() as session:
            stmt = (
                select(ThreadRun)
                .where(ThreadRun.thread_id == thread_id)
                .order_by(ThreadRun.created_at.desc())
                .limit(1)
            )
            result = await session.execute(stmt)
            latest_run = result.scalar_one_or_none()
            if not latest_run:
                return None
            return {
                "id": latest_run.id,
                "status": latest_run.status,
                "error_message": latest_run.last_error,
                "created_at": latest_run.created_at,
                "started_at": latest_run.started_at,
                "completed_at": latest_run.completed_at,
                "cancelled_at": latest_run.cancelled_at,
                "failed_at": latest_run.failed_at,
                "model": latest_run.model,
                "usage": latest_run.usage,
            }

    async def get_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
        """Return the serialized run if it exists and belongs to ``thread_id``, else None.

        The result has the shape produced by :meth:`serialize_thread_run`.
        """
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            if run and run.thread_id == thread_id:
                return self.serialize_thread_run(run)
            return None

    async def cancel_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
        """Cancel an in-progress run.

        Returns:
            The updated run as returned by :meth:`get_run`, or None if the
            run does not exist, belongs to another thread, or is not in
            progress.
        """
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            if run and run.thread_id == thread_id and run.status == "in_progress":
                run.status = "cancelled"
                run.cancelled_at = self._now_ts()
                await session.commit()
                return await self.get_run(thread_id, run_id)
            return None

    async def list_runs(self, thread_id: str, limit: int) -> List[Dict[str, Any]]:
        """Return up to ``limit`` serialized runs of a thread, newest first."""
        async with self.db.get_async_session() as session:
            stmt = (
                select(ThreadRun)
                .where(ThreadRun.thread_id == thread_id)
                .order_by(ThreadRun.created_at.desc())
                .limit(limit)
            )
            result = await session.execute(stmt)
            return [self.serialize_thread_run(run) for run in result.scalars().all()]

    async def create_thread_run(self, thread_id: str, **kwargs) -> "ThreadRun":
        """Create and persist a queued run for ``thread_id``.

        Keyword arguments mirror the ``run_thread`` settings (``model_name``,
        ``temperature``, ``max_tokens``, ``top_p``, ``tool_choice``,
        ``execute_tools_async``, ``system_message``, ``tools``,
        ``response_format``, ``autonomous_iterations_amount``,
        ``continue_instructions``).
        """
        thread_run = ThreadRun(
            id=str(uuid.uuid4()),
            thread_id=thread_id,
            status="queued",
            model=kwargs.get('model_name'),
            temperature=kwargs.get('temperature'),
            max_tokens=kwargs.get('max_tokens'),
            top_p=kwargs.get('top_p'),
            tool_choice=kwargs.get('tool_choice', "auto"),
            execute_tools_async=kwargs.get('execute_tools_async', True),
            system_message=json.dumps(kwargs.get('system_message')),
            tools=json.dumps(kwargs.get('tools')),
            response_format=json.dumps(kwargs.get('response_format')),
            autonomous_iterations_amount=kwargs.get('autonomous_iterations_amount'),
            continue_instructions=kwargs.get('continue_instructions'),
        )
        async with self.db.get_async_session() as session:
            session.add(thread_run)
            await session.commit()
        return thread_run

    async def get_thread_run_count(self, thread_id: str) -> int:
        """Return the number of runs recorded for ``thread_id``."""
        from sqlalchemy import func

        async with self.db.get_async_session() as session:
            # Count in the database instead of loading every row.
            stmt = select(func.count()).select_from(ThreadRun).where(ThreadRun.thread_id == thread_id)
            result = await session.execute(stmt)
            return result.scalar_one()

    async def get_thread_run_status(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]:
        """Return the serialized run if it belongs to ``thread_id``, else None."""
        async with self.db.get_async_session() as session:
            run = await session.get(ThreadRun, run_id)
            if run and run.thread_id == thread_id:
                return self.serialize_thread_run(run)
            return None

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    def serialize_thread_run(self, run: "ThreadRun") -> Dict[str, Any]:
        """Convert a ``ThreadRun`` into a plain dict, decoding its JSON columns."""
        return {
            "id": run.id,
            "thread_id": run.thread_id,
            "status": run.status,
            "created_at": run.created_at,
            "started_at": run.started_at,
            "completed_at": run.completed_at,
            "cancelled_at": run.cancelled_at,
            "failed_at": run.failed_at,
            "model": run.model,
            "temperature": run.temperature,
            "max_tokens": run.max_tokens,
            "top_p": run.top_p,
            "tool_choice": run.tool_choice,
            "execute_tools_async": run.execute_tools_async,
            "system_message": self._loads_or_none(run.system_message),
            "tools": self._loads_or_none(run.tools),
            "usage": run.usage,
            "response_format": self._loads_or_none(run.response_format),
            "last_error": run.last_error,
            "autonomous_iterations_amount": run.autonomous_iterations_amount,
            "continue_instructions": run.continue_instructions,
            "iterations": self._loads_or_none(run.iterations),
        }

    def serialize_usage(self, usage):
        """Convert an LLM usage object into a plain dict.

        The detail sub-objects are optional; missing ones serialize to None.
        """
        return {
            "completion_tokens": usage.completion_tokens,
            "prompt_tokens": usage.prompt_tokens,
            "total_tokens": usage.total_tokens,
            "completion_tokens_details": self.serialize_completion_tokens_details(
                self._read_field(usage, 'completion_tokens_details')
            ),
            "prompt_tokens_details": self.serialize_prompt_tokens_details(
                self._read_field(usage, 'prompt_tokens_details')
            ),
        }

    def serialize_completion_tokens_details(self, details):
        """Convert completion-token details into a dict, or None if absent."""
        if details is None:
            return None
        return {
            "audio_tokens": self._read_field(details, 'audio_tokens'),
            "reasoning_tokens": self._read_field(details, 'reasoning_tokens'),
        }

    def serialize_prompt_tokens_details(self, details):
        """Convert prompt-token details into a dict, or None if absent."""
        if details is None:
            return None
        return {
            "audio_tokens": self._read_field(details, 'audio_tokens'),
            "cached_tokens": self._read_field(details, 'cached_tokens'),
        }

    def serialize_choice(self, choice):
        """Convert a response choice into a plain dict."""
        return {
            "finish_reason": choice.finish_reason,
            "index": choice.index,
            "message": self.serialize_message(choice.message),
        }

    def serialize_message(self, message):
        """Convert a response message into a plain dict; ``tool_calls`` is None when empty."""
        tool_calls = self._read_field(message, 'tool_calls')
        return {
            "content": message.content,
            "role": message.role,
            "tool_calls": [self.serialize_tool_call(tc) for tc in tool_calls] if tool_calls else None,
        }

    def serialize_tool_call(self, tool_call):
        """Convert a tool call object into a plain dict."""
        return {
            "id": tool_call.id,
            "type": tool_call.type,
            "function": {
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            },
        }
2024-10-06 01:04:15 +08:00
if __name__ == "__main__":
    import asyncio
    from agentpress.db import Database
    from tools.files_tool import FilesTool

    async def main():
        """Demo: run a three-iteration conversation with hooks and print the result."""
        db = Database()
        manager = ThreadManager(db)
        thread_id = await manager.create_thread()
        await manager.add_message(thread_id, {"role": "user", "content": "Let's have a conversation about artificial intelligence and create a file summarizing our discussion."})

        system_message = {"role": "system", "content": "You are an AI expert engaging in a conversation about artificial intelligence. You can also create and manage files."}

        # NOTE(review): files_tool is never registered with
        # manager.tool_registry in this block, yet run_thread resolves tool
        # names through that registry — confirm registration happens elsewhere.
        files_tool = FilesTool()
        tool_schemas = files_tool.get_schemas()

        # Hooks communicate with the run by editing manager.run_config.
        def initializer():
            print("Initializing thread run...")
            manager.run_config['temperature'] = 0.8

        def pre_iteration():
            print(f"Preparing iteration {manager.current_iteration}...")
            manager.run_config['max_tokens'] = 200 if manager.current_iteration > 3 else 150

        def after_iteration():
            print(f"Completed iteration {manager.current_iteration}. Status: {manager.run_config.get('status')}")
            manager.run_config['continue_instructions'] = "Let's focus more on AI ethics in the next iteration and update our summary file."

        def finalizer():
            print(f"Thread run finished with status: {manager.run_config.get('status')}")
            print(f"Final configuration: {manager.run_config}")

        settings = {
            "thread_id": thread_id,
            "system_message": system_message,
            "model_name": "gpt-4o",
            "temperature": 0.7,
            "max_tokens": 150,
            "autonomous_iterations_amount": 3,
            "continue_instructions": "Continue the conversation about AI, introducing new aspects or asking thought-provoking questions. Don't forget to update our summary file.",
            "initializer": initializer,
            "pre_iteration": pre_iteration,
            "after_iteration": after_iteration,
            "finalizer": finalizer,
            "tools": list(tool_schemas.keys()),
            "tool_choice": "auto",
        }

        response = await manager.run_thread(settings)
        print(f"Thread run response: {response}")

        messages = await manager.list_messages(thread_id)
        print("\nFinal conversation:")
        for msg in messages:
            # .get() so a message without content does not abort the dump.
            print(f"{str(msg.get('role', '')).capitalize()}: {msg.get('content')}")

    asyncio.run(main())