# suna/backend/knowledge_base/file_processor.py

import os
import io
import zipfile
import tempfile
import shutil
import asyncio
import fnmatch
import re
from typing import List, Dict, Any, Optional
from pathlib import Path
import mimetypes
import chardet
import PyPDF2
import docx
import openpyxl
import csv
import json
import yaml
import xml.etree.ElementTree as ET
from PIL import Image
import pytesseract
from utils.logger import logger
from services.supabase import DBConnection


class FileProcessor:
    """Handles file upload, content extraction, and processing for agent knowledge bases."""

    SUPPORTED_TEXT_EXTENSIONS = {
        '.txt', '.md', '.py', '.js', '.ts', '.html', '.css', '.json', '.yaml', '.yml',
        '.xml', '.csv', '.sql', '.sh', '.bat', '.ps1', '.dockerfile', '.gitignore',
        '.env', '.ini', '.cfg', '.conf', '.log', '.rst', '.toml', '.lock'
    }
    SUPPORTED_DOCUMENT_EXTENSIONS = {
        '.pdf', '.docx', '.xlsx', '.pptx'
    }
    SUPPORTED_IMAGE_EXTENSIONS = {
        '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'
    }
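
    # Extraction is dispatched per extension in _extract_file_content; note
    # that .pptx is listed above but currently falls back to plain-text
    # decoding, since no dedicated extractor is wired up for it.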
    MAX_FILE_SIZE = 50 * 1024 * 1024  # 50 MiB per file
    MAX_ZIP_ENTRIES = 1000  # maximum members processed per ZIP archive
    MAX_CONTENT_LENGTH = 100000  # characters stored per knowledge base entry

    def __init__(self):
        self.db = DBConnection()

    async def process_file_upload(
        self,
        agent_id: str,
        account_id: str,
        file_content: bytes,
        filename: str,
        mime_type: str
    ) -> Dict[str, Any]:
        """Process a single uploaded file and extract its content."""
        try:
            file_size = len(file_content)
            if file_size > self.MAX_FILE_SIZE:
                raise ValueError(f"File too large: {file_size} bytes (max: {self.MAX_FILE_SIZE})")

            file_extension = Path(filename).suffix.lower()
            # ZIP archives are expanded into one entry per contained file
            if file_extension == '.zip':
                return await self._process_zip_file(agent_id, account_id, file_content, filename)

            content = await self._extract_file_content(file_content, filename, mime_type)
            if not content or not content.strip():
                raise ValueError(f"No extractable content found in {filename}")

            client = await self.db.client
            entry_data = {
                'agent_id': agent_id,
                'account_id': account_id,
                'name': f"📄 {filename}",
                'description': f"Content extracted from uploaded file: {filename}",
                'content': content[:self.MAX_CONTENT_LENGTH],
                'source_type': 'file',
                'source_metadata': {
                    'filename': filename,
                    'mime_type': mime_type,
                    'file_size': file_size,
                    'extraction_method': self._get_extraction_method(file_extension, mime_type)
                },
                'file_size': file_size,
                'file_mime_type': mime_type,
                'usage_context': 'always',
                'is_active': True
            }
            result = await client.table('agent_knowledge_base_entries').insert(entry_data).execute()
            if not result.data:
                raise Exception("Failed to create knowledge base entry")

            return {
                'success': True,
                'entry_id': result.data[0]['entry_id'],
                'filename': filename,
                'content_length': len(content),
                'extraction_method': entry_data['source_metadata']['extraction_method']
            }
        except Exception as e:
            logger.error(f"Error processing file {filename}: {str(e)}")
            return {
                'success': False,
                'filename': filename,
                'error': str(e)
            }

    async def _process_zip_file(
        self,
        agent_id: str,
        account_id: str,
        zip_content: bytes,
        zip_filename: str
    ) -> Dict[str, Any]:
        """Extract and process all files from a ZIP archive."""
        try:
            client = await self.db.client
            zip_entry_data = {
                'agent_id': agent_id,
                'account_id': account_id,
                'name': f"📦 {zip_filename}",
                'description': f"ZIP archive: {zip_filename}",
                'content': "ZIP archive containing multiple files. Extracted files will appear as separate entries.",
                'source_type': 'file',
                'source_metadata': {
                    'filename': zip_filename,
                    'mime_type': 'application/zip',
                    'file_size': len(zip_content),
                    'is_zip_container': True
                },
                'file_size': len(zip_content),
                'file_mime_type': 'application/zip',
                'usage_context': 'always',
                'is_active': True
            }
            zip_result = await client.table('agent_knowledge_base_entries').insert(zip_entry_data).execute()
            if not zip_result.data:
                raise Exception("Failed to create knowledge base entry for ZIP archive")
            zip_entry_id = zip_result.data[0]['entry_id']

            # Extract files from ZIP. Members are read in memory only, so
            # nothing is written to disk and archive paths cannot escape.
            extracted_files = []
            failed_files = []
            with zipfile.ZipFile(io.BytesIO(zip_content), 'r') as zip_ref:
                file_list = zip_ref.namelist()
                if len(file_list) > self.MAX_ZIP_ENTRIES:
                    raise ValueError(f"ZIP contains too many files: {len(file_list)} (max: {self.MAX_ZIP_ENTRIES})")

                for file_path in file_list:
                    if file_path.endswith('/'):  # Skip directory entries
                        continue
                    try:
                        file_content = zip_ref.read(file_path)
                        filename = os.path.basename(file_path)
                        if not filename:  # Skip if no filename
                            continue

                        # Detect MIME type
                        mime_type, _ = mimetypes.guess_type(filename)
                        if not mime_type:
                            mime_type = 'application/octet-stream'

                        # Extract content
                        content = await self._extract_file_content(file_content, filename, mime_type)
                        if content and content.strip():
                            extracted_entry_data = {
                                'agent_id': agent_id,
                                'account_id': account_id,
                                'name': f"📄 {filename}",
                                'description': f"Extracted from {zip_filename}: {file_path}",
                                'content': content[:self.MAX_CONTENT_LENGTH],
                                'source_type': 'zip_extracted',
                                'source_metadata': {
                                    'filename': filename,
                                    'original_path': file_path,
                                    'zip_filename': zip_filename,
                                    'mime_type': mime_type,
                                    'file_size': len(file_content),
                                    'extraction_method': self._get_extraction_method(Path(filename).suffix.lower(), mime_type)
                                },
                                'file_size': len(file_content),
                                'file_mime_type': mime_type,
                                'extracted_from_zip_id': zip_entry_id,
                                'usage_context': 'always',
                                'is_active': True
                            }
                            extracted_result = await client.table('agent_knowledge_base_entries').insert(extracted_entry_data).execute()
                            extracted_files.append({
                                'filename': filename,
                                'path': file_path,
                                'entry_id': extracted_result.data[0]['entry_id'],
                                'content_length': len(content)
                            })
                    except Exception as e:
                        logger.error(f"Error extracting {file_path} from ZIP: {str(e)}")
                        failed_files.append({
                            'filename': os.path.basename(file_path),
                            'path': file_path,
                            'error': str(e)
                        })

            return {
                'success': True,
                'zip_entry_id': zip_entry_id,
                'zip_filename': zip_filename,
                'extracted_files': extracted_files,
                'failed_files': failed_files,
                'total_extracted': len(extracted_files),
                'total_failed': len(failed_files)
            }
        except Exception as e:
            logger.error(f"Error processing ZIP file {zip_filename}: {str(e)}")
            return {
                'success': False,
                'zip_filename': zip_filename,
                'error': str(e)
            }

    async def process_git_repository(
        self,
        agent_id: str,
        account_id: str,
        git_url: str,
        branch: str = 'main',
        include_patterns: Optional[List[str]] = None,
        exclude_patterns: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Clone a Git repository and extract content from supported files."""
        if include_patterns is None:
            include_patterns = ['*.py', '*.js', '*.ts', '*.md', '*.txt', '*.json', '*.yaml', '*.yml']
        if exclude_patterns is None:
            exclude_patterns = ['node_modules/*', '.git/*', '*.pyc', '__pycache__/*', '.env', '*.log']

        temp_dir = None
        try:
            # Create temporary directory
            temp_dir = tempfile.mkdtemp()

            # Shallow-clone only the requested branch
            clone_cmd = ['git', 'clone', '--depth', '1', '--branch', branch, git_url, temp_dir]
            process = await asyncio.create_subprocess_exec(
                *clone_cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )
            stdout, stderr = await process.communicate()
            if process.returncode != 0:
                raise Exception(f"Git clone failed: {stderr.decode()}")

            # Create main repository entry
            client = await self.db.client
            repo_name = git_url.split('/')[-1].replace('.git', '')
            repo_entry_data = {
                'agent_id': agent_id,
                'account_id': account_id,
                'name': f"🔗 {repo_name}",
                'description': f"Git repository: {git_url} (branch: {branch})",
                'content': f"Git repository cloned from {git_url}. Individual files are processed as separate entries.",
                'source_type': 'git_repo',
                'source_metadata': {
                    'git_url': git_url,
                    'branch': branch,
                    'include_patterns': include_patterns,
                    'exclude_patterns': exclude_patterns
                },
                'usage_context': 'always',
                'is_active': True
            }
            repo_result = await client.table('agent_knowledge_base_entries').insert(repo_entry_data).execute()
            if not repo_result.data:
                raise Exception("Failed to create knowledge base entry for repository")
            repo_entry_id = repo_result.data[0]['entry_id']

            # Process files in repository
            processed_files = []
            failed_files = []
            for root, dirs, files in os.walk(temp_dir):
                # Skip .git directory
                if '.git' in dirs:
                    dirs.remove('.git')
                for file in files:
                    file_path = os.path.join(root, file)
                    relative_path = os.path.relpath(file_path, temp_dir)

                    # Check if file should be included
                    if not self._should_include_file(relative_path, include_patterns, exclude_patterns):
                        continue
                    try:
                        with open(file_path, 'rb') as f:
                            file_content = f.read()
                        if len(file_content) > self.MAX_FILE_SIZE:
                            continue  # Skip large files

                        # Detect MIME type
                        mime_type, _ = mimetypes.guess_type(file)
                        if not mime_type:
                            mime_type = 'application/octet-stream'

                        # Extract content
                        content = await self._extract_file_content(file_content, file, mime_type)
                        if content and content.strip():
                            # Create entry for file
                            file_entry_data = {
                                'agent_id': agent_id,
                                'account_id': account_id,
                                'name': f"📄 {file}",
                                'description': f"From {repo_name}: {relative_path}",
                                'content': content[:self.MAX_CONTENT_LENGTH],
                                'source_type': 'git_repo',
                                'source_metadata': {
                                    'filename': file,
                                    'relative_path': relative_path,
                                    'git_url': git_url,
                                    'branch': branch,
                                    'repo_name': repo_name,
                                    'mime_type': mime_type,
                                    'file_size': len(file_content),
                                    'extraction_method': self._get_extraction_method(Path(file).suffix.lower(), mime_type)
                                },
                                'file_size': len(file_content),
                                'file_mime_type': mime_type,
                                'extracted_from_zip_id': repo_entry_id,  # Reuse this field for git repo reference
                                'usage_context': 'always',
                                'is_active': True
                            }
                            file_result = await client.table('agent_knowledge_base_entries').insert(file_entry_data).execute()
                            processed_files.append({
                                'filename': file,
                                'relative_path': relative_path,
                                'entry_id': file_result.data[0]['entry_id'],
                                'content_length': len(content)
                            })
                    except Exception as e:
                        logger.error(f"Error processing {relative_path} from git repo: {str(e)}")
                        failed_files.append({
                            'filename': file,
                            'relative_path': relative_path,
                            'error': str(e)
                        })

            return {
                'success': True,
                'repo_entry_id': repo_entry_id,
                'repo_name': repo_name,
                'git_url': git_url,
                'branch': branch,
                'processed_files': processed_files,
                'failed_files': failed_files,
                'total_processed': len(processed_files),
                'total_failed': len(failed_files)
            }
        except Exception as e:
            logger.error(f"Error processing git repository {git_url}: {str(e)}")
            return {
                'success': False,
                'git_url': git_url,
                'error': str(e)
            }
        finally:
            # Clean up temporary directory
            if temp_dir and os.path.exists(temp_dir):
                shutil.rmtree(temp_dir, ignore_errors=True)

    async def _extract_file_content(self, file_content: bytes, filename: str, mime_type: str) -> str:
        """Extract text content from various file types."""
        file_extension = Path(filename).suffix.lower()
        try:
            # Structured formats are dispatched first: several of these
            # extensions (.json, .yaml, .xml, .csv) also appear in
            # SUPPORTED_TEXT_EXTENSIONS, so checking the generic text branch
            # first would shadow their dedicated parsers.
            # PDF files
            if file_extension == '.pdf':
                return self._extract_pdf_content(file_content)
            # Word documents
            elif file_extension == '.docx':
                return self._extract_docx_content(file_content)
            # Excel files
            elif file_extension == '.xlsx':
                return self._extract_xlsx_content(file_content)
            # Images (OCR)
            elif file_extension in self.SUPPORTED_IMAGE_EXTENSIONS:
                return self._extract_image_content(file_content)
            # JSON files
            elif file_extension == '.json':
                return self._extract_json_content(file_content)
            # YAML files
            elif file_extension in {'.yaml', '.yml'}:
                return self._extract_yaml_content(file_content)
            # XML files
            elif file_extension == '.xml':
                return self._extract_xml_content(file_content)
            # CSV files
            elif file_extension == '.csv':
                return self._extract_csv_content(file_content)
            # Text files
            elif file_extension in self.SUPPORTED_TEXT_EXTENSIONS or mime_type.startswith('text/'):
                return self._extract_text_content(file_content)
            else:
                # Try to extract as text if possible
                return self._extract_text_content(file_content)
        except Exception as e:
            logger.error(f"Error extracting content from {filename}: {str(e)}")
            return f"Error extracting content: {str(e)}"

    def _extract_text_content(self, file_content: bytes) -> str:
        """Extract content from text files with encoding detection."""
        detected = chardet.detect(file_content)
        # chardet reports {'encoding': None, ...} when detection fails
        encoding = detected.get('encoding') or 'utf-8'
        try:
            raw_text = file_content.decode(encoding)
        except (UnicodeDecodeError, LookupError):
            raw_text = file_content.decode('utf-8', errors='replace')
        return self._sanitize_content(raw_text)

    def _extract_pdf_content(self, file_content: bytes) -> str:
        """Extract text from PDF files."""
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
        text_content = []
        for page in pdf_reader.pages:
            text_content.append(page.extract_text())
        raw_text = '\n\n'.join(text_content)
        return self._sanitize_content(raw_text)

    def _extract_docx_content(self, file_content: bytes) -> str:
        """Extract text from Word documents."""
        doc = docx.Document(io.BytesIO(file_content))
        text_content = []
        for paragraph in doc.paragraphs:
            text_content.append(paragraph.text)
        raw_text = '\n'.join(text_content)
        return self._sanitize_content(raw_text)

    def _extract_xlsx_content(self, file_content: bytes) -> str:
        """Extract text from Excel files."""
        workbook = openpyxl.load_workbook(io.BytesIO(file_content))
        text_content = []
        for sheet_name in workbook.sheetnames:
            sheet = workbook[sheet_name]
            text_content.append(f"Sheet: {sheet_name}")
            for row in sheet.iter_rows(values_only=True):
                row_text = [str(cell) if cell is not None else '' for cell in row]
                if any(row_text):
                    text_content.append('\t'.join(row_text))
        raw_text = '\n'.join(text_content)
        return self._sanitize_content(raw_text)

    def _extract_image_content(self, file_content: bytes) -> str:
        """Extract text from images using OCR."""
        try:
            image = Image.open(io.BytesIO(file_content))
            # Requires the Tesseract binary to be installed on the host
            raw_text = pytesseract.image_to_string(image)
            return self._sanitize_content(raw_text)
        except Exception as e:
            return f"OCR extraction failed: {str(e)}"

    def _extract_json_content(self, file_content: bytes) -> str:
        """Extract and format JSON content."""
        text = self._extract_text_content(file_content)
        try:
            parsed = json.loads(text)
            formatted = json.dumps(parsed, indent=2)
            return self._sanitize_content(formatted)
        except json.JSONDecodeError:
            return self._sanitize_content(text)

    def _extract_yaml_content(self, file_content: bytes) -> str:
        """Extract and format YAML content."""
        text = self._extract_text_content(file_content)
        try:
            parsed = yaml.safe_load(text)
            formatted = yaml.dump(parsed, default_flow_style=False)
            return self._sanitize_content(formatted)
        except yaml.YAMLError:
            return self._sanitize_content(text)

    def _extract_xml_content(self, file_content: bytes) -> str:
        """Extract content from XML files."""
        try:
            root = ET.fromstring(file_content)
            xml_string = ET.tostring(root, encoding='unicode')
            return self._sanitize_content(xml_string)
        except ET.ParseError:
            return self._extract_text_content(file_content)

    def _extract_csv_content(self, file_content: bytes) -> str:
        """Extract and format CSV content."""
        text = self._extract_text_content(file_content)
        try:
            reader = csv.reader(io.StringIO(text))
            rows = list(reader)
            formatted = '\n'.join(['\t'.join(row) for row in rows])
            return self._sanitize_content(formatted)
        except Exception:
            return self._sanitize_content(text)

    def _sanitize_content(self, content: str) -> str:
        """Sanitize extracted content to remove problematic characters for PostgreSQL."""
        if not content:
            return content
        # Drop control characters (including NUL, which PostgreSQL text
        # columns reject), keeping newlines, carriage returns, and tabs
        sanitized = ''.join(char for char in content if ord(char) >= 32 or char in '\n\r\t')
        # Strip byte order marks, which survive the control-character filter
        sanitized = sanitized.replace('\ufeff', '')
        # Normalize line endings and collapse long runs of blank lines
        sanitized = sanitized.replace('\r\n', '\n').replace('\r', '\n')
        sanitized = re.sub(r'\n{4,}', '\n\n\n', sanitized)
        return sanitized.strip()

    def _get_extraction_method(self, file_extension: str, mime_type: str) -> str:
        """Get the extraction method used for a file type."""
        if file_extension == '.pdf':
            return 'PyPDF2'
        elif file_extension == '.docx':
            return 'python-docx'
        elif file_extension == '.xlsx':
            return 'openpyxl'
        elif file_extension in self.SUPPORTED_IMAGE_EXTENSIONS:
            return 'pytesseract OCR'
        elif file_extension == '.json':
            return 'JSON parser'
        elif file_extension in {'.yaml', '.yml'}:
            return 'YAML parser'
        elif file_extension == '.xml':
            return 'XML parser'
        elif file_extension == '.csv':
            return 'CSV parser'
        else:
            return 'text encoding detection'

    def _should_include_file(self, file_path: str, include_patterns: List[str], exclude_patterns: List[str]) -> bool:
        """Check if a file should be included based on patterns."""
        # Exclusions take precedence over inclusions
        for pattern in exclude_patterns:
            if fnmatch.fnmatch(file_path, pattern):
                return False
        for pattern in include_patterns:
            if fnmatch.fnmatch(file_path, pattern):
                return True
        return False
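

# Minimal usage sketch (illustrative only): the agent/account IDs below are
# hypothetical placeholders, and DBConnection is assumed to pick up the
# service's usual Supabase configuration from the environment.
#
#     import asyncio
#
#     async def main():
#         processor = FileProcessor()
#         result = await processor.process_file_upload(
#             agent_id="agent-123",        # hypothetical ID
#             account_id="account-456",    # hypothetical ID
#             file_content=b"hello world\n",
#             filename="notes.txt",
#             mime_type="text/plain",
#         )
#         print(result)
#
#     asyncio.run(main())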