mirror of https://github.com/kortix-ai/suna.git
refactor wip
This commit is contained in:
parent 6e229b3830
commit 6a6b9d8e85
@@ -124,16 +124,6 @@ class ThreadAgentResponse(BaseModel):
     source: str # "thread", "default", "none", "missing"
     message: str
-
-class AgentBuilderChatRequest(BaseModel):
-    message: str
-    conversation_history: List[Dict[str, str]] = []
-    partial_config: Optional[Dict[str, Any]] = None
-
-class AgentBuilderChatResponse(BaseModel):
-    response: str
-    suggested_config: Optional[Dict[str, Any]] = None
-    next_step: Optional[str] = None
 
 def initialize(
     _db: DBConnection,
     _instance_id: Optional[str] = None
@@ -299,71 +289,6 @@ async def get_agent_run_with_access_check(client, agent_run_id: str, user_id: st
     await verify_thread_access(client, thread_id, user_id)
     return agent_run_data
 
-async def enhance_system_prompt(agent_name: str, description: str, user_system_prompt: str) -> str:
-    """
-    Enhance a basic system prompt using GPT-4o to create a more comprehensive and effective system prompt.
-
-    Args:
-        agent_name: Name of the agent
-        description: Description of the agent
-        user_system_prompt: User's basic system prompt/instructions
-
-    Returns:
-        Enhanced system prompt generated by GPT-4o
-    """
-    try:
-        system_message = """You are an expert at creating comprehensive system prompts for AI agents. Your task is to take basic agent information and transform it into a detailed, effective system prompt that will help the agent perform optimally.
-
-Guidelines for creating system prompts:
-1. Be specific about the agent's role, expertise, and capabilities
-2. Include clear behavioral guidelines and interaction style
-3. Specify the agent's knowledge domains and areas of expertise
-4. Include guidance on how to handle different types of requests
-5. Set appropriate boundaries and limitations
-6. Make the prompt engaging and easy to understand
-7. Ensure the prompt is comprehensive but not overly verbose
-8. Include relevant context about tools and capabilities the agent might have
-
-The enhanced prompt should be professional, clear, and actionable."""
-
-        user_message = f"""Please create an enhanced system prompt for an AI agent with the following details:
-
-Agent Name: {agent_name}
-Agent Description: {description}
-User's Instructions: {user_system_prompt}
-
-Transform this basic information into a comprehensive, effective system prompt that will help the agent perform at its best. The prompt should be detailed enough to guide the agent's behavior while remaining clear and actionable."""
-
-        messages = [
-            {"role": "system", "content": system_message},
-            {"role": "user", "content": user_message}
-        ]
-
-        logger.info(f"Enhancing system prompt for agent: {agent_name}")
-        response = await make_llm_api_call(
-            messages=messages,
-            model_name="openai/gpt-4o",
-            max_tokens=2000,
-            temperature=0.7
-        )
-
-        if response and response.get('choices') and response['choices'][0].get('message'):
-            enhanced_prompt = response['choices'][0]['message'].get('content', '').strip()
-            if enhanced_prompt:
-                logger.info(f"Successfully enhanced system prompt for agent: {agent_name}")
-                return enhanced_prompt
-            else:
-                logger.warning(f"GPT-4o returned empty enhanced prompt for agent: {agent_name}")
-                return user_system_prompt
-        else:
-            logger.warning(f"Failed to get valid response from GPT-4o for agent: {agent_name}")
-            return user_system_prompt
-
-    except Exception as e:
-        logger.error(f"Error enhancing system prompt for agent {agent_name}: {str(e)}")
-        # Return the original prompt if enhancement fails
-        return user_system_prompt
-
 @router.post("/thread/{thread_id}/agent/start")
 async def start_agent(
     thread_id: str,
@@ -1210,11 +1135,8 @@ async def initiate_agent_with_files(
         raise HTTPException(status_code=500, detail=f"Failed to initiate agent session: {str(e)}")
 
-
-
 # Custom agents
 
-
 @router.get("/agents", response_model=AgentsResponse)
 async def get_agents(
     user_id: str = Depends(get_current_user_id_from_jwt),
@@ -57,7 +57,7 @@ async def run_agent(
 
     if not trace:
         trace = langfuse.trace(name="run_agent", session_id=thread_id, metadata={"project_id": project_id})
-    thread_manager = ThreadManager(trace=trace, is_agent_builder=is_agent_builder, target_agent_id=target_agent_id, agent_config=agent_config)
+    thread_manager = ThreadManager(trace=trace, is_agent_builder=is_agent_builder or False, target_agent_id=target_agent_id, agent_config=agent_config)
 
     client = await thread_manager.db.client
 
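The only functional change in the hunk above is `is_agent_builder or False`, which normalizes an optional flag to a plain bool before ThreadManager sees it. A minimal sketch of the effect; the flag name comes from the diff, the rest is illustrative:

    # is_agent_builder may arrive as None from the caller
    is_agent_builder = None
    normalized = is_agent_builder or False   # None -> False, True -> True
    # ThreadManager(is_agent_builder=normalized, ...) then receives a real bool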
@@ -8,7 +8,7 @@ server tool calls through dynamically generated individual function methods.
 import json
 from typing import Any, Dict, List, Optional
 from agentpress.tool import Tool, ToolResult, openapi_schema, xml_schema, ToolSchema, SchemaType
-from mcp_local.client import MCPManager
+from mcp_service.client import MCPManager
 from utils.logger import logger
 import inspect
 from mcp import ClientSession
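This hunk, like the api.py hunks further down, only repoints imports: the MCP integration package appears to have been renamed from mcp_local to mcp_service in this commit. A minimal sketch of the updated import, assuming MCPManager itself is unchanged:

    from mcp_service.client import MCPManager  # previously: from mcp_local.client import MCPManager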
@@ -6,9 +6,10 @@ reaching the context window limitations of LLM models.
 """
 
 import json
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Union
 
-from litellm import token_counter, completion_cost
+from litellm.utils import token_counter
+from litellm.cost_calculator import completion_cost
 from services.supabase import DBConnection
 from services.llm import make_llm_api_call
 from utils.logger import logger
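token_counter and completion_cost are now imported from their litellm submodules instead of the package root. A minimal sketch of counting tokens for an OpenAI-style message list with token_counter; the model name is only an example:

    from litellm.utils import token_counter

    messages = [{"role": "user", "content": "hello"}]
    n_tokens = token_counter(model="gpt-4o-mini", messages=messages)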
@@ -29,270 +30,291 @@ class ContextManager:
         """
         self.db = DBConnection()
         self.token_threshold = token_threshold
 
-    async def get_thread_token_count(self, thread_id: str) -> int:
-        """Get the current token count for a thread using LiteLLM.
-
-        Args:
-            thread_id: ID of the thread to analyze
-
-        Returns:
-            The total token count for relevant messages in the thread
-        """
-        logger.debug(f"Getting token count for thread {thread_id}")
-
-        try:
-            # Get messages for the thread
-            messages = await self.get_messages_for_summarization(thread_id)
-
-            if not messages:
-                logger.debug(f"No messages found for thread {thread_id}")
-                return 0
-
-            # Use litellm's token_counter for accurate model-specific counting
-            # This is much more accurate than the SQL-based estimation
-            token_count = token_counter(model="gpt-4", messages=messages)
-
-            logger.info(f"Thread {thread_id} has {token_count} tokens (calculated with litellm)")
-            return token_count
-
-        except Exception as e:
-            logger.error(f"Error getting token count: {str(e)}")
-            return 0
-
-    async def get_messages_for_summarization(self, thread_id: str) -> List[Dict[str, Any]]:
-        """Get all LLM messages from the thread that need to be summarized.
-
-        This gets messages after the most recent summary or all messages if
-        no summary exists. Unlike get_llm_messages, this includes ALL messages
-        since the last summary, even if we're generating a new summary.
-
-        Args:
-            thread_id: ID of the thread to get messages from
-
-        Returns:
-            List of message objects to summarize
-        """
-        logger.debug(f"Getting messages for summarization for thread {thread_id}")
-        client = await self.db.client
-
-        try:
-            # Find the most recent summary message
-            summary_result = await client.table('messages').select('created_at') \
-                .eq('thread_id', thread_id) \
-                .eq('type', 'summary') \
-                .eq('is_llm_message', True) \
-                .order('created_at', desc=True) \
-                .limit(1) \
-                .execute()
-
-            # Get messages after the most recent summary or all messages if no summary
-            if summary_result.data and len(summary_result.data) > 0:
-                last_summary_time = summary_result.data[0]['created_at']
-                logger.debug(f"Found last summary at {last_summary_time}")
-
-                # Get all messages after the summary, but NOT including the summary itself
-                messages_result = await client.table('messages').select('*') \
-                    .eq('thread_id', thread_id) \
-                    .eq('is_llm_message', True) \
-                    .gt('created_at', last_summary_time) \
-                    .order('created_at') \
-                    .execute()
-            else:
-                logger.debug("No previous summary found, getting all messages")
-                # Get all messages
-                messages_result = await client.table('messages').select('*') \
-                    .eq('thread_id', thread_id) \
-                    .eq('is_llm_message', True) \
-                    .order('created_at') \
-                    .execute()
-
-            # Parse the message content if needed
-            messages = []
-            for msg in messages_result.data:
-                # Skip existing summary messages - we don't want to summarize summaries
-                if msg.get('type') == 'summary':
-                    logger.debug(f"Skipping summary message from {msg.get('created_at')}")
-                    continue
-
-                # Parse content if it's a string
-                content = msg['content']
-                if isinstance(content, str):
-                    try:
-                        content = json.loads(content)
-                    except json.JSONDecodeError:
-                        pass # Keep as string if not valid JSON
-
-                # Ensure we have the proper format for the LLM
-                if 'role' not in content and 'type' in msg:
-                    # Convert message type to role if needed
-                    role = msg['type']
-                    if role == 'assistant' or role == 'user' or role == 'system' or role == 'tool':
-                        content = {'role': role, 'content': content}
-
-                messages.append(content)
-
-            logger.info(f"Got {len(messages)} messages to summarize for thread {thread_id}")
-            return messages
-
-        except Exception as e:
-            logger.error(f"Error getting messages for summarization: {str(e)}", exc_info=True)
-            return []
-
-    async def create_summary(
-        self,
-        thread_id: str,
-        messages: List[Dict[str, Any]],
-        model: str = "gpt-4o-mini"
-    ) -> Optional[Dict[str, Any]]:
-        """Generate a summary of conversation messages.
-
-        Args:
-            thread_id: ID of the thread to summarize
-            messages: Messages to summarize
-            model: LLM model to use for summarization
-
-        Returns:
-            Summary message object or None if summarization failed
-        """
-        if not messages:
-            logger.warning("No messages to summarize")
-            return None
-
-        logger.info(f"Creating summary for thread {thread_id} with {len(messages)} messages")
-
-        # Create system message with summarization instructions
-        system_message = {
-            "role": "system",
-            "content": f"""You are a specialized summarization assistant. Your task is to create a concise but comprehensive summary of the conversation history.
-
-The summary should:
-1. Preserve all key information including decisions, conclusions, and important context
-2. Include any tools that were used and their results
-3. Maintain chronological order of events
-4. Be presented as a narrated list of key points with section headers
-5. Include only factual information from the conversation (no new information)
-6. Be concise but detailed enough that the conversation can continue with this summary as context
-
-VERY IMPORTANT: This summary will replace older parts of the conversation in the LLM's context window, so ensure it contains ALL key information and LATEST STATE OF THE CONVERSATION - SO WE WILL KNOW HOW TO PICK UP WHERE WE LEFT OFF.
-
-
-THE CONVERSATION HISTORY TO SUMMARIZE IS AS FOLLOWS:
-===============================================================
-==================== CONVERSATION HISTORY ====================
-{messages}
-==================== END OF CONVERSATION HISTORY ====================
-===============================================================
-"""
-        }
-
-        try:
-            # Call LLM to generate summary
-            response = await make_llm_api_call(
-                model_name=model,
-                messages=[system_message, {"role": "user", "content": "PLEASE PROVIDE THE SUMMARY NOW."}],
-                temperature=0,
-                max_tokens=SUMMARY_TARGET_TOKENS,
-                stream=False
-            )
-
-            if response and hasattr(response, 'choices') and response.choices:
-                summary_content = response.choices[0].message.content
-
-                # Track token usage
-                try:
-                    token_count = token_counter(model=model, messages=[{"role": "user", "content": summary_content}])
-                    cost = completion_cost(model=model, prompt="", completion=summary_content)
-                    logger.info(f"Summary generated with {token_count} tokens at cost ${cost:.6f}")
-                except Exception as e:
-                    logger.error(f"Error calculating token usage: {str(e)}")
-
-                # Format the summary message with clear beginning and end markers
-                formatted_summary = f"""
-======== CONVERSATION HISTORY SUMMARY ========
-
-{summary_content}
-
-======== END OF SUMMARY ========
-
-The above is a summary of the conversation history. The conversation continues below.
-"""
-
-                # Format the summary message
-                summary_message = {
-                    "role": "user",
-                    "content": formatted_summary
-                }
-
-                return summary_message
-            else:
-                logger.error("Failed to generate summary: Invalid response")
-                return None
-
-        except Exception as e:
-            logger.error(f"Error creating summary: {str(e)}", exc_info=True)
-            return None
-
-    async def check_and_summarize_if_needed(
-        self,
-        thread_id: str,
-        add_message_callback,
-        model: str = "gpt-4o-mini",
-        force: bool = False
-    ) -> bool:
-        """Check if thread needs summarization and summarize if so.
-
-        Args:
-            thread_id: ID of the thread to check
-            add_message_callback: Callback to add the summary message to the thread
-            model: LLM model to use for summarization
-            force: Whether to force summarization regardless of token count
-
-        Returns:
-            True if summarization was performed, False otherwise
-        """
-        try:
-            # Get token count using LiteLLM (accurate model-specific counting)
-            token_count = await self.get_thread_token_count(thread_id)
-
-            # If token count is below threshold and not forcing, no summarization needed
-            if token_count < self.token_threshold and not force:
-                logger.debug(f"Thread {thread_id} has {token_count} tokens, below threshold {self.token_threshold}")
-                return False
-
-            # Log reason for summarization
-            if force:
-                logger.info(f"Forced summarization of thread {thread_id} with {token_count} tokens")
-            else:
-                logger.info(f"Thread {thread_id} exceeds token threshold ({token_count} >= {self.token_threshold}), summarizing...")
-
-            # Get messages to summarize
-            messages = await self.get_messages_for_summarization(thread_id)
-
-            # If there are too few messages, don't summarize
-            if len(messages) < 3:
-                logger.info(f"Thread {thread_id} has too few messages ({len(messages)}) to summarize")
-                return False
-
-            # Create summary
-            summary = await self.create_summary(thread_id, messages, model)
-
-            if summary:
-                # Add summary message to thread
-                await add_message_callback(
-                    thread_id=thread_id,
-                    type="summary",
-                    content=summary,
-                    is_llm_message=True,
-                    metadata={"token_count": token_count}
-                )
-
-                logger.info(f"Successfully added summary to thread {thread_id}")
-                return True
-            else:
-                logger.error(f"Failed to create summary for thread {thread_id}")
-                return False
-
-        except Exception as e:
-            logger.error(f"Error in check_and_summarize_if_needed: {str(e)}", exc_info=True)
-            return False
+    def is_tool_result_message(self, msg: Dict[str, Any]) -> bool:
+        """Check if a message is a tool result message."""
+        if not ("content" in msg and msg['content']):
+            return False
+        content = msg['content']
+        if isinstance(content, str) and "ToolResult" in content:
+            return True
+        if isinstance(content, dict) and "tool_execution" in content:
+            return True
+        if isinstance(content, dict) and "interactive_elements" in content:
+            return True
+        if isinstance(content, str):
+            try:
+                parsed_content = json.loads(content)
+                if isinstance(parsed_content, dict) and "tool_execution" in parsed_content:
+                    return True
+                if isinstance(parsed_content, dict) and "interactive_elements" in content:
+                    return True
+            except (json.JSONDecodeError, TypeError):
+                pass
+        return False
+
+    def compress_message(self, msg_content: Union[str, dict], message_id: Optional[str] = None, max_length: int = 3000) -> Union[str, dict]:
+        """Compress the message content."""
+        if isinstance(msg_content, str):
+            if len(msg_content) > max_length:
+                return msg_content[:max_length] + "... (truncated)" + f"\n\nmessage_id \"{message_id}\"\nUse expand-message tool to see contents"
+            else:
+                return msg_content
+        elif isinstance(msg_content, dict):
+            if len(json.dumps(msg_content)) > max_length:
+                return json.dumps(msg_content)[:max_length] + "... (truncated)" + f"\n\nmessage_id \"{message_id}\"\nUse expand-message tool to see contents"
+            else:
+                return msg_content
+
+    def safe_truncate(self, msg_content: Union[str, dict], max_length: int = 100000) -> Union[str, dict]:
+        """Truncate the message content safely by removing the middle portion."""
+        max_length = min(max_length, 100000)
+        if isinstance(msg_content, str):
+            if len(msg_content) > max_length:
+                # Calculate how much to keep from start and end
+                keep_length = max_length - 150 # Reserve space for truncation message
+                start_length = keep_length // 2
+                end_length = keep_length - start_length
+
+                start_part = msg_content[:start_length]
+                end_part = msg_content[-end_length:] if end_length > 0 else ""
+
+                return start_part + f"\n\n... (middle truncated) ...\n\n" + end_part + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
+            else:
+                return msg_content
+        elif isinstance(msg_content, dict):
+            json_str = json.dumps(msg_content)
+            if len(json_str) > max_length:
+                # Calculate how much to keep from start and end
+                keep_length = max_length - 150 # Reserve space for truncation message
+                start_length = keep_length // 2
+                end_length = keep_length - start_length
+
+                start_part = json_str[:start_length]
+                end_part = json_str[-end_length:] if end_length > 0 else ""
+
+                return start_part + f"\n\n... (middle truncated) ...\n\n" + end_part + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
+            else:
+                return msg_content
+
+    def compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
+        """Compress the tool result messages except the most recent one."""
+        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+        max_tokens_value = max_tokens or (100 * 1000)
+
+        if uncompressed_total_token_count > max_tokens_value:
+            _i = 0 # Count the number of ToolResult messages
+            for msg in reversed(messages): # Start from the end and work backwards
+                if self.is_tool_result_message(msg): # Only compress ToolResult messages
+                    _i += 1 # Count the number of ToolResult messages
+                    msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
+                    if msg_token_count > token_threshold: # If the message is too long
+                        if _i > 1: # If this is not the most recent ToolResult message
+                            message_id = msg.get('message_id') # Get the message_id
+                            if message_id:
+                                msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
+                            else:
+                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
+                        else:
+                            msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
+        return messages
+
+    def compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
+        """Compress the user messages except the most recent one."""
+        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+        max_tokens_value = max_tokens or (100 * 1000)
+
+        if uncompressed_total_token_count > max_tokens_value:
+            _i = 0 # Count the number of User messages
+            for msg in reversed(messages): # Start from the end and work backwards
+                if msg.get('role') == 'user': # Only compress User messages
+                    _i += 1 # Count the number of User messages
+                    msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
+                    if msg_token_count > token_threshold: # If the message is too long
+                        if _i > 1: # If this is not the most recent User message
+                            message_id = msg.get('message_id') # Get the message_id
+                            if message_id:
+                                msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
+                            else:
+                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
+                        else:
+                            msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
+        return messages
+
+    def compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: int = 1000) -> List[Dict[str, Any]]:
+        """Compress the assistant messages except the most recent one."""
+        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
+        max_tokens_value = max_tokens or (100 * 1000)
+
+        if uncompressed_total_token_count > max_tokens_value:
+            _i = 0 # Count the number of Assistant messages
+            for msg in reversed(messages): # Start from the end and work backwards
+                if msg.get('role') == 'assistant': # Only compress Assistant messages
+                    _i += 1 # Count the number of Assistant messages
+                    msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
+                    if msg_token_count > token_threshold: # If the message is too long
+                        if _i > 1: # If this is not the most recent Assistant message
+                            message_id = msg.get('message_id') # Get the message_id
+                            if message_id:
+                                msg["content"] = self.compress_message(msg["content"], message_id, token_threshold * 3)
+                            else:
+                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
+                        else:
+                            msg["content"] = self.safe_truncate(msg["content"], int(max_tokens_value * 2))
+
+        return messages
+
+    def remove_meta_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Remove meta messages from the messages."""
+        result: List[Dict[str, Any]] = []
+        for msg in messages:
+            msg_content = msg.get('content')
+            # Try to parse msg_content as JSON if it's a string
+            if isinstance(msg_content, str):
+                try:
+                    msg_content = json.loads(msg_content)
+                except json.JSONDecodeError:
+                    pass
+            if isinstance(msg_content, dict):
+                # Create a copy to avoid modifying the original
+                msg_content_copy = msg_content.copy()
+                if "tool_execution" in msg_content_copy:
+                    tool_execution = msg_content_copy["tool_execution"].copy()
+                    if "arguments" in tool_execution:
+                        del tool_execution["arguments"]
+                    msg_content_copy["tool_execution"] = tool_execution
+                # Create a new message dict with the modified content
+                new_msg = msg.copy()
+                new_msg["content"] = json.dumps(msg_content_copy)
+                result.append(new_msg)
+            else:
+                result.append(msg)
+        return result
+
+    def compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: int = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
+        """Compress the messages.
+
+        Args:
+            messages: List of messages to compress
+            llm_model: Model name for token counting
+            max_tokens: Maximum allowed tokens
+            token_threshold: Token threshold for individual message compression (must be a power of 2)
+            max_iterations: Maximum number of compression iterations
+        """
+        # Set model-specific token limits
+        if 'sonnet' in llm_model.lower():
+            max_tokens = 200 * 1000 - 64000 - 28000
+        elif 'gpt' in llm_model.lower():
+            max_tokens = 128 * 1000 - 28000
+        elif 'gemini' in llm_model.lower():
+            max_tokens = 1000 * 1000 - 300000
+        elif 'deepseek' in llm_model.lower():
+            max_tokens = 128 * 1000 - 28000
+        else:
+            max_tokens = 41 * 1000 - 10000
+
+        result = messages
+        result = self.remove_meta_messages(result)
+
+        uncompressed_total_token_count = token_counter(model=llm_model, messages=result)
+
+        result = self.compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
+        result = self.compress_user_messages(result, llm_model, max_tokens, token_threshold)
+        result = self.compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
+
+        compressed_token_count = token_counter(model=llm_model, messages=result)
+
+        logger.info(f"compress_messages: {uncompressed_total_token_count} -> {compressed_token_count}") # Log the token compression for debugging later
+
+        if max_iterations <= 0:
+            logger.warning(f"compress_messages: Max iterations reached, omitting messages")
+            result = self.compress_messages_by_omitting_messages(messages, llm_model, max_tokens)
+            return result
+
+        if compressed_token_count > max_tokens:
+            logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
+            result = self.compress_messages(messages, llm_model, max_tokens, token_threshold // 2, max_iterations - 1)
+
+        return self.middle_out_messages(result)
+
+    def compress_messages_by_omitting_messages(
+        self,
+        messages: List[Dict[str, Any]],
+        llm_model: str,
+        max_tokens: Optional[int] = 41000,
+        removal_batch_size: int = 10,
+        min_messages_to_keep: int = 10
+    ) -> List[Dict[str, Any]]:
+        """Compress the messages by omitting messages from the middle.
+
+        Args:
+            messages: List of messages to compress
+            llm_model: Model name for token counting
+            max_tokens: Maximum allowed tokens
+            removal_batch_size: Number of messages to remove per iteration
+            min_messages_to_keep: Minimum number of messages to preserve
+        """
+        if not messages:
+            return messages
+
+        result = messages
+        result = self.remove_meta_messages(result)
+
+        # Early exit if no compression needed
+        initial_token_count = token_counter(model=llm_model, messages=result)
+        max_allowed_tokens = max_tokens or (100 * 1000)
+
+        if initial_token_count <= max_allowed_tokens:
+            return result
+
+        # Separate system message (assumed to be first) from conversation messages
+        system_message = messages[0] if messages and messages[0].get('role') == 'system' else None
+        conversation_messages = result[1:] if system_message else result
+
+        safety_limit = 500
+        current_token_count = initial_token_count
+
+        while current_token_count > max_allowed_tokens and safety_limit > 0:
+            safety_limit -= 1
+
+            if len(conversation_messages) <= min_messages_to_keep:
+                logger.warning(f"Cannot compress further: only {len(conversation_messages)} messages remain (min: {min_messages_to_keep})")
+                break
+
+            # Calculate removal strategy based on current message count
+            if len(conversation_messages) > (removal_batch_size * 2):
+                # Remove from middle, keeping recent and early context
+                middle_start = len(conversation_messages) // 2 - (removal_batch_size // 2)
+                middle_end = middle_start + removal_batch_size
+                conversation_messages = conversation_messages[:middle_start] + conversation_messages[middle_end:]
+            else:
+                # Remove from earlier messages, preserving recent context
+                messages_to_remove = min(removal_batch_size, len(conversation_messages) // 2)
+                if messages_to_remove > 0:
+                    conversation_messages = conversation_messages[messages_to_remove:]
+                else:
+                    # Can't remove any more messages
+                    break
+
+            # Recalculate token count
+            messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
+            current_token_count = token_counter(model=llm_model, messages=messages_to_count)
+
+        # Prepare final result
+        final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
+        final_token_count = token_counter(model=llm_model, messages=final_messages)
+
+        logger.info(f"compress_messages_by_omitting_messages: {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")
+
+        return final_messages
+
+    def middle_out_messages(self, messages: List[Dict[str, Any]], max_messages: int = 320) -> List[Dict[str, Any]]:
+        """Remove messages from the middle of the list, keeping max_messages total."""
+        if len(messages) <= max_messages:
+            return messages
+
+        # Keep half from the beginning and half from the end
+        keep_start = max_messages // 2
+        keep_end = max_messages - keep_start
+
+        return messages[:keep_start] + messages[-keep_end:]
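Taken together, the new ContextManager methods replace thread summarization with in-place compression: remove_meta_messages strips tool-execution arguments, the three compress_* passes shorten older tool, user, and assistant messages, and compress_messages_by_omitting_messages plus middle_out_messages drop whole messages if the budget is still exceeded. A minimal usage sketch, assuming an already prepared OpenAI-style message list; the module path and model name are assumptions, not taken from this diff:

    from agentpress.context_manager import ContextManager  # module path assumed

    cm = ContextManager()
    prepared = [
        {"role": "system", "content": "You are a helpful agent."},
        {"role": "user", "content": "..."},
    ]
    compressed = cm.compress_messages(prepared, llm_model="anthropic/claude-3-5-sonnet-20241022")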
@@ -1602,51 +1602,6 @@ class ResponseProcessor:
         )
         return message_obj # Return the full message object
 
-        # Check if this is an MCP tool (function_name starts with "call_mcp_tool")
-        function_name = tool_call.get("function_name", "")
-
-        # Check if this is an MCP tool - either the old call_mcp_tool or a dynamically registered MCP tool
-        is_mcp_tool = False
-        if function_name == "call_mcp_tool":
-            is_mcp_tool = True
-        else:
-            # Check if the result indicates it's an MCP tool by looking for MCP metadata
-            if hasattr(result, 'output') and isinstance(result.output, str):
-                # Check for MCP metadata pattern in the output
-                if "MCP Tool Result from" in result.output and "Tool Metadata:" in result.output:
-                    is_mcp_tool = True
-                # Also check for MCP metadata in JSON format
-                elif "mcp_metadata" in result.output:
-                    is_mcp_tool = True
-
-        if is_mcp_tool:
-            # Special handling for MCP tools - make content prominent and LLM-friendly
-            result_role = "user" if strategy == "user_message" else "assistant"
-
-            # Extract the actual content from the ToolResult
-            if hasattr(result, 'output'):
-                mcp_content = str(result.output)
-            else:
-                mcp_content = str(result)
-
-            # Create a simple, LLM-friendly message format that puts content first
-            simple_message = {
-                "role": result_role,
-                "content": mcp_content # Direct content, no complex nesting
-            }
-
-            logger.info(f"Adding MCP tool result with simplified format for LLM visibility")
-            self.trace.event(name="adding_mcp_tool_result_simplified", level="DEFAULT", status_message="Adding MCP tool result with simplified format for LLM visibility")
-
-            message_obj = await self.add_message(
-                thread_id=thread_id,
-                type="tool",
-                content=simple_message,
-                is_llm_message=True,
-                metadata=metadata
-            )
-            return message_obj
-
         # For XML and other non-native tools, use the new structured format
         # Determine message role based on strategy
         result_role = "user" if strategy == "user_message" else "assistant"
@@ -1781,28 +1736,6 @@ class ResponseProcessor:
 
         return structured_result_v1
 
-    def _format_xml_tool_result(self, tool_call: Dict[str, Any], result: ToolResult) -> str:
-        """Format a tool result wrapped in a <tool_result> tag.
-
-        DEPRECATED: This method is kept for backwards compatibility.
-        New implementations should use _create_structured_tool_result instead.
-
-        Args:
-            tool_call: The tool call that was executed
-            result: The result of the tool execution
-
-        Returns:
-            String containing the formatted result wrapped in <tool_result> tag
-        """
-        # Always use xml_tag_name if it exists
-        if "xml_tag_name" in tool_call:
-            xml_tag_name = tool_call["xml_tag_name"]
-            return f"<tool_result> <{xml_tag_name}> {str(result)} </{xml_tag_name}> </tool_result>"
-
-        # Non-XML tool, just return the function result
-        function_name = tool_call["function_name"]
-        return f"Result for {function_name}: {str(result)}"
-
     def _create_tool_context(self, tool_call: Dict[str, Any], tool_index: int, assistant_message_id: Optional[str] = None, parsing_details: Optional[Dict[str, Any]] = None) -> ToolExecutionContext:
         """Create a tool execution context with display name and parsing details populated."""
         context = ToolExecutionContext(
@@ -25,7 +25,7 @@ from utils.logger import logger
 from langfuse.client import StatefulGenerationClient, StatefulTraceClient
 from services.langfuse import langfuse
 import datetime
-from litellm import token_counter
+from litellm.utils import token_counter
 
 # Type alias for tool choice
 ToolChoice = Literal["auto", "required", "none"]
@@ -65,279 +65,6 @@ class ThreadManager:
         )
         self.context_manager = ContextManager()
 
-    def _is_tool_result_message(self, msg: Dict[str, Any]) -> bool:
-        if not ("content" in msg and msg['content']):
-            return False
-        content = msg['content']
-        if isinstance(content, str) and "ToolResult" in content: return True
-        if isinstance(content, dict) and "tool_execution" in content: return True
-        if isinstance(content, dict) and "interactive_elements" in content: return True
-        if isinstance(content, str):
-            try:
-                parsed_content = json.loads(content)
-                if isinstance(parsed_content, dict) and "tool_execution" in parsed_content: return True
-                if isinstance(parsed_content, dict) and "interactive_elements" in content: return True
-            except (json.JSONDecodeError, TypeError):
-                pass
-        return False
-
-    def _compress_message(self, msg_content: Union[str, dict], message_id: Optional[str] = None, max_length: int = 3000) -> Union[str, dict]:
-        """Compress the message content."""
-        # print("max_length", max_length)
-        if isinstance(msg_content, str):
-            if len(msg_content) > max_length:
-                return msg_content[:max_length] + "... (truncated)" + f"\n\nmessage_id \"{message_id}\"\nUse expand-message tool to see contents"
-            else:
-                return msg_content
-        elif isinstance(msg_content, dict):
-            if len(json.dumps(msg_content)) > max_length:
-                return json.dumps(msg_content)[:max_length] + "... (truncated)" + f"\n\nmessage_id \"{message_id}\"\nUse expand-message tool to see contents"
-            else:
-                return msg_content
-
-    def _safe_truncate(self, msg_content: Union[str, dict], max_length: int = 100000) -> Union[str, dict]:
-        """Truncate the message content safely by removing the middle portion."""
-        max_length = min(max_length, 100000)
-        if isinstance(msg_content, str):
-            if len(msg_content) > max_length:
-                # Calculate how much to keep from start and end
-                keep_length = max_length - 150 # Reserve space for truncation message
-                start_length = keep_length // 2
-                end_length = keep_length - start_length
-
-                start_part = msg_content[:start_length]
-                end_part = msg_content[-end_length:] if end_length > 0 else ""
-
-                return start_part + f"\n\n... (middle truncated) ...\n\n" + end_part + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
-            else:
-                return msg_content
-        elif isinstance(msg_content, dict):
-            json_str = json.dumps(msg_content)
-            if len(json_str) > max_length:
-                # Calculate how much to keep from start and end
-                keep_length = max_length - 150 # Reserve space for truncation message
-                start_length = keep_length // 2
-                end_length = keep_length - start_length
-
-                start_part = json_str[:start_length]
-                end_part = json_str[-end_length:] if end_length > 0 else ""
-
-                return start_part + f"\n\n... (middle truncated) ...\n\n" + end_part + f"\n\nThis message is too long, repeat relevant information in your response to remember it"
-            else:
-                return msg_content
-
-    def _compress_tool_result_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
-        """Compress the tool result messages except the most recent one."""
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
-
-        if uncompressed_total_token_count > (max_tokens or (100 * 1000)):
-            _i = 0 # Count the number of ToolResult messages
-            for msg in reversed(messages): # Start from the end and work backwards
-                if self._is_tool_result_message(msg): # Only compress ToolResult messages
-                    _i += 1 # Count the number of ToolResult messages
-                    msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
-                    if msg_token_count > token_threshold: # If the message is too long
-                        if _i > 1: # If this is not the most recent ToolResult message
-                            message_id = msg.get('message_id') # Get the message_id
-                            if message_id:
-                                msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
-                            else:
-                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
-                        else:
-                            msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
-        return messages
-
-    def _compress_user_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
-        """Compress the user messages except the most recent one."""
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
-
-        if uncompressed_total_token_count > (max_tokens or (100 * 1000)):
-            _i = 0 # Count the number of User messages
-            for msg in reversed(messages): # Start from the end and work backwards
-                if msg.get('role') == 'user': # Only compress User messages
-                    _i += 1 # Count the number of User messages
-                    msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
-                    if msg_token_count > token_threshold: # If the message is too long
-                        if _i > 1: # If this is not the most recent User message
-                            message_id = msg.get('message_id') # Get the message_id
-                            if message_id:
-                                msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
-                            else:
-                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
-                        else:
-                            msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
-        return messages
-
-    def _compress_assistant_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int], token_threshold: Optional[int] = 1000) -> List[Dict[str, Any]]:
-        """Compress the assistant messages except the most recent one."""
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=messages)
-        if uncompressed_total_token_count > (max_tokens or (100 * 1000)):
-            _i = 0 # Count the number of Assistant messages
-            for msg in reversed(messages): # Start from the end and work backwards
-                if msg.get('role') == 'assistant': # Only compress Assistant messages
-                    _i += 1 # Count the number of Assistant messages
-                    msg_token_count = token_counter(messages=[msg]) # Count the number of tokens in the message
-                    if msg_token_count > token_threshold: # If the message is too long
-                        if _i > 1: # If this is not the most recent Assistant message
-                            message_id = msg.get('message_id') # Get the message_id
-                            if message_id:
-                                msg["content"] = self._compress_message(msg["content"], message_id, token_threshold * 3)
-                            else:
-                                logger.warning(f"UNEXPECTED: Message has no message_id {str(msg)[:100]}")
-                        else:
-                            msg["content"] = self._safe_truncate(msg["content"], int(max_tokens * 2))
-
-        return messages
-
-    def _remove_meta_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """Remove meta messages from the messages."""
-        result: List[Dict[str, Any]] = []
-        for msg in messages:
-            msg_content = msg.get('content')
-            # Try to parse msg_content as JSON if it's a string
-            if isinstance(msg_content, str):
-                try: msg_content = json.loads(msg_content)
-                except json.JSONDecodeError: pass
-            if isinstance(msg_content, dict):
-                # Create a copy to avoid modifying the original
-                msg_content_copy = msg_content.copy()
-                if "tool_execution" in msg_content_copy:
-                    tool_execution = msg_content_copy["tool_execution"].copy()
-                    if "arguments" in tool_execution:
-                        del tool_execution["arguments"]
-                    msg_content_copy["tool_execution"] = tool_execution
-                # Create a new message dict with the modified content
-                new_msg = msg.copy()
-                new_msg["content"] = json.dumps(msg_content_copy)
-                result.append(new_msg)
-            else:
-                result.append(msg)
-        return result
-
-    def _compress_messages(self, messages: List[Dict[str, Any]], llm_model: str, max_tokens: Optional[int] = 41000, token_threshold: Optional[int] = 4096, max_iterations: int = 5) -> List[Dict[str, Any]]:
-        """Compress the messages.
-        token_threshold: must be a power of 2
-        """
-
-        if 'sonnet' in llm_model.lower():
-            max_tokens = 200 * 1000 - 64000 - 28000
-        elif 'gpt' in llm_model.lower():
-            max_tokens = 128 * 1000 - 28000
-        elif 'gemini' in llm_model.lower():
-            max_tokens = 1000 * 1000 - 300000
-        elif 'deepseek' in llm_model.lower():
-            max_tokens = 128 * 1000 - 28000
-        else:
-            max_tokens = 41 * 1000 - 10000
-
-        result = messages
-        result = self._remove_meta_messages(result)
-
-        uncompressed_total_token_count = token_counter(model=llm_model, messages=result)
-
-        result = self._compress_tool_result_messages(result, llm_model, max_tokens, token_threshold)
-        result = self._compress_user_messages(result, llm_model, max_tokens, token_threshold)
-        result = self._compress_assistant_messages(result, llm_model, max_tokens, token_threshold)
-
-        compressed_token_count = token_counter(model=llm_model, messages=result)
-
-        logger.info(f"_compress_messages: {uncompressed_total_token_count} -> {compressed_token_count}") # Log the token compression for debugging later
-
-        if max_iterations <= 0:
-            logger.warning(f"_compress_messages: Max iterations reached, omitting messages")
-            result = self._compress_messages_by_omitting_messages(messages, llm_model, max_tokens)
-            return result
-
-        if (compressed_token_count > max_tokens):
-            logger.warning(f"Further token compression is needed: {compressed_token_count} > {max_tokens}")
-            result = self._compress_messages(messages, llm_model, max_tokens, int(token_threshold / 2), max_iterations - 1)
-
-        return self._middle_out_messages(result)
-
-    def _compress_messages_by_omitting_messages(
-        self,
-        messages: List[Dict[str, Any]],
-        llm_model: str,
-        max_tokens: Optional[int] = 41000,
-        removal_batch_size: int = 10,
-        min_messages_to_keep: int = 10
-    ) -> List[Dict[str, Any]]:
-        """Compress the messages by omitting messages from the middle.
-
-        Args:
-            messages: List of messages to compress
-            llm_model: Model name for token counting
-            max_tokens: Maximum allowed tokens
-            removal_batch_size: Number of messages to remove per iteration
-            min_messages_to_keep: Minimum number of messages to preserve
-        """
-        if not messages:
-            return messages
-
-        result = messages
-        result = self._remove_meta_messages(result)
-
-        # Early exit if no compression needed
-        initial_token_count = token_counter(model=llm_model, messages=result)
-        max_allowed_tokens = max_tokens or (100 * 1000)
-
-        if initial_token_count <= max_allowed_tokens:
-            return result
-
-        # Separate system message (assumed to be first) from conversation messages
-        system_message = messages[0] if messages and messages[0].get('role') == 'system' else None
-        conversation_messages = result[1:] if system_message else result
-
-        safety_limit = 500
-        current_token_count = initial_token_count
-
-        while current_token_count > max_allowed_tokens and safety_limit > 0:
-            safety_limit -= 1
-
-            if len(conversation_messages) <= min_messages_to_keep:
-                logger.warning(f"Cannot compress further: only {len(conversation_messages)} messages remain (min: {min_messages_to_keep})")
-                break
-
-            # Calculate removal strategy based on current message count
-            if len(conversation_messages) > (removal_batch_size * 2):
-                # Remove from middle, keeping recent and early context
-                middle_start = len(conversation_messages) // 2 - (removal_batch_size // 2)
-                middle_end = middle_start + removal_batch_size
-                conversation_messages = conversation_messages[:middle_start] + conversation_messages[middle_end:]
-            else:
-                # Remove from earlier messages, preserving recent context
-                messages_to_remove = min(removal_batch_size, len(conversation_messages) // 2)
-                if messages_to_remove > 0:
-                    conversation_messages = conversation_messages[messages_to_remove:]
-                else:
-                    # Can't remove any more messages
-                    break
-
-            # Recalculate token count
-            messages_to_count = ([system_message] + conversation_messages) if system_message else conversation_messages
-            current_token_count = token_counter(model=llm_model, messages=messages_to_count)
-
-        # Prepare final result
-        final_messages = ([system_message] + conversation_messages) if system_message else conversation_messages
-        final_token_count = token_counter(model=llm_model, messages=final_messages)
-
-        logger.info(f"_compress_messages_by_omitting_messages: {initial_token_count} -> {final_token_count} tokens ({len(messages)} -> {len(final_messages)} messages)")
-
-        return final_messages
-
-    def _middle_out_messages(self, messages: List[Dict[str, Any]], max_messages: int = 320) -> List[Dict[str, Any]]:
-        """Remove messages from the middle of the list, keeping max_messages total."""
-        if len(messages) <= max_messages:
-            return messages
-
-        # Keep half from the beginning and half from the end
-        keep_start = max_messages // 2
-        keep_end = max_messages - keep_start
-
-        return messages[:keep_start] + messages[-keep_end:]
-
     def add_tool(self, tool_class: Type[Tool], function_names: Optional[List[str]] = None, **kwargs):
         """Add a tool to the ThreadManager."""
         self.tool_registry.register_tool(tool_class, function_names, **kwargs)
@@ -637,7 +364,7 @@ Here are the XML tools available with examples:
         openapi_tool_schemas = self.tool_registry.get_openapi_schemas()
         logger.debug(f"Retrieved {len(openapi_tool_schemas) if openapi_tool_schemas else 0} OpenAPI tool schemas")
 
-        prepared_messages = self._compress_messages(prepared_messages, llm_model)
+        prepared_messages = self.context_manager.compress_messages(prepared_messages, llm_model)
 
         # 5. Make LLM API call
         logger.debug("Making LLM API call")
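With the hunk above, ThreadManager stops calling its own private compression helpers (removed earlier in this commit) and delegates to the ContextManager instance it already holds. A sketch of how the old helpers appear to map onto the new public methods; the behaviour looks essentially equivalent, with the underscore prefix dropped:

    # Former ThreadManager helpers            -> new ContextManager equivalents
    # self._compress_messages(...)            -> self.context_manager.compress_messages(...)
    # self._remove_meta_messages(...)          -> self.context_manager.remove_meta_messages(...)
    # self._middle_out_messages(...)           -> self.context_manager.middle_out_messages(...)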
@@ -22,7 +22,7 @@ from sandbox import api as sandbox_api
 from services import billing as billing_api
 from flags import api as feature_flags_api
 from services import transcription as transcription_api
-from services.mcp_custom import discover_custom_tools
+from mcp_service.mcp_custom import discover_custom_tools
 import sys
 from services import email_api
 from triggers import api as triggers_api
@@ -151,8 +151,8 @@ app.include_router(billing_api.router, prefix="/api")
 
 app.include_router(feature_flags_api.router, prefix="/api")
 
-from mcp_local import api as mcp_api
-from mcp_local import secure_api as secure_mcp_api
+from mcp_service import api as mcp_api
+from mcp_service import secure_api as secure_mcp_api
 
 app.include_router(mcp_api.router, prefix="/api")
 app.include_router(secure_mcp_api.router, prefix="/api/secure-mcp")
@@ -1,299 +0,0 @@
import os
import sys
import json
import asyncio
import subprocess
from typing import Dict, Any
from concurrent.futures import ThreadPoolExecutor
from fastapi import HTTPException  # type: ignore
from utils.logger import logger
from mcp import ClientSession
from mcp.client.sse import sse_client  # type: ignore
from mcp.client.streamable_http import streamablehttp_client  # type: ignore

windows_executor = ThreadPoolExecutor(max_workers=4)

# def run_mcp_stdio_sync(command, args, env_vars, timeout=30):
#     try:
#         env = os.environ.copy()
#         env.update(env_vars)
#
#         full_command = [command] + args
#
#         process = subprocess.Popen(
#             full_command,
#             stdin=subprocess.PIPE,
#             stdout=subprocess.PIPE,
#             stderr=subprocess.PIPE,
#             env=env,
#             text=True,
#             bufsize=0,
#             creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if sys.platform == "win32" else 0
#         )
#
#         init_request = {
#             "jsonrpc": "2.0",
#             "id": 1,
#             "method": "initialize",
#             "params": {
#                 "protocolVersion": "2024-11-05",
#                 "capabilities": {},
#                 "clientInfo": {"name": "mcp-client", "version": "1.0.0"}
#             }
#         }
#
#         process.stdin.write(json.dumps(init_request) + "\n")
#         process.stdin.flush()
#
#         init_response_line = process.stdout.readline().strip()
#         if not init_response_line:
#             raise Exception("No response from MCP server during initialization")
#
#         init_response = json.loads(init_response_line)
#
#         init_notification = {
#             "jsonrpc": "2.0",
#             "method": "notifications/initialized"
#         }
#         process.stdin.write(json.dumps(init_notification) + "\n")
#         process.stdin.flush()
#
#         tools_request = {
#             "jsonrpc": "2.0",
#             "id": 2,
#             "method": "tools/list",
#             "params": {}
#         }
#
#         process.stdin.write(json.dumps(tools_request) + "\n")
#         process.stdin.flush()
#
#         tools_response_line = process.stdout.readline().strip()
#         if not tools_response_line:
#             raise Exception("No response from MCP server for tools list")
#
#         tools_response = json.loads(tools_response_line)
#
#         tools_info = []
#         if "result" in tools_response and "tools" in tools_response["result"]:
#             for tool in tools_response["result"]["tools"]:
#                 tool_info = {
#                     "name": tool["name"],
#                     "description": tool.get("description", ""),
#                     "input_schema": tool.get("inputSchema", {})
#                 }
#                 tools_info.append(tool_info)
#
#         return {
#             "status": "connected",
#             "transport": "stdio",
#             "tools": tools_info
#         }
#
#     except subprocess.TimeoutExpired:
#         return {
#             "status": "error",
#             "error": f"Process timeout after {timeout} seconds",
#             "tools": []
#         }
#     except json.JSONDecodeError as e:
#         return {
#             "status": "error",
#             "error": f"Invalid JSON response: {str(e)}",
#             "tools": []
#         }
#     except Exception as e:
#         return {
#             "status": "error",
#             "error": str(e),
#             "tools": []
#         }
#     finally:
#         try:
#             if 'process' in locals():
#                 process.terminate()
#                 process.wait(timeout=5)
#         except:
#             pass

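The commented-out helper above speaks plain JSON-RPC over the child process's stdin and stdout. For readability, these are the three messages it writes, one per line, taken directly from that code:

import json

handshake = [
    {"jsonrpc": "2.0", "id": 1, "method": "initialize",
     "params": {"protocolVersion": "2024-11-05", "capabilities": {},
                "clientInfo": {"name": "mcp-client", "version": "1.0.0"}}},
    {"jsonrpc": "2.0", "method": "notifications/initialized"},  # notification, no response expected
    {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}},
]
for msg in handshake:
    print(json.dumps(msg))  # each message is written as a single line to the server's stdin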
# async def connect_stdio_server_windows(server_name, server_config, all_tools, timeout):
#     """Windows-compatible stdio connection using subprocess"""
#
#     logger.info(f"Connecting to {server_name} using Windows subprocess method")
#
#     command = server_config["command"]
#     args = server_config.get("args", [])
#     env_vars = server_config.get("env", {})
#
#     loop = asyncio.get_event_loop()
#     result = await loop.run_in_executor(
#         windows_executor,
#         run_mcp_stdio_sync,
#         command,
#         args,
#         env_vars,
#         timeout
#     )
#
#     all_tools[server_name] = result
#
#     if result["status"] == "connected":
#         logger.info(f" {server_name}: Connected via Windows subprocess ({len(result['tools'])} tools)")
#     else:
#         logger.error(f" {server_name}: Error - {result['error']}")


# async def list_mcp_tools_mixed_windows(config, timeout=15):
#     all_tools = {}
#
#     if "mcpServers" not in config:
#         return all_tools
#
#     mcp_servers = config["mcpServers"]
#
#     for server_name, server_config in mcp_servers.items():
#         logger.info(f"Connecting to MCP server: {server_name}")
#         if server_config.get("disabled", False):
#             all_tools[server_name] = {"status": "disabled", "tools": []}
#             logger.info(f" {server_name}: Disabled")
#             continue
#
#         try:
#             await connect_stdio_server_windows(server_name, server_config, all_tools, timeout)
#
#         except asyncio.TimeoutError:
#             all_tools[server_name] = {
#                 "status": "error",
#                 "error": f"Connection timeout after {timeout} seconds",
#                 "tools": []
#             }
#             logger.error(f" {server_name}: Timeout after {timeout} seconds")
#         except Exception as e:
#             error_msg = str(e)
#             all_tools[server_name] = {
#                 "status": "error",
#                 "error": error_msg,
#                 "tools": []
#             }
#             logger.error(f" {server_name}: Error - {error_msg}")
#             import traceback
#             logger.debug(f"Full traceback for {server_name}: {traceback.format_exc()}")
#
#     return all_tools

async def discover_custom_tools(request_type: str, config: Dict[str, Any]):
    logger.info(f"Received custom MCP discovery request: type={request_type}")
    logger.debug(f"Request config: {config}")

    tools = []
    server_name = None

    # if request_type == 'json':
    #     try:
    #         all_tools = await list_mcp_tools_mixed_windows(config, timeout=30)
    #         if "mcpServers" in config and config["mcpServers"]:
    #             server_name = list(config["mcpServers"].keys())[0]
    #
    #             if server_name in all_tools:
    #                 server_info = all_tools[server_name]
    #                 if server_info["status"] == "connected":
    #                     tools = server_info["tools"]
    #                     logger.info(f"Found {len(tools)} tools for server {server_name}")
    #                 else:
    #                     error_msg = server_info.get("error", "Unknown error")
    #                     logger.error(f"Server {server_name} failed: {error_msg}")
    #                     raise HTTPException(
    #                         status_code=400,
    #                         detail=f"Failed to connect to MCP server '{server_name}': {error_msg}"
    #                     )
    #             else:
    #                 logger.error(f"Server {server_name} not found in results")
    #                 raise HTTPException(status_code=400, detail=f"Server '{server_name}' not found in results")
    #         else:
    #             logger.error("No MCP servers configured")
    #             raise HTTPException(status_code=400, detail="No MCP servers configured")
    #
    #     except HTTPException:
    #         raise
    #     except Exception as e:
    #         logger.error(f"Error connecting to stdio MCP server: {e}")
    #         import traceback
    #         logger.error(f"Full traceback: {traceback.format_exc()}")
    #         raise HTTPException(status_code=400, detail=f"Failed to connect to MCP server: {str(e)}")
    #
    # if request_type == 'http':
    #     if 'url' not in config:
    #         raise HTTPException(status_code=400, detail="HTTP configuration must include 'url' field")
    #     url = config['url']
    #     await connect_streamable_http_server(url)
    #     tools = await connect_streamable_http_server(url)
    #
    # elif request_type == 'sse':
    #     if 'url' not in config:
    #         raise HTTPException(status_code=400, detail="SSE configuration must include 'url' field")
    #
    #     url = config['url']
    #     headers = config.get('headers', {})
    #
    #     try:
    #         async with asyncio.timeout(15):
    #             try:
    #                 async with sse_client(url, headers=headers) as (read, write):
    #                     async with ClientSession(read, write) as session:
    #                         await session.initialize()
    #                         tools_result = await session.list_tools()
    #                         tools_info = []
    #                         for tool in tools_result.tools:
    #                             tool_info = {
    #                                 "name": tool.name,
    #                                 "description": tool.description,
    #                                 "input_schema": tool.inputSchema
    #                             }
    #                             tools_info.append(tool_info)
    #
    #                         for tool_info in tools_info:
    #                             tools.append({
    #                                 "name": tool_info["name"],
    #                                 "description": tool_info["description"],
    #                                 "inputSchema": tool_info["input_schema"]
    #                             })
    #             except TypeError as e:
    #                 if "unexpected keyword argument" in str(e):
    #                     async with sse_client(url) as (read, write):
    #                         async with ClientSession(read, write) as session:
    #                             await session.initialize()
    #                             tools_result = await session.list_tools()
    #                             tools_info = []
    #                             for tool in tools_result.tools:
    #                                 tool_info = {
    #                                     "name": tool.name,
    #                                     "description": tool.description,
    #                                     "input_schema": tool.inputSchema
    #                                 }
    #                                 tools_info.append(tool_info)
    #
    #                             for tool_info in tools_info:
    #                                 tools.append({
    #                                     "name": tool_info["name"],
    #                                     "description": tool_info["description"],
    #                                     "inputSchema": tool_info["input_schema"]
    #                                 })
    #                 else:
    #                     raise
    #     except asyncio.TimeoutError:
    #         raise HTTPException(status_code=408, detail="Connection timeout - server took too long to respond")
    #     except Exception as e:
    #         logger.error(f"Error connecting to SSE MCP server: {e}")
    #         raise HTTPException(status_code=400, detail=f"Failed to connect to MCP server: {str(e)}")
    # else:
    #     raise HTTPException(status_code=400, detail="Invalid server type. Must be 'json' or 'sse'")
    #
    # response_data = {"tools": tools, "count": len(tools)}
    #
    # if server_name:
    #     response_data["serverName"] = server_name
    #
    # logger.info(f"Returning {len(tools)} tools for server {server_name}")
    # return response_data

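For reference, the SSE path that the commented-out block above walks through can be written compactly with the mcp client package. A minimal sketch using only the calls that appear in the deleted code (the URL in the usage comment is a placeholder):

from typing import Any, Dict, List, Optional

from mcp import ClientSession
from mcp.client.sse import sse_client

async def list_sse_tools(url: str, headers: Optional[Dict[str, str]] = None) -> List[Dict[str, Any]]:
    # Open the SSE transport, run the MCP handshake, and list the server's tools.
    async with sse_client(url, headers=headers or {}) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.list_tools()
            return [
                {"name": t.name, "description": t.description, "inputSchema": t.inputSchema}
                for t in result.tools
            ]

# Example usage (placeholder URL): asyncio.run(list_sse_tools("https://example.com/sse"))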
@@ -136,10 +136,10 @@ export function FooterSection() {
          className="block w-full h-48 md:h-64 relative mt-24 z-0 cursor-pointer"
        >
          <div className="absolute inset-0 bg-gradient-to-t from-transparent to-background z-10 from-40%" />
-         <div className="absolute inset-0 mx-6">
+         <div className="absolute inset-0 ">
            <FlickeringGrid
-             text={tablet ? 'Agents Agents Agents' : 'Agents Agents Agents'}
+             text={tablet ? 'Agents' : 'Agents Agents Agents'}
-             fontSize={tablet ? 70 : 90}
+             fontSize={tablet ? 60 : 90}
              className="h-full w-full"
              squareSize={2}
              gridGap={tablet ? 2 : 3}