refactor: consolidate duplicate account lookup logic

- Extracted get_account_id_from_thread() to auth_utils.py (new shared utility)
- Refactored 2 identical _get_current_account_id() methods:
  * tools/agent_builder_tools/base_tool.py (reduced 23 lines → 9 lines)
  * tools/sb_upload_file_tool.py (reduced 23 lines → 9 lines)
- Simplified agent_creation_tool.py's version for consistency
- Eliminated 28+ lines of duplicate code
- Centralized thread→account lookup logic in one place
This commit is contained in:
marko-kraemer 2025-10-04 22:45:09 +02:00
parent 5383897977
commit 8b7bc36d5f
4 changed files with 46 additions and 45 deletions

View File

@ -13,28 +13,15 @@ class AgentBuilderBaseTool(Tool):
self.agent_id = agent_id
async def _get_current_account_id(self) -> str:
try:
context_vars = structlog.contextvars.get_contextvars()
thread_id = context_vars.get('thread_id')
if not thread_id:
raise ValueError("No thread_id available from execution context")
client = await self.db.client
thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
if not thread_result.data:
raise ValueError(f"Could not find thread with ID: {thread_id}")
account_id = thread_result.data[0]['account_id']
if not account_id:
raise ValueError("Thread has no associated account_id")
return account_id
except Exception as e:
logger.error(f"Error getting current account_id: {e}")
raise
"""Get account_id from current thread context."""
context_vars = structlog.contextvars.get_contextvars()
thread_id = context_vars.get('thread_id')
if not thread_id:
raise ValueError("No thread_id available from execution context")
from core.utils.auth_utils import get_account_id_from_thread
return await get_account_id_from_thread(thread_id, self.db)
async def _get_agent_data(self) -> Optional[dict]:
try:

View File

@ -15,7 +15,10 @@ class AgentCreationTool(Tool):
self.db = db_connection
self.account_id = account_id
async def _get_current_account_id(self) -> Optional[str]:
async def _get_current_account_id(self) -> str:
"""Get account_id (already provided in constructor)."""
if not self.account_id:
raise ValueError("No account_id available")
return self.account_id
async def _sync_workflows_to_version_config(self, agent_id: str) -> None:

View File

@ -137,28 +137,15 @@ class SandboxUploadFileTool(SandboxToolsBase):
return self.fail_response(f"Unexpected error during secure file upload: {str(e)}")
async def _get_current_account_id(self) -> str:
try:
context_vars = structlog.contextvars.get_contextvars()
thread_id = context_vars.get('thread_id')
if not thread_id:
raise ValueError("No thread_id available from execution context")
client = await self.db.client
thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
if not thread_result.data:
raise ValueError(f"Could not find thread with ID: {thread_id}")
account_id = thread_result.data[0]['account_id']
if not account_id:
raise ValueError("Thread has no associated account_id")
return account_id
except Exception as e:
logger.error(f"Error getting current account_id: {e}")
raise
"""Get account_id from current thread context."""
context_vars = structlog.contextvars.get_contextvars()
thread_id = context_vars.get('thread_id')
if not thread_id:
raise ValueError("No thread_id available from execution context")
from core.utils.auth_utils import get_account_id_from_thread
return await get_account_id_from_thread(thread_id, self.db)
async def _track_upload(
self,

View File

@ -44,6 +44,30 @@ def _decode_jwt_safely(token: str) -> dict:
}
)
async def get_account_id_from_thread(thread_id: str, db: "DBConnection") -> str:
"""
Get account_id from thread_id.
Raises:
ValueError: If thread not found or has no account_id
"""
try:
client = await db.client
thread_result = await client.table('threads').select('account_id').eq('thread_id', thread_id).limit(1).execute()
if not thread_result.data:
raise ValueError(f"Could not find thread with ID: {thread_id}")
account_id = thread_result.data[0]['account_id']
if not account_id:
raise ValueError("Thread has no associated account_id")
return account_id
except Exception as e:
structlog.get_logger().error(f"Error getting account_id from thread: {e}")
raise
async def _get_user_id_from_account_cached(account_id: str) -> Optional[str]:
cache_key = f"account_user:{account_id}"