mirror of https://github.com/kortix-ai/suna.git
58 lines
2.3 KiB
Python
58 lines
2.3 KiB
Python
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
from sqlalchemy.future import select
|
||
|
from core.db import Agent
|
||
|
from datetime import datetime
|
||
|
from typing import List, Optional
|
||
|
import json
|
||
|
|
||
|
class AgentManager:
    """Create, read, update, delete and list ``Agent`` rows.

    Every public method opens its own session via
    ``self.db.get_async_session()`` and uses it as an async context manager.
    NOTE(review): ``db`` is assumed to yield a SQLAlchemy ``AsyncSession``
    from that call (inferred from the awaited ``execute``/``commit`` usage
    below); confirm against the ``db`` implementation.
    """

    def __init__(self, db):
        # Handle used to open a fresh session for each operation.
        self.db = db

    @staticmethod
    async def _fetch_agent(session: AsyncSession, agent_id: int) -> Optional[Agent]:
        """Return the ``Agent`` whose ``id`` equals ``agent_id``, or None.

        Shared by the get/update/delete paths so they all run the same query.
        ``scalar_one_or_none`` raises if more than one row matches.
        """
        result = await session.execute(select(Agent).where(Agent.id == agent_id))
        return result.scalar_one_or_none()

    async def create_agent(
        self,
        model: str,
        name: str,
        system_prompt: str,
        selected_tools: List[str],
        temperature: float = 0.5,
    ) -> int:
        """Insert a new agent, commit it, and return its ``id``.

        Args:
            model: Value stored in the agent's ``model`` column.
            name: Value stored in the agent's ``name`` column.
            system_prompt: Value stored in the agent's ``system_prompt`` column.
            selected_tools: Tool names, stored as a list directly.
            temperature: Value stored in ``temperature``; defaults to 0.5.

        Returns:
            The ``id`` of the newly committed agent.
        """
        async with self.db.get_async_session() as session:
            new_agent = Agent(
                model=model,
                name=name,
                system_prompt=system_prompt,
                selected_tools=selected_tools,  # Store as a list directly
                temperature=temperature,
                # NOTE(review): naive local time rendered as an ISO-8601
                # string (no UTC offset); consider a timezone-aware value.
                created_at=datetime.now().isoformat(),
            )
            session.add(new_agent)
            await session.commit()
            # Reload attributes after commit so ``id`` can be read without
            # an implicit lazy load on the async session.
            await session.refresh(new_agent)
            return new_agent.id

    async def get_agent(self, agent_id: int) -> Optional[Agent]:
        """Return the agent with the given ``id``, or None if it does not exist."""
        async with self.db.get_async_session() as session:
            return await self._fetch_agent(session, agent_id)

    async def update_agent(self, agent_id: int, **kwargs) -> bool:
        """Set each ``kwargs`` item as an attribute on the agent and commit.

        Returns:
            True if the agent existed (and the commit ran), False otherwise.

        NOTE(review): keys are applied with ``setattr`` and are not validated
        against the model's columns, so a misspelled key is accepted and
        still yields True.
        """
        async with self.db.get_async_session() as session:
            agent = await self._fetch_agent(session, agent_id)
            if agent is None:
                return False
            for key, value in kwargs.items():
                setattr(agent, key, value)
            await session.commit()
            return True

    async def delete_agent(self, agent_id: int) -> bool:
        """Delete the agent with the given ``id``.

        Returns:
            True if an agent was found and deleted, False if none existed.
        """
        async with self.db.get_async_session() as session:
            agent = await self._fetch_agent(session, agent_id)
            if agent is None:
                return False
            await session.delete(agent)
            await session.commit()
            return True

    async def list_agents(self) -> List[Agent]:
        """Return every stored agent as a plain list."""
        async with self.db.get_async_session() as session:
            result = await session.execute(select(Agent))
            # ``ScalarResult.all()`` is typed as a Sequence in SQLAlchemy 2.x;
            # materialise it to honour the ``List[Agent]`` annotation.
            return list(result.scalars().all())
|