mirror of https://github.com/kortix-ai/suna.git
Merge pull request #47 from kortix-ai/reasoning
Reasoning + Cost Calculation
Commit be9e6d46cc
@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Depends, Request
from fastapi import APIRouter, HTTPException, Depends, Request, Body
from fastapi.responses import StreamingResponse
import asyncio
import json
@ -7,6 +7,7 @@ from datetime import datetime, timezone
import uuid
from typing import Optional, List, Dict, Any
import jwt
from pydantic import BaseModel

from agentpress.thread_manager import ThreadManager
from services.supabase import DBConnection
@ -15,7 +16,14 @@ from agent.run import run_agent
from utils.auth_utils import get_current_user_id, get_user_id_from_stream_auth, verify_thread_access
from utils.logger import logger
from utils.billing import check_billing_status, get_account_id_from_thread
from utils.db import update_agent_run_status
# Removed duplicate import of update_agent_run_status from utils.db as it's defined locally

# Define request model for starting agent
class AgentStartRequest(BaseModel):
    model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
    enable_thinking: Optional[bool] = False
    reasoning_effort: Optional[str] = 'low'
    stream: Optional[bool] = False # Add stream option, default to False

# Initialize shared resources
router = APIRouter()
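
For reference, a client can exercise the new request body roughly like this; a minimal sketch assuming the API is served locally and uses the bearer-token auth implied by get_current_user_id (base URL, path prefix, and token are placeholders, not values from this PR):

import requests  # assumed to be available in the client environment

BASE_URL = "http://localhost:8000"   # placeholder host/port
THREAD_ID = "<thread-id>"            # placeholder thread ID

resp = requests.post(
    f"{BASE_URL}/thread/{THREAD_ID}/agent/start",
    headers={"Authorization": "Bearer <jwt>"},  # auth header format is an assumption
    json={
        "model_name": "anthropic/claude-3-7-sonnet-latest",
        "enable_thinking": True,
        "reasoning_effort": "low",
        "stream": True,
    },
)
print(resp.status_code, resp.json())

All four body fields are optional, so omitting any of them falls back to the defaults declared on AgentStartRequest.
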
@ -236,11 +244,15 @@ async def _cleanup_agent_run(agent_run_id: str):
        # Non-fatal error, can continue

@router.post("/thread/{thread_id}/agent/start")
async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id)):
    """Start an agent for a specific thread in the background."""
    logger.info(f"Starting new agent for thread: {thread_id}")
async def start_agent(
    thread_id: str,
    body: AgentStartRequest = Body(...), # Accept request body
    user_id: str = Depends(get_current_user_id)
):
    """Start an agent for a specific thread in the background with dynamic settings."""
    logger.info(f"Starting new agent for thread: {thread_id} with model: {body.model_name}, thinking: {body.enable_thinking}, effort: {body.reasoning_effort}, stream: {body.stream}")
    client = await db.client

    # Verify user has access to this thread
    await verify_thread_access(client, thread_id, user_id)

@ -291,11 +303,20 @@ async def start_agent(thread_id: str, user_id: str = Depends(get_current_user_id
    except Exception as e:
        logger.warning(f"Failed to register agent run in Redis, continuing without Redis tracking: {str(e)}")

    # Run the agent in the background
    # Run the agent in the background, passing the dynamic settings
    task = asyncio.create_task(
        run_agent_background(agent_run_id, thread_id, instance_id, project_id)
        run_agent_background(
            agent_run_id=agent_run_id,
            thread_id=thread_id,
            instance_id=instance_id,
            project_id=project_id,
            model_name=body.model_name,
            enable_thinking=body.enable_thinking,
            reasoning_effort=body.reasoning_effort,
            stream=body.stream # Pass stream parameter
        )
    )

    # Set a callback to clean up when task is done
    task.add_done_callback(
        lambda _: asyncio.create_task(
@ -420,11 +441,20 @@ async def stream_agent_run(
|
|||
}
|
||||
)
|
||||
|
||||
async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: str, project_id: str):
|
||||
async def run_agent_background(
|
||||
agent_run_id: str,
|
||||
thread_id: str,
|
||||
instance_id: str,
|
||||
project_id: str,
|
||||
model_name: str, # Add model_name parameter
|
||||
enable_thinking: Optional[bool], # Add enable_thinking parameter
|
||||
reasoning_effort: Optional[str], # Add reasoning_effort parameter
|
||||
stream: bool # Add stream parameter
|
||||
):
|
||||
"""Run the agent in the background and handle status updates."""
|
||||
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id})")
|
||||
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model: {model_name}, stream: {stream}")
|
||||
client = await db.client
|
||||
|
||||
|
||||
# Tracking variables
|
||||
total_responses = 0
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
@ -538,11 +568,18 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
|
|||
logger.warning(f"No stop signal checker for agent run: {agent_run_id} - pubsub unavailable")
|
||||
|
||||
try:
|
||||
# Run the agent
|
||||
# Run the agent, passing the dynamic settings
|
||||
logger.debug(f"Initializing agent generator for thread: {thread_id} (instance: {instance_id})")
|
||||
agent_gen = run_agent(thread_id, stream=True,
|
||||
thread_manager=thread_manager, project_id=project_id)
|
||||
|
||||
agent_gen = run_agent(
|
||||
thread_id=thread_id,
|
||||
project_id=project_id,
|
||||
stream=stream, # Pass stream parameter from API request
|
||||
thread_manager=thread_manager,
|
||||
model_name=model_name, # Pass model_name
|
||||
enable_thinking=enable_thinking, # Pass enable_thinking
|
||||
reasoning_effort=reasoning_effort # Pass reasoning_effort
|
||||
)
|
||||
|
||||
# Collect all responses to save to database
|
||||
all_responses = []
|
||||
|
||||
|
@ -654,4 +691,4 @@ async def run_agent_background(agent_run_id: str, thread_id: str, instance_id: s
|
|||
except Exception as e:
|
||||
logger.warning(f"Error deleting active run key: {str(e)}")
|
||||
|
||||
logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
|
||||
logger.info(f"Agent run background task fully completed for: {agent_run_id} (instance: {instance_id})")
|
||||
|
|
|
@ -20,9 +20,19 @@ from utils.billing import check_billing_status, get_account_id_from_thread
|
|||
|
||||
load_dotenv()
|
||||
|
||||
async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread_manager: Optional[ThreadManager] = None, native_max_auto_continues: int = 25, max_iterations: int = 150):
|
||||
async def run_agent(
|
||||
thread_id: str,
|
||||
project_id: str,
|
||||
stream: bool, # Accept stream parameter from caller (api.py)
|
||||
thread_manager: Optional[ThreadManager] = None,
|
||||
native_max_auto_continues: int = 25,
|
||||
max_iterations: int = 150,
|
||||
model_name: str = "anthropic/claude-3-7-sonnet-latest", # Add model_name parameter with default
|
||||
enable_thinking: Optional[bool] = False, # Add enable_thinking parameter
|
||||
reasoning_effort: Optional[str] = 'low' # Add reasoning_effort parameter
|
||||
):
|
||||
"""Run the development agent with specified configuration."""
|
||||
|
||||
|
||||
if not thread_manager:
|
||||
thread_manager = ThreadManager()
|
||||
client = await thread_manager.db.client
|
||||
|
@ -59,8 +69,8 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
|
|||
'sandbox': {
|
||||
'id': sandbox_id,
|
||||
'pass': sandbox_pass,
|
||||
'vnc_preview': sandbox.get_preview_link(6080),
|
||||
'sandbox_url': sandbox.get_preview_link(8080)
|
||||
'vnc_preview': str(sandbox.get_preview_link(6080)), # Convert to string
|
||||
'sandbox_url': str(sandbox.get_preview_link(8080)) # Convert to string
|
||||
}
|
||||
}).eq('project_id', project_id).execute()
|
||||
|
||||
|
@ -161,23 +171,29 @@ async def run_agent(thread_id: str, project_id: str, stream: bool = True, thread
|
|||
print(f"Error parsing browser state: {e}")
|
||||
# print(latest_browser_state.data[0])
|
||||
|
||||
# Run Thread
|
||||
# Determine max tokens based on the passed model_name
|
||||
max_tokens = None
|
||||
if model_name == "anthropic/claude-3-7-sonnet-latest":
|
||||
max_tokens = 64000 # Example: Set max tokens for a specific model
|
||||
|
||||
# Run Thread, passing the dynamic settings
|
||||
response = await thread_manager.run_thread(
|
||||
thread_id=thread_id,
|
||||
system_prompt=system_message, # Pass the constructed message
|
||||
stream=stream,
|
||||
llm_model=os.getenv("MODEL_TO_USE", "anthropic/claude-3-7-sonnet-latest"),
|
||||
llm_temperature=0,
|
||||
llm_max_tokens=64000,
|
||||
stream=stream, # Pass the stream parameter received by run_agent
|
||||
llm_model=model_name, # Use the passed model_name
|
||||
llm_temperature=1, # Example temperature
|
||||
llm_max_tokens=max_tokens, # Use the determined value
|
||||
tool_choice="auto",
|
||||
max_xml_tool_calls=1,
|
||||
temporary_message=temporary_message,
|
||||
processor_config=processor_config, # Pass the config object
|
||||
native_max_auto_continues=native_max_auto_continues,
|
||||
# Explicitly set include_xml_examples to False here
|
||||
include_xml_examples=False,
|
||||
include_xml_examples=False, # Explicitly set include_xml_examples to False here
|
||||
enable_thinking=enable_thinking, # Pass enable_thinking
|
||||
reasoning_effort=reasoning_effort # Pass reasoning_effort
|
||||
)
|
||||
|
||||
|
||||
if isinstance(response, dict) and "status" in response and response["status"] == "error":
|
||||
yield response
|
||||
break
|
||||
|
@ -279,7 +295,8 @@ async def test_agent():
|
|||
|
||||
if not user_message.strip():
|
||||
print("\n🔄 Running agent...\n")
|
||||
await process_agent_response(thread_id, project_id, thread_manager)
|
||||
# Pass stream=True explicitly when calling from test_agent
|
||||
await process_agent_response(thread_id, project_id, thread_manager, stream=True)
|
||||
continue
|
||||
|
||||
# Add the user message to the thread
|
||||
|
@ -294,19 +311,40 @@ async def test_agent():
|
|||
)
|
||||
|
||||
print("\n🔄 Running agent...\n")
|
||||
await process_agent_response(thread_id, project_id, thread_manager)
|
||||
|
||||
# Pass stream=True explicitly when calling from test_agent
|
||||
await process_agent_response(thread_id, project_id, thread_manager, stream=True)
|
||||
|
||||
print("\n👋 Test completed. Goodbye!")
|
||||
|
||||
async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
|
||||
"""Process the streaming response from the agent."""
|
||||
async def process_agent_response(
|
||||
thread_id: str,
|
||||
project_id: str,
|
||||
thread_manager: ThreadManager,
|
||||
stream: bool = True, # Add stream parameter, default to True for testing
|
||||
model_name: str = "anthropic/claude-3-7-sonnet-latest", # Add model_name with default
|
||||
enable_thinking: Optional[bool] = False, # Add enable_thinking
|
||||
reasoning_effort: Optional[str] = 'low' # Add reasoning_effort
|
||||
):
|
||||
"""Process the streaming response from the agent, passing model/thinking parameters."""
|
||||
chunk_counter = 0
|
||||
current_response = ""
|
||||
tool_call_counter = 0 # Track number of tool calls
|
||||
|
||||
async for chunk in run_agent(thread_id=thread_id, project_id=project_id, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
|
||||
|
||||
# Pass the received parameters to the run_agent call
|
||||
agent_generator = run_agent(
|
||||
thread_id=thread_id,
|
||||
project_id=project_id,
|
||||
stream=stream, # Pass the stream parameter here
|
||||
thread_manager=thread_manager,
|
||||
native_max_auto_continues=25,
|
||||
model_name=model_name,
|
||||
enable_thinking=enable_thinking,
|
||||
reasoning_effort=reasoning_effort
|
||||
)
|
||||
|
||||
async for chunk in agent_generator:
|
||||
chunk_counter += 1
|
||||
|
||||
|
||||
if chunk.get('type') == 'content' and 'content' in chunk:
|
||||
current_response += chunk.get('content', '')
|
||||
# Print the response as it comes in
|
||||
|
@ -382,4 +420,4 @@ if __name__ == "__main__":
|
|||
load_dotenv() # Ensure environment variables are loaded
|
||||
|
||||
# Run the test function
|
||||
asyncio.run(test_agent())
|
||||
asyncio.run(test_agent())
|
||||
|
|
|
@ -95,6 +95,8 @@ class ResponseProcessor:
        self,
        llm_response: AsyncGenerator,
        thread_id: str,
        prompt_messages: List[Dict[str, Any]],
        llm_model: str,
        config: ProcessorConfig = ProcessorConfig(),
    ) -> AsyncGenerator:
        """Process a streaming LLM response, handling tool calls and execution.
@ -102,6 +104,8 @@ class ResponseProcessor:
        Args:
            llm_response: Streaming response from the LLM
            thread_id: ID of the conversation thread
            prompt_messages: List of messages sent to the LLM (the prompt)
            llm_model: The name of the LLM model used
            config: Configuration for parsing and execution

        Yields:
@ -140,9 +144,6 @@ class ResponseProcessor:
        # if config.max_xml_tool_calls > 0:
        #     logger.info(f"XML tool call limit enabled: {config.max_xml_tool_calls}")

        accumulated_cost = 0
        accumulated_token_count = 0

        try:
            async for chunk in llm_response:
                # Default content to yield
@ -155,22 +156,17 @@ class ResponseProcessor:
                if hasattr(chunk, 'choices') and chunk.choices:
                    delta = chunk.choices[0].delta if hasattr(chunk.choices[0], 'delta') else None

                    # Check for and log Anthropic thinking content
                    if delta and hasattr(delta, 'reasoning_content') and delta.reasoning_content:
                        logger.info(f"[THINKING]: {delta.reasoning_content}")
                        accumulated_content += delta.reasoning_content # Append reasoning to main content

                    # Process content chunk
                    if delta and hasattr(delta, 'content') and delta.content:
                        chunk_content = delta.content
                        accumulated_content += chunk_content
                        current_xml_content += chunk_content

                        # Calculate cost using prompt and completion
                        try:
                            cost = completion_cost(model=chunk.model, prompt=accumulated_content, completion=chunk_content)
                            tcount = token_counter(model=chunk.model, messages=[{"role": "user", "content": accumulated_content}])
                            accumulated_cost += cost
                            accumulated_token_count += tcount
                            logger.debug(f"Cost: {cost:.6f}, Token count: {tcount}")
                        except Exception as e:
                            logger.error(f"Error calculating cost: {str(e)}")

                        # Check if we've reached the XML tool call limit before yielding content
                        if config.max_xml_tool_calls > 0 and xml_tool_call_count >= config.max_xml_tool_calls:
                            # We've reached the limit, don't yield any more content
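
For context, the per-chunk accounting shown above (which this commit appears to drop in favour of a single end-of-stream calculation further down) leans on litellm's cost helpers. A standalone sketch of those helpers, with an illustrative model name and strings:

from litellm import completion_cost, token_counter

model = "anthropic/claude-3-7-sonnet-latest"   # illustrative
prompt = "Summarize the attached report in two sentences."
completion = "The report covers Q1 revenue and highlights two risks."

# Estimated dollar cost for a prompt/completion pair
cost = completion_cost(model=model, prompt=prompt, completion=completion)

# Token count for a message list, as used for accumulated_token_count
tokens = token_counter(model=model, messages=[{"role": "user", "content": prompt}])

print(f"cost={cost:.6f}, tokens={tokens}")
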
@ -334,7 +330,7 @@ class ResponseProcessor:
|
|||
|
||||
# If we've reached the XML tool call limit, stop streaming
|
||||
if finish_reason == "xml_tool_limit_reached":
|
||||
logger.info("Stopping stream due to XML tool call limit")
|
||||
logger.info("Stopping stream processing after loop due to XML tool call limit")
|
||||
break
|
||||
|
||||
# After streaming completes or is stopped due to limit, wait for any remaining tool executions
|
||||
|
@ -466,6 +462,27 @@
                    is_llm_message=True
                )

                # Calculate and store cost AFTER adding the main assistant message
                if accumulated_content: # Calculate cost if there was content (now includes reasoning)
                    try:
                        final_cost = completion_cost(
                            model=llm_model, # Use the passed model name
                            messages=prompt_messages,
                            completion=accumulated_content
                        )
                        if final_cost is not None and final_cost > 0:
                            logger.info(f"Calculated final cost for stream: {final_cost}")
                            await self.add_message(
                                thread_id=thread_id,
                                type="cost",
                                content={"cost": final_cost},
                                is_llm_message=False # Cost is metadata, not LLM content
                            )
                        else:
                            logger.info("Cost calculation resulted in zero or None, not storing cost message.")
                    except Exception as e:
                        logger.error(f"Error calculating final cost for stream: {str(e)}")

                # Now add all buffered tool results AFTER the assistant message, but don't yield if already yielded
                for tool_call, result, result_tool_index in tool_results_buffer:
                    # Add result based on tool type to the conversation history
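
As stored by add_message above, the cost record becomes a non-LLM message on the thread, roughly of this shape (illustrative values; the actual row schema is defined by the messages table, not by this PR):

cost_message = {
    "thread_id": "<thread-id>",     # placeholder
    "type": "cost",
    "content": {"cost": 0.0123},    # illustrative dollar amount
    "is_llm_message": False,        # cost is metadata, not LLM content
}
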
@ -559,24 +576,19 @@ class ResponseProcessor:
|
|||
yield {"type": "error", "message": str(e)}
|
||||
|
||||
finally:
|
||||
pass
|
||||
# track the cost and token count
|
||||
# todo: there is a bug as it adds every chunk to db because finally will run every time even in yield
|
||||
# await self.add_message(
|
||||
# thread_id=thread_id,
|
||||
# type="cost",
|
||||
# content={
|
||||
# "cost": accumulated_cost,
|
||||
# "token_count": accumulated_token_count
|
||||
# },
|
||||
# is_llm_message=False
|
||||
# )
|
||||
|
||||
# Finally, yield the finish reason if it was detected
|
||||
if finish_reason:
|
||||
yield {
|
||||
"type": "finish",
|
||||
"finish_reason": finish_reason
|
||||
}
|
||||
|
||||
async def process_non_streaming_response(
|
||||
self,
|
||||
llm_response: Any,
|
||||
thread_id: str,
|
||||
prompt_messages: List[Dict[str, Any]],
|
||||
llm_model: str,
|
||||
config: ProcessorConfig = ProcessorConfig(),
|
||||
) -> AsyncGenerator:
|
||||
"""Process a non-streaming LLM response, handling tool calls and execution.
|
||||
|
@ -584,6 +596,8 @@ class ResponseProcessor:
|
|||
Args:
|
||||
llm_response: Response from the LLM
|
||||
thread_id: ID of the conversation thread
|
||||
prompt_messages: List of messages sent to the LLM (the prompt)
|
||||
llm_model: The name of the LLM model used
|
||||
config: Configuration for parsing and execution
|
||||
|
||||
Yields:
|
||||
|
@ -684,6 +698,33 @@
                is_llm_message=True
            )

            # Calculate and store cost AFTER adding the main assistant message
            if content or tool_calls: # Calculate cost if there's content or tool calls
                try:
                    # Use the full response object for potentially more accurate cost calculation
                    # Pass model explicitly as it might not be reliably in response_object for all providers
                    # First check if response_cost is directly available in _hidden_params
                    if hasattr(llm_response, '_hidden_params') and 'response_cost' in llm_response._hidden_params and llm_response._hidden_params['response_cost'] != 0.0:
                        final_cost = llm_response._hidden_params['response_cost']
                    else:
                        # Fall back to calculating cost if direct cost not available
                        final_cost = completion_cost(
                            completion_response=llm_response,
                            model=llm_model,
                            call_type="completion" # Assuming 'completion' type for this context
                        )

                    if final_cost is not None and final_cost > 0:
                        logger.info(f"Calculated final cost for non-stream: {final_cost}")
                        await self.add_message(
                            thread_id=thread_id,
                            type="cost",
                            content={"cost": final_cost},
                            is_llm_message=False # Cost is metadata
                        )
                except Exception as e:
                    logger.error(f"Error calculating final cost for non-stream: {str(e)}")

            # Yield content first
            yield {"type": "content", "content": content}
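
The _hidden_params check above can be exercised directly against litellm; a rough sketch (model and prompt are illustrative, and _hidden_params is an internal litellm field, so it is read defensively):

import asyncio
import litellm

async def demo():
    response = await litellm.acompletion(
        model="anthropic/claude-3-7-sonnet-latest",   # illustrative
        messages=[{"role": "user", "content": "Say hello."}],
    )
    hidden = getattr(response, "_hidden_params", {}) or {}
    cost = hidden.get("response_cost")
    if cost:
        print(f"response_cost reported by litellm: {cost}")
    else:
        # Fall back to computing the cost from the full response object
        print(litellm.completion_cost(completion_response=response))

asyncio.run(demo())
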
|
||||
|
||||
|
|
|
@ -154,9 +154,10 @@ class ThreadManager:
|
|||
native_max_auto_continues: int = 25,
|
||||
max_xml_tool_calls: int = 0,
|
||||
include_xml_examples: bool = False,
|
||||
enable_thinking: Optional[bool] = False, # Add enable_thinking parameter
|
||||
reasoning_effort: Optional[str] = 'low' # Add reasoning_effort parameter
|
||||
) -> Union[Dict[str, Any], AsyncGenerator]:
|
||||
"""Run a conversation thread with LLM integration and tool execution.
|
||||
|
||||
Args:
|
||||
thread_id: The ID of the thread to run
|
||||
system_prompt: System message to set the assistant's behavior
|
||||
|
@ -279,7 +280,9 @@ class ThreadManager:
|
|||
max_tokens=llm_max_tokens,
|
||||
tools=openapi_tool_schemas,
|
||||
tool_choice=tool_choice if processor_config.native_tool_calling else None,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
enable_thinking=enable_thinking, # Pass enable_thinking
|
||||
reasoning_effort=reasoning_effort # Pass reasoning_effort
|
||||
)
|
||||
logger.debug("Successfully received raw LLM API response stream/object")
|
||||
|
||||
|
@ -293,6 +296,8 @@ class ThreadManager:
|
|||
response_generator = self.response_processor.process_streaming_response(
|
||||
llm_response=llm_response,
|
||||
thread_id=thread_id,
|
||||
prompt_messages=prepared_messages,
|
||||
llm_model=llm_model,
|
||||
config=processor_config
|
||||
)
|
||||
|
||||
|
@ -303,55 +308,91 @@ class ThreadManager:
|
|||
response_generator = self.response_processor.process_non_streaming_response(
|
||||
llm_response=llm_response,
|
||||
thread_id=thread_id,
|
||||
prompt_messages=prepared_messages,
|
||||
llm_model=llm_model,
|
||||
config=processor_config
|
||||
)
|
||||
return response_generator # Return the generator
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_thread: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
logger.error(f"Error in _run_once: {str(e)}", exc_info=True)
|
||||
# For generators, we need to yield an error structure if returning a generator is expected
|
||||
async def error_generator():
|
||||
yield {
|
||||
"type": "error",
|
||||
"message": f"Error during LLM call or setup: {str(e)}"
|
||||
}
|
||||
return error_generator()
|
||||
|
||||
# Define a wrapper generator that handles auto-continue logic
|
||||
async def auto_continue_wrapper():
|
||||
nonlocal auto_continue, auto_continue_count
|
||||
nonlocal auto_continue, auto_continue_count, temporary_message
|
||||
|
||||
current_temp_message = temporary_message # Use a local copy for the first run
|
||||
|
||||
while auto_continue and (native_max_auto_continues == 0 or auto_continue_count < native_max_auto_continues):
|
||||
# Reset auto_continue for this iteration
|
||||
auto_continue = False
|
||||
|
||||
# Run the thread once
|
||||
response_gen = await _run_once(temporary_message if auto_continue_count == 0 else None)
|
||||
# Pass current_temp_message, which is only set for the first iteration
|
||||
response_gen = await _run_once(temp_msg=current_temp_message)
|
||||
|
||||
# Handle error responses
|
||||
if isinstance(response_gen, dict) and "status" in response_gen and response_gen["status"] == "error":
|
||||
# Clear the temporary message after the first run
|
||||
current_temp_message = None
|
||||
|
||||
# Handle error responses (checking if it's an error dict, which _run_once might return directly)
|
||||
if isinstance(response_gen, dict) and response_gen.get("status") == "error":
|
||||
yield response_gen
|
||||
return
|
||||
|
||||
# Check if it's the error generator from _run_once exception handling
|
||||
# Need a way to check if it's the specific error generator or just inspect the first item
|
||||
first_chunk = None
|
||||
try:
|
||||
first_chunk = await anext(response_gen)
|
||||
except StopAsyncIteration:
|
||||
# Empty generator, possibly due to an issue before yielding.
|
||||
logger.warning("Response generator was empty.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting first chunk from generator: {e}")
|
||||
yield {"type": "error", "message": f"Error processing response: {e}"}
|
||||
break
|
||||
|
||||
if first_chunk and first_chunk.get('type') == 'error' and "Error during LLM call" in first_chunk.get('message', ''):
|
||||
yield first_chunk
|
||||
return # Stop processing if setup failed
|
||||
|
||||
# Yield the first chunk if it wasn't an error
|
||||
if first_chunk:
|
||||
yield first_chunk
|
||||
|
||||
# Process each chunk
|
||||
# Process remaining chunks
|
||||
async for chunk in response_gen:
|
||||
# Check if this is a finish reason chunk with tool_calls or xml_tool_limit_reached
|
||||
if chunk.get('type') == 'finish':
|
||||
if chunk.get('finish_reason') == 'tool_calls':
|
||||
finish_reason = chunk.get('finish_reason')
|
||||
if finish_reason == 'tool_calls':
|
||||
# Only auto-continue if enabled (max > 0)
|
||||
if native_max_auto_continues > 0:
|
||||
logger.info(f"Detected finish_reason='tool_calls', auto-continuing ({auto_continue_count + 1}/{native_max_auto_continues})")
|
||||
auto_continue = True
|
||||
auto_continue_count += 1
|
||||
# Don't yield the finish chunk to avoid confusing the client
|
||||
continue
|
||||
elif chunk.get('finish_reason') == 'xml_tool_limit_reached':
|
||||
# Don't yield the finish chunk to avoid confusing the client during auto-continue
|
||||
continue
|
||||
elif finish_reason == 'xml_tool_limit_reached':
|
||||
# Don't auto-continue if XML tool limit was reached
|
||||
logger.info(f"Detected finish_reason='xml_tool_limit_reached', stopping auto-continue")
|
||||
auto_continue = False
|
||||
# Still yield the chunk to inform the client
|
||||
|
||||
# Yield other finish reasons normally
|
||||
|
||||
# Otherwise just yield the chunk normally
|
||||
# Yield the chunk normally
|
||||
yield chunk
|
||||
|
||||
# If not auto-continuing, we're done
|
||||
# If not auto-continuing, we're done with the loop
|
||||
if not auto_continue:
|
||||
break
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
streamlit-quill==0.0.3
python-dotenv==1.0.1
litellm>=1.44.0
litellm>=1.66.2
click==8.1.7
questionary==2.0.1
requests>=2.31.0

@ -28,7 +28,7 @@ RATE_LIMIT_DELAY = 30
RETRY_DELAY = 5

# Define debug log directory relative to this file's location
DEBUG_LOG_DIR = os.path.join(os.path.dirname(__file__), '..', 'debug_logs') # Assumes backend/debug_logs
DEBUG_LOG_DIR = os.path.join(os.path.dirname(__file__), 'debug_logs')

class LLMError(Exception):
    """Base exception for LLM-related errors."""
@ -86,7 +86,10 @@ def prepare_params(
    api_base: Optional[str] = None,
    stream: bool = False,
    top_p: Optional[float] = None,
    model_id: Optional[str] = None
    model_id: Optional[str] = None,
    # Add parameters for thinking/reasoning
    enable_thinking: Optional[bool] = None,
    reasoning_effort: Optional[str] = None
) -> Dict[str, Any]:
    """Prepare parameters for the API call."""
    params = {
@ -156,6 +159,24 @@
        params["model_id"] = "arn:aws:bedrock:us-west-2:935064898258:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0"
        logger.debug(f"Auto-set model_id for Claude 3.7 Sonnet: {params['model_id']}")

    # --- Add Anthropic Thinking/Reasoning Effort ---
    # Determine if thinking should be enabled based on the passed parameter
    use_thinking = enable_thinking if enable_thinking is not None else False

    # Check if the model is Anthropic
    is_anthropic = "sonnet-3-7" in model_name.lower() or "anthropic" in model_name.lower()

    # Add reasoning_effort parameter if enabled and applicable
    if is_anthropic and use_thinking:
        # Determine reasoning effort based on the passed parameter, defaulting to 'low'
        effort_level = reasoning_effort if reasoning_effort else 'low'

        params["reasoning_effort"] = effort_level
        logger.info(f"Anthropic thinking enabled with reasoning_effort='{effort_level}'")

        # Anthropic requires temperature=1 when thinking/reasoning_effort is enabled
        params["temperature"] = 1.0

    return params

async def make_llm_api_call(
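
Net effect of the branch above: an Anthropic call with thinking enabled carries reasoning_effort and temperature=1.0 into litellm. A minimal standalone sketch of the equivalent call, with an illustrative model and prompt:

import asyncio
import litellm

async def demo():
    response = await litellm.acompletion(
        model="anthropic/claude-3-7-sonnet-latest",   # illustrative
        messages=[{"role": "user", "content": "Outline a three-step plan."}],
        reasoning_effort="low",   # what prepare_params sets when thinking is enabled
        temperature=1.0,          # required by Anthropic when reasoning_effort is set
    )
    print(response.choices[0].message.content)

asyncio.run(demo())
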
@ -170,7 +191,10 @@ async def make_llm_api_call(
|
|||
api_base: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
top_p: Optional[float] = None,
|
||||
model_id: Optional[str] = None
|
||||
model_id: Optional[str] = None,
|
||||
# Add parameters for thinking/reasoning
|
||||
enable_thinking: Optional[bool] = None,
|
||||
reasoning_effort: Optional[str] = None
|
||||
) -> Union[Dict[str, Any], AsyncGenerator]:
|
||||
"""
|
||||
Make an API call to a language model using LiteLLM.
|
||||
|
@ -209,7 +233,10 @@ async def make_llm_api_call(
|
|||
api_base=api_base,
|
||||
stream=stream,
|
||||
top_p=top_p,
|
||||
model_id=model_id
|
||||
model_id=model_id,
|
||||
# Add parameters for thinking/reasoning
|
||||
enable_thinking=enable_thinking,
|
||||
reasoning_effort=reasoning_effort
|
||||
)
|
||||
|
||||
# Apply Anthropic prompt caching (minimal implementation)
|
||||
|
@ -272,15 +299,19 @@ async def make_llm_api_call(
|
|||
# Initialize log path to None, it will be set only if logging is enabled
|
||||
response_log_path = None
|
||||
enable_debug_logging = os.environ.get('ENABLE_LLM_DEBUG_LOGGING', 'false').lower() == 'true'
|
||||
# enable_debug_logging = True
|
||||
|
||||
if enable_debug_logging:
|
||||
try:
|
||||
os.makedirs(DEBUG_LOG_DIR, exist_ok=True)
|
||||
# save the model name too
|
||||
model_name = params["model"]
|
||||
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
# Use a unique ID or counter if calls can happen in the same second
|
||||
# For simplicity, using timestamp only for now
|
||||
request_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_request_{timestamp}.json")
|
||||
response_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_response_{timestamp}.json") # Set here if enabled
|
||||
request_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_request_{timestamp}_{model_name}.json")
|
||||
response_log_path = os.path.join(DEBUG_LOG_DIR, f"llm_response_{timestamp}_{model_name}.json") # Set here if enabled
|
||||
|
||||
# Log the request parameters just before the attempt loop
|
||||
logger.debug(f"Logging LLM request parameters to {request_log_path}")
|
||||
|
@ -300,7 +331,7 @@ async def make_llm_api_call(
|
|||
for attempt in range(MAX_RETRIES):
|
||||
try:
|
||||
logger.debug(f"Attempt {attempt + 1}/{MAX_RETRIES}")
|
||||
|
||||
# print(params)
|
||||
response = await litellm.acompletion(**params)
|
||||
logger.debug(f"Successfully received API response from {model_name}")
|
||||
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
import asyncio
|
||||
import litellm
|
||||
import copy
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
async def run_conversation_turn(model: str, messages: list, user_prompt: str | list, reasoning_effort: str | None = None):
|
||||
"""
|
||||
Handles a single turn of the conversation, prepares arguments, and calls litellm.acompletion.
|
||||
|
||||
Args:
|
||||
model: The model name string.
|
||||
messages: The list of message dictionaries (will be modified in place).
|
||||
user_prompt: The user's prompt for this turn (string or list).
|
||||
reasoning_effort: Optional reasoning effort string for Anthropic models.
|
||||
|
||||
Returns:
|
||||
The response object from litellm.acompletion.
|
||||
"""
|
||||
# Append user prompt
|
||||
if isinstance(user_prompt, str):
|
||||
messages.append({"role": "user", "content": user_prompt})
|
||||
elif isinstance(user_prompt, list): # Handle list/dict content structure
|
||||
messages.append({"role": "user", "content": user_prompt})
|
||||
|
||||
# --- Start of merged logic from call_litellm_with_cache ---
|
||||
processed_messages = copy.deepcopy(messages) # Work on a copy for modification
|
||||
is_anthropic = model.startswith("anthropic")
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"messages": processed_messages,
|
||||
}
|
||||
|
||||
call_description = [f"Calling {model}", f"{len(processed_messages)} messages"]
|
||||
|
||||
if is_anthropic:
|
||||
# Add cache_control for Anthropic models
|
||||
if processed_messages and processed_messages[0]["role"] == "system":
|
||||
content = processed_messages[0].get("content")
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part["cache_control"] = {"type": "ephemeral"}
|
||||
elif isinstance(content, str):
|
||||
processed_messages[0]["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
|
||||
|
||||
if processed_messages and processed_messages[-1]["role"] == "user":
|
||||
content = processed_messages[-1].get("content")
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part["cache_control"] = {"type": "ephemeral"}
|
||||
elif isinstance(content, str):
|
||||
processed_messages[-1]["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
|
||||
call_description.append("cache enabled")
|
||||
|
||||
# Add reasoning_effort only for Anthropic models if provided
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
call_description.append(f"reasoning: {reasoning_effort}")
|
||||
|
||||
print(f"\n--- {' | '.join(call_description)} ---")
|
||||
|
||||
response = await litellm.acompletion(**kwargs)
|
||||
|
||||
print("--- Full Response Object ---")
|
||||
# Convert response object to dict and print as indented JSON
|
||||
try:
|
||||
print(json.dumps(response.dict(), indent=2))
|
||||
print(response._hidden_params)
|
||||
except Exception as e:
|
||||
print(f"Could not format response as JSON: {e}")
|
||||
print(response) # Fallback to printing the raw object if conversion fails
|
||||
print("--- End Response ---")
|
||||
# --- End of merged logic ---
|
||||
|
||||
|
||||
# Append assistant response to the original messages list
|
||||
if response.choices and response.choices[0].message.content:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": response.choices[0].message.content
|
||||
})
|
||||
else:
|
||||
# Handle cases where response might be empty or malformed
|
||||
print("Warning: Assistant response content is missing.")
|
||||
messages.append({"role": "assistant", "content": ""}) # Append empty content
|
||||
|
||||
return response
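
For reference, after the cache_control wrapping above a string system message ends up with roughly this structure (a sketch of the resulting shape, not output captured from a run):

cached_system_message = {
    "role": "system",
    "content": [
        {
            "type": "text",
            "text": "You are a helpful assistant.",
            "cache_control": {"type": "ephemeral"},  # marks the block for Anthropic prompt caching
        }
    ],
}
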
|
||||
|
||||
async def main(model_name: str, reasoning_effort: str = "medium"):
|
||||
hello_string = "Hello " * 1234
|
||||
|
||||
# Initial messages
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": f"Here is some text: {hello_string}"},
|
||||
{"role": "assistant", "content": "Okay, I have received the text."},
|
||||
]
|
||||
|
||||
# Turn 1: Ask to count "Hello"
|
||||
print("\n=== Turn 1: Counting 'Hello' ===")
|
||||
await run_conversation_turn(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
user_prompt="How many times does the word 'Hello' appear in the text I provided?",
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
# Turn 2: Ask for a short story
|
||||
print("\n=== Turn 2: Short Story Request ===")
|
||||
await run_conversation_turn(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
user_prompt=[ # Using list/dict format for user content
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Great, thanks for counting. Now, can you write a short story (less than 50 words) where the word 'Hello' appears exactly 5 times?",
|
||||
}
|
||||
],
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
# Turn 3: Ask about the main character
|
||||
print("\n=== Turn 3: Main Character Question ===")
|
||||
await run_conversation_turn(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
user_prompt=[ # Using list/dict format for user content
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Based on the short story you just wrote, who is the main character?",
|
||||
}
|
||||
],
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Select the model to test
|
||||
model = "anthropic/claude-3-7-sonnet-latest"
|
||||
# model = "groq/llama-3.3-70b-versatile"
|
||||
# model = "openai/gpt-4o-mini"
|
||||
# model = "openai/gpt-4.1-2025-04-14" # Placeholder if needed
|
||||
|
||||
print(f"Running test with model: {model}")
|
||||
asyncio.run(main(
|
||||
model_name=model,
|
||||
# reasoning_effort="medium"
|
||||
reasoning_effort="low"
|
||||
))
|
||||
|
||||
|
|
@ -0,0 +1,181 @@
|
|||
import asyncio
|
||||
import litellm
|
||||
import copy
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
async def run_streaming_conversation_turn(model: str, messages: list, user_prompt: str | list, enable_thinking: bool = False, reasoning_effort: str | None = None):
|
||||
"""
|
||||
Handles a single turn of the conversation using streaming, prepares arguments,
|
||||
and calls litellm.acompletion with stream=True.
|
||||
|
||||
Args:
|
||||
model: The model name string.
|
||||
messages: The list of message dictionaries (will be modified in place).
|
||||
user_prompt: The user's prompt for this turn (string or list).
|
||||
enable_thinking: Boolean to enable thinking for Anthropic models.
|
||||
reasoning_effort: Optional reasoning effort string for Anthropic models.
|
||||
|
||||
Returns:
|
||||
The final accumulated assistant response dictionary.
|
||||
"""
|
||||
# Append user prompt
|
||||
if isinstance(user_prompt, str):
|
||||
messages.append({"role": "user", "content": user_prompt})
|
||||
elif isinstance(user_prompt, list): # Handle list/dict content structure
|
||||
messages.append({"role": "user", "content": user_prompt})
|
||||
|
||||
processed_messages = copy.deepcopy(messages) # Work on a copy for modification
|
||||
is_anthropic = model.startswith("anthropic")
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"messages": processed_messages,
|
||||
"stream": True, # Enable streaming
|
||||
}
|
||||
|
||||
|
||||
if is_anthropic:
|
||||
# Add cache_control for Anthropic models
|
||||
if processed_messages and processed_messages[0]["role"] == "system":
|
||||
content = processed_messages[0].get("content")
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part["cache_control"] = {"type": "ephemeral"}
|
||||
elif isinstance(content, str):
|
||||
processed_messages[0]["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
|
||||
|
||||
if processed_messages and processed_messages[-1]["role"] == "user":
|
||||
content = processed_messages[-1].get("content")
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part["cache_control"] = {"type": "ephemeral"}
|
||||
elif isinstance(content, str):
|
||||
processed_messages[-1]["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
|
||||
|
||||
# Add reasoning_effort only for Anthropic models if provided and thinking is enabled
|
||||
if enable_thinking and reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
|
||||
stream_response = await litellm.acompletion(**kwargs)
|
||||
|
||||
# Collect the full response from streaming chunks
|
||||
full_response_content = ""
|
||||
thinking_printed = False
|
||||
response_printed = False
|
||||
async for chunk in stream_response:
|
||||
if chunk.choices and chunk.choices[0].delta:
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# Print thinking/reasoning content if present
|
||||
if hasattr(delta, "reasoning_content") and delta.reasoning_content:
|
||||
if not thinking_printed:
|
||||
print("[Thinking]: ", end="", flush=True)
|
||||
thinking_printed = True
|
||||
print(f"{delta.reasoning_content}", end="", flush=True) # Print thinking step
|
||||
|
||||
# Print and accumulate regular content
|
||||
if delta.content:
|
||||
if not response_printed:
|
||||
# Add newline if thinking was printed before response starts
|
||||
if thinking_printed:
|
||||
print() # Newline to separate thinking and response
|
||||
print("[Response]: ", end="", flush=True)
|
||||
response_printed = True
|
||||
chunk_content = delta.content
|
||||
full_response_content += chunk_content
|
||||
# Stream to stdout in real-time
|
||||
print(chunk_content, end="", flush=True)
|
||||
|
||||
print() # Newline after streaming finishes
|
||||
|
||||
# Print hidden params if available
|
||||
try:
|
||||
print("--- Hidden Params ---")
|
||||
print(stream_response._hidden_params)
|
||||
print("--- End Hidden Params ---")
|
||||
except AttributeError:
|
||||
print("(_hidden_params attribute not found on stream response object)")
|
||||
except Exception as e:
|
||||
print(f"Could not print _hidden_params: {e}")
|
||||
|
||||
print("--------------------------------")
|
||||
print() # Add another newline for separation
|
||||
|
||||
# Create a complete response object with the full content
|
||||
final_response = {
|
||||
"model": model,
|
||||
"choices": [{
|
||||
"message": {"role": "assistant", "content": full_response_content}
|
||||
}]
|
||||
}
|
||||
|
||||
# Add the assistant's response to the messages
|
||||
messages.append({"role": "assistant", "content": full_response_content})
|
||||
|
||||
return final_response
|
||||
|
||||
|
||||
async def main(model_name: str, enable_thinking: bool = False, reasoning_effort: str = "medium"):
|
||||
"""Runs a multi-turn conversation test with streaming enabled."""
|
||||
hello_string = "Hello " * 1234
|
||||
|
||||
# Initial messages
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": f"Here is some text: {hello_string}"},
|
||||
{"role": "assistant", "content": "Okay, I have received the text."},
|
||||
]
|
||||
|
||||
# Turn 1: Ask to count "Hello"
|
||||
await run_streaming_conversation_turn(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
user_prompt="How many times does the word 'Hello' appear in the text I provided?",
|
||||
enable_thinking=enable_thinking,
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
# Turn 2: Ask for a short story
|
||||
await run_streaming_conversation_turn(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
user_prompt=[ # Using list/dict format for user content
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Great, thanks for counting. Now, can you write a short story (less than 50 words) where the word 'Hello' appears exactly 5 times?",
|
||||
}
|
||||
],
|
||||
enable_thinking=enable_thinking,
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
# Turn 3: Ask about the main character
|
||||
await run_streaming_conversation_turn(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
user_prompt=[ # Using list/dict format for user content
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Based on the short story you just wrote, who is the main character?",
|
||||
}
|
||||
],
|
||||
enable_thinking=enable_thinking,
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Select the model to test
|
||||
model = "anthropic/claude-3-7-sonnet-latest"
|
||||
# model = "openai/gpt-4o-mini"
|
||||
# model = "openai/gpt-4.1-2025-04-14" # Placeholder if needed
|
||||
|
||||
asyncio.run(main(
|
||||
model_name=model,
|
||||
enable_thinking=True, # Enable thinking for the test run
|
||||
# reasoning_effort="medium"
|
||||
reasoning_effort="low" # Start with low for faster responses
|
||||
))
|
|
@ -1,82 +0,0 @@
|
|||
import asyncio
|
||||
import litellm
|
||||
|
||||
async def main():
|
||||
initial_messages=[
|
||||
# System Message
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Here is the full text of a complex legal agreement"
|
||||
* 400,
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
],
|
||||
},
|
||||
# marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache.
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What are the key terms and conditions in this agreement?",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/month",
|
||||
},
|
||||
# The final turn is marked with cache-control, for continuing in followups.
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What are the key terms and conditions in this agreement?",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
print("--- First call ---")
|
||||
first_response = await litellm.acompletion(
|
||||
model="anthropic/claude-3-7-sonnet-latest",
|
||||
messages=initial_messages
|
||||
)
|
||||
print(first_response)
|
||||
|
||||
# Prepare messages for the second call
|
||||
second_call_messages = initial_messages + [
|
||||
{
|
||||
"role": "assistant",
|
||||
# Extract the assistant's response content from the first call
|
||||
"content": first_response.choices[0].message.content
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Can you elaborate on the termination clause based on the provided text? Remember the context.",
|
||||
"cache_control": {"type": "ephemeral"}, # Mark for caching
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
print("\n--- Second call (testing cache) ---")
|
||||
second_response = await litellm.acompletion(
|
||||
model="anthropic/claude-3-7-sonnet-latest",
|
||||
messages=second_call_messages
|
||||
)
|
||||
print(second_response)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
Automated test script for the AgentPress agent.
|
||||
|
||||
This script sends a specific query ("short report for today's news") to the agent
|
||||
and prints the streaming output to the console.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from services.supabase import DBConnection
|
||||
from agent.run import process_agent_response # Reusing the processing logic
|
||||
from utils.logger import logger
|
||||
|
||||
async def main():
|
||||
"""Main function to run the automated agent test."""
|
||||
load_dotenv() # Ensure environment variables are loaded
|
||||
|
||||
logger.info("--- Starting Automated Agent Test: News Report ---")
|
||||
|
||||
# Initialize ThreadManager and DBConnection
|
||||
thread_manager = ThreadManager()
|
||||
db_connection = DBConnection()
|
||||
await db_connection.initialize() # Ensure connection is ready
|
||||
client = await db_connection.client
|
||||
|
||||
project_id = None
|
||||
thread_id = None
|
||||
|
||||
try:
|
||||
# 1. Set up Test Project and Thread
|
||||
logger.info("Setting up test project and thread...")
|
||||
|
||||
# Get user's personal account (replace with actual logic if needed)
|
||||
# Using a fixed account ID for simplicity in this example
|
||||
# In a real scenario, you might fetch this dynamically
|
||||
account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59" # Example account ID
|
||||
test_project_name = "automated_test_project_news"
|
||||
|
||||
if not account_id:
|
||||
logger.error("Error: Could not determine account ID.")
|
||||
return
|
||||
|
||||
# Find or create a test project
|
||||
project_result = await client.table('projects').select('*').eq('name', test_project_name).eq('account_id',
|
||||
account_id).limit(1).execute()
|
||||
|
||||
if project_result.data:
|
||||
project_id = project_result.data[0]['project_id']
|
||||
logger.info(f"Using existing test project: {project_id}")
|
||||
else:
|
||||
project_insert_result = await client.table('projects').insert({
|
||||
"name": test_project_name,
|
||||
"account_id": account_id
|
||||
}).execute()
|
||||
if not project_insert_result.data:
|
||||
logger.error("Failed to create test project.")
|
||||
return
|
||||
project_id = project_insert_result.data[0]['project_id']
|
||||
logger.info(f"Created new test project: {project_id}")
|
||||
|
||||
# Create a new thread for this test run
|
||||
thread_result = await client.table('threads').insert({
|
||||
'project_id': project_id,
|
||||
'account_id': account_id
|
||||
}).execute()
|
||||
|
||||
if not thread_result.data:
|
||||
logger.error("Error: Failed to create test thread.")
|
||||
return
|
||||
|
||||
thread_id = thread_result.data[0]['thread_id']
|
||||
logger.info(f"Test Thread Created: {thread_id}")
|
||||
|
||||
# 2. Define and Add User Message
|
||||
user_message = "short report for today's news"
|
||||
logger.info(f"Adding user message to thread: '{user_message}'")
|
||||
await thread_manager.add_message(
|
||||
thread_id=thread_id,
|
||||
type="user",
|
||||
content={
|
||||
"role": "user",
|
||||
"content": user_message
|
||||
},
|
||||
is_llm_message=True # Treat it as a message the LLM should see
|
||||
)
|
||||
|
||||
# 3. Run the Agent and Process Response
|
||||
logger.info("Running agent and processing response...")
|
||||
# We reuse the process_agent_response function from run.py which handles the streaming output
|
||||
await process_agent_response(thread_id, project_id, thread_manager)
|
||||
|
||||
logger.info("--- Agent Test Completed ---")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred during the test: {str(e)}", exc_info=True)
|
||||
finally:
|
||||
# Optional: Clean up the created thread?
|
||||
# if thread_id:
|
||||
# logger.info(f"Cleaning up test thread: {thread_id}")
|
||||
# await client.table('messages').delete().eq('thread_id', thread_id).execute()
|
||||
# await client.table('threads').delete().eq('thread_id', thread_id).execute()
|
||||
|
||||
# Disconnect DB
|
||||
await db_connection.disconnect()
|
||||
logger.info("Database connection closed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Configure logging if needed (e.g., set level)
|
||||
# logging.getLogger('agentpress').setLevel(logging.DEBUG)
|
||||
|
||||
# Run the main async function
|
||||
asyncio.run(main())
|
|
@ -0,0 +1,215 @@
|
|||
"""
|
||||
Test script for running the AgentPress agent with thinking enabled.
|
||||
|
||||
This test specifically targets Anthropic models that support the 'reasoning_effort'
|
||||
parameter to observe the agent's behavior when thinking is explicitly enabled.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Ensure the backend directory is in the Python path
|
||||
backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
if backend_dir not in sys.path:
|
||||
sys.path.insert(0, backend_dir)
|
||||
|
||||
import logging
|
||||
from agentpress.thread_manager import ThreadManager
|
||||
from services.supabase import DBConnection
|
||||
from agent.run import run_agent, process_agent_response # Reuse processing logic
|
||||
from utils.logger import logger
|
||||
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
async def test_agent_with_thinking():
|
||||
"""
|
||||
Test running the agent with thinking enabled for an Anthropic model.
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("🧪 TESTING AGENT RUN WITH THINKING ENABLED (Anthropic)")
|
||||
print("="*80 + "\n")
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize ThreadManager and DBConnection
|
||||
thread_manager = ThreadManager()
|
||||
db_connection = DBConnection()
|
||||
await db_connection.initialize() # Ensure connection is ready
|
||||
client = await db_connection.client
|
||||
|
||||
thread_id = None
|
||||
project_id = None
|
||||
project_created = False # Flag to track if we created the project
|
||||
|
||||
try:
|
||||
# --- Test Setup ---
|
||||
print("🔧 Setting up test environment (Project & Thread)...")
|
||||
logger.info("Setting up test project and thread...")
|
||||
|
||||
# Using a hardcoded account ID for consistency in tests
|
||||
account_id = "a5fe9cb6-4812-407e-a61c-fe95b7320c59" # Replace if necessary
|
||||
test_project_name = "test_agent_thinking_project"
|
||||
logger.info(f"Using Account ID: {account_id}")
|
||||
|
||||
if not account_id:
|
||||
print("❌ Error: Could not determine Account ID.")
|
||||
logger.error("Could not determine Account ID.")
|
||||
return
|
||||
|
||||
# Find or create a test project
|
||||
project_result = await client.table('projects').select('*').eq('name', test_project_name).eq('account_id', account_id).limit(1).execute()
|
||||
|
||||
if project_result.data:
|
||||
project_id = project_result.data[0]['project_id']
|
||||
print(f"🔄 Using existing test project: {project_id}")
|
||||
logger.info(f"Using existing test project: {project_id}")
|
||||
else:
|
||||
project_insert_result = await client.table('projects').insert({
|
||||
"name": test_project_name,
|
||||
"account_id": account_id
|
||||
}).execute()
|
||||
if not project_insert_result.data:
|
||||
print("❌ Error: Failed to create test project.")
|
||||
logger.error("Failed to create test project.")
|
||||
return
|
||||
project_id = project_insert_result.data[0]['project_id']
|
||||
project_created = True
|
||||
print(f"✨ Created new test project: {project_id}")
|
||||
logger.info(f"Created new test project: {project_id}")
|
||||
|
||||
# Create a new thread for this test run
|
||||
thread_result = await client.table('threads').insert({
|
||||
'project_id': project_id,
|
||||
'account_id': account_id
|
||||
}).execute()
|
||||
|
||||
if not thread_result.data:
|
||||
print("❌ Error: Failed to create test thread.")
|
||||
logger.error("Failed to create test thread.")
|
||||
return
|
||||
|
||||
thread_id = thread_result.data[0]['thread_id']
|
||||
print(f"🧵 Created new test thread: {thread_id}")
|
||||
logger.info(f"Test Thread Created: {thread_id}")
|
||||
|
||||
# Add an initial user message that requires planning
|
||||
initial_message = "Create a plan to build a simple 'Hello World' HTML page in the workspace, then execute the first step of the plan."
|
||||
print(f"\n💬 Adding initial user message: '{initial_message}'")
|
||||
logger.info(f"Adding initial user message: '{initial_message}'")
|
||||
await thread_manager.add_message(
|
||||
thread_id=thread_id,
|
||||
type="user",
|
||||
content={
|
||||
"role": "user",
|
||||
"content": initial_message
|
||||
},
|
||||
is_llm_message=True
|
||||
)
|
||||
print("✅ Initial message added.")
|
||||
|
||||
# --- Run Agent with Thinking Enabled ---
|
||||
logger.info("Running agent ...")
|
||||
|
||||
# Use the process_agent_response helper to handle streaming output.
|
||||
# Pass the desired model, thinking, and stream parameters directly to it.
|
||||
await process_agent_response(
|
||||
thread_id=thread_id,
|
||||
project_id=project_id,
|
||||
thread_manager=thread_manager,
|
||||
stream=False, # Explicitly set stream to False for this test run
|
||||
model_name="anthropic/claude-3-7-sonnet-latest", # Specify the model here
|
||||
enable_thinking=True, # Enable thinking here
|
||||
reasoning_effort='low' # Specify effort here
|
||||
)
|
||||
# await process_agent_response(
|
||||
# thread_id=thread_id,
|
||||
# project_id=project_id,
|
||||
# thread_manager=thread_manager,
|
||||
# model_name="openai/gpt-4.1-2025-04-14", # Specify the model here
|
||||
# model_name="groq/llama-3.3-70b-versatile",
|
||||
# enable_thinking=False, # Enable thinking here
|
||||
# reasoning_effort='low' # Specify effort here
|
||||
# )
|
||||
|
||||
# --- Direct Stream Processing (Alternative to process_agent_response) ---
|
||||
# The direct run_agent call above was removed as process_agent_response handles it.
|
||||
# print("\n--- Agent Response Stream ---")
|
||||
# async for chunk in agent_run_generator:
|
||||
# chunk_type = chunk.get('type', 'unknown')
|
||||
# if chunk_type == 'content' and 'content' in chunk:
|
||||
# print(chunk['content'], end='', flush=True)
|
||||
# elif chunk_type == 'tool_result':
|
||||
# tool_name = chunk.get('function_name', 'Tool')
|
||||
# result = chunk.get('result', '')
|
||||
# print(f"\n\n🛠️ TOOL RESULT [{tool_name}] → {result}", flush=True)
|
||||
# elif chunk_type == 'tool_status':
|
||||
# status = chunk.get('status', '')
|
||||
# func_name = chunk.get('function_name', '')
|
||||
# if status and func_name:
|
||||
# emoji = "✅" if status == "completed" else "⏳" if status == "started" else "❌"
|
||||
# print(f"\n{emoji} TOOL {status.upper()}: {func_name}", flush=True)
|
||||
# elif chunk_type == 'finish':
|
||||
# reason = chunk.get('finish_reason', '')
|
||||
# if reason:
|
||||
# print(f"\n📌 Finished: {reason}", flush=True)
|
||||
# elif chunk_type == 'error':
|
||||
# print(f"\n❌ ERROR: {chunk.get('message', 'Unknown error')}", flush=True)
|
||||
# break # Stop processing on error
|
||||
|
||||
print("\n\n✅ Agent run finished.")
|
||||
logger.info("Agent run finished.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ An error occurred during the test: {e}")
|
||||
logger.error(f"An error occurred during the test: {str(e)}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# --- Cleanup ---
|
||||
print("\n🧹 Cleaning up test resources...")
|
||||
logger.info("Cleaning up test resources...")
|
||||
if thread_id:
|
||||
try:
|
||||
await client.table('messages').delete().eq('thread_id', thread_id).execute()
|
||||
await client.table('threads').delete().eq('thread_id', thread_id).execute()
|
||||
print(f"🗑️ Deleted test thread: {thread_id}")
|
||||
logger.info(f"Deleted test thread: {thread_id}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error cleaning up thread {thread_id}: {e}")
|
||||
logger.warning(f"Error cleaning up thread {thread_id}: {e}")
|
||||
if project_id and project_created: # Only delete if we created it in this run
|
||||
try:
|
||||
await client.table('projects').delete().eq('project_id', project_id).execute()
|
||||
print(f"🗑️ Deleted test project: {project_id}")
|
||||
logger.info(f"Deleted test project: {project_id}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error cleaning up project {project_id}: {e}")
|
||||
logger.warning(f"Error cleaning up project {project_id}: {e}")
|
||||
|
||||
# Disconnect DB
|
||||
await db_connection.disconnect()
|
||||
logger.info("Database connection closed.")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("🏁 THINKING TEST COMPLETE")
|
||||
print("="*80 + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Ensure the logger is configured
|
||||
logger.info("Starting test_agent_thinking script...")
|
||||
try:
|
||||
asyncio.run(test_agent_with_thinking())
|
||||
print("\n✅ Test script completed successfully.")
|
||||
sys.exit(0)
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n❌ Test interrupted by user.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n\n❌ Error running test script: {e}")
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|