mirror of https://github.com/kortix-ai/suna.git
refactor: consolidate duplicate account lookup logic
- Extracted get_account_id_from_thread() to auth_utils.py (new shared utility) - Refactored 2 identical _get_current_account_id() methods: * tools/agent_builder_tools/base_tool.py (reduced 23 lines → 9 lines) * tools/sb_upload_file_tool.py (reduced 23 lines → 9 lines) - Simplified agent_creation_tool.py's version for consistency - Eliminated 28+ lines of duplicate code - Centralized thread→account lookup logic in one place
This commit is contained in:
parent
5383897977
commit
8b7bc36d5f
|
@ -13,28 +13,15 @@ class AgentBuilderBaseTool(Tool):
|
|||
self.agent_id = agent_id
|
||||
|
||||
async def _get_current_account_id(self) -> str:
|
||||
try:
|
||||
context_vars = structlog.contextvars.get_contextvars()
|
||||
thread_id = context_vars.get('thread_id')
|
||||
|
||||
if not thread_id:
|
||||
raise ValueError("No thread_id available from execution context")
|
||||
|
||||
client = await self.db.client
|
||||
|
||||
thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
|
||||
if not thread_result.data:
|
||||
raise ValueError(f"Could not find thread with ID: {thread_id}")
|
||||
|
||||
account_id = thread_result.data[0]['account_id']
|
||||
if not account_id:
|
||||
raise ValueError("Thread has no associated account_id")
|
||||
|
||||
return account_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current account_id: {e}")
|
||||
raise
|
||||
"""Get account_id from current thread context."""
|
||||
context_vars = structlog.contextvars.get_contextvars()
|
||||
thread_id = context_vars.get('thread_id')
|
||||
|
||||
if not thread_id:
|
||||
raise ValueError("No thread_id available from execution context")
|
||||
|
||||
from core.utils.auth_utils import get_account_id_from_thread
|
||||
return await get_account_id_from_thread(thread_id, self.db)
|
||||
|
||||
async def _get_agent_data(self) -> Optional[dict]:
|
||||
try:
|
||||
|
|
|
@ -15,7 +15,10 @@ class AgentCreationTool(Tool):
|
|||
self.db = db_connection
|
||||
self.account_id = account_id
|
||||
|
||||
async def _get_current_account_id(self) -> Optional[str]:
|
||||
async def _get_current_account_id(self) -> str:
|
||||
"""Get account_id (already provided in constructor)."""
|
||||
if not self.account_id:
|
||||
raise ValueError("No account_id available")
|
||||
return self.account_id
|
||||
|
||||
async def _sync_workflows_to_version_config(self, agent_id: str) -> None:
|
||||
|
|
|
@ -137,28 +137,15 @@ class SandboxUploadFileTool(SandboxToolsBase):
|
|||
return self.fail_response(f"Unexpected error during secure file upload: {str(e)}")
|
||||
|
||||
async def _get_current_account_id(self) -> str:
|
||||
try:
|
||||
context_vars = structlog.contextvars.get_contextvars()
|
||||
thread_id = context_vars.get('thread_id')
|
||||
|
||||
if not thread_id:
|
||||
raise ValueError("No thread_id available from execution context")
|
||||
|
||||
client = await self.db.client
|
||||
|
||||
thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
|
||||
if not thread_result.data:
|
||||
raise ValueError(f"Could not find thread with ID: {thread_id}")
|
||||
|
||||
account_id = thread_result.data[0]['account_id']
|
||||
if not account_id:
|
||||
raise ValueError("Thread has no associated account_id")
|
||||
|
||||
return account_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current account_id: {e}")
|
||||
raise
|
||||
"""Get account_id from current thread context."""
|
||||
context_vars = structlog.contextvars.get_contextvars()
|
||||
thread_id = context_vars.get('thread_id')
|
||||
|
||||
if not thread_id:
|
||||
raise ValueError("No thread_id available from execution context")
|
||||
|
||||
from core.utils.auth_utils import get_account_id_from_thread
|
||||
return await get_account_id_from_thread(thread_id, self.db)
|
||||
|
||||
async def _track_upload(
|
||||
self,
|
||||
|
|
|
@ -44,6 +44,30 @@ def _decode_jwt_safely(token: str) -> dict:
|
|||
}
|
||||
)
|
||||
|
||||
async def get_account_id_from_thread(thread_id: str, db: "DBConnection") -> str:
|
||||
"""
|
||||
Get account_id from thread_id.
|
||||
|
||||
Raises:
|
||||
ValueError: If thread not found or has no account_id
|
||||
"""
|
||||
try:
|
||||
client = await db.client
|
||||
thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
|
||||
|
||||
if not thread_result.data:
|
||||
raise ValueError(f"Could not find thread with ID: {thread_id}")
|
||||
|
||||
account_id = thread_result.data[0]['account_id']
|
||||
if not account_id:
|
||||
raise ValueError("Thread has no associated account_id")
|
||||
|
||||
return account_id
|
||||
except Exception as e:
|
||||
structlog.get_logger().error(f"Error getting account_id from thread: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def _get_user_id_from_account_cached(account_id: str) -> Optional[str]:
|
||||
cache_key = f"account_user:{account_id}"
|
||||
|
||||
|
|
Loading…
Reference in New Issue