mirror of https://github.com/kortix-ai/suna.git
541 lines
19 KiB
Python
541 lines
19 KiB
Python
import json
|
|
from typing import List, Optional
|
|
from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form, BackgroundTasks
|
|
from pydantic import BaseModel, Field, HttpUrl
|
|
from utils.auth_utils import get_current_user_id_from_jwt, verify_agent_access
|
|
from services.supabase import DBConnection
|
|
from knowledge_base.file_processor import FileProcessor
|
|
from utils.logger import logger
|
|
from flags.flags import is_enabled
|
|
|
|
router = APIRouter(prefix="/knowledge-base", tags=["knowledge-base"])
|
|
|
|
class KnowledgeBaseEntry(BaseModel):
|
|
entry_id: Optional[str] = None
|
|
name: str = Field(..., min_length=1, max_length=255)
|
|
description: Optional[str] = None
|
|
content: str = Field(..., min_length=1)
|
|
usage_context: str = Field(default="always", pattern="^(always|on_request|contextual)$")
|
|
is_active: bool = True
|
|
|
|
class KnowledgeBaseEntryResponse(BaseModel):
|
|
entry_id: str
|
|
name: str
|
|
description: Optional[str]
|
|
content: str
|
|
usage_context: str
|
|
is_active: bool
|
|
content_tokens: Optional[int]
|
|
created_at: str
|
|
updated_at: str
|
|
source_type: Optional[str] = None
|
|
source_metadata: Optional[dict] = None
|
|
file_size: Optional[int] = None
|
|
file_mime_type: Optional[str] = None
|
|
|
|
class KnowledgeBaseListResponse(BaseModel):
|
|
entries: List[KnowledgeBaseEntryResponse]
|
|
total_count: int
|
|
total_tokens: int
|
|
|
|
class CreateKnowledgeBaseEntryRequest(BaseModel):
|
|
name: str = Field(..., min_length=1, max_length=255)
|
|
description: Optional[str] = None
|
|
content: str = Field(..., min_length=1)
|
|
usage_context: str = Field(default="always", pattern="^(always|on_request|contextual)$")
|
|
|
|
class UpdateKnowledgeBaseEntryRequest(BaseModel):
|
|
name: Optional[str] = Field(None, min_length=1, max_length=255)
|
|
description: Optional[str] = None
|
|
content: Optional[str] = Field(None, min_length=1)
|
|
usage_context: Optional[str] = Field(None, pattern="^(always|on_request|contextual)$")
|
|
is_active: Optional[bool] = None
|
|
|
|
class ProcessingJobResponse(BaseModel):
|
|
job_id: str
|
|
job_type: str
|
|
status: str
|
|
source_info: dict
|
|
result_info: dict
|
|
entries_created: int
|
|
total_files: int
|
|
created_at: str
|
|
completed_at: Optional[str]
|
|
error_message: Optional[str]
|
|
|
|
db = DBConnection()
|
|
|
|
|
|
@router.get("/agents/{agent_id}", response_model=KnowledgeBaseListResponse)
|
|
async def get_agent_knowledge_base(
|
|
agent_id: str,
|
|
include_inactive: bool = False,
|
|
user_id: str = Depends(get_current_user_id_from_jwt)
|
|
):
|
|
if not await is_enabled("knowledge_base"):
|
|
raise HTTPException(
|
|
status_code=403,
|
|
detail="This feature is not available at the moment."
|
|
)
|
|
|
|
"""Get all knowledge base entries for an agent"""
|
|
try:
|
|
client = await db.client
|
|
|
|
# Verify agent access
|
|
await verify_agent_access(client, agent_id, user_id)
|
|
|
|
result = await client.rpc('get_agent_knowledge_base', {
|
|
'p_agent_id': agent_id,
|
|
'p_include_inactive': include_inactive
|
|
}).execute()
|
|
|
|
entries = []
|
|
total_tokens = 0
|
|
|
|
for entry_data in result.data or []:
|
|
entry = KnowledgeBaseEntryResponse(
|
|
entry_id=entry_data['entry_id'],
|
|
name=entry_data['name'],
|
|
description=entry_data['description'],
|
|
content=entry_data['content'],
|
|
usage_context=entry_data['usage_context'],
|
|
is_active=entry_data['is_active'],
|
|
content_tokens=entry_data.get('content_tokens'),
|
|
created_at=entry_data['created_at'],
|
|
updated_at=entry_data.get('updated_at', entry_data['created_at']),
|
|
source_type=entry_data.get('source_type'),
|
|
source_metadata=entry_data.get('source_metadata'),
|
|
file_size=entry_data.get('file_size'),
|
|
file_mime_type=entry_data.get('file_mime_type')
|
|
)
|
|
entries.append(entry)
|
|
total_tokens += entry_data.get('content_tokens', 0) or 0
|
|
|
|
return KnowledgeBaseListResponse(
|
|
entries=entries,
|
|
total_count=len(entries),
|
|
total_tokens=total_tokens
|
|
)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error getting knowledge base for agent {agent_id}: {str(e)}")
|
|
raise HTTPException(status_code=500, detail="Failed to retrieve agent knowledge base")
|
|
|
|
@router.post("/agents/{agent_id}", response_model=KnowledgeBaseEntryResponse)
|
|
async def create_agent_knowledge_base_entry(
|
|
agent_id: str,
|
|
entry_data: CreateKnowledgeBaseEntryRequest,
|
|
user_id: str = Depends(get_current_user_id_from_jwt)
|
|
):
|
|
if not await is_enabled("knowledge_base"):
|
|
raise HTTPException(
|
|
status_code=403,
|
|
detail="This feature is not available at the moment."
|
|
)
|
|
|
|
"""Create a new knowledge base entry for an agent"""
|
|
try:
|
|
client = await db.client
|
|
|
|
# Verify agent access and get agent data
|
|
agent_data = await verify_agent_access(client, agent_id, user_id)
|
|
account_id = agent_data['account_id']
|
|
|
|
insert_data = {
|
|
'agent_id': agent_id,
|
|
'account_id': account_id,
|
|
'name': entry_data.name,
|
|
'description': entry_data.description,
|
|
'content': entry_data.content,
|
|
'usage_context': entry_data.usage_context
|
|
}
|
|
|
|
result = await client.table('agent_knowledge_base_entries').insert(insert_data).execute()
|
|
|
|
if not result.data:
|
|
raise HTTPException(status_code=500, detail="Failed to create agent knowledge base entry")
|
|
|
|
created_entry = result.data[0]
|
|
|
|
return KnowledgeBaseEntryResponse(
|
|
entry_id=created_entry['entry_id'],
|
|
name=created_entry['name'],
|
|
description=created_entry['description'],
|
|
content=created_entry['content'],
|
|
usage_context=created_entry['usage_context'],
|
|
is_active=created_entry['is_active'],
|
|
content_tokens=created_entry.get('content_tokens'),
|
|
created_at=created_entry['created_at'],
|
|
updated_at=created_entry['updated_at']
|
|
)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error creating knowledge base entry for agent {agent_id}: {str(e)}")
|
|
raise HTTPException(status_code=500, detail="Failed to create agent knowledge base entry")
|
|
|
|
@router.post("/agents/{agent_id}/upload-file")
|
|
async def upload_file_to_agent_kb(
|
|
agent_id: str,
|
|
background_tasks: BackgroundTasks,
|
|
file: UploadFile = File(...),
|
|
user_id: str = Depends(get_current_user_id_from_jwt)
|
|
):
|
|
if not await is_enabled("knowledge_base"):
|
|
raise HTTPException(
|
|
status_code=403,
|
|
detail="This feature is not available at the moment."
|
|
)
|
|
|
|
"""Upload and process a file for agent knowledge base"""
|
|
try:
|
|
client = await db.client
|
|
|
|
# Verify agent access and get agent data
|
|
agent_data = await verify_agent_access(client, agent_id, user_id)
|
|
account_id = agent_data['account_id']
|
|
|
|
file_content = await file.read()
|
|
job_id = await client.rpc('create_agent_kb_processing_job', {
|
|
'p_agent_id': agent_id,
|
|
'p_account_id': account_id,
|
|
'p_job_type': 'file_upload',
|
|
'p_source_info': {
|
|
'filename': file.filename,
|
|
'mime_type': file.content_type,
|
|
'file_size': len(file_content)
|
|
}
|
|
}).execute()
|
|
|
|
if not job_id.data:
|
|
raise HTTPException(status_code=500, detail="Failed to create processing job")
|
|
|
|
job_id = job_id.data
|
|
background_tasks.add_task(
|
|
process_file_background,
|
|
job_id,
|
|
agent_id,
|
|
account_id,
|
|
file_content,
|
|
file.filename,
|
|
file.content_type or 'application/octet-stream'
|
|
)
|
|
|
|
return {
|
|
"job_id": job_id,
|
|
"message": "File upload started. Processing in background.",
|
|
"filename": file.filename
|
|
}
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error uploading file to agent {agent_id}: {str(e)}")
|
|
raise HTTPException(status_code=500, detail="Failed to upload file")
|
|
|
|
|
|
@router.put("/{entry_id}", response_model=KnowledgeBaseEntryResponse)
|
|
async def update_knowledge_base_entry(
|
|
entry_id: str,
|
|
entry_data: UpdateKnowledgeBaseEntryRequest,
|
|
user_id: str = Depends(get_current_user_id_from_jwt)
|
|
):
|
|
if not await is_enabled("knowledge_base"):
|
|
raise HTTPException(
|
|
status_code=403,
|
|
detail="This feature is not available at the moment."
|
|
)
|
|
|
|
"""Update an agent knowledge base entry"""
|
|
try:
|
|
client = await db.client
|
|
|
|
# Get the entry and verify it exists in agent_knowledge_base_entries table
|
|
entry_result = await client.table('agent_knowledge_base_entries').select('*').eq('entry_id', entry_id).execute()
|
|
|
|
if not entry_result.data:
|
|
raise HTTPException(status_code=404, detail="Knowledge base entry not found")
|
|
|
|
entry = entry_result.data[0]
|
|
agent_id = entry['agent_id']
|
|
|
|
# Verify agent access
|
|
await verify_agent_access(client, agent_id, user_id)
|
|
|
|
update_data = {}
|
|
if entry_data.name is not None:
|
|
update_data['name'] = entry_data.name
|
|
if entry_data.description is not None:
|
|
update_data['description'] = entry_data.description
|
|
if entry_data.content is not None:
|
|
update_data['content'] = entry_data.content
|
|
if entry_data.usage_context is not None:
|
|
update_data['usage_context'] = entry_data.usage_context
|
|
if entry_data.is_active is not None:
|
|
update_data['is_active'] = entry_data.is_active
|
|
|
|
if not update_data:
|
|
raise HTTPException(status_code=400, detail="No fields to update")
|
|
|
|
result = await client.table('agent_knowledge_base_entries').update(update_data).eq('entry_id', entry_id).execute()
|
|
|
|
if not result.data:
|
|
raise HTTPException(status_code=500, detail="Failed to update knowledge base entry")
|
|
|
|
updated_entry = result.data[0]
|
|
|
|
logger.info(f"Updated agent knowledge base entry {entry_id} for agent {agent_id}")
|
|
|
|
return KnowledgeBaseEntryResponse(
|
|
entry_id=updated_entry['entry_id'],
|
|
name=updated_entry['name'],
|
|
description=updated_entry['description'],
|
|
content=updated_entry['content'],
|
|
usage_context=updated_entry['usage_context'],
|
|
is_active=updated_entry['is_active'],
|
|
content_tokens=updated_entry.get('content_tokens'),
|
|
created_at=updated_entry['created_at'],
|
|
updated_at=updated_entry['updated_at'],
|
|
source_type=updated_entry.get('source_type'),
|
|
source_metadata=updated_entry.get('source_metadata'),
|
|
file_size=updated_entry.get('file_size'),
|
|
file_mime_type=updated_entry.get('file_mime_type')
|
|
)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error updating knowledge base entry {entry_id}: {str(e)}")
|
|
raise HTTPException(status_code=500, detail="Failed to update knowledge base entry")
|
|
|
|
@router.delete("/{entry_id}")
|
|
async def delete_knowledge_base_entry(
|
|
entry_id: str,
|
|
user_id: str = Depends(get_current_user_id_from_jwt)
|
|
):
|
|
if not await is_enabled("knowledge_base"):
|
|
raise HTTPException(
|
|
status_code=403,
|
|
detail="This feature is not available at the moment."
|
|
)
|
|
|
|
"""Delete an agent knowledge base entry"""
|
|
try:
|
|
client = await db.client
|
|
|
|
# Get the entry and verify it exists in agent_knowledge_base_entries table
|
|
entry_result = await client.table('agent_knowledge_base_entries').select('entry_id, agent_id').eq('entry_id', entry_id).execute()
|
|
|
|
if not entry_result.data:
|
|
raise HTTPException(status_code=404, detail="Knowledge base entry not found")
|
|
|
|
entry = entry_result.data[0]
|
|
agent_id = entry['agent_id']
|
|
|
|
# Verify agent access
|
|
await verify_agent_access(client, agent_id, user_id)
|
|
|
|
result = await client.table('agent_knowledge_base_entries').delete().eq('entry_id', entry_id).execute()
|
|
|
|
logger.info(f"Deleted agent knowledge base entry {entry_id} for agent {agent_id}")
|
|
|
|
return {"message": "Knowledge base entry deleted successfully"}
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error deleting knowledge base entry {entry_id}: {str(e)}")
|
|
raise HTTPException(status_code=500, detail="Failed to delete knowledge base entry")
|
|
|
|
|
|
@router.get("/{entry_id}", response_model=KnowledgeBaseEntryResponse)
|
|
async def get_knowledge_base_entry(
|
|
entry_id: str,
|
|
user_id: str = Depends(get_current_user_id_from_jwt)
|
|
):
|
|
if not await is_enabled("knowledge_base"):
|
|
raise HTTPException(
|
|
status_code=403,
|
|
detail="This feature is not available at the moment."
|
|
)
|
|
"""Get a specific agent knowledge base entry"""
|
|
try:
|
|
client = await db.client
|
|
|
|
# Get the entry from agent_knowledge_base_entries table only
|
|
result = await client.table('agent_knowledge_base_entries').select('*').eq('entry_id', entry_id).execute()
|
|
|
|
if not result.data:
|
|
raise HTTPException(status_code=404, detail="Knowledge base entry not found")
|
|
|
|
entry = result.data[0]
|
|
agent_id = entry['agent_id']
|
|
|
|
# Verify agent access
|
|
await verify_agent_access(client, agent_id, user_id)
|
|
|
|
logger.info(f"Retrieved agent knowledge base entry {entry_id} for agent {agent_id}")
|
|
|
|
return KnowledgeBaseEntryResponse(
|
|
entry_id=entry['entry_id'],
|
|
name=entry['name'],
|
|
description=entry['description'],
|
|
content=entry['content'],
|
|
usage_context=entry['usage_context'],
|
|
is_active=entry['is_active'],
|
|
content_tokens=entry.get('content_tokens'),
|
|
created_at=entry['created_at'],
|
|
updated_at=entry['updated_at'],
|
|
source_type=entry.get('source_type'),
|
|
source_metadata=entry.get('source_metadata'),
|
|
file_size=entry.get('file_size'),
|
|
file_mime_type=entry.get('file_mime_type')
|
|
)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error getting knowledge base entry {entry_id}: {str(e)}")
|
|
raise HTTPException(status_code=500, detail="Failed to retrieve knowledge base entry")
|
|
|
|
|
|
@router.get("/agents/{agent_id}/processing-jobs", response_model=List[ProcessingJobResponse])
|
|
async def get_agent_processing_jobs(
|
|
agent_id: str,
|
|
limit: int = 10,
|
|
user_id: str = Depends(get_current_user_id_from_jwt)
|
|
):
|
|
if not await is_enabled("knowledge_base"):
|
|
raise HTTPException(
|
|
status_code=403,
|
|
detail="This feature is not available at the moment."
|
|
)
|
|
|
|
"""Get processing jobs for an agent"""
|
|
try:
|
|
client = await db.client
|
|
|
|
# Verify agent access
|
|
await verify_agent_access(client, agent_id, user_id)
|
|
|
|
result = await client.rpc('get_agent_kb_processing_jobs', {
|
|
'p_agent_id': agent_id,
|
|
'p_limit': limit
|
|
}).execute()
|
|
|
|
jobs = []
|
|
for job_data in result.data or []:
|
|
job = ProcessingJobResponse(
|
|
job_id=job_data['job_id'],
|
|
job_type=job_data['job_type'],
|
|
status=job_data['status'],
|
|
source_info=job_data['source_info'],
|
|
result_info=job_data['result_info'],
|
|
entries_created=job_data['entries_created'],
|
|
total_files=job_data['total_files'],
|
|
created_at=job_data['created_at'],
|
|
completed_at=job_data.get('completed_at'),
|
|
error_message=job_data.get('error_message')
|
|
)
|
|
jobs.append(job)
|
|
|
|
return jobs
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error getting processing jobs for agent {agent_id}: {str(e)}")
|
|
raise HTTPException(status_code=500, detail="Failed to get processing jobs")
|
|
|
|
async def process_file_background(
|
|
job_id: str,
|
|
agent_id: str,
|
|
account_id: str,
|
|
file_content: bytes,
|
|
filename: str,
|
|
mime_type: str
|
|
):
|
|
"""Background task to process uploaded files"""
|
|
|
|
processor = FileProcessor()
|
|
client = await processor.db.client
|
|
try:
|
|
await client.rpc('update_agent_kb_job_status', {
|
|
'p_job_id': job_id,
|
|
'p_status': 'processing'
|
|
}).execute()
|
|
|
|
result = await processor.process_file_upload(
|
|
agent_id, account_id, file_content, filename, mime_type
|
|
)
|
|
|
|
if result['success']:
|
|
await client.rpc('update_agent_kb_job_status', {
|
|
'p_job_id': job_id,
|
|
'p_status': 'completed',
|
|
'p_result_info': result,
|
|
'p_entries_created': 1,
|
|
'p_total_files': 1
|
|
}).execute()
|
|
else:
|
|
await client.rpc('update_agent_kb_job_status', {
|
|
'p_job_id': job_id,
|
|
'p_status': 'failed',
|
|
'p_error_message': result.get('error', 'Unknown error')
|
|
}).execute()
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in background file processing for job {job_id}: {str(e)}")
|
|
try:
|
|
await client.rpc('update_agent_kb_job_status', {
|
|
'p_job_id': job_id,
|
|
'p_status': 'failed',
|
|
'p_error_message': str(e)
|
|
}).execute()
|
|
except:
|
|
pass
|
|
|
|
|
|
@router.get("/agents/{agent_id}/context")
|
|
async def get_agent_knowledge_base_context(
|
|
agent_id: str,
|
|
max_tokens: int = 4000,
|
|
user_id: str = Depends(get_current_user_id_from_jwt)
|
|
):
|
|
if not await is_enabled("knowledge_base"):
|
|
raise HTTPException(
|
|
status_code=403,
|
|
detail="This feature is not available at the moment."
|
|
)
|
|
|
|
"""Get knowledge base context for agent prompts"""
|
|
try:
|
|
client = await db.client
|
|
|
|
# Verify agent access
|
|
await verify_agent_access(client, agent_id, user_id)
|
|
|
|
result = await client.rpc('get_agent_knowledge_base_context', {
|
|
'p_agent_id': agent_id,
|
|
'p_max_tokens': max_tokens
|
|
}).execute()
|
|
|
|
context = result.data if result.data else None
|
|
|
|
return {
|
|
"context": context,
|
|
"max_tokens": max_tokens,
|
|
"agent_id": agent_id
|
|
}
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"Error getting knowledge base context for agent {agent_id}: {str(e)}")
|
|
raise HTTPException(status_code=500, detail="Failed to retrieve agent knowledge base context")
|
|
|