mirror of https://github.com/kortix-ai/suna.git

Merge branch 'main' into feat/ux
This commit is contained in: commit 3fd321df28

README.md
@@ -86,6 +86,7 @@ Suna can be self-hosted on your own infrastructure. Follow these steps to set up
 You'll need the following components:
 - A Supabase project for database and authentication
 - Redis database for caching and session management
+- RabbitMQ message queue for orchestrating worker tasks
 - Daytona sandbox for secure agent execution
 - Python 3.11 for the API backend
 - API keys for LLM providers (Anthropic, OpenRouter)
@@ -99,9 +100,9 @@ You'll need the following components:
    - Save your project's API URL, anon key, and service role key for later use
    - Install the [Supabase CLI](https://supabase.com/docs/guides/cli/getting-started)

-2. **Redis**:
+2. **Redis and RabbitMQ**:
    - Go to the `/backend` folder
-   - Run `docker compose up redis`
+   - Run `docker compose up redis rabbitmq`

 3. **Daytona**:
    - Create an account on [Daytona](https://app.daytona.io/)
@@ -157,6 +158,9 @@ REDIS_PORT=6379
 REDIS_PASSWORD=your_redis_password
 REDIS_SSL=True # Set to False for local Redis without SSL

+RABBITMQ_HOST=your_rabbitmq_host # Set to localhost if running locally
+RABBITMQ_PORT=5672
+
 # Daytona credentials from step 3
 DAYTONA_API_KEY=your_daytona_api_key
 DAYTONA_SERVER_URL="https://app.daytona.io/api"
@@ -230,6 +234,12 @@ npm run dev
 ```bash
 cd backend
 poetry run python3.11 api.py
 ```

+In one more terminal, start the backend worker:
+```bash
+cd backend
+poetry run python3.11 -m dramatiq run_agent_background
+```
+
 5-6. **Docker Compose Alternative**:
@@ -237,12 +247,16 @@ poetry run python3.11 api.py
 Before running with Docker Compose, make sure your environment files are properly configured:
 - In `backend/.env`, set all the required environment variables as described above
 - For Redis configuration, use `REDIS_HOST=redis` instead of localhost
+- For RabbitMQ, use `RABBITMQ_HOST=rabbitmq` instead of localhost
 - The Docker Compose setup will automatically set these Redis environment variables:
   ```
   REDIS_HOST=redis
   REDIS_PORT=6379
   REDIS_PASSWORD=
   REDIS_SSL=False
+
+  RABBITMQ_HOST=rabbitmq
+  RABBITMQ_PORT=5672
   ```
 - In `frontend/.env.local`, make sure to set `NEXT_PUBLIC_BACKEND_URL="http://backend:8000/api"` to use the container name
@@ -257,7 +271,7 @@ If you're building the images locally instead of using pre-built ones:
 docker compose up
 ```

-The Docker Compose setup includes a Redis service that will be used by the backend automatically.
+The Docker Compose setup includes Redis and RabbitMQ services that will be used by the backend automatically.


 7. **Access Suna**:
@@ -14,6 +14,9 @@ REDIS_PORT=6379
 REDIS_PASSWORD=
 REDIS_SSL=false

+RABBITMQ_HOST=rabbitmq
+RABBITMQ_PORT=5672
+
 # LLM Providers:
 ANTHROPIC_API_KEY=
 OPENAI_API_KEY=
@@ -11,25 +11,25 @@ docker compose down && docker compose up --build

 You can run individual services from the docker-compose file. This is particularly useful during development:

-### Running only Redis
+### Running only Redis and RabbitMQ
 ```bash
-docker compose up redis
+docker compose up redis rabbitmq
 ```

-### Running only the API
+### Running only the API and Worker
 ```bash
-docker compose up api
+docker compose up api worker
 ```

 ## Development Setup

-For local development, you might only need to run Redis while working on the API locally. This is useful when:
+For local development, you might only need to run Redis and RabbitMQ, while working on the API locally. This is useful when:
 - You're making changes to the API code and want to test them directly
 - You want to avoid rebuilding the API container on every change
 - You're running the API service directly on your machine

-To run just Redis for development:```bash
-docker compose up redis
+To run just Redis and RabbitMQ for development:```bash
+docker compose up redis rabbitmq
 ```

 Then you can run your API service locally with your preferred method (e.g., poetry run python3.11 api.py).
@@ -38,16 +38,25 @@ Then you can run your API service locally with your preferred method (e.g., poet
 When running services individually, make sure to:
 1. Check your `.env` file and adjust any necessary environment variables
 2. Ensure Redis connection settings match your local setup (default: `localhost:6379`)
-3. Update any service-specific environment variables if needed
+3. Ensure RabbitMQ connection settings match your local setup (default: `localhost:5672`)
+4. Update any service-specific environment variables if needed

 ### Important: Redis Host Configuration
 When running the API locally with Redis in Docker, you need to set the correct Redis host in your `.env` file:
 - For Docker-to-Docker communication (when running both services in Docker): use `REDIS_HOST=redis`
 - For local-to-Docker communication (when running API locally): use `REDIS_HOST=localhost`

+### Important: RabbitMQ Host Configuration
+When running the API locally with RabbitMQ in Docker, you need to set the correct RabbitMQ host in your `.env` file:
+- For Docker-to-Docker communication (when running both services in Docker): use `RABBITMQ_HOST=rabbitmq`
+- For local-to-Docker communication (when running API locally): use `RABBITMQ_HOST=localhost`
+
 Example `.env` configuration for local development:
 ```env
 REDIS_HOST=localhost (instead of 'redis')
 REDIS_PORT=6379
 REDIS_PASSWORD=
+
+RABBITMQ_HOST=localhost (instead of 'rabbitmq')
+RABBITMQ_PORT=5672
 ```
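The host rules above reduce to "default to the Docker service name, override via the environment when running on the host". A minimal sketch of that resolution logic (hypothetical helper, not part of the repo):

```python
import os

def broker_endpoints() -> dict:
    """Resolve Redis and RabbitMQ endpoints from the environment.

    Defaults assume Docker-to-Docker networking (service names as hosts);
    set REDIS_HOST / RABBITMQ_HOST to "localhost" when the API runs on the host.
    """
    return {
        "redis": (os.getenv("REDIS_HOST", "redis"), int(os.getenv("REDIS_PORT", "6379"))),
        "rabbitmq": (os.getenv("RABBITMQ_HOST", "rabbitmq"), int(os.getenv("RABBITMQ_PORT", "5672"))),
    }
```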
@@ -21,6 +21,7 @@ from services.billing import check_billing_status, can_use_model
 from utils.config import config
 from sandbox.sandbox import create_sandbox, get_or_start_sandbox
 from services.llm import make_llm_api_call
+from run_agent_background import run_agent_background, _cleanup_redis_response_list, update_agent_run_status

 # Initialize shared resources
 router = APIRouter()
@@ -122,63 +123,6 @@ async def cleanup():
     await redis.close()
     logger.info("Completed cleanup of agent API resources")

-async def update_agent_run_status(
-    client,
-    agent_run_id: str,
-    status: str,
-    error: Optional[str] = None,
-    responses: Optional[List[Any]] = None  # Expects parsed list of dicts
-) -> bool:
-    """
-    Centralized function to update agent run status.
-    Returns True if update was successful.
-    """
-    try:
-        update_data = {
-            "status": status,
-            "completed_at": datetime.now(timezone.utc).isoformat()
-        }
-
-        if error:
-            update_data["error"] = error
-
-        if responses:
-            # Ensure responses are stored correctly as JSONB
-            update_data["responses"] = responses
-
-        # Retry up to 3 times
-        for retry in range(3):
-            try:
-                update_result = await client.table('agent_runs').update(update_data).eq("id", agent_run_id).execute()
-
-                if hasattr(update_result, 'data') and update_result.data:
-                    logger.info(f"Successfully updated agent run {agent_run_id} status to '{status}' (retry {retry})")
-
-                    # Verify the update
-                    verify_result = await client.table('agent_runs').select('status', 'completed_at').eq("id", agent_run_id).execute()
-                    if verify_result.data:
-                        actual_status = verify_result.data[0].get('status')
-                        completed_at = verify_result.data[0].get('completed_at')
-                        logger.info(f"Verified agent run update: status={actual_status}, completed_at={completed_at}")
-                    return True
-                else:
-                    logger.warning(f"Database update returned no data for agent run {agent_run_id} on retry {retry}: {update_result}")
-                    if retry == 2:  # Last retry
-                        logger.error(f"Failed to update agent run status after all retries: {agent_run_id}")
-                        return False
-            except Exception as db_error:
-                logger.error(f"Database error on retry {retry} updating status for {agent_run_id}: {str(db_error)}")
-                if retry < 2:  # Not the last retry yet
-                    await asyncio.sleep(0.5 * (2 ** retry))  # Exponential backoff
-                else:
-                    logger.error(f"Failed to update agent run status after all retries: {agent_run_id}", exc_info=True)
-                    return False
-    except Exception as e:
-        logger.error(f"Unexpected error updating agent run status for {agent_run_id}: {str(e)}", exc_info=True)
-        return False
-
-    return False
-
 async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None):
     """Update database and publish stop signal to Redis."""
     logger.info(f"Stopping agent run: {agent_run_id}")
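The helper deleted above (and moved to the worker module) retries the database update up to three times with exponential backoff between failures. That retry shape is reusable on its own; a generic stdlib sketch, not the repo's code:

```python
import asyncio

async def retry_async(op, attempts: int = 3, base_delay: float = 0.5) -> bool:
    """Run async `op` until it returns truthy, sleeping 0.5s, 1s, ... between
    failed attempts (mirroring the 0.5 * 2**retry backoff in the diff)."""
    for attempt in range(attempts):
        try:
            if await op():
                return True
        except Exception:
            pass  # a raised error counts as a failed attempt
        if attempt < attempts - 1:
            await asyncio.sleep(base_delay * (2 ** attempt))
    return False
```

In the deleted code, `op` would be the Supabase `update(...).execute()` call plus its verification read.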
@@ -239,16 +183,6 @@ async def stop_agent_run(agent_run_id: str, error_message: Optional[str] = None)

     logger.info(f"Successfully initiated stop process for agent run: {agent_run_id}")

-
-async def _cleanup_redis_response_list(agent_run_id: str):
-    """Set TTL on the Redis response list."""
-    response_list_key = f"agent_run:{agent_run_id}:responses"
-    try:
-        await redis.expire(response_list_key, REDIS_RESPONSE_LIST_TTL)
-        logger.debug(f"Set TTL ({REDIS_RESPONSE_LIST_TTL}s) on response list: {response_list_key}")
-    except Exception as e:
-        logger.warning(f"Failed to set TTL on response list {response_list_key}: {str(e)}")
-
 # async def restore_running_agent_runs():
 #     """Mark agent runs that were still 'running' in the database as failed and clean up Redis resources."""
 #     logger.info("Restoring running agent runs after server restart")
@@ -307,20 +241,6 @@ async def get_agent_run_with_access_check(client, agent_run_id: str, user_id: st
     await verify_thread_access(client, thread_id, user_id)
     return agent_run_data

-async def _cleanup_redis_instance_key(agent_run_id: str):
-    """Clean up the instance-specific Redis key for an agent run."""
-    if not instance_id:
-        logger.warning("Instance ID not set, cannot clean up instance key.")
-        return
-    key = f"active_run:{instance_id}:{agent_run_id}"
-    logger.debug(f"Cleaning up Redis instance key: {key}")
-    try:
-        await redis.delete(key)
-        logger.debug(f"Successfully cleaned up Redis key: {key}")
-    except Exception as e:
-        logger.warning(f"Failed to clean up Redis key {key}: {str(e)}")
-
-
 async def get_or_create_project_sandbox(client, project_id: str):
     """Get or create a sandbox for a project."""
     project = await client.table('projects').select('*').eq('project_id', project_id).execute()
@@ -438,19 +358,15 @@ async def start_agent(
         logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")

     # Run the agent in the background
-    task = asyncio.create_task(
-        run_agent_background(
-            agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
-            project_id=project_id, sandbox=sandbox,
-            model_name=model_name,  # Already resolved above
-            enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort,
-            stream=body.stream, enable_context_manager=body.enable_context_manager
-        )
-    )
+    run_agent_background.send(
+        agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
+        project_id=project_id,
+        model_name=model_name,  # Already resolved above
+        enable_thinking=body.enable_thinking, reasoning_effort=body.reasoning_effort,
+        stream=body.stream, enable_context_manager=body.enable_context_manager
+    )

-    # Set a callback to clean up Redis instance key when task is done
-    task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id)))
-
     return {"agent_run_id": agent_run_id, "status": "running"}

 @router.post("/agent-run/{agent_run_id}/stop")
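The change above is the heart of the commit: `asyncio.create_task(run_agent_background(...))` ran the agent inside the API process, while `run_agent_background.send(...)` only enqueues a message that any dramatiq worker process can consume, so the API returns immediately and no done-callback cleanup is needed. The producer/consumer split can be illustrated with a stdlib stand-in (illustrative only; dramatiq and RabbitMQ do this for real):

```python
import queue
import threading

job_queue: "queue.Queue[dict]" = queue.Queue()  # stand-in for the RabbitMQ queue
handled = []

def send(**kwargs) -> None:
    """Like actor.send(): drop a message on the queue and return immediately."""
    job_queue.put(kwargs)

def worker_loop() -> None:
    """Like `python -m dramatiq run_agent_background`: consume jobs in a loop."""
    while True:
        job = job_queue.get()
        if job.get("_stop"):  # sentinel so this demo thread can exit
            break
        handled.append(job["agent_run_id"])

worker = threading.Thread(target=worker_loop)
worker.start()
send(agent_run_id="run-1", model_name="some-model")  # API side: enqueue, don't await
send(_stop=True)
worker.join()
```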
@@ -671,187 +587,6 @@ async def stream_agent_run(
         "Access-Control-Allow-Origin": "*"
     })

-async def run_agent_background(
-    agent_run_id: str,
-    thread_id: str,
-    instance_id: str,  # Use the global instance ID passed during initialization
-    project_id: str,
-    sandbox,
-    model_name: str,
-    enable_thinking: Optional[bool],
-    reasoning_effort: Optional[str],
-    stream: bool,
-    enable_context_manager: bool
-):
-    """Run the agent in the background using Redis for state."""
-    logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})")
-    logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})")
-
-    client = await db.client
-    start_time = datetime.now(timezone.utc)
-    total_responses = 0
-    pubsub = None
-    stop_checker = None
-    stop_signal_received = False
-
-    # Define Redis keys and channels
-    response_list_key = f"agent_run:{agent_run_id}:responses"
-    response_channel = f"agent_run:{agent_run_id}:new_response"
-    instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}"
-    global_control_channel = f"agent_run:{agent_run_id}:control"
-    instance_active_key = f"active_run:{instance_id}:{agent_run_id}"
-
-    async def check_for_stop_signal():
-        nonlocal stop_signal_received
-        if not pubsub: return
-        try:
-            while not stop_signal_received:
-                message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5)
-                if message and message.get("type") == "message":
-                    data = message.get("data")
-                    if isinstance(data, bytes): data = data.decode('utf-8')
-                    if data == "STOP":
-                        logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})")
-                        stop_signal_received = True
-                        break
-                # Periodically refresh the active run key TTL
-                if total_responses % 50 == 0:  # Refresh every 50 responses or so
-                    try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL)
-                    except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}")
-                await asyncio.sleep(0.1)  # Short sleep to prevent tight loop
-        except asyncio.CancelledError:
-            logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})")
-        except Exception as e:
-            logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True)
-            stop_signal_received = True  # Stop the run if the checker fails
-
-    try:
-        # Setup Pub/Sub listener for control signals
-        pubsub = await redis.create_pubsub()
-        await pubsub.subscribe(instance_control_channel, global_control_channel)
-        logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}")
-        stop_checker = asyncio.create_task(check_for_stop_signal())
-
-        # Ensure active run key exists and has TTL
-        await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL)
-
-        # Initialize agent generator
-        agent_gen = run_agent(
-            thread_id=thread_id, project_id=project_id, stream=stream,
-            thread_manager=thread_manager, model_name=model_name,
-            enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
-            enable_context_manager=enable_context_manager
-        )
-
-        final_status = "running"
-        error_message = None
-
-        async for response in agent_gen:
-            if stop_signal_received:
-                logger.info(f"Agent run {agent_run_id} stopped by signal.")
-                final_status = "stopped"
-                break
-
-            # Store response in Redis list and publish notification
-            response_json = json.dumps(response)
-            await redis.rpush(response_list_key, response_json)
-            await redis.publish(response_channel, "new")
-            total_responses += 1
-
-            # Check for agent-signaled completion or error
-            if response.get('type') == 'status':
-                status_val = response.get('status')
-                if status_val in ['completed', 'failed', 'stopped']:
-                    logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}")
-                    final_status = status_val
-                    if status_val == 'failed' or status_val == 'stopped':
-                        error_message = response.get('message', f"Run ended with status: {status_val}")
-                    break
-
-        # If loop finished without explicit completion/error/stop signal, mark as completed
-        if final_status == "running":
-            final_status = "completed"
-            duration = (datetime.now(timezone.utc) - start_time).total_seconds()
-            logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})")
-            completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"}
-            await redis.rpush(response_list_key, json.dumps(completion_message))
-            await redis.publish(response_channel, "new")  # Notify about the completion message
-
-        # Fetch final responses from Redis for DB update
-        all_responses_json = await redis.lrange(response_list_key, 0, -1)
-        all_responses = [json.loads(r) for r in all_responses_json]
-
-        # Update DB status
-        await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses)
-
-        # Publish final control signal (END_STREAM or ERROR)
-        control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP"
-        try:
-            await redis.publish(global_control_channel, control_signal)
-            # No need to publish to instance channel as the run is ending on this instance
-            logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}")
-        except Exception as e:
-            logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}")
-
-    except Exception as e:
-        error_message = str(e)
-        traceback_str = traceback.format_exc()
-        duration = (datetime.now(timezone.utc) - start_time).total_seconds()
-        logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})")
-        final_status = "failed"
-
-        # Push error message to Redis list
-        error_response = {"type": "status", "status": "error", "message": error_message}
-        try:
-            await redis.rpush(response_list_key, json.dumps(error_response))
-            await redis.publish(response_channel, "new")
-        except Exception as redis_err:
-            logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}")
-
-        # Fetch final responses (including the error)
-        all_responses = []
-        try:
-            all_responses_json = await redis.lrange(response_list_key, 0, -1)
-            all_responses = [json.loads(r) for r in all_responses_json]
-        except Exception as fetch_err:
-            logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}")
-            all_responses = [error_response]  # Use the error message we tried to push
-
-        # Update DB status
-        await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses)
-
-        # Publish ERROR signal
-        try:
-            await redis.publish(global_control_channel, "ERROR")
-            logger.debug(f"Published ERROR signal to {global_control_channel}")
-        except Exception as e:
-            logger.warning(f"Failed to publish ERROR signal: {str(e)}")
-
-    finally:
-        # Cleanup stop checker task
-        if stop_checker and not stop_checker.done():
-            stop_checker.cancel()
-            try: await stop_checker
-            except asyncio.CancelledError: pass
-            except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}")
-
-        # Close pubsub connection
-        if pubsub:
-            try:
-                await pubsub.unsubscribe()
-                await pubsub.close()
-                logger.debug(f"Closed pubsub connection for {agent_run_id}")
-            except Exception as e:
-                logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}")
-
-        # Set TTL on the response list in Redis
-        await _cleanup_redis_response_list(agent_run_id)
-
-        # Remove the instance-specific active run key
-        await _cleanup_redis_instance_key(agent_run_id)
-
-        logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}")

 async def generate_and_update_project_name(project_id: str, prompt: str):
     """Generates a project name using an LLM and updates the database."""
     logger.info(f"Starting background task to generate name for project: {project_id}")
@@ -1044,16 +779,13 @@ async def initiate_agent_with_files(
         logger.warning(f"Failed to register agent run in Redis ({instance_key}): {str(e)}")

     # Run agent in background
-    task = asyncio.create_task(
-        run_agent_background(
-            agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
-            project_id=project_id, sandbox=sandbox,
-            model_name=model_name,  # Already resolved above
-            enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
-            stream=stream, enable_context_manager=enable_context_manager
-        )
-    )
+    run_agent_background.send(
+        agent_run_id=agent_run_id, thread_id=thread_id, instance_id=instance_id,
+        project_id=project_id,
+        model_name=model_name,  # Already resolved above
+        enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
+        stream=stream, enable_context_manager=enable_context_manager
+    )
-    task.add_done_callback(lambda _: asyncio.create_task(_cleanup_redis_instance_key(agent_run_id)))

     return {"thread_id": thread_id, "agent_run_id": agent_run_id}
@@ -147,6 +147,8 @@ class ResponseProcessor:
             if assist_start_msg_obj: yield assist_start_msg_obj
             # --- End Start Events ---

+            __sequence = 0
+
             async for chunk in llm_response:
                 if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason:
                     finish_reason = chunk.choices[0].finish_reason
@@ -175,12 +177,14 @@ class ResponseProcessor:
                         # Yield ONLY content chunk (don't save)
                         now_chunk = datetime.now(timezone.utc).isoformat()
                         yield {
+                            "sequence": __sequence,
                             "message_id": None, "thread_id": thread_id, "type": "assistant",
                             "is_llm_message": True,
                             "content": json.dumps({"role": "assistant", "content": chunk_content}),
                             "metadata": json.dumps({"stream_status": "chunk", "thread_run_id": thread_run_id}),
                             "created_at": now_chunk, "updated_at": now_chunk
                         }
+                        __sequence += 1
                     else:
                         logger.info("XML tool call limit reached - not yielding more content chunks")
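The new `__sequence` counter stamps every streamed chunk with a monotonically increasing index, so the frontend can order and de-duplicate chunks arriving over the stream. Stripped to its essentials (hypothetical helper, simplified fields):

```python
import json
from datetime import datetime, timezone

def tag_chunks(chunks, thread_id="thread-1"):
    """Yield assistant chunk messages stamped with an increasing `sequence`."""
    sequence = 0
    for content in chunks:
        now = datetime.now(timezone.utc).isoformat()
        yield {
            "sequence": sequence,
            "message_id": None,
            "thread_id": thread_id,
            "type": "assistant",
            "content": json.dumps({"role": "assistant", "content": content}),
            "created_at": now,
        }
        sequence += 1
```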
@@ -1,4 +1,4 @@
-version: '3.8'
+version: "3.8"

 services:
   api:
@@ -16,6 +16,8 @@ services:
     depends_on:
       redis:
         condition: service_healthy
+      rabbitmq:
+        condition: service_healthy
     networks:
       - app-network
     environment:
@@ -23,6 +25,8 @@ services:
       - REDIS_PORT=6379
      - REDIS_PASSWORD=
       - LOG_LEVEL=INFO
+      - RABBITMQ_HOST=rabbitmq
+      - RABBITMQ_PORT=5672
     logging:
       driver: "json-file"
       options:
@@ -31,10 +35,10 @@ services:
     deploy:
       resources:
         limits:
-          cpus: '14'
+          cpus: "14"
           memory: 48G
         reservations:
-          cpus: '8'
+          cpus: "8"
           memory: 32G
     healthcheck:
       test: ["CMD", "curl", "-f", "http://localhost:8000/api/health"]
@@ -43,6 +47,84 @@ services:
       retries: 3
       start_period: 40s

+  worker-1:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    command: python -m dramatiq run_agent_background
+    env_file:
+      - .env
+    volumes:
+      - .:/app
+      - ./worker-1-logs:/app/logs
+    restart: unless-stopped
+    depends_on:
+      redis:
+        condition: service_healthy
+      rabbitmq:
+        condition: service_healthy
+    networks:
+      - app-network
+    environment:
+      - REDIS_HOST=redis
+      - REDIS_PORT=6379
+      - REDIS_PASSWORD=
+      - LOG_LEVEL=INFO
+      - RABBITMQ_HOST=rabbitmq
+      - RABBITMQ_PORT=5672
+    logging:
+      driver: "json-file"
+      options:
+        max-size: "10m"
+        max-file: "3"
+    deploy:
+      resources:
+        limits:
+          cpus: "14"
+          memory: 48G
+        reservations:
+          cpus: "8"
+          memory: 32G
+
+  worker-2:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    command: python -m dramatiq run_agent_background
+    env_file:
+      - .env
+    volumes:
+      - .:/app
+      - ./worker-2-logs:/app/logs
+    restart: unless-stopped
+    depends_on:
+      redis:
+        condition: service_healthy
+      rabbitmq:
+        condition: service_healthy
+    networks:
+      - app-network
+    environment:
+      - REDIS_HOST=redis
+      - REDIS_PORT=6379
+      - REDIS_PASSWORD=
+      - LOG_LEVEL=INFO
+      - RABBITMQ_HOST=rabbitmq
+      - RABBITMQ_PORT=5672
+    logging:
+      driver: "json-file"
+      options:
+        max-size: "10m"
+        max-file: "3"
+    deploy:
+      resources:
+        limits:
+          cpus: "14"
+          memory: 48G
+        reservations:
+          cpus: "8"
+          memory: 32G
+
   redis:
     image: redis:7-alpine
     ports:
@@ -67,10 +149,39 @@ services:
     deploy:
       resources:
         limits:
-          cpus: '2'
+          cpus: "2"
           memory: 12G
         reservations:
-          cpus: '1'
+          cpus: "1"
           memory: 8G

+  rabbitmq:
+    image: rabbitmq
+    ports:
+      - "127.0.0.1:5672:5672"
+    volumes:
+      - rabbitmq_data:/var/lib/rabbitmq
+    restart: unless-stopped
+    networks:
+      - app-network
+    healthcheck:
+      test: ["CMD", "rabbitmq-diagnostics", "-q", "ping"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
+      start_period: 10s
+    logging:
+      driver: "json-file"
+      options:
+        max-size: "10m"
+        max-file: "3"
+    deploy:
+      resources:
+        limits:
+          cpus: "2"
+          memory: 12G
+        reservations:
+          cpus: "1"
+          memory: 8G
+
 networks:
@@ -78,4 +189,5 @@ networks:
     driver: bridge

 volumes:
   redis_data:
+  rabbitmq_data:
@@ -31,4 +31,5 @@ vncdotool>=1.2.0
 pydantic
 tavily-python>=0.5.4
 pytesseract==0.3.13
 stripe>=7.0.0
+dramatiq[rabbitmq]>=1.17.1
@@ -0,0 +1,309 @@
+import asyncio
+import json
+import traceback
+from datetime import datetime, timezone
+from typing import Optional
+from services import redis
+from agent.run import run_agent
+from utils.logger import logger
+import dramatiq
+import uuid
+from agentpress.thread_manager import ThreadManager
+from services.supabase import DBConnection
+from services import redis
+from dramatiq.brokers.rabbitmq import RabbitmqBroker
+import os
+
+rabbitmq_host = os.getenv('RABBITMQ_HOST', 'rabbitmq')
+rabbitmq_port = int(os.getenv('RABBITMQ_PORT', 5672))
+rabbitmq_broker = RabbitmqBroker(host=rabbitmq_host, port=rabbitmq_port, middleware=[dramatiq.middleware.AsyncIO()])
+dramatiq.set_broker(rabbitmq_broker)
+
+_initialized = False
+db = DBConnection()
+thread_manager = None
+instance_id = "single"
+
+async def initialize():
+    """Initialize the agent API with resources from the main API."""
+    global thread_manager, db, instance_id, _initialized
+    if _initialized:
+        return
+
+    # Use provided instance_id or generate a new one
+    if not instance_id:
+        # Generate instance ID
+        instance_id = str(uuid.uuid4())[:8]
+    await redis.initialize_async()
+    await db.initialize()
+    thread_manager = ThreadManager()
+
+    _initialized = True
+    logger.info(f"Initialized agent API with instance ID: {instance_id}")
+
+
+@dramatiq.actor
+async def run_agent_background(
+    agent_run_id: str,
+    thread_id: str,
+    instance_id: str,  # Use the global instance ID passed during initialization
+    project_id: str,
+    model_name: str,
+    enable_thinking: Optional[bool],
+    reasoning_effort: Optional[str],
+    stream: bool,
+    enable_context_manager: bool
+):
+    """Run the agent in the background using Redis for state."""
+    await initialize()
+
+    logger.info(f"Starting background agent run: {agent_run_id} for thread: {thread_id} (Instance: {instance_id})")
+    logger.info(f"🚀 Using model: {model_name} (thinking: {enable_thinking}, reasoning_effort: {reasoning_effort})")
+
+    client = await db.client
+    start_time = datetime.now(timezone.utc)
+    total_responses = 0
+    pubsub = None
+    stop_checker = None
+    stop_signal_received = False
+
+    # Define Redis keys and channels
+    response_list_key = f"agent_run:{agent_run_id}:responses"
+    response_channel = f"agent_run:{agent_run_id}:new_response"
+    instance_control_channel = f"agent_run:{agent_run_id}:control:{instance_id}"
+    global_control_channel = f"agent_run:{agent_run_id}:control"
+    instance_active_key = f"active_run:{instance_id}:{agent_run_id}"
+
+    async def check_for_stop_signal():
+        nonlocal stop_signal_received
+        if not pubsub: return
+        try:
+            while not stop_signal_received:
+                message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.5)
+                if message and message.get("type") == "message":
+                    data = message.get("data")
+                    if isinstance(data, bytes): data = data.decode('utf-8')
+                    if data == "STOP":
+                        logger.info(f"Received STOP signal for agent run {agent_run_id} (Instance: {instance_id})")
+                        stop_signal_received = True
+                        break
+                # Periodically refresh the active run key TTL
+                if total_responses % 50 == 0:  # Refresh every 50 responses or so
+                    try: await redis.expire(instance_active_key, redis.REDIS_KEY_TTL)
+                    except Exception as ttl_err: logger.warning(f"Failed to refresh TTL for {instance_active_key}: {ttl_err}")
+                await asyncio.sleep(0.1)  # Short sleep to prevent tight loop
+        except asyncio.CancelledError:
+            logger.info(f"Stop signal checker cancelled for {agent_run_id} (Instance: {instance_id})")
+        except Exception as e:
+            logger.error(f"Error in stop signal checker for {agent_run_id}: {e}", exc_info=True)
+            stop_signal_received = True  # Stop the run if the checker fails
+
+    try:
+        # Setup Pub/Sub listener for control signals
+        pubsub = await redis.create_pubsub()
+        await pubsub.subscribe(instance_control_channel, global_control_channel)
+        logger.debug(f"Subscribed to control channels: {instance_control_channel}, {global_control_channel}")
+        stop_checker = asyncio.create_task(check_for_stop_signal())
+
+        # Ensure active run key exists and has TTL
+        await redis.set(instance_active_key, "running", ex=redis.REDIS_KEY_TTL)
+
+        # Initialize agent generator
+        agent_gen = run_agent(
+            thread_id=thread_id, project_id=project_id, stream=stream,
+            thread_manager=thread_manager, model_name=model_name,
+            enable_thinking=enable_thinking, reasoning_effort=reasoning_effort,
+            enable_context_manager=enable_context_manager
+        )
+
+        final_status = "running"
+        error_message = None
+
+        async for response in agent_gen:
+            if stop_signal_received:
+                logger.info(f"Agent run {agent_run_id} stopped by signal.")
+                final_status = "stopped"
+                break
+
+            # Store response in Redis list and publish notification
+            response_json = json.dumps(response)
+            asyncio.create_task(redis.rpush(response_list_key, response_json))
+            asyncio.create_task(redis.publish(response_channel, "new"))
+            total_responses += 1
+
+            # Check for agent-signaled completion or error
+            if response.get('type') == 'status':
+                status_val = response.get('status')
+                if status_val in ['completed', 'failed', 'stopped']:
+                    logger.info(f"Agent run {agent_run_id} finished via status message: {status_val}")
+                    final_status = status_val
+                    if status_val == 'failed' or status_val == 'stopped':
+                        error_message = response.get('message', f"Run ended with status: {status_val}")
+                    break
+
+        # If loop finished without explicit completion/error/stop signal, mark as completed
|
||||
if final_status == "running":
|
||||
final_status = "completed"
|
||||
duration = (datetime.now(timezone.utc) - start_time).total_seconds()
|
||||
logger.info(f"Agent run {agent_run_id} completed normally (duration: {duration:.2f}s, responses: {total_responses})")
|
||||
completion_message = {"type": "status", "status": "completed", "message": "Agent run completed successfully"}
|
||||
await redis.rpush(response_list_key, json.dumps(completion_message))
|
||||
await redis.publish(response_channel, "new") # Notify about the completion message
|
||||
|
||||
# Fetch final responses from Redis for DB update
|
||||
all_responses_json = await redis.lrange(response_list_key, 0, -1)
|
||||
all_responses = [json.loads(r) for r in all_responses_json]
|
||||
|
||||
# Update DB status
|
||||
await update_agent_run_status(client, agent_run_id, final_status, error=error_message, responses=all_responses)
|
||||
|
||||
# Publish final control signal (END_STREAM or ERROR)
|
||||
control_signal = "END_STREAM" if final_status == "completed" else "ERROR" if final_status == "failed" else "STOP"
|
||||
try:
|
||||
await redis.publish(global_control_channel, control_signal)
|
||||
# No need to publish to instance channel as the run is ending on this instance
|
||||
logger.debug(f"Published final control signal '{control_signal}' to {global_control_channel}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to publish final control signal {control_signal}: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
traceback_str = traceback.format_exc()
|
||||
duration = (datetime.now(timezone.utc) - start_time).total_seconds()
|
||||
logger.error(f"Error in agent run {agent_run_id} after {duration:.2f}s: {error_message}\n{traceback_str} (Instance: {instance_id})")
|
||||
final_status = "failed"
|
||||
|
||||
# Push error message to Redis list
|
||||
error_response = {"type": "status", "status": "error", "message": error_message}
|
||||
try:
|
||||
await redis.rpush(response_list_key, json.dumps(error_response))
|
||||
await redis.publish(response_channel, "new")
|
||||
except Exception as redis_err:
|
||||
logger.error(f"Failed to push error response to Redis for {agent_run_id}: {redis_err}")
|
||||
|
||||
# Fetch final responses (including the error)
|
||||
all_responses = []
|
||||
try:
|
||||
all_responses_json = await redis.lrange(response_list_key, 0, -1)
|
||||
all_responses = [json.loads(r) for r in all_responses_json]
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"Failed to fetch responses from Redis after error for {agent_run_id}: {fetch_err}")
|
||||
all_responses = [error_response] # Use the error message we tried to push
|
||||
|
||||
# Update DB status
|
||||
await update_agent_run_status(client, agent_run_id, "failed", error=f"{error_message}\n{traceback_str}", responses=all_responses)
|
||||
|
||||
# Publish ERROR signal
|
||||
try:
|
||||
await redis.publish(global_control_channel, "ERROR")
|
||||
logger.debug(f"Published ERROR signal to {global_control_channel}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to publish ERROR signal: {str(e)}")
|
||||
|
||||
finally:
|
||||
# Cleanup stop checker task
|
||||
if stop_checker and not stop_checker.done():
|
||||
stop_checker.cancel()
|
||||
try: await stop_checker
|
||||
except asyncio.CancelledError: pass
|
||||
except Exception as e: logger.warning(f"Error during stop_checker cancellation: {e}")
|
||||
|
||||
# Close pubsub connection
|
||||
if pubsub:
|
||||
try:
|
||||
await pubsub.unsubscribe()
|
||||
await pubsub.close()
|
||||
logger.debug(f"Closed pubsub connection for {agent_run_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing pubsub for {agent_run_id}: {str(e)}")
|
||||
|
||||
# Set TTL on the response list in Redis
|
||||
await _cleanup_redis_response_list(agent_run_id)
|
||||
|
||||
# Remove the instance-specific active run key
|
||||
await _cleanup_redis_instance_key(agent_run_id)
|
||||
|
||||
logger.info(f"Agent run background task fully completed for: {agent_run_id} (Instance: {instance_id}) with final status: {final_status}")
|
||||
|
||||
|
||||
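The final control-signal choice in the run above (END_STREAM / ERROR / STOP) can be read as a small pure function. This is a sketch for illustration only — `final_control_signal` is a hypothetical name, and the actual code inlines the conditional expression:

```python
def final_control_signal(final_status: str) -> str:
    """Map a run's final status to the control signal published on the global channel."""
    if final_status == "completed":
        return "END_STREAM"  # clean end of stream
    if final_status == "failed":
        return "ERROR"       # listeners should surface an error
    return "STOP"            # e.g. "stopped": listeners should halt
```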
async def _cleanup_redis_instance_key(agent_run_id: str):
    """Clean up the instance-specific Redis key for an agent run."""
    if not instance_id:
        logger.warning("Instance ID not set, cannot clean up instance key.")
        return
    key = f"active_run:{instance_id}:{agent_run_id}"
    logger.debug(f"Cleaning up Redis instance key: {key}")
    try:
        await redis.delete(key)
        logger.debug(f"Successfully cleaned up Redis key: {key}")
    except Exception as e:
        logger.warning(f"Failed to clean up Redis key {key}: {str(e)}")


# TTL for Redis response lists (24 hours)
REDIS_RESPONSE_LIST_TTL = 3600 * 24


async def _cleanup_redis_response_list(agent_run_id: str):
    """Set TTL on the Redis response list."""
    response_list_key = f"agent_run:{agent_run_id}:responses"
    try:
        await redis.expire(response_list_key, REDIS_RESPONSE_LIST_TTL)
        logger.debug(f"Set TTL ({REDIS_RESPONSE_LIST_TTL}s) on response list: {response_list_key}")
    except Exception as e:
        logger.warning(f"Failed to set TTL on response list {response_list_key}: {str(e)}")

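The cleanup helpers above depend on a consistent Redis key-naming scheme. A minimal sketch of that scheme as pure functions — the helper names here are illustrative only; the real code builds these f-strings inline:

```python
def response_list_key(agent_run_id: str) -> str:
    # Per-run list of streamed responses; expires via REDIS_RESPONSE_LIST_TTL.
    return f"agent_run:{agent_run_id}:responses"

def instance_active_key(instance_id: str, agent_run_id: str) -> str:
    # Marks the run as live on one API instance; deleted during cleanup.
    return f"active_run:{instance_id}:{agent_run_id}"
```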
async def update_agent_run_status(
    client,
    agent_run_id: str,
    status: str,
    error: Optional[str] = None,
    responses: Optional[list[Any]] = None  # Expects a parsed list of dicts
) -> bool:
    """
    Centralized function to update agent run status.
    Returns True if the update was successful.
    """
    try:
        update_data = {
            "status": status,
            "completed_at": datetime.now(timezone.utc).isoformat()
        }

        if error:
            update_data["error"] = error

        if responses:
            # Ensure responses are stored correctly as JSONB
            update_data["responses"] = responses

        # Retry up to 3 times
        for retry in range(3):
            try:
                update_result = await client.table('agent_runs').update(update_data).eq("id", agent_run_id).execute()

                if hasattr(update_result, 'data') and update_result.data:
                    logger.info(f"Successfully updated agent run {agent_run_id} status to '{status}' (retry {retry})")

                    # Verify the update
                    verify_result = await client.table('agent_runs').select('status', 'completed_at').eq("id", agent_run_id).execute()
                    if verify_result.data:
                        actual_status = verify_result.data[0].get('status')
                        completed_at = verify_result.data[0].get('completed_at')
                        logger.info(f"Verified agent run update: status={actual_status}, completed_at={completed_at}")
                    return True
                else:
                    logger.warning(f"Database update returned no data for agent run {agent_run_id} on retry {retry}: {update_result}")
                    if retry == 2:  # Last retry
                        logger.error(f"Failed to update agent run status after all retries: {agent_run_id}")
                        return False
            except Exception as db_error:
                logger.error(f"Database error on retry {retry} updating status for {agent_run_id}: {str(db_error)}")
                if retry < 2:  # Not the last retry yet
                    await asyncio.sleep(0.5 * (2 ** retry))  # Exponential backoff
                else:
                    logger.error(f"Failed to update agent run status after all retries: {agent_run_id}", exc_info=True)
                    return False
    except Exception as e:
        logger.error(f"Unexpected error updating agent run status for {agent_run_id}: {str(e)}", exc_info=True)
        return False

    return False

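The retry loop above sleeps `0.5 * 2**retry` seconds between attempts (exponential backoff). A sketch of the resulting schedule — the helper name is illustrative, not part of the codebase:

```python
def backoff_delays(attempts: int = 3, base: float = 0.5) -> list[float]:
    # One delay between consecutive attempts; no sleep after the final attempt.
    return [base * (2 ** k) for k in range(attempts - 1)]
```

With the defaults this yields delays of 0.5s after the first failure and 1.0s after the second.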
@@ -9,6 +9,7 @@ export type ThreadParams = {
 
 // Unified Message Interface matching the backend/database schema
 export interface UnifiedMessage {
+  sequence?: number;
   message_id: string | null; // Can be null for transient stream events (chunks, unsaved statuses)
   thread_id: string;
   type: 'user' | 'assistant' | 'tool' | 'system' | 'status' | 'browser_state'; // Add 'system' if used
@ -1,4 +1,4 @@
|
|||
import { useState, useEffect, useRef, useCallback } from 'react';
|
||||
import { useState, useEffect, useRef, useCallback, useMemo } from 'react';
|
||||
import {
|
||||
streamAgent,
|
||||
getAgentStatus,
|
||||
|
@ -72,7 +72,9 @@ export function useAgentStream(
|
|||
): UseAgentStreamResult {
|
||||
const [agentRunId, setAgentRunId] = useState<string | null>(null);
|
||||
const [status, setStatus] = useState<string>('idle');
|
||||
const [textContent, setTextContent] = useState<string>('');
|
||||
const [textContent, setTextContent] = useState<
|
||||
{ content: string; sequence?: number }[]
|
||||
>([]);
|
||||
const [toolCall, setToolCall] = useState<ParsedContent | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
|
@ -82,6 +84,12 @@ export function useAgentStream(
|
|||
const threadIdRef = useRef(threadId); // Ref to hold the current threadId
|
||||
const setMessagesRef = useRef(setMessages); // Ref to hold the setMessages function
|
||||
|
||||
const orderedTextContent = useMemo(() => {
|
||||
return textContent
|
||||
.sort((a, b) => a.sequence - b.sequence)
|
||||
.reduce((acc, curr) => acc + curr.content, '');
|
||||
}, [textContent]);
|
||||
|
||||
// Update refs if threadId or setMessages changes
|
||||
useEffect(() => {
|
||||
threadIdRef.current = threadId;
|
||||
|
@ -148,7 +156,7 @@ export function useAgentStream(
|
|||
}
|
||||
|
||||
// Reset streaming-specific state
|
||||
setTextContent('');
|
||||
setTextContent([]);
|
||||
setToolCall(null);
|
||||
|
||||
// Update status and clear run ID
|
||||
|
@ -292,10 +300,15 @@ export function useAgentStream(
|
|||
parsedMetadata.stream_status === 'chunk' &&
|
||||
parsedContent.content
|
||||
) {
|
||||
setTextContent((prev) => prev + parsedContent.content);
|
||||
setTextContent((prev) => {
|
||||
return prev.concat({
|
||||
sequence: message.sequence,
|
||||
content: parsedContent.content,
|
||||
});
|
||||
});
|
||||
callbacks.onAssistantChunk?.({ content: parsedContent.content });
|
||||
} else if (parsedMetadata.stream_status === 'complete') {
|
||||
setTextContent('');
|
||||
setTextContent([]);
|
||||
setToolCall(null);
|
||||
if (message.message_id) callbacks.onMessage(message);
|
||||
} else if (!parsedMetadata.stream_status) {
|
||||
|
@ -501,7 +514,7 @@ export function useAgentStream(
|
|||
}
|
||||
// Reset state on unmount if needed, though finalizeStream should handle most cases
|
||||
setStatus('idle');
|
||||
setTextContent('');
|
||||
setTextContent([]);
|
||||
setToolCall(null);
|
||||
setError(null);
|
||||
setAgentRunId(null);
|
||||
|
@ -528,7 +541,7 @@ export function useAgentStream(
|
|||
}
|
||||
|
||||
// Reset state before starting
|
||||
setTextContent('');
|
||||
setTextContent([]);
|
||||
setToolCall(null);
|
||||
setError(null);
|
||||
updateStatus('connecting');
|
||||
|
@ -616,7 +629,7 @@ export function useAgentStream(
|
|||
|
||||
return {
|
||||
status,
|
||||
textContent,
|
||||
textContent: orderedTextContent,
|
||||
toolCall,
|
||||
error,
|
||||
agentRunId,
|
||||
|
|
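The `orderedTextContent` memo above sorts streamed chunks by their `sequence` number and concatenates them, so chunks delivered out of order still render as coherent text. The same idea expressed in Python — a sketch for illustration, not the frontend code:

```python
def assemble_text(chunks: list[dict]) -> str:
    # Sort by sequence (a missing sequence sorts first), then concatenate content.
    ordered = sorted(chunks, key=lambda c: c.get("sequence", 0))
    return "".join(c["content"] for c in ordered)
```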