mirror of https://github.com/kortix-ai/suna.git
add stream option
This commit is contained in:
parent
36c8fce0f6
commit
94ee217e36
|
@ -23,6 +23,7 @@ class AgentStartRequest(BaseModel):
|
||||||
model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
|
model_name: Optional[str] = "anthropic/claude-3-7-sonnet-latest"
|
||||||
enable_thinking: Optional[bool] = False
|
enable_thinking: Optional[bool] = False
|
||||||
reasoning_effort: Optional[str] = 'low'
|
reasoning_effort: Optional[str] = 'low'
|
||||||
|
stream: Optional[bool] = False # Add stream option, default to False
|
||||||
|
|
||||||
# Initialize shared resources
|
# Initialize shared resources
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -249,9 +250,9 @@ async def start_agent(
|
||||||
user_id: str = Depends(get_current_user_id)
|
user_id: str = Depends(get_current_user_id)
|
||||||
):
|
):
|
||||||
"""Start an agent for a specific thread in the background with dynamic settings."""
|
"""Start an agent for a specific thread in the background with dynamic settings."""
|
||||||
logger.info(f"Starting new agent for thread: {thread_id} with model: {body.model_name}, thinking: {body.enable_thinking}, effort: {body.reasoning_effort}")
|
logger.info(f"Starting new agent for thread: {thread_id} with model: {body.model_name}, thinking: {body.enable_thinking}, effort: {body.reasoning_effort}, stream: {body.stream}")
|
||||||
client = await db.client
|
client = await db.client
|
||||||
|
|
||||||
# Verify user has access to this thread
|
# Verify user has access to this thread
|
||||||
await verify_thread_access(client, thread_id, user_id)
|
await verify_thread_access(client, thread_id, user_id)
|
||||||
|
|
||||||
|
@ -311,7 +312,8 @@ async def start_agent(
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
model_name=body.model_name,
|
model_name=body.model_name,
|
||||||
enable_thinking=body.enable_thinking,
|
enable_thinking=body.enable_thinking,
|
||||||
reasoning_effort=body.reasoning_effort
|
reasoning_effort=body.reasoning_effort,
|
||||||
|
stream=body.stream # Pass stream parameter
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -446,12 +448,13 @@ async def run_agent_background(
|
||||||
project_id: str,
|
project_id: str,
|
||||||
model_name: str, # Add model_name parameter
|
model_name: str, # Add model_name parameter
|
||||||
enable_thinking: Optional[bool], # Add enable_thinking parameter
|
enable_thinking: Optional[bool], # Add enable_thinking parameter
|
||||||
reasoning_effort: Optional[str] # Add reasoning_effort parameter
|
reasoning_effort: Optional[str], # Add reasoning_effort parameter
|
||||||
|
stream: bool # Add stream parameter
|
||||||
):
|
):
|
||||||
"""Run the agent in the background and handle status updates."""
|
"""Run the agent in the background and handle status updates."""
|
||||||
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model: {model_name}")
|
logger.debug(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (instance: {instance_id}) with model: {model_name}, stream: {stream}")
|
||||||
client = await db.client
|
client = await db.client
|
||||||
|
|
||||||
# Tracking variables
|
# Tracking variables
|
||||||
total_responses = 0
|
total_responses = 0
|
||||||
start_time = datetime.now(timezone.utc)
|
start_time = datetime.now(timezone.utc)
|
||||||
|
@ -570,7 +573,7 @@ async def run_agent_background(
|
||||||
agent_gen = run_agent(
|
agent_gen = run_agent(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
stream=True,
|
stream=stream, # Pass stream parameter from API request
|
||||||
thread_manager=thread_manager,
|
thread_manager=thread_manager,
|
||||||
model_name=model_name, # Pass model_name
|
model_name=model_name, # Pass model_name
|
||||||
enable_thinking=enable_thinking, # Pass enable_thinking
|
enable_thinking=enable_thinking, # Pass enable_thinking
|
||||||
|
|
|
@ -23,7 +23,7 @@ load_dotenv()
|
||||||
async def run_agent(
|
async def run_agent(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
stream: bool = True,
|
stream: bool, # Accept stream parameter from caller (api.py)
|
||||||
thread_manager: Optional[ThreadManager] = None,
|
thread_manager: Optional[ThreadManager] = None,
|
||||||
native_max_auto_continues: int = 25,
|
native_max_auto_continues: int = 25,
|
||||||
max_iterations: int = 150,
|
max_iterations: int = 150,
|
||||||
|
@ -180,7 +180,7 @@ async def run_agent(
|
||||||
response = await thread_manager.run_thread(
|
response = await thread_manager.run_thread(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
system_prompt=system_message, # Pass the constructed message
|
system_prompt=system_message, # Pass the constructed message
|
||||||
stream=stream,
|
stream=stream, # Pass the stream parameter received by run_agent
|
||||||
llm_model=model_name, # Use the passed model_name
|
llm_model=model_name, # Use the passed model_name
|
||||||
llm_temperature=1, # Example temperature
|
llm_temperature=1, # Example temperature
|
||||||
llm_max_tokens=max_tokens, # Use the determined value
|
llm_max_tokens=max_tokens, # Use the determined value
|
||||||
|
@ -295,7 +295,8 @@ async def test_agent():
|
||||||
|
|
||||||
if not user_message.strip():
|
if not user_message.strip():
|
||||||
print("\n🔄 Running agent...\n")
|
print("\n🔄 Running agent...\n")
|
||||||
await process_agent_response(thread_id, project_id, thread_manager)
|
# Pass stream=True explicitly when calling from test_agent
|
||||||
|
await process_agent_response(thread_id, project_id, thread_manager, stream=True)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Add the user message to the thread
|
# Add the user message to the thread
|
||||||
|
@ -310,19 +311,40 @@ async def test_agent():
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\n🔄 Running agent...\n")
|
print("\n🔄 Running agent...\n")
|
||||||
await process_agent_response(thread_id, project_id, thread_manager)
|
# Pass stream=True explicitly when calling from test_agent
|
||||||
|
await process_agent_response(thread_id, project_id, thread_manager, stream=True)
|
||||||
|
|
||||||
print("\n👋 Test completed. Goodbye!")
|
print("\n👋 Test completed. Goodbye!")
|
||||||
|
|
||||||
async def process_agent_response(thread_id: str, project_id: str, thread_manager: ThreadManager):
|
async def process_agent_response(
|
||||||
"""Process the streaming response from the agent."""
|
thread_id: str,
|
||||||
|
project_id: str,
|
||||||
|
thread_manager: ThreadManager,
|
||||||
|
stream: bool = True, # Add stream parameter, default to True for testing
|
||||||
|
model_name: str = "anthropic/claude-3-7-sonnet-latest", # Add model_name with default
|
||||||
|
enable_thinking: Optional[bool] = False, # Add enable_thinking
|
||||||
|
reasoning_effort: Optional[str] = 'low' # Add reasoning_effort
|
||||||
|
):
|
||||||
|
"""Process the streaming response from the agent, passing model/thinking parameters."""
|
||||||
chunk_counter = 0
|
chunk_counter = 0
|
||||||
current_response = ""
|
current_response = ""
|
||||||
tool_call_counter = 0 # Track number of tool calls
|
tool_call_counter = 0 # Track number of tool calls
|
||||||
|
|
||||||
async for chunk in run_agent(thread_id=thread_id, project_id=project_id, stream=True, thread_manager=thread_manager, native_max_auto_continues=25):
|
# Pass the received parameters to the run_agent call
|
||||||
|
agent_generator = run_agent(
|
||||||
|
thread_id=thread_id,
|
||||||
|
project_id=project_id,
|
||||||
|
stream=stream, # Pass the stream parameter here
|
||||||
|
thread_manager=thread_manager,
|
||||||
|
native_max_auto_continues=25,
|
||||||
|
model_name=model_name,
|
||||||
|
enable_thinking=enable_thinking,
|
||||||
|
reasoning_effort=reasoning_effort
|
||||||
|
)
|
||||||
|
|
||||||
|
async for chunk in agent_generator:
|
||||||
chunk_counter += 1
|
chunk_counter += 1
|
||||||
|
|
||||||
if chunk.get('type') == 'content' and 'content' in chunk:
|
if chunk.get('type') == 'content' and 'content' in chunk:
|
||||||
current_response += chunk.get('content', '')
|
current_response += chunk.get('content', '')
|
||||||
# Print the response as it comes in
|
# Print the response as it comes in
|
||||||
|
|
Loading…
Reference in New Issue