mirror of https://github.com/kortix-ai/suna.git
feat(api): enhance thread creation and agent functionality
- Updated the `create_thread` endpoint to set a default name if none is provided. - Modified the `Agent` class to use a dictionary for `agentpress_tools` instead of a custom type. - Improved the `LocalKVStore` initialization to use a hidden filename. - Added a new weather retrieval tool to the MCP. - Updated agent and thread setup in the main function for better clarity and functionality. - Refactored message handling in the `ThreadsClient` to use query parameters for message posting.
This commit is contained in:
parent
307a9a80ae
commit
7050490d03
|
@ -2702,7 +2702,7 @@ async def get_thread(
|
|||
|
||||
@router.post("/threads", response_model=CreateThreadResponse)
|
||||
async def create_thread(
|
||||
name: Optional[str] = Form(...),
|
||||
name: Optional[str] = Form(None),
|
||||
user_id: str = Depends(get_current_user_id_from_jwt)
|
||||
):
|
||||
"""
|
||||
|
@ -2710,6 +2710,8 @@ async def create_thread(
|
|||
|
||||
[WARNING] Keep in sync with initiate endpoint.
|
||||
"""
|
||||
if not name:
|
||||
name = "New Project"
|
||||
logger.info(f"Creating new thread with name: {name}")
|
||||
client = await db.client
|
||||
account_id = user_id # In Basejump, personal account_id is the same as user_id
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
"""
|
||||
Kortix SDK for Suna AI Agent Platform
|
||||
|
||||
A Python SDK for creating and managing AI agents with tool execution capabilities.
|
||||
A Python SDK for creating and managing AI agents with thread execution capabilities.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
from .kortix.config import global_config as config
|
||||
from kortix._kortix import Kortix
|
||||
from kortix.tools import AgentPressTools, KortixMCP
|
||||
|
||||
__all__ = ["config"]
|
||||
__all__ = ["Kortix", "AgentPressTools", "KortixMCP"]
|
||||
|
|
|
@ -3,7 +3,6 @@ from thread import Thread, AgentRun
|
|||
from tools import AgentPressTools, KortixMCP, KortixTools
|
||||
from api.agents import (
|
||||
AgentCreateRequest,
|
||||
AgentPress_Tools,
|
||||
AgentPress_ToolConfig,
|
||||
AgentsClient,
|
||||
CustomMCP,
|
||||
|
@ -35,7 +34,7 @@ class KortixAgent:
|
|||
async def create(
|
||||
self, name: str, system_prompt: str, model: str, tools: list[KortixTools] = []
|
||||
) -> Agent:
|
||||
agentpress_tools = AgentPress_Tools()
|
||||
agentpress_tools = {}
|
||||
custom_mcps: list[CustomMCP] = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, AgentPressTools):
|
||||
|
|
|
@ -25,18 +25,13 @@ class AgentPress_ToolConfig:
|
|||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentPress_Tools(Dict[AgentPressTools, AgentPress_ToolConfig]):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentCreateRequest:
|
||||
name: str
|
||||
system_prompt: str
|
||||
description: Optional[str] = None
|
||||
custom_mcps: Optional[List[CustomMCP]] = None
|
||||
agentpress_tools: Optional[AgentPress_Tools] = None
|
||||
agentpress_tools: Optional[Dict[AgentPressTools, AgentPress_ToolConfig]] = None
|
||||
is_default: bool = False
|
||||
avatar: Optional[str] = None
|
||||
avatar_color: Optional[str] = None
|
||||
|
@ -48,7 +43,7 @@ class AgentUpdateRequest:
|
|||
description: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
custom_mcps: Optional[List[CustomMCP]] = None
|
||||
agentpress_tools: Optional[AgentPress_Tools] = None
|
||||
agentpress_tools: Optional[Dict[AgentPressTools, AgentPress_ToolConfig]] = None
|
||||
is_default: Optional[bool] = None
|
||||
avatar: Optional[str] = None
|
||||
avatar_color: Optional[str] = None
|
||||
|
@ -75,7 +70,7 @@ class AgentVersionResponse:
|
|||
version_name: str
|
||||
system_prompt: str
|
||||
custom_mcps: List[CustomMCP]
|
||||
agentpress_tools: AgentPress_Tools
|
||||
agentpress_tools: Dict[AgentPressTools, AgentPress_ToolConfig]
|
||||
is_active: bool
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
@ -89,7 +84,7 @@ class AgentResponse:
|
|||
name: str
|
||||
system_prompt: str
|
||||
custom_mcps: List[CustomMCP]
|
||||
agentpress_tools: AgentPress_Tools
|
||||
agentpress_tools: Dict[AgentPressTools, AgentPress_ToolConfig]
|
||||
is_default: bool
|
||||
created_at: str
|
||||
description: Optional[str] = None
|
||||
|
|
|
@ -92,6 +92,10 @@ class Message:
|
|||
is_llm_message: bool
|
||||
content: Any # Can be string, dict, or ContentObject
|
||||
created_at: str
|
||||
updated_at: str
|
||||
agent_id: str
|
||||
agent_version_id: str
|
||||
metadata: Any
|
||||
|
||||
@property
|
||||
def message_type(self) -> MessageType:
|
||||
|
@ -429,7 +433,7 @@ class ThreadsClient:
|
|||
# This endpoint expects form data, not JSON
|
||||
response = await self.client.post(
|
||||
f"/threads/{thread_id}/messages/add",
|
||||
data={"message": message},
|
||||
params={"message": message},
|
||||
headers={k: v for k, v in self.headers.items() if k != "Content-Type"},
|
||||
)
|
||||
data = self._handle_response(response)
|
||||
|
@ -462,11 +466,10 @@ class ThreadsClient:
|
|||
Returns:
|
||||
CreateThreadResponse containing the new thread ID and project ID
|
||||
"""
|
||||
# This endpoint expects form data, not JSON
|
||||
form_data = {"name": name} if name is not None else {"name": "New Thread"}
|
||||
data = None if name is None else {"name": name}
|
||||
response = await self.client.post(
|
||||
"/threads",
|
||||
data=form_data,
|
||||
data=data,
|
||||
headers={k: v for k, v in self.headers.items() if k != "Content-Type"},
|
||||
)
|
||||
data = self._handle_response(response)
|
||||
|
|
|
@ -5,10 +5,13 @@ from typing import Any, Optional
|
|||
|
||||
|
||||
from _kortix import Kortix
|
||||
from tools import AgentPressTools, KortixMCP
|
||||
from fastmcp import FastMCP
|
||||
|
||||
|
||||
# Local key-value store for storing agent and thread IDs
|
||||
class LocalKVStore:
|
||||
def __init__(self, filename: str = "kvstore.json"):
|
||||
def __init__(self, filename: str = ".kvstore.json"):
|
||||
self.filename = filename
|
||||
self._data = {}
|
||||
self._load()
|
||||
|
@ -47,20 +50,41 @@ class LocalKVStore:
|
|||
kv = LocalKVStore()
|
||||
|
||||
|
||||
mcp = FastMCP(name="Kortix")
|
||||
|
||||
|
||||
@mcp.tool
|
||||
async def get_weather(city: str) -> str:
|
||||
return f"The weather in {city} is rainy."
|
||||
|
||||
|
||||
async def main():
|
||||
kortixMCP = KortixMCP(mcp, "http://localhost:4000/mcp/")
|
||||
await kortixMCP.initialize()
|
||||
|
||||
# Start the MCP server in the background
|
||||
asyncio.create_task(
|
||||
mcp.run_http_async(
|
||||
show_banner=False, log_level="error", host="0.0.0.0", port=4000
|
||||
)
|
||||
)
|
||||
|
||||
kortix = Kortix("af3d6952-2109-4ab3-bbfe-4a2e2326c740", "http://localhost:8000/api")
|
||||
|
||||
# Setup the agent
|
||||
agent_id = kv.get("agent_id")
|
||||
if not agent_id:
|
||||
agent = await kortix.Agent.create(
|
||||
name="Test Agent",
|
||||
system_prompt="You are a test agent. You only respond with 'Hello, world!'",
|
||||
name="Generic Agent",
|
||||
system_prompt="You are a generic agent. You can use the tools provided to you to answer questions.",
|
||||
model="gpt-4o-mini",
|
||||
tools=[AgentPressTools.WEB_SEARCH_TOOL, kortixMCP],
|
||||
)
|
||||
kv.set("agent_id", agent._agent_id)
|
||||
else:
|
||||
agent = await kortix.Agent.get(agent_id)
|
||||
|
||||
# Setup the thread
|
||||
thread_id = kv.get("thread_id")
|
||||
if not thread_id:
|
||||
thread = await kortix.Thread.create()
|
||||
|
@ -68,7 +92,8 @@ async def main():
|
|||
else:
|
||||
thread = await kortix.Thread.get(thread_id)
|
||||
|
||||
agent_run = await agent.run("What is the weather in Tokyo?", thread)
|
||||
# Run the agent
|
||||
agent_run = await agent.run("What is the weather in Singapore?", thread)
|
||||
|
||||
stream = await agent_run.get_stream()
|
||||
|
||||
|
@ -77,4 +102,7 @@ async def main():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except:
|
||||
exit(0)
|
||||
|
|
|
@ -4,16 +4,20 @@ from fastmcp import FastMCP
|
|||
|
||||
|
||||
class KortixMCP:
|
||||
async def create(self, endpoint: str, mcp: FastMCP):
|
||||
def __init__(self, mcp: FastMCP, endpoint: str):
|
||||
self._fastmcp = mcp
|
||||
self.url = endpoint
|
||||
self.name = mcp.name
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
self.name = self._fastmcp.name
|
||||
self.type = "http"
|
||||
self.enabled_tools: list[str] = []
|
||||
tools = await mcp.get_tools()
|
||||
tools = await self._fastmcp.get_tools()
|
||||
for tool in tools.values():
|
||||
if tool.enabled:
|
||||
self.enabled_tools.append(tool.name)
|
||||
self._initialized = True
|
||||
return self
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue