# Mirror of https://github.com/kortix-ai/suna.git (322 lines, 12 KiB, Python)
import json
|
|
from typing import List, Optional
|
|
from fastapi import APIRouter, HTTPException, Depends
|
|
from pydantic import BaseModel, Field
|
|
from utils.auth_utils import get_current_user_id_from_jwt
|
|
from services.supabase import DBConnection
|
|
from utils.logger import logger
|
|
from flags.flags import is_enabled
|
|
|
|
# Router for the endpoints defined below; all of them are served under the
# "/knowledge-base" prefix and tagged "knowledge-base".
router = APIRouter(prefix="/knowledge-base", tags=["knowledge-base"])
|
|
|
|
class KnowledgeBaseEntry(BaseModel):
    # Full knowledge base entry: optional identifier, name, optional
    # description, content, usage context and active flag.

    # Optional identifier; defaults to None.
    entry_id: Optional[str] = None

    # Required, between 1 and 255 characters.
    name: str = Field(
        ...,
        min_length=1,
        max_length=255,
    )

    description: Optional[str] = None

    # Required and non-empty.
    content: str = Field(
        ...,
        min_length=1,
    )

    # Must be exactly "always", "on_request" or "contextual".
    usage_context: str = Field(
        default="always",
        pattern="^(always|on_request|contextual)$",
    )

    is_active: bool = True
|
|
|
|
class KnowledgeBaseEntryResponse(BaseModel):
    # Response payload describing a single knowledge base entry. Every field
    # must be supplied by the caller; the Optional ones may be None.

    # --- identity and descriptive text ---
    entry_id: str
    name: str
    description: Optional[str]

    # --- payload and how it is used ---
    content: str
    usage_context: str
    is_active: bool
    content_tokens: Optional[int]

    # --- timestamps, carried as strings ---
    created_at: str
    updated_at: str
|
|
|
|
class KnowledgeBaseListResponse(BaseModel):
    # Response payload for a listing: the entries plus two aggregate counters.

    entries: List[KnowledgeBaseEntryResponse]

    # Number of entries in the listing.
    total_count: int

    # Token total for the listing.
    total_tokens: int
|
|
|
|
class CreateKnowledgeBaseEntryRequest(BaseModel):
    # Request body for creating an entry: name and content are mandatory,
    # description is optional, usage_context defaults to "always".

    # Required, between 1 and 255 characters.
    name: str = Field(
        ...,
        min_length=1,
        max_length=255,
    )

    description: Optional[str] = None

    # Required and non-empty.
    content: str = Field(
        ...,
        min_length=1,
    )

    # Must be exactly "always", "on_request" or "contextual".
    usage_context: str = Field(
        default="always",
        pattern="^(always|on_request|contextual)$",
    )
|
|
|
|
class UpdateKnowledgeBaseEntryRequest(BaseModel):
    # Request body for a partial update: every field is optional and
    # defaults to None.

    # When given, between 1 and 255 characters.
    name: Optional[str] = Field(
        default=None,
        min_length=1,
        max_length=255,
    )

    description: Optional[str] = None

    # When given, must be non-empty.
    content: Optional[str] = Field(
        default=None,
        min_length=1,
    )

    # When given, must be exactly "always", "on_request" or "contextual".
    usage_context: Optional[str] = Field(
        default=None,
        pattern="^(always|on_request|contextual)$",
    )

    is_active: Optional[bool] = None
|
|
|
|
# Module-level database connection object used by the handlers below; each
# handler awaits `db.client` to obtain the client it runs its queries on.
db = DBConnection()
|
|
|
|
@router.get("/threads/{thread_id}", response_model=KnowledgeBaseListResponse)
async def get_thread_knowledge_base(
    thread_id: str,
    include_inactive: bool = False,
    user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Get all knowledge base entries for a thread.

    Args:
        thread_id: Identifier of the thread whose entries are listed.
        include_inactive: Forwarded to the ``get_thread_knowledge_base`` RPC
            as ``p_include_inactive``.
        user_id: Injected by the ``get_current_user_id_from_jwt`` dependency.

    Returns:
        A ``KnowledgeBaseListResponse`` holding the entries, their count and
        the sum of their ``content_tokens`` (missing/null counts as 0).

    Raises:
        HTTPException: 403 when the ``knowledge_base`` flag is disabled, 404
            when no thread row matches ``thread_id``, 500 on any other error.
    """
    # NOTE(review): `user_id` is not used below, so this handler only requires
    # an authenticated caller -- confirm that access to the thread is
    # enforced elsewhere.
    if not await is_enabled("knowledge_base"):
        raise HTTPException(
            status_code=403,
            detail="This feature is not available at the moment."
        )

    try:
        client = await db.client

        # Existence check only, so fetch just the key column.
        thread_result = await client.table('threads').select('thread_id').eq('thread_id', thread_id).execute()
        if not thread_result.data:
            raise HTTPException(status_code=404, detail="Thread not found")

        result = await client.rpc('get_thread_knowledge_base', {
            'p_thread_id': thread_id,
            'p_include_inactive': include_inactive
        }).execute()

        entries = []
        total_tokens = 0

        for entry_data in result.data or []:
            content_tokens = entry_data.get('content_tokens')
            entries.append(KnowledgeBaseEntryResponse(
                entry_id=entry_data['entry_id'],
                name=entry_data['name'],
                description=entry_data['description'],
                content=entry_data['content'],
                usage_context=entry_data['usage_context'],
                is_active=entry_data['is_active'],
                content_tokens=content_tokens,
                created_at=entry_data['created_at'],
                # Fall back to created_at when updated_at is absent *or* null;
                # a null would otherwise fail the `updated_at: str` validation.
                updated_at=entry_data.get('updated_at') or entry_data['created_at'],
            ))
            total_tokens += content_tokens or 0

        return KnowledgeBaseListResponse(
            entries=entries,
            total_count=len(entries),
            total_tokens=total_tokens
        )

    except HTTPException:
        # Deliberate 4xx responses raised above must pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting knowledge base for thread {thread_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to retrieve knowledge base") from e
|
|
|
|
@router.post("/threads/{thread_id}", response_model=KnowledgeBaseEntryResponse)
async def create_knowledge_base_entry(
    thread_id: str,
    entry_data: CreateKnowledgeBaseEntryRequest,
    user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Create a new knowledge base entry for a thread.

    Args:
        thread_id: Identifier of the thread the entry is attached to.
        entry_data: Validated request body with the new entry's fields.
        user_id: Injected by the ``get_current_user_id_from_jwt`` dependency.

    Returns:
        A ``KnowledgeBaseEntryResponse`` built from the inserted row.

    Raises:
        HTTPException: 403 when the ``knowledge_base`` flag is disabled, 404
            when no thread row matches ``thread_id``, 500 when the insert
            returns no row or any other error occurs.
    """
    # NOTE(review): `user_id` is not used below, so this handler only requires
    # an authenticated caller -- confirm that access to the thread is
    # enforced elsewhere.
    if not await is_enabled("knowledge_base"):
        raise HTTPException(
            status_code=403,
            detail="This feature is not available at the moment."
        )

    try:
        client = await db.client

        thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).execute()
        if not thread_result.data:
            raise HTTPException(status_code=404, detail="Thread not found")

        # The new entry is stored under the account of the thread it belongs to.
        account_id = thread_result.data[0]['account_id']

        insert_data = {
            'thread_id': thread_id,
            'account_id': account_id,
            'name': entry_data.name,
            'description': entry_data.description,
            'content': entry_data.content,
            'usage_context': entry_data.usage_context
        }

        result = await client.table('knowledge_base_entries').insert(insert_data).execute()
        if not result.data:
            raise HTTPException(status_code=500, detail="Failed to create knowledge base entry")

        created_entry = result.data[0]

        return KnowledgeBaseEntryResponse(
            entry_id=created_entry['entry_id'],
            name=created_entry['name'],
            description=created_entry['description'],
            content=created_entry['content'],
            usage_context=created_entry['usage_context'],
            is_active=created_entry['is_active'],
            content_tokens=created_entry.get('content_tokens'),
            created_at=created_entry['created_at'],
            updated_at=created_entry['updated_at']
        )

    except HTTPException:
        # Deliberate HTTP errors raised above must pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error creating knowledge base entry for thread {thread_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to create knowledge base entry") from e
|
|
|
|
@router.put("/{entry_id}", response_model=KnowledgeBaseEntryResponse)
async def update_knowledge_base_entry(
    entry_id: str,
    entry_data: UpdateKnowledgeBaseEntryRequest,
    user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Update a knowledge base entry.

    Only the request fields that are not ``None`` are written.

    Args:
        entry_id: Identifier of the entry to update.
        entry_data: Validated request body; ``None`` fields are left unchanged.
        user_id: Injected by the ``get_current_user_id_from_jwt`` dependency.

    Returns:
        A ``KnowledgeBaseEntryResponse`` built from the updated row.

    Raises:
        HTTPException: 403 when the ``knowledge_base`` flag is disabled, 404
            when no entry row matches ``entry_id``, 400 when the body carries
            no non-null field, 500 when the update returns no row or any
            other error occurs.
    """
    # NOTE(review): `user_id` is not used below, so this handler only requires
    # an authenticated caller -- confirm that access to the entry is
    # enforced elsewhere.
    if not await is_enabled("knowledge_base"):
        raise HTTPException(
            status_code=403,
            detail="This feature is not available at the moment."
        )

    try:
        client = await db.client

        # Existence check only, so fetch just the key column.
        entry_result = await client.table('knowledge_base_entries').select('entry_id').eq('entry_id', entry_id).execute()
        if not entry_result.data:
            raise HTTPException(status_code=404, detail="Knowledge base entry not found")

        # None means "leave unchanged"; as a consequence a field cannot be
        # reset to null through this endpoint.
        update_data = {}
        for field in ('name', 'description', 'content', 'usage_context', 'is_active'):
            value = getattr(entry_data, field)
            if value is not None:
                update_data[field] = value

        if not update_data:
            raise HTTPException(status_code=400, detail="No fields to update")

        result = await client.table('knowledge_base_entries').update(update_data).eq('entry_id', entry_id).execute()
        if not result.data:
            raise HTTPException(status_code=500, detail="Failed to update knowledge base entry")

        updated_entry = result.data[0]

        return KnowledgeBaseEntryResponse(
            entry_id=updated_entry['entry_id'],
            name=updated_entry['name'],
            description=updated_entry['description'],
            content=updated_entry['content'],
            usage_context=updated_entry['usage_context'],
            is_active=updated_entry['is_active'],
            content_tokens=updated_entry.get('content_tokens'),
            created_at=updated_entry['created_at'],
            updated_at=updated_entry['updated_at']
        )

    except HTTPException:
        # Deliberate HTTP errors raised above must pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error updating knowledge base entry {entry_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to update knowledge base entry") from e
|
|
|
|
@router.delete("/{entry_id}")
async def delete_knowledge_base_entry(
    entry_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Delete a knowledge base entry.

    Args:
        entry_id: Identifier of the entry to delete.
        user_id: Injected by the ``get_current_user_id_from_jwt`` dependency.

    Returns:
        A dict with a single ``message`` key confirming the deletion.

    Raises:
        HTTPException: 403 when the ``knowledge_base`` flag is disabled, 404
            when no entry row matches ``entry_id``, 500 on any other error.
    """
    # NOTE(review): `user_id` is not used below, so this handler only requires
    # an authenticated caller -- confirm that access to the entry is
    # enforced elsewhere.
    if not await is_enabled("knowledge_base"):
        raise HTTPException(
            status_code=403,
            detail="This feature is not available at the moment."
        )

    try:
        client = await db.client

        entry_result = await client.table('knowledge_base_entries').select('entry_id').eq('entry_id', entry_id).execute()
        if not entry_result.data:
            raise HTTPException(status_code=404, detail="Knowledge base entry not found")

        # The delete response is not inspected; existence was checked above.
        await client.table('knowledge_base_entries').delete().eq('entry_id', entry_id).execute()

        return {"message": "Knowledge base entry deleted successfully"}

    except HTTPException:
        # Deliberate HTTP errors raised above must pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error deleting knowledge base entry {entry_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to delete knowledge base entry") from e
|
|
|
|
@router.get("/{entry_id}", response_model=KnowledgeBaseEntryResponse)
async def get_knowledge_base_entry(
    entry_id: str,
    user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Get a specific knowledge base entry.

    Args:
        entry_id: Identifier of the entry to fetch.
        user_id: Injected by the ``get_current_user_id_from_jwt`` dependency.

    Returns:
        A ``KnowledgeBaseEntryResponse`` built from the first matching row.

    Raises:
        HTTPException: 403 when the ``knowledge_base`` flag is disabled, 404
            when no entry row matches ``entry_id``, 500 on any other error.
    """
    # NOTE(review): `user_id` is not used below, so this handler only requires
    # an authenticated caller -- confirm that access to the entry is
    # enforced elsewhere.
    if not await is_enabled("knowledge_base"):
        raise HTTPException(
            status_code=403,
            detail="This feature is not available at the moment."
        )

    try:
        client = await db.client

        result = await client.table('knowledge_base_entries').select('*').eq('entry_id', entry_id).execute()
        if not result.data:
            raise HTTPException(status_code=404, detail="Knowledge base entry not found")

        entry = result.data[0]

        return KnowledgeBaseEntryResponse(
            entry_id=entry['entry_id'],
            name=entry['name'],
            description=entry['description'],
            content=entry['content'],
            usage_context=entry['usage_context'],
            is_active=entry['is_active'],
            content_tokens=entry.get('content_tokens'),
            created_at=entry['created_at'],
            updated_at=entry['updated_at']
        )

    except HTTPException:
        # Deliberate HTTP errors raised above must pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting knowledge base entry {entry_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to retrieve knowledge base entry") from e
|
|
|
|
@router.get("/threads/{thread_id}/context")
async def get_knowledge_base_context(
    thread_id: str,
    max_tokens: int = 4000,
    user_id: str = Depends(get_current_user_id_from_jwt)
):
    """Get knowledge base context for agent prompts.

    Args:
        thread_id: Identifier of the thread whose context is requested.
        max_tokens: Forwarded to the ``get_knowledge_base_context`` RPC as
            ``p_max_tokens``; echoed back in the response.
        user_id: Injected by the ``get_current_user_id_from_jwt`` dependency.

    Returns:
        A dict with ``context`` (the RPC result, or ``None`` when it is
        empty), ``max_tokens`` and ``thread_id``.

    Raises:
        HTTPException: 403 when the ``knowledge_base`` flag is disabled, 404
            when no thread row matches ``thread_id``, 500 on any other error.
    """
    # NOTE(review): `user_id` is not used below, so this handler only requires
    # an authenticated caller -- confirm that access to the thread is
    # enforced elsewhere.
    if not await is_enabled("knowledge_base"):
        raise HTTPException(
            status_code=403,
            detail="This feature is not available at the moment."
        )

    try:
        client = await db.client

        thread_result = await client.table('threads').select('thread_id').eq('thread_id', thread_id).execute()
        if not thread_result.data:
            raise HTTPException(status_code=404, detail="Thread not found")

        result = await client.rpc('get_knowledge_base_context', {
            'p_thread_id': thread_id,
            'p_max_tokens': max_tokens
        }).execute()

        # Normalise any empty/falsy RPC result to None.
        context = result.data or None

        return {
            "context": context,
            "max_tokens": max_tokens,
            "thread_id": thread_id
        }

    except HTTPException:
        # Deliberate HTTP errors raised above must pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting knowledge base context for thread {thread_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to retrieve knowledge base context") from e