This commit is contained in:
marko-kraemer 2024-10-17 13:44:10 +02:00
parent cceba7ad16
commit 1eb750eacd
3 changed files with 44 additions and 35 deletions

View File

@ -1,52 +1,60 @@
from typing import Dict, Any
import asyncio
from agentpress.db import Database
from agentpress.thread_manager import ThreadAgent, ThreadManager
from agentpress.tool_registry import ToolRegistry
from agentpress.thread_manager import ThreadManager
from tools.files_tool import FilesTool
async def run_agent():
db = Database()
manager = ThreadManager(db)
tool_registry = ToolRegistry()
thread_id = await manager.create_thread()
await manager.add_message(thread_id, {"role": "user", "content": "Let's create a Python + HTML website. Start by outlining the project structure. Use the available tools to create and edit files as needed."})
await manager.add_message(thread_id, {"role": "user", "content": "Let's have a conversation about artificial intelligence and create a file summarizing our discussion."})
tools = tool_registry.get_all_tools()
system_message = {"role": "system", "content": "You are an AI expert engaging in a conversation about artificial intelligence. You can also create and manage files."}
async def initializer(agent):
pass
files_tool = FilesTool()
tool_schemas = files_tool.get_schemas()
async def pre_iteration(iteration: int, agent: ThreadAgent):
pass
def initializer():
print("Initializing thread run...")
manager.run_config['temperature'] = 0.8
async def after_iteration(iteration: int, result: Dict[str, Any], agent: ThreadAgent):
pass
def pre_iteration():
print(f"Preparing iteration {manager.current_iteration}...")
manager.run_config['max_tokens'] = 200 if manager.current_iteration > 3 else 150
async def finalizer(status: str, agent: ThreadAgent):
pass
def after_iteration():
print(f"Completed iteration {manager.current_iteration}. Status: {manager.run_config['status']}")
manager.run_config['continue_instructions'] = "Let's focus more on AI ethics in the next iteration and update our summary file."
system_message = {
"role": "system",
"content": "You are an expert AI Software Developer specializing in creating Python + HTML websites. Your task is to design, structure, and implement a website. Use the available tools to create and edit files for the project. Provide clear explanations and use the tools to implement code snippets directly into the appropriate files."
def finalizer():
print(f"Thread run finished with status: {manager.run_config['status']}")
print(f"Final configuration: {manager.run_config}")
settings = {
"thread_id": thread_id,
"system_message": system_message,
"model_name": "gpt-4",
"temperature": 0.7,
"max_tokens": 150,
"autonomous_iterations_amount": 3,
"continue_instructions": "Continue the conversation about AI, introducing new aspects or asking thought-provoking questions. Don't forget to update our summary file.",
"initializer": initializer,
"pre_iteration": pre_iteration,
"after_iteration": after_iteration,
"finalizer": finalizer,
"tools": list(tool_schemas.keys()),
"tool_choice": "auto"
}
response = await manager.run_thread(settings)
response = await manager.run_thread_agent(
thread_id=thread_id,
system_message=system_message,
model_name="gpt-4",
temperature=0.7,
max_tokens=4096,
autonomous_iterations_amount=3,
continue_instructions="Continue developing the Python + HTML website. Focus on one aspect at a time, such as backend setup, frontend design, or specific features. Use the available tools to create new files or edit existing ones. Implement code snippets directly into the appropriate files using the tools.",
initializer=initializer,
pre_iteration=pre_iteration,
after_iteration=after_iteration,
finalizer=finalizer,
tools=list(tools.keys())
)
return response
print(f"Thread run response: {response}")
messages = await manager.list_messages(thread_id)
print("\nFinal conversation:")
for msg in messages:
print(f"{msg['role'].capitalize()}: {msg['content']}")
if __name__ == "__main__":
import asyncio
asyncio.run(run_agent())

View File

@ -417,6 +417,7 @@ class ThreadManager:
"role": "assistant",
"content": response_message.get('content') or "",
}
tool_calls = response_message.get('tool_calls')
if tool_calls:
message["tool_calls"] = [
@ -430,7 +431,7 @@ class ThreadManager:
} for tool_call in tool_calls
]
return message
def get_available_functions(self) -> Dict[str, Callable]:
available_functions = {}
for tool_name, tool_info in self.tool_registry.get_all_tools().items():

BIN
main.db

Binary file not shown.