diff --git a/core/agent_manager.py b/core/agent_manager.py deleted file mode 100644 index 1766a7d0..00000000 --- a/core/agent_manager.py +++ /dev/null @@ -1,58 +0,0 @@ -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from core.db import Agent -from datetime import datetime -from typing import List, Optional -import json - -class AgentManager: - def __init__(self, db): - self.db = db - - async def create_agent(self, model: str, name: str, system_prompt: str, selected_tools: List[str], temperature: float = 0.5) -> int: - async with self.db.get_async_session() as session: - new_agent = Agent( - model=model, - name=name, - system_prompt=system_prompt, - selected_tools=selected_tools, # Store as a list directly - temperature=temperature, - created_at=datetime.now().isoformat() - ) - session.add(new_agent) - await session.commit() - await session.refresh(new_agent) - return new_agent.id - - async def get_agent(self, agent_id: int) -> Optional[Agent]: - async with self.db.get_async_session() as session: - result = await session.execute(select(Agent).filter(Agent.id == agent_id)) - agent = result.scalar_one_or_none() - return agent - - async def update_agent(self, agent_id: int, **kwargs) -> bool: - async with self.db.get_async_session() as session: - result = await session.execute(select(Agent).filter(Agent.id == agent_id)) - agent = result.scalar_one_or_none() - if agent: - for key, value in kwargs.items(): - setattr(agent, key, value) - await session.commit() - return True - return False - - async def delete_agent(self, agent_id: int) -> bool: - async with self.db.get_async_session() as session: - result = await session.execute(select(Agent).filter(Agent.id == agent_id)) - agent = result.scalar_one_or_none() - if agent: - await session.delete(agent) - await session.commit() - return True - return False - - async def list_agents(self) -> List[Agent]: - async with self.db.get_async_session() as session: - result = await session.execute(select(Agent)) - agents = result.scalars().all() - return agents \ No newline at end of file diff --git a/core/api.py b/core/api.py index 5f2cab21..dd499845 100644 --- a/core/api.py +++ b/core/api.py @@ -1,109 +1,107 @@ -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel +from fastapi import FastAPI, HTTPException, Query, Path +from pydantic import BaseModel, Field from typing import List, Dict, Any, Optional import asyncio +from fastapi.responses import StreamingResponse from core.db import Database from core.thread_manager import ThreadManager -from core.agent_manager import AgentManager -from core.tools.tool_registry import ToolRegistry +from core.tool_registry import ToolRegistry from core.config import Settings -app = FastAPI() +app = FastAPI( + title="Thread Manager API", + description="API for managing and running threads with LLM integration", + version="1.0.0", +) + db = Database() manager = ThreadManager(db) -tool_registry = ToolRegistry() # Initialize here -agent_manager = AgentManager(db) +tool_registry = ToolRegistry() -# Pydantic models for request and response bodies class Message(BaseModel): - role: str - content: str + role: str = Field(..., description="The role of the message sender (e.g., 'user', 'assistant')") + content: str = Field(..., description="The content of the message") class RunThreadRequest(BaseModel): - system_message: Dict[str, Any] - model_name: str - temperature: float = 0.5 - max_tokens: Optional[int] = 500 - tools: Optional[List[str]] = None - tool_choice: str = "required" - additional_instructions: Optional[str] = None - stream: bool = False + system_message: Dict[str, Any] = Field(..., description="The system message to be used for the thread run") + model_name: str = Field(..., description="The name of the LLM model to be used") + temperature: float = Field(0.5, description="The sampling temperature for the LLM") + max_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate") + tools: Optional[List[str]] = Field(None, description="The list of tools to be used in the thread run") + tool_choice: str = Field("auto", description="Controls which tool is called by the model") + additional_system_message: Optional[str] = Field(None, description="Additional system message to be appended") + stream: bool = Field(False, description="Whether to stream the response") + top_p: Optional[float] = Field(None, description="The nucleus sampling value") + response_format: Optional[Dict[str, Any]] = Field(None, description="Specifies the format that the model must output") -class Agent(BaseModel): - name: str - model: str - system_prompt: str - selected_tools: List[str] - temperature: float = 0.5 - -@app.post("/threads/") +@app.post("/threads/", response_model=Dict[str, str], summary="Create a new thread") async def create_thread(): + """ + Create a new thread and return its ID. + """ thread_id = await manager.create_thread() return {"thread_id": thread_id} -@app.get("/threads/") +@app.get("/threads/", response_model=List[Dict[str, str]], summary="Get all threads") async def get_threads(): + """ + Retrieve a list of all thread IDs. + """ threads = await manager.get_threads() return [{"thread_id": thread.thread_id} for thread in threads] -@app.post("/threads/{thread_id}/messages/") -async def add_message(thread_id: int, message: Message): +@app.post("/threads/{thread_id}/messages/", response_model=Dict[str, str], summary="Add a message to a thread") +async def add_message(thread_id: str, message: Message): + """ + Add a new message to the specified thread. + """ await manager.add_message(thread_id, message.dict()) return {"status": "success"} -@app.get("/threads/{thread_id}/messages/") -async def list_messages(thread_id: int): +@app.get("/threads/{thread_id}/messages/", response_model=List[Dict[str, Any]], summary="List messages in a thread") +async def list_messages(thread_id: str): + """ + Retrieve all messages from the specified thread. + """ messages = await manager.list_messages(thread_id) return messages -@app.post("/threads/{thread_id}/run/") -async def run_thread(thread_id: int, request: Dict[str, Any]): - if 'agent_id' in request: - # Agent-based run - response_gen = manager.run_thread( - thread_id=thread_id, - agent_id=request['agent_id'], - additional_instructions=request.get('additional_instructions'), - stream=request.get('stream', False) - ) +@app.post("/threads/{thread_id}/run/", response_model=Dict[str, Any], summary="Run a thread") +async def run_thread(thread_id: str, request: RunThreadRequest): + """ + Run the specified thread with the given parameters. + """ + response = await manager.run_thread( + thread_id=thread_id, + system_message=request.system_message, + model_name=request.model_name, + temperature=request.temperature, + max_tokens=request.max_tokens, + tools=request.tools, + additional_system_message=request.additional_system_message, + top_p=request.top_p, + tool_choice=request.tool_choice, + response_format=request.response_format) + + return {"status": "success", "response": response} + +@app.get("/threads/{thread_id}/run/status/", response_model=Dict[str, Any], summary="Get thread run status") +async def get_thread_run_status(thread_id: str): + """ + Retrieve the status of the latest run for the specified thread. + """ + latest_thread_run = await manager.get_latest_thread_run(thread_id) + if latest_thread_run: + return latest_thread_run else: - # Manual configuration run - response_gen = manager.run_thread( - thread_id=thread_id, - system_message=request['system_message'], - model_name=request['model_name'], - temperature=request['temperature'], - max_tokens=request['max_tokens'], - tools=request['tools'], - additional_instructions=request.get('additional_instructions'), - stream=request.get('stream', False) - ) + return {"status": "No runs found for this thread."} - if request.get('stream', False): - raise HTTPException(status_code=501, detail="Streaming is not supported via this endpoint.") - else: - response = [] - async for chunk in response_gen: - response.append(chunk) - return {"response": response} - -@app.get("/threads/{thread_id}/run/status/") -async def get_thread_run_status(thread_id: int): - try: - latest_thread_run = await manager.get_latest_thread_run(thread_id) - if latest_thread_run: - return { - "status": latest_thread_run.status, - "error_message": latest_thread_run.error_message - } - else: - return {"status": "No runs found for this thread."} - except AttributeError: - return {"status": "Error", "message": "Unable to retrieve thread run status."} - -@app.get("/tools/") +@app.get("/tools/", response_model=Dict[str, Dict[str, Any]], summary="Get available tools") async def get_tools(): + """ + Retrieve a list of all available tools and their schemas. + """ tools = tool_registry.get_all_tools() if not tools: print("No tools found in the registry") # Debug print @@ -116,52 +114,45 @@ async def get_tools(): for name, tool_info in tools.items() } -@app.post("/agents/") -async def create_agent(agent: Agent): - agent_id = await agent_manager.create_agent(**agent.dict()) - return {"agent_id": agent_id} +@app.get("/threads/{thread_id}/runs/{run_id}", response_model=Dict[str, Any], summary="Retrieve a run") +async def get_run( + thread_id: str = Path(..., description="The ID of the thread that was run"), + run_id: str = Path(..., description="The ID of the run to retrieve") +): + """ + Retrieve the run object matching the specified ID. + """ + run = await manager.get_run(thread_id, run_id) + if run is None: + raise HTTPException(status_code=404, detail="Run not found") + return run -@app.get("/agents/") -async def list_agents(): - agents = await agent_manager.list_agents() - return [ - { - "id": agent.id, - "name": agent.name, - "model": agent.model, - "system_prompt": agent.system_prompt, - "selected_tools": agent.selected_tools, - "temperature": agent.temperature, - "created_at": agent.created_at - } - for agent in agents - ] +@app.post("/threads/{thread_id}/runs/{run_id}/cancel", response_model=Dict[str, Any], summary="Cancel a run") +async def cancel_run( + thread_id: str = Path(..., description="The ID of the thread to which this run belongs"), + run_id: str = Path(..., description="The ID of the run to cancel") +): + """ + Cancels a run that is in_progress. + """ + run = await manager.cancel_run(thread_id, run_id) + if run is None: + raise HTTPException(status_code=404, detail="Run not found") + return run -@app.get("/agents/{agent_id}") -async def get_agent(agent_id: int): - agent = await agent_manager.get_agent(agent_id) - if agent: - return { - "id": agent.id, - "name": agent.name, - "model": agent.model, - "system_prompt": agent.system_prompt, - "selected_tools": agent.selected_tools, - "temperature": agent.temperature, - "created_at": agent.created_at - } - raise HTTPException(status_code=404, detail="Agent not found") +@app.get("/threads/{thread_id}/runs", response_model=List[Dict[str, Any]], summary="List runs") +async def list_runs( + thread_id: str = Path(..., description="The ID of the thread the runs belong to"), + limit: int = Query(20, ge=1, le=100, description="A limit on the number of objects to be returned") +): + """ + Returns a list of runs belonging to a thread. + """ + runs = await manager.list_runs(thread_id, limit) + return runs -@app.put("/agents/{agent_id}") -async def update_agent(agent_id: int, agent: Agent): - success = await agent_manager.update_agent(agent_id, **agent.dict()) - if success: - return {"status": "success"} - raise HTTPException(status_code=404, detail="Agent not found") +# Add more endpoints as needed for production use -@app.delete("/agents/{agent_id}") -async def delete_agent(agent_id: int): - success = await agent_manager.delete_agent(agent_id) - if success: - return {"status": "success"} - raise HTTPException(status_code=404, detail="Agent not found") \ No newline at end of file +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/core/db.py b/core/db.py index c3d1bc72..ceeacaa8 100644 --- a/core/db.py +++ b/core/db.py @@ -1,17 +1,19 @@ -from sqlalchemy import Column, Integer, String, Text, ForeignKey, Float, JSON +from sqlalchemy import Column, Integer, String, Text, ForeignKey, Float, JSON, Boolean from sqlalchemy.orm import relationship, declarative_base from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker from core.config import settings # Changed from Settings to settings import os from contextlib import asynccontextmanager +import uuid +from datetime import datetime Base = declarative_base() class Thread(Base): __tablename__ = 'threads' - thread_id = Column(Integer, primary_key=True) + thread_id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) messages = Column(Text) creation_date = Column(String) last_updated_date = Column(String) @@ -22,25 +24,31 @@ class Thread(Base): class ThreadRun(Base): __tablename__ = 'thread_runs' - run_id = Column(Integer, primary_key=True) - thread_id = Column(Integer, ForeignKey('threads.thread_id')) - messages = Column(Text) - creation_date = Column(String) - status = Column(String) - error_message = Column(Text, nullable=True) + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + thread_id = Column(String(36), ForeignKey('threads.thread_id')) + created_at = Column(Integer) + status = Column(String) + last_error = Column(Text, nullable=True) + started_at = Column(Integer, nullable=True) + cancelled_at = Column(Integer, nullable=True) + failed_at = Column(Integer, nullable=True) + completed_at = Column(Integer, nullable=True) + model = Column(String) + system_message = Column(Text) + tools = Column(JSON, nullable=True) + usage = Column(JSON, nullable=True) + temperature = Column(Float, nullable=True) + top_p = Column(Float, nullable=True) + max_tokens = Column(Integer, nullable=True) + tool_choice = Column(String, nullable=True) + execute_tools_async = Column(Boolean) + response_format = Column(JSON, nullable=True) thread = relationship("Thread", back_populates="thread_runs") -class Agent(Base): - __tablename__ = 'agents' - - id = Column(Integer, primary_key=True) - model = Column(String, nullable=False) - name = Column(String, nullable=False) - system_prompt = Column(Text, nullable=False) - selected_tools = Column(JSON) # Changed from ARRAY to JSON - temperature = Column(Float, default=0.5) - created_at = Column(String, nullable=False) + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.created_at = int(datetime.utcnow().timestamp()) # class MemoryModule(Base): # __tablename__ = 'memory_modules' diff --git a/core/llm.py b/core/llm.py index e8d70521..ac5b3fa2 100644 --- a/core/llm.py +++ b/core/llm.py @@ -1,4 +1,4 @@ -from typing import Union, AsyncGenerator +from typing import Union, Dict, Any import litellm import os import json @@ -26,24 +26,13 @@ os.environ['GROQ_API_KEY'] = GROQ_API_KEY logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False) -> AsyncGenerator[Union[dict, str], None]: +async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0, max_tokens=None, tools=None, tool_choice="auto", use_tool_parser=False, api_key=None, api_base=None, agentops_session=None, stream=False, top_p=None, response_format=None) -> Union[Dict[str, Any], str]: litellm.set_verbose = True async def attempt_api_call(api_call_func, max_attempts=3): for attempt in range(max_attempts): try: - response = await api_call_func() - if stream: - async for chunk in response: - yield chunk - else: - response_content = response.choices[0].message['content'] if json_mode else response - if json_mode: - if not json.loads(response_content): - logger.info(f"Invalid JSON received, retrying attempt {attempt + 1}") - continue - yield response - return + return await api_call_func() except litellm.exceptions.RateLimitError as e: logger.warning(f"Rate limit exceeded. Waiting for 30 seconds before retrying...") await asyncio.sleep(30) @@ -60,7 +49,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0 "model": model_name, "messages": messages, "temperature": temperature, - "response_format": {"type": "json_object"} if json_mode else None, + "response_format": response_format or ({"type": "json_object"} if json_mode else None), + "top_p": top_p, + "stream": stream, } # Add api_key and api_base if provided @@ -88,12 +79,9 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0 api_call_params["tool_choice"] = tool_choice if "claude" in model_name.lower() or "anthropic" in model_name.lower(): - # if messages[0]["role"] != "user": - # api_call_params["messages"] = [{"role": "user", "content": "."}] + messages api_call_params["extra_headers"] = { "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" } - # # "anthropic-beta": "prompt-caching-2024-07-31" # Log the API request logger.info(f"Sending API request: {json.dumps(api_call_params, indent=2)}") @@ -101,20 +89,49 @@ async def make_llm_api_call(messages, model_name, json_mode=False, temperature=0 if agentops_session: response = await agentops_session.patch(litellm.acompletion)(**api_call_params) else: - if stream: - response = await litellm.acompletion(**api_call_params, stream=True) - else: - response = await litellm.acompletion(**api_call_params) + response = await litellm.acompletion(**api_call_params) # Log the API response logger.info(f"Received API response: {response}") return response - async for result in attempt_api_call(api_call): - yield result + return await attempt_api_call(api_call) # Sample Usage if __name__ == "__main__": + import asyncio + async def test_llm_api_call(stream=True): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Complex essay on economics"} + ] + model_name = "gpt-4o" - pass \ No newline at end of file + response = await make_llm_api_call(messages, model_name, stream=stream) + + if stream: + print("Streaming response:") + async for chunk in response: + if isinstance(chunk, dict) and 'choices' in chunk: + content = chunk['choices'][0]['delta'].get('content', '') + print(content, end='', flush=True) + else: + # For non-dict responses (like ModelResponse objects) + content = chunk.choices[0].delta.content + if content: + print(content, end='', flush=True) + print("\nStream completed.") + else: + print("Non-streaming response:") + if isinstance(response, dict) and 'choices' in response: + print(response['choices'][0]['message']['content']) + else: + # For non-dict responses (like ModelResponse objects) + print(response.choices[0].message.content) + + # Example usage: + # asyncio.run(test_llm_api_call(stream=True)) # For streaming + # asyncio.run(test_llm_api_call(stream=False)) # For non-streaming + + asyncio.run(test_llm_api_call()) \ No newline at end of file diff --git a/core/thread_manager.py b/core/thread_manager.py index 6948a4ed..1ac4fa8d 100644 --- a/core/thread_manager.py +++ b/core/thread_manager.py @@ -5,19 +5,18 @@ from typing import List, Dict, Any, Optional, Callable, AsyncGenerator, Union from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession from core.db import Database, Thread, ThreadRun -from core.tools.tool import Tool, ToolResult +from core.tool import Tool, ToolResult from core.llm import make_llm_api_call # from core.working_memory_manager import WorkingMemory from datetime import datetime -from core.tools.tool_registry import ToolRegistry +from core.tool_registry import ToolRegistry import re -from core.agent_manager import AgentManager +import uuid class ThreadManager: def __init__(self, db: Database): self.db = db self.tool_registry = ToolRegistry() - self.agent_manager = AgentManager(db) async def create_thread(self) -> int: async with self.db.get_async_session() as session: @@ -197,84 +196,176 @@ class ThreadManager: async def run_thread( self, - thread_id: int, - agent_id: Optional[int] = None, - system_message: Optional[Dict[str, Any]] = None, - model_name: Optional[str] = None, - temperature: Optional[float] = None, + thread_id: str, + system_message: Dict[str, Any], + model_name: str, + temperature: float = 0.5, max_tokens: Optional[int] = None, - tools: Optional[List[Dict[str, Any]]] = None, - additional_instructions: Optional[str] = None, + tools: Optional[List[str]] = None, + additional_system_message: Optional[str] = None, hide_tool_msgs: bool = False, execute_tools_async: bool = True, use_tool_parser: bool = False, - stream: bool = False - ) -> AsyncGenerator[Union[Dict[str, Any], str], None]: - if agent_id is not None: - agent = await self.agent_manager.get_agent(agent_id) - if not agent: - raise ValueError(f"Agent with id {agent_id} not found") - system_message = {"role": "system", "content": agent.system_prompt} - model_name = agent.model - temperature = agent.temperature - tools = [self.tool_registry.get_tool(tool).schema()[0] for tool in agent.selected_tools] if agent.selected_tools else None - elif system_message is None or model_name is None: - raise ValueError("Either agent_id or system_message and model_name must be provided") + top_p: Optional[float] = None, + tool_choice: str = "auto", + response_format: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + run_id = str(uuid.uuid4()) + + # Fetch full tool objects based on the provided tool names + full_tools = None + if tools: + full_tools = [self.tool_registry.get_tool(tool_name)['schema'] for tool_name in tools if self.tool_registry.get_tool(tool_name)] + + thread_run = ThreadRun( + id=run_id, + thread_id=thread_id, + status="queued", + model=model_name, + system_message=json.dumps(system_message), + tools=json.dumps(full_tools) if full_tools else None, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + tool_choice=tool_choice, + execute_tools_async=execute_tools_async, + response_format=json.dumps(response_format) if response_format else None + ) - if await self.should_stop(thread_id): - yield {"status": "stopped", "message": "Session cancelled"} - return + async with self.db.get_async_session() as session: + session.add(thread_run) + await session.commit() - if use_tool_parser: - hide_tool_msgs = True - - await self.cleanup_incomplete_tool_calls(thread_id) - - # Prepare messages - messages = await self.list_messages(thread_id, hide_tool_msgs=hide_tool_msgs) - prepared_messages = [system_message] + messages - - if additional_instructions: - additional_instruction_message = { - "role": "user", - "content": additional_instructions - } - prepared_messages.append(additional_instruction_message) - try: - response_stream = make_llm_api_call( + thread_run.status = "in_progress" + thread_run.started_at = int(datetime.utcnow().timestamp()) + await self.update_thread_run(thread_run) + + if await self.should_stop(thread_id): + thread_run.status = "cancelled" + thread_run.cancelled_at = int(datetime.utcnow().timestamp()) + await self.update_thread_run(thread_run) + return {"status": "stopped", "message": "Session cancelled"} + + if use_tool_parser: + hide_tool_msgs = True + + await self.cleanup_incomplete_tool_calls(thread_id) + + # Prepare messages + messages = await self.list_messages(thread_id, hide_tool_msgs=hide_tool_msgs) + prepared_messages = [system_message] + messages + + if additional_system_message: + additional_instruction_message = { + "role": "user", + "content": additional_system_message + } + prepared_messages.append(additional_instruction_message) + + response = await make_llm_api_call( prepared_messages, model_name, temperature=temperature, max_tokens=max_tokens, - tools=tools, - tool_choice="auto", - stream=stream + tools=full_tools, + tool_choice=tool_choice, + stream=False, + top_p=top_p, + response_format=response_format ) - async for partial_response in response_stream: - if stream: - yield partial_response - else: - response = partial_response + usage = response.usage if hasattr(response, 'usage') else None + usage_dict = self.serialize_usage(usage) if usage else None + thread_run.usage = usage_dict - if not stream: - if tools is None or use_tool_parser: - await self.handle_response_without_tools(thread_id, response, use_tool_parser) - else: - await self.handle_response_with_tools(thread_id, response, execute_tools_async) + # Add the assistant's message to the thread + assistant_message = { + "role": "assistant", + "content": response.choices[0].message['content'] + } + if 'tool_calls' in response.choices[0].message: + assistant_message['tool_calls'] = response.choices[0].message['tool_calls'] + + await self.add_message(thread_id, assistant_message) - if await self.should_stop(thread_id): - yield {"status": "stopped", "message": "Session cancelled"} - else: - await self.save_thread_run(thread_id) - yield response + if tools is None or use_tool_parser: + await self.handle_response_without_tools(thread_id, response, use_tool_parser) + else: + await self.handle_response_with_tools(thread_id, response, execute_tools_async) + + thread_run.status = "completed" + thread_run.completed_at = int(datetime.utcnow().timestamp()) + await self.update_thread_run(thread_run) + + return { + "id": thread_run.id, + "choices": [self.serialize_choice(choice) for choice in response.choices], + "usage": usage_dict, + "model": model_name, + "object": "chat.completion", + "created": int(datetime.utcnow().timestamp()) + } except Exception as e: error_message = f"Error in API call: {str(e)}\n\nFull error: {repr(e)}" logging.error(error_message) - await self.update_thread_run_with_error(thread_id, error_message) - yield {"status": "error", "message": error_message} + thread_run.status = "failed" + thread_run.failed_at = int(datetime.utcnow().timestamp()) + thread_run.last_error = error_message + await self.update_thread_run(thread_run) + return {"status": "error", "message": error_message} + + def serialize_usage(self, usage): + return { + "completion_tokens": usage.completion_tokens, + "prompt_tokens": usage.prompt_tokens, + "total_tokens": usage.total_tokens, + "completion_tokens_details": self.serialize_completion_tokens_details(usage.completion_tokens_details), + "prompt_tokens_details": self.serialize_prompt_tokens_details(usage.prompt_tokens_details) + } + + def serialize_completion_tokens_details(self, details): + return { + "audio_tokens": details.audio_tokens, + "reasoning_tokens": details.reasoning_tokens + } + + def serialize_prompt_tokens_details(self, details): + return { + "audio_tokens": details.audio_tokens, + "cached_tokens": details.cached_tokens + } + + def serialize_choice(self, choice): + return { + "finish_reason": choice.finish_reason, + "index": choice.index, + "message": self.serialize_message(choice.message) + } + + def serialize_message(self, message): + return { + "content": message.content, + "role": message.role, + "tool_calls": [self.serialize_tool_call(tc) for tc in message.tool_calls] if message.tool_calls else None + } + + def serialize_tool_call(self, tool_call): + return { + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments + } + } + + async def update_thread_run(self, thread_run: ThreadRun): + async with self.db.get_async_session() as session: + session.add(thread_run) + await session.commit() + await session.refresh(thread_run) async def handle_response_without_tools(self, thread_id: int, response: Any, use_tool_parser: bool): response_content = response.choices[0].message['content'] @@ -282,8 +373,8 @@ class ThreadManager: if use_tool_parser: await self.handle_tool_parser_response(thread_id, response_content) else: - logging.info("Adding assistant message to thread.") - await self.add_message(thread_id, {"role": "assistant", "content": response_content}) + # The message has already been added in the run_thread method, so we don't need to add it again here + pass async def handle_tool_parser_response(self, thread_id: int, response_content: str): tool_call_match = re.search(r'\{[\s\S]*"function_calls"[\s\S]*\}', response_content) @@ -325,9 +416,8 @@ class ThreadManager: response_message = response.choices[0].message tool_calls = response_message.get('tool_calls', []) - assistant_message = self.create_assistant_message_with_tools(response_message) - await self.add_message(thread_id, assistant_message) - + # The assistant message has already been added in the run_thread method + available_functions = self.get_available_functions() if await self.should_stop(thread_id): @@ -348,9 +438,7 @@ class ThreadManager: except AttributeError as e: logging.error(f"AttributeError: {e}") - content = response_message.get('content', '') - if content: - await self.add_message(thread_id, {"role": "assistant", "content": content}) + # No need to add the message here as it's already been added in the run_thread method def create_assistant_message_with_tools(self, response_message: Any) -> Dict[str, Any]: message = { @@ -452,7 +540,7 @@ class ThreadManager: result = await session.execute(stmt) return result.scalar_one_or_none() is not None - async def save_thread_run(self, thread_id: int): + async def save_thread_run(self, thread_id: str): async with self.db.get_async_session() as session: thread = await session.get(Thread, thread_id) if not thread: @@ -461,14 +549,26 @@ class ThreadManager: messages = json.loads(thread.messages) creation_date = datetime.now().isoformat() - new_thread_run = ThreadRun( - thread_id=thread_id, - messages=json.dumps(messages), - creation_date=creation_date, - status='completed' - ) - session.add(new_thread_run) - await session.commit() + # Get the latest ThreadRun for this thread + stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1) + result = await session.execute(stmt) + latest_thread_run = result.scalar_one_or_none() + + if latest_thread_run: + # Update the existing ThreadRun + latest_thread_run.messages = json.dumps(messages) + latest_thread_run.last_updated_date = creation_date + await session.commit() + else: + # Create a new ThreadRun if none exists + new_thread_run = ThreadRun( + thread_id=thread_id, + messages=json.dumps(messages), + creation_date=creation_date, + status='completed' + ) + session.add(new_thread_run) + await session.commit() async def get_thread(self, thread_id: int) -> Optional[Thread]: async with self.db.get_async_session() as session: @@ -489,11 +589,92 @@ class ThreadManager: result = await session.execute(select(Thread).order_by(Thread.thread_id.desc())) return result.scalars().all() - async def get_latest_thread_run(self, thread_id: int): + async def get_latest_thread_run(self, thread_id: str): async with self.db.get_async_session() as session: - stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.run_id.desc()).limit(1) + stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(1) result = await session.execute(stmt) - return result.scalar_one_or_none() + latest_run = result.scalar_one_or_none() + if latest_run: + return { + "id": latest_run.id, + "status": latest_run.status, + "error_message": latest_run.last_error, + "created_at": latest_run.created_at, + "started_at": latest_run.started_at, + "completed_at": latest_run.completed_at, + "cancelled_at": latest_run.cancelled_at, + "failed_at": latest_run.failed_at, + "model": latest_run.model, + "usage": latest_run.usage + } + return None + + async def get_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]: + async with self.db.get_async_session() as session: + run = await session.get(ThreadRun, run_id) + if run and run.thread_id == thread_id: + return { + "id": run.id, + "thread_id": run.thread_id, + "status": run.status, + "created_at": run.created_at, + "started_at": run.started_at, + "completed_at": run.completed_at, + "cancelled_at": run.cancelled_at, + "failed_at": run.failed_at, + "model": run.model, + "system_message": json.loads(run.system_message) if run.system_message else None, + "tools": json.loads(run.tools) if run.tools else None, + "usage": run.usage, + "temperature": run.temperature, + "top_p": run.top_p, + "max_tokens": run.max_tokens, + "tool_choice": run.tool_choice, + "execute_tools_async": run.execute_tools_async, + "response_format": json.loads(run.response_format) if run.response_format else None, + "last_error": run.last_error + } + return None + + async def cancel_run(self, thread_id: str, run_id: str) -> Optional[Dict[str, Any]]: + async with self.db.get_async_session() as session: + run = await session.get(ThreadRun, run_id) + if run and run.thread_id == thread_id and run.status == "in_progress": + run.status = "cancelled" + run.cancelled_at = int(datetime.utcnow().timestamp()) + await session.commit() + return await self.get_run(thread_id, run_id) + return None + + async def list_runs(self, thread_id: str, limit: int) -> List[Dict[str, Any]]: + async with self.db.get_async_session() as session: + stmt = select(ThreadRun).where(ThreadRun.thread_id == thread_id).order_by(ThreadRun.created_at.desc()).limit(limit) + result = await session.execute(stmt) + runs = result.scalars().all() + return [ + { + "id": run.id, + "thread_id": run.thread_id, + "status": run.status, + "created_at": run.created_at, + "started_at": run.started_at, + "completed_at": run.completed_at, + "cancelled_at": run.cancelled_at, + "failed_at": run.failed_at, + "model": run.model, + "system_message": json.loads(run.system_message) if run.system_message else None, + "tools": json.loads(run.tools) if run.tools else None, + "usage": run.usage, + "temperature": run.temperature, + "top_p": run.top_p, + "max_tokens": run.max_tokens, + "tool_choice": run.tool_choice, + "execute_tools_async": run.execute_tools_async, + "response_format": json.loads(run.response_format) if run.response_format else None, + "last_error": run.last_error + } + for run in runs + ] if __name__ == "__main__": import asyncio @@ -519,6 +700,4 @@ if __name__ == "__main__": print(f"Response: {response}") - asyncio.run(main()) - asyncio.run(main()) \ No newline at end of file diff --git a/core/tools/tool.py b/core/tool.py similarity index 100% rename from core/tools/tool.py rename to core/tool.py diff --git a/core/tools/tool_registry.py b/core/tool_registry.py similarity index 91% rename from core/tools/tool_registry.py rename to core/tool_registry.py index a9cf8fdf..27c88727 100644 --- a/core/tools/tool_registry.py +++ b/core/tool_registry.py @@ -1,5 +1,5 @@ -from typing import Dict, Type, Any -from core.tools.tool import Tool +from typing import Dict, Type, Any, Optional +from core.tool import Tool from core.config import settings import importlib.util import os @@ -39,7 +39,7 @@ class ToolRegistry: print(f"Registered tools: {list(self.tools.keys())}") # Debug print - def get_tool(self, tool_name: str) -> Dict[str, Any]: + def get_tool(self, tool_name: str) -> Optional[Dict[str, Any]]: return self.tools.get(tool_name) def get_all_tools(self) -> Dict[str, Dict[str, Any]]: diff --git a/core/ui/agent_management.py b/core/ui/agent_management.py deleted file mode 100644 index 6ecccf06..00000000 --- a/core/ui/agent_management.py +++ /dev/null @@ -1,97 +0,0 @@ -import streamlit as st -import requests -from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE - -def display_agent_management(): - st.header("Agent Management") - - col1, col2 = st.columns([1, 1]) - - with col1: - display_create_agent_form() - - with col2: - display_existing_agents() - -def display_create_agent_form(): - st.subheader("Create New Agent") - with st.form("create_agent_form"): - new_agent_name = st.text_input("Agent Name") - new_agent_model = st.selectbox("Model", AI_MODELS) - new_agent_system_prompt = st.text_area("System Prompt", value=STANDARD_SYSTEM_MESSAGE) - new_agent_temperature = st.slider("Temperature", 0.0, 1.0, 0.5) - tool_options = list(st.session_state.tools.keys()) - new_agent_tools = st.multiselect("Tools", tool_options) - - submitted = st.form_submit_button("Create Agent") - if submitted: - create_agent(new_agent_name, new_agent_model, new_agent_system_prompt, new_agent_temperature, new_agent_tools) - -def create_agent(name, model, system_prompt, temperature, tools): - response = requests.post(f"{API_BASE_URL}/agents/", json={ - "name": name, - "model": model, - "system_prompt": system_prompt, - "temperature": temperature, - "selected_tools": tools - }) - if response.status_code == 200: - st.success("Agent created successfully!") - st.session_state.fetch_agents() - else: - st.error("Failed to create agent.") - -def display_existing_agents(): - st.subheader("Existing Agents") - for agent in st.session_state.agents: - with st.expander(f"Agent: {agent['name']}"): - display_agent_details(agent) - -def display_agent_details(agent): - st.write(f"Model: {agent['model']}") - st.write(f"Temperature: {agent['temperature']}") - st.write(f"Tools: {', '.join(agent['selected_tools'])}") - - if st.button(f"Edit Agent {agent['id']}"): - st.session_state.editing_agent = agent['id'] - - if st.button(f"Delete Agent {agent['id']}"): - delete_agent(agent['id']) - - if st.session_state.get('editing_agent') == agent['id']: - edit_agent_form(agent) - -def edit_agent_form(agent): - with st.form(f"edit_agent_form_{agent['id']}"): - updated_name = st.text_input("Agent Name", value=agent['name']) - updated_model = st.selectbox("Model", AI_MODELS, index=AI_MODELS.index(agent['model'])) - updated_system_prompt = st.text_area("System Prompt", value=agent['system_prompt']) - updated_temperature = st.slider("Temperature", 0.0, 1.0, value=agent['temperature']) - tool_options = list(st.session_state.tools.keys()) - updated_tools = st.multiselect("Tools", options=tool_options, default=agent['selected_tools']) - - if st.form_submit_button("Update Agent"): - update_agent(agent['id'], updated_name, updated_model, updated_system_prompt, updated_temperature, updated_tools) - st.session_state.editing_agent = None - -def update_agent(agent_id, name, model, system_prompt, temperature, tools): - response = requests.put(f"{API_BASE_URL}/agents/{agent_id}", json={ - "name": name, - "model": model, - "system_prompt": system_prompt, - "temperature": temperature, - "selected_tools": tools - }) - if response.status_code == 200: - st.success("Agent updated successfully!") - st.session_state.fetch_agents() - else: - st.error("Failed to update agent.") - -def delete_agent(agent_id): - response = requests.delete(f"{API_BASE_URL}/agents/{agent_id}") - if response.status_code == 200: - st.success("Agent deleted successfully!") - st.session_state.fetch_agents() - else: - st.error("Failed to delete agent.") \ No newline at end of file diff --git a/core/ui/main.py b/core/ui/main.py index b71a8b65..47dc253f 100644 --- a/core/ui/main.py +++ b/core/ui/main.py @@ -1,34 +1,45 @@ import streamlit as st from core.ui.thread_management import display_thread_management -from core.ui.message_display import display_messages -from core.ui.thread_runner import display_thread_runner -from core.ui.agent_management import display_agent_management +from core.ui.message_display import display_messages_and_runner +from core.ui.thread_runner import fetch_thread_runs, display_runs from core.ui.tool_display import display_tools -from core.ui.utils import initialize_session_state, fetch_data +from core.ui.utils import initialize_session_state, fetch_data, API_BASE_URL def main(): initialize_session_state() fetch_data() + st.set_page_config(page_title="AI Assistant Management System", layout="wide") + st.sidebar.title("Navigation") - mode = st.sidebar.radio("Select Mode", ["Agent Management", "Thread Management", "Tools"]) + mode = st.sidebar.radio("Select Mode", ["Thread Management", "Tools"]) st.title("AI Assistant Management System") - if mode == "Agent Management": - display_agent_management() - elif mode == "Tools": + if mode == "Tools": display_tools() else: # Thread Management display_thread_management_content() def display_thread_management_content(): - st.header("Thread Management") - display_thread_management() + col1, col2 = st.columns([1, 3]) + + with col1: + display_thread_management() + if st.session_state.selected_thread: + display_thread_runner(st.session_state.selected_thread) + + with col2: + if st.session_state.selected_thread: + display_messages_and_runner(st.session_state.selected_thread) - if st.session_state.selected_thread: - display_messages(st.session_state.selected_thread) - display_thread_runner(st.session_state.selected_thread) +def display_thread_runner(thread_id): + st.subheader("Thread Runs") + + limit = st.number_input("Number of runs to retrieve", min_value=1, max_value=100, value=20) + if st.button("Fetch Runs"): + runs = fetch_thread_runs(thread_id, limit) + display_runs(runs) if __name__ == "__main__": main() \ No newline at end of file diff --git a/core/ui/message_display.py b/core/ui/message_display.py index 72a961a0..33081555 100644 --- a/core/ui/message_display.py +++ b/core/ui/message_display.py @@ -1,16 +1,21 @@ import streamlit as st import requests -import json -from core.ui.utils import API_BASE_URL +from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE +from core.ui.thread_runner import prepare_run_thread_data, run_thread, display_response_content -def display_messages(thread_id): - st.subheader(f"๐Ÿงต Thread ID: {thread_id}") - # st.write("### ๐Ÿ“ Messages") +def display_messages_and_runner(thread_id): + st.subheader(f"Messages for Thread: {thread_id}") messages = fetch_messages(thread_id) - display_message_json(messages) display_message_list(messages) - display_add_message_form(thread_id) + + col1, col2 = st.columns(2) + + with col1: + display_add_message_form(thread_id) + + with col2: + display_run_thread_form(thread_id) def fetch_messages(thread_id): messages_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/messages/") @@ -20,28 +25,37 @@ def fetch_messages(thread_id): st.error("Failed to fetch messages.") return [] -def display_message_json(messages): - if st.button("Show/Hide JSON"): - st.session_state.show_json = not st.session_state.get('show_json', False) - - if st.session_state.get('show_json', False): - json_str = json.dumps(messages, indent=2) - st.code(json_str, language="json") - def display_message_list(messages): for msg in messages: with st.chat_message(msg['role']): st.write(msg['content']) def display_add_message_form(thread_id): - st.write("### โž• Add a New Message") + st.write("### Add a New Message") with st.form(key="add_message_form"): - role = st.selectbox("๐Ÿ”น Role", ["user", "assistant"], key="add_role") - content = st.text_area("๐Ÿ“ Content", key="add_content") - submitted = st.form_submit_button("โž• Add Message") + role = st.selectbox("Role", ["user", "assistant"], key="add_role") + content = st.text_area("Content", key="add_content") + submitted = st.form_submit_button("Add Message") if submitted: add_message(thread_id, role, content) +def display_run_thread_form(thread_id): + st.write("### Run Thread") + with st.form(key="run_thread_form"): + model_name = st.selectbox("Model", AI_MODELS, key="model_name") + temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature") + max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens") + system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100) + additional_system_message = st.text_area("Additional System Message", key="additional_system_message", height=100) + + tool_options = list(st.session_state.tools.keys()) + selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools") + + submitted = st.form_submit_button("Run Thread") + if submitted: + run_thread_data = prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools) + run_thread(thread_id, run_thread_data) + def add_message(thread_id, role, content): message_data = {"role": role, "content": content} add_msg_response = requests.post( @@ -49,6 +63,7 @@ def add_message(thread_id, role, content): json=message_data ) if add_msg_response.status_code == 200: + st.success("Message added successfully.") st.rerun() else: st.error("Failed to add message.") \ No newline at end of file diff --git a/core/ui/thread_management.py b/core/ui/thread_management.py index 30ce0948..b9e34aff 100644 --- a/core/ui/thread_management.py +++ b/core/ui/thread_management.py @@ -1,22 +1,22 @@ import streamlit as st import requests from core.ui.utils import API_BASE_URL +from datetime import datetime def display_thread_management(): - col1, col2 = st.columns([1, 2]) + st.subheader("Thread Management") - with col1: - if st.button("โž• Create New Thread"): - create_new_thread() + if st.button("โž• Create New Thread", key="create_thread_button"): + create_new_thread() - with col2: - display_thread_selector() + display_thread_selector() def create_new_thread(): response = requests.post(f"{API_BASE_URL}/threads/") if response.status_code == 200: thread_id = response.json()['thread_id'] st.session_state.selected_thread = thread_id + st.success(f"New thread created with ID: {thread_id}") st.rerun() else: st.error("Failed to create a new thread.") @@ -25,17 +25,36 @@ def display_thread_selector(): threads_response = requests.get(f"{API_BASE_URL}/threads/") if threads_response.status_code == 200: threads = threads_response.json() - thread_options = [str(thread['thread_id']) for thread in threads] + + # Sort threads by creation date if available, otherwise by thread_id + def sort_key(thread): + if 'creation_date' in thread: + try: + return datetime.strptime(thread['creation_date'], "%Y-%m-%d %H:%M:%S") + except ValueError: + st.warning(f"Invalid date format for thread {thread['thread_id']}") + return thread['thread_id'] + + sorted_threads = sorted(threads, key=sort_key, reverse=True) + + thread_options = [] + for thread in sorted_threads: + if 'creation_date' in thread: + thread_options.append(f"{thread['thread_id']} - Created: {thread['creation_date']}") + else: + thread_options.append(thread['thread_id']) - if st.session_state.selected_thread is None and threads: - st.session_state.selected_thread = str(threads[0]['thread_id']) + if st.session_state.selected_thread is None and sorted_threads: + st.session_state.selected_thread = sorted_threads[0]['thread_id'] selected_thread = st.selectbox( "๐Ÿ” Select Thread", thread_options, key="thread_select", - index=thread_options.index(str(st.session_state.selected_thread)) if st.session_state.selected_thread else 0 + index=next((i for i, t in enumerate(sorted_threads) if t['thread_id'] == st.session_state.selected_thread), 0) ) if selected_thread: - st.session_state.selected_thread = int(selected_thread) \ No newline at end of file + st.session_state.selected_thread = selected_thread.split(' - ')[0] if ' - ' in selected_thread else selected_thread + else: + st.error(f"Failed to fetch threads. Status code: {threads_response.status_code}") \ No newline at end of file diff --git a/core/ui/thread_runner.py b/core/ui/thread_runner.py index 72078225..a83c44df 100644 --- a/core/ui/thread_runner.py +++ b/core/ui/thread_runner.py @@ -1,66 +1,16 @@ import streamlit as st import requests -import json -from core.ui.utils import API_BASE_URL, AI_MODELS, STANDARD_SYSTEM_MESSAGE +from core.ui.utils import API_BASE_URL +from datetime import datetime -def display_thread_runner(thread_id): - st.write("## โš™๏ธ Run Thread") - - manual_config = display_manual_setup_tab() - - # Common settings - additional_instructions = st.text_area("Additional Instructions", key="additional_instructions", height=100) - stream = st.checkbox("๐Ÿ“ก Stream Responses", key="stream_responses") - - # Prepare the run thread data - run_thread_data = prepare_run_thread_data(manual_config, additional_instructions, stream) - - # Display the preview of the request payload in an expander - with st.expander("๐Ÿ“ค Preview Request Payload", expanded=False): - st.json(run_thread_data) - - # Center the run button and make it more prominent - col1, col2, col3 = st.columns([1, 2, 1]) - with col2: - if st.button("โ–ถ๏ธ Run Thread", key="run_thread_button", use_container_width=True): - run_thread(thread_id, run_thread_data) - - display_thread_run_status(thread_id) - -def display_manual_setup_tab(): - model_name = st.selectbox("Model", AI_MODELS, key="model_name") - - col1, col2 = st.columns(2) - with col1: - temperature = st.slider("Temperature", 0.0, 1.0, 0.5, key="temperature") - with col2: - max_tokens = st.number_input("Max Tokens", min_value=1, max_value=10000, value=500, key="max_tokens") - - system_message = st.text_area("System Message", value=STANDARD_SYSTEM_MESSAGE, key="system_message", height=100) - - tool_options = list(st.session_state.tools.keys()) - selected_tools = st.multiselect("Select Tools", options=tool_options, key="selected_tools") - - manual_config = { +def prepare_run_thread_data(model_name, temperature, max_tokens, system_message, additional_system_message, selected_tools): + return { + "system_message": {"role": "system", "content": system_message}, "model_name": model_name, "temperature": temperature, "max_tokens": max_tokens, - "system_message": system_message, - "selected_tools": selected_tools - } - - return manual_config - -def prepare_run_thread_data(manual_config, additional_instructions, stream): - tools = [st.session_state.tools[tool]['schema'] for tool in manual_config['selected_tools'] if tool in st.session_state.tools] - return { - "system_message": {"role": "system", "content": manual_config['system_message']}, - "model_name": manual_config['model_name'], - "temperature": manual_config['temperature'], - "max_tokens": manual_config['max_tokens'], - "tools": tools, - "additional_instructions": additional_instructions, - "stream": stream + "tools": selected_tools, + "additional_system_message": additional_system_message } def run_thread(thread_id, run_thread_data): @@ -73,49 +23,77 @@ def run_thread(thread_id, run_thread_data): response_data = run_thread_response.json() st.success("Thread run completed successfully!") - # Display the return payload in an expander - with st.expander("๐Ÿ“ฅ Return Payload", expanded=False): - st.json(response_data) + if 'id' in response_data: + st.session_state.latest_run_id = response_data['id'] - # Display the actual response content - st.write("### ๐Ÿ“ฌ Response Content") + st.subheader("Response Content") display_response_content(response_data) - st.rerun() else: - st.error("Failed to run thread.") - with st.expander("โŒ Error Response", expanded=True): - st.json(run_thread_response.json()) + st.error(f"Failed to run thread. Status code: {run_thread_response.status_code}") + st.text("Response content:") + st.text(run_thread_response.text) def display_response_content(response_data): - if isinstance(response_data, dict) and 'response' in response_data: - for item in response_data['response']: - if isinstance(item, dict): - if 'content' in item: - st.markdown(item['content']) - elif 'tool_calls' in item: - st.write("**Tool Calls:**") - for tool_call in item['tool_calls']: - st.write(f"- Function: `{tool_call['function']['name']}`") - st.code(tool_call['function']['arguments'], language="json") - elif isinstance(item, str): - st.markdown(item) + if isinstance(response_data, dict) and 'choices' in response_data: + message = response_data['choices'][0]['message'] + st.write(f"**Role:** {message['role']}") + st.write(f"**Content:** {message['content']}") + + if 'tool_calls' in message: + st.write("**Tool Calls:**") + for tool_call in message['tool_calls']: + st.write(f"- Function: `{tool_call['function']['name']}`") + st.code(tool_call['function']['arguments'], language="json") else: st.json(response_data) -def display_thread_run_status(thread_id): - status_response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/run/status/") - if status_response.status_code == 200: - status_data = status_response.json() - st.write("### โš™๏ธ Thread Run Status") - status = status_data.get('status') - if status == 'completed': - st.success(f"**Status:** {status}") - elif status == 'error': - st.error(f"**Status:** {status}") - with st.expander("Error Details", expanded=True): - st.code(status_data.get('error_message'), language="") - else: - st.info(f"**Status:** {status}") +def fetch_thread_runs(thread_id, limit): + response = requests.get(f"{API_BASE_URL}/threads/{thread_id}/runs?limit={limit}") + if response.status_code == 200: + return response.json() else: - st.warning("Could not retrieve thread run status.") \ No newline at end of file + st.error("Failed to retrieve runs.") + return [] + +def format_timestamp(timestamp): + if timestamp: + return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S') + return 'N/A' + +def display_runs(runs): + for run in runs: + with st.expander(f"Run {run['id']} - Status: {run['status']}", expanded=False): + col1, col2 = st.columns(2) + with col1: + st.write(f"**Created At:** {format_timestamp(run['created_at'])}") + st.write(f"**Started At:** {format_timestamp(run['started_at'])}") + st.write(f"**Completed At:** {format_timestamp(run['completed_at'])}") + st.write(f"**Cancelled At:** {format_timestamp(run['cancelled_at'])}") + st.write(f"**Failed At:** {format_timestamp(run['failed_at'])}") + with col2: + st.write(f"**Model:** {run['model']}") + st.write(f"**Temperature:** {run['temperature']}") + st.write(f"**Top P:** {run['top_p']}") + st.write(f"**Max Tokens:** {run['max_tokens']}") + st.write(f"**Tool Choice:** {run['tool_choice']}") + st.write(f"**Execute Tools Async:** {run['execute_tools_async']}") + + st.write("**System Message:**") + st.json(run['system_message']) + + if run['tools']: + st.write("**Tools:**") + st.json(run['tools']) + + if run['usage']: + st.write("**Usage:**") + st.json(run['usage']) + + if run['response_format']: + st.write("**Response Format:**") + st.json(run['response_format']) + + if run['last_error']: + st.error("**Last Error:**") + st.code(run['last_error']) \ No newline at end of file diff --git a/core/ui/utils.py b/core/ui/utils.py index 70b2b298..aa8676b9 100644 --- a/core/ui/utils.py +++ b/core/ui/utils.py @@ -4,13 +4,6 @@ from core.constants import AI_MODELS, STANDARD_SYSTEM_MESSAGE API_BASE_URL = "http://localhost:8000" -def fetch_agents(): - response = requests.get(f"{API_BASE_URL}/agents/") - if response.status_code == 200: - st.session_state.agents = response.json() - else: - st.error("Failed to fetch agents.") - def fetch_tools(): response = requests.get(f"{API_BASE_URL}/tools/") if response.status_code == 200: @@ -21,15 +14,10 @@ def fetch_tools(): def initialize_session_state(): if 'selected_thread' not in st.session_state: st.session_state.selected_thread = None - if 'agents' not in st.session_state: - st.session_state.agents = [] if 'tools' not in st.session_state: st.session_state.tools = [] - if 'fetch_agents' not in st.session_state: - st.session_state.fetch_agents = fetch_agents if 'fetch_tools' not in st.session_state: st.session_state.fetch_tools = fetch_tools def fetch_data(): - fetch_agents() fetch_tools() \ No newline at end of file diff --git a/main.db b/main.db index c2dc9adf..711ac612 100644 Binary files a/main.db and b/main.db differ diff --git a/main.py b/main.py new file mode 100644 index 00000000..e69de29b diff --git a/poetry.lock b/poetry.lock index f98f7443..044c0088 100644 --- a/poetry.lock +++ b/poetry.lock @@ -417,6 +417,26 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "fastapi" +version = "0.115.0" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastapi-0.115.0-py3-none-any.whl", hash = "sha256:17ea427674467486e997206a5ab25760f6b09e069f099b96f5b55a32fb6f1631"}, + {file = "fastapi-0.115.0.tar.gz", hash = "sha256:f93b4ca3529a8ebc6fc3fcf710e5efa8de3df9b41570958abf1d97d843138004"}, +] + +[package.dependencies] +pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" +starlette = ">=0.37.2,<0.39.0" +typing-extensions = ">=4.8.0" + +[package.extras] +all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] + [[package]] name = "filelock" version = "3.16.1" @@ -2284,6 +2304,53 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sse-starlette" +version = "2.1.3" +description = "SSE plugin for Starlette" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sse_starlette-2.1.3-py3-none-any.whl", hash = "sha256:8ec846438b4665b9e8c560fcdea6bc8081a3abf7942faa95e5a744999d219772"}, + {file = "sse_starlette-2.1.3.tar.gz", hash = "sha256:9cd27eb35319e1414e3d2558ee7414487f9529ce3b3cf9b21434fd110e017169"}, +] + +[package.dependencies] +anyio = "*" +starlette = "*" +uvicorn = "*" + +[package.extras] +examples = ["fastapi"] + +[[package]] +name = "sseclient-py" +version = "1.7.2" +description = "SSE client for Python" +optional = false +python-versions = "*" +files = [ + {file = "sseclient-py-1.7.2.tar.gz", hash = "sha256:ba3197d314766eccb72a1dda80b5fa14a0fbba07d796a287654c07edde88fe0f"}, + {file = "sseclient_py-1.7.2-py2.py3-none-any.whl", hash = "sha256:a758653b13b78df42cdb696740635a26cb72ad433b75efb68dbbb163d099b6a9"}, +] + +[[package]] +name = "starlette" +version = "0.38.6" +description = "The little ASGI library that shines." +optional = false +python-versions = ">=3.8" +files = [ + {file = "starlette-0.38.6-py3-none-any.whl", hash = "sha256:4517a1409e2e73ee4951214ba012052b9e16f60e90d73cfb06192c19203bbb05"}, + {file = "starlette-0.38.6.tar.gz", hash = "sha256:863a1588f5574e70a821dadefb41e4881ea451a47a3cd1b4df359d4ffefe5ead"}, +] + +[package.dependencies] +anyio = ">=3.4.0,<5" + +[package.extras] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] + [[package]] name = "streamlit" version = "1.39.0" @@ -2615,6 +2682,24 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uvicorn" +version = "0.31.0" +description = "The lightning-fast ASGI server." +optional = false +python-versions = ">=3.8" +files = [ + {file = "uvicorn-0.31.0-py3-none-any.whl", hash = "sha256:cac7be4dd4d891c363cd942160a7b02e69150dcbc7a36be04d5f4af4b17c8ced"}, + {file = "uvicorn-0.31.0.tar.gz", hash = "sha256:13bc21373d103859f68fe739608e2eb054a816dea79189bc3ca08ea89a275906"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" + +[package.extras] +standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "watchdog" version = "5.0.3" @@ -2784,4 +2869,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "817f3d510c67d6a59033e26f11a247490013e3e63c6184844d6b261f43352cfd" +content-hash = "11cd10192810fd94515b1f6ad5d8ce27b5d665d1d4e65fff9b05b4cd80c9d377" diff --git a/pyproject.toml b/pyproject.toml index d15a9b01..fb9cdb0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,9 @@ litellm = "^1.44.4" pytest = "^8.3.2" pytest-asyncio = "^0.24.0" agentops = "^0.3.10" +sseclient-py = "1.7.2" +fastapi = "^0.115.0" +sse-starlette = "^2.1.3" [tool.poetry.group.dev.dependencies] diff --git a/tools/files_tool.py b/tools/files_tool.py index 76968117..13be0cd3 100644 --- a/tools/files_tool.py +++ b/tools/files_tool.py @@ -1,7 +1,7 @@ import os import asyncio from typing import Dict, Any -from core.tools.tool import Tool, ToolResult +from core.tool import Tool, ToolResult from core.config import settings class FilesTool(Tool): diff --git a/tools/tool_example.py b/tools/tool_example.py index d3662a5b..9668bd09 100644 --- a/tools/tool_example.py +++ b/tools/tool_example.py @@ -1,5 +1,5 @@ from typing import Dict, Any -from core.tools.tool import Tool, ToolResult +from core.tool import Tool, ToolResult class ExampleTool(Tool): description = "An example tool for demonstration purposes." diff --git a/workspace/hello.txt b/workspace/hello.txt new file mode 100644 index 00000000..9be6036d --- /dev/null +++ b/workspace/hello.txt @@ -0,0 +1 @@ +Hello there! This is a sample file created just for you. \ No newline at end of file diff --git a/workspace/info.txt b/workspace/info.txt new file mode 100644 index 00000000..0451cf81 --- /dev/null +++ b/workspace/info.txt @@ -0,0 +1 @@ +This file contains random information about this assistant. I'm here to help you with a wide range of tasks, from answering questions to managing files. \ No newline at end of file diff --git a/workspace/random_file_1.txt b/workspace/random_file_1.txt deleted file mode 100644 index 91f36e18..00000000 --- a/workspace/random_file_1.txt +++ /dev/null @@ -1 +0,0 @@ -This is some random content for the file. \ No newline at end of file diff --git a/workspace/test.txt b/workspace/test.txt deleted file mode 100644 index 5ac36801..00000000 --- a/workspace/test.txt +++ /dev/null @@ -1 +0,0 @@ -random contents \ No newline at end of file diff --git a/workspace/tips.txt b/workspace/tips.txt new file mode 100644 index 00000000..8f666d10 --- /dev/null +++ b/workspace/tips.txt @@ -0,0 +1,4 @@ +Here are some random tips: +1. Stay curious and keep learning. +2. Organize your workspace for better productivity. +3. Take breaks to maintain focus and energy. \ No newline at end of file