2026-01-10 21:09:24 -05:00

304 lines
10 KiB
Python

#!/usr/bin/env python3
"""
HTTP file server with user authentication and per-user download tracking.
Serves files from a source directory, tracking which users have downloaded which files
in an SQLite database to prevent duplicate downloads.
"""
import logging
import zipfile
import io
import json
import sqlite3
from pathlib import Path
from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Lock
from datetime import datetime
# Configuration -- where files live, where the download log lives, and where we listen.
SOURCE_DIR = Path("/data/books/ingest")           # files offered for download
DATABASE_PATH = Path("/data/books/downloads.db")  # SQLite users/downloads database
HOST = "0.0.0.0"                                  # all interfaces
PORT = 18000

# Setup logging: every record is written both to file_server.log (resolved
# against the current working directory) and to stderr.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=[logging.FileHandler('file_server.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Thread-safe file operation lock.
# NOTE: run_server uses the single-threaded HTTPServer, so this lock only has
# an effect if a threading server is swapped in.
file_lock = Lock()
def init_database():
    """Create the database directory and schema if they do not exist yet.

    Idempotent: the directory is created with exist_ok=True and every
    statement uses IF NOT EXISTS, so this is safe to run on every start.
    The connection is closed even if a statement fails.
    """
    DATABASE_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(DATABASE_PATH))
    try:
        cursor = conn.cursor()
        # Users table - auth_token is the unique identifier
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                auth_token TEXT NOT NULL UNIQUE,
                name TEXT UNIQUE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        # NOTE(review): the UNIQUE constraint already gives SQLite an automatic
        # index on auth_token, so this one is likely redundant; kept so existing
        # databases keep the same schema.
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_token ON users(auth_token)')
        # Downloads table - tracks which user downloaded which file.
        # UNIQUE(user_id, filename) lets record_downloads use INSERT OR IGNORE.
        # NOTE: SQLite only enforces FOREIGN KEY when PRAGMA foreign_keys=ON is
        # set on the connection, which nothing in this module does.
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS downloads (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                filename TEXT NOT NULL,
                downloaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                UNIQUE(user_id, filename),
                FOREIGN KEY (user_id) REFERENCES users(id)
            )
        ''')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_downloads_user_id ON downloads(user_id)')
        conn.commit()
    finally:
        conn.close()
    logger.info(f"Database initialized at {DATABASE_PATH}")
def get_user_by_token(auth_token):
    """Look up the user that owns *auth_token*.

    Returns:
        ``{'id': int, 'name': str | None}`` for a known token, or ``None``
        when no user has this token.
    """
    conn = sqlite3.connect(str(DATABASE_PATH))
    try:
        row = conn.execute(
            'SELECT id, name FROM users WHERE auth_token = ?', (auth_token,)
        ).fetchone()
    finally:
        # Close even if the query raises (e.g. a locked database).
        conn.close()
    if row is None:
        return None
    return {'id': row[0], 'name': row[1]}
def get_user_downloads(user_id):
    """Return the set of filenames *user_id* has already downloaded.

    A set rather than a list, so callers get cheap membership tests.
    """
    conn = sqlite3.connect(str(DATABASE_PATH))
    try:
        rows = conn.execute(
            'SELECT filename FROM downloads WHERE user_id = ?', (user_id,)
        ).fetchall()
    finally:
        # Close even if the query raises.
        conn.close()
    return {filename for (filename,) in rows}
def record_downloads(user_id, filenames):
    """Record that *user_id* has downloaded each name in *filenames*.

    Pairs that are already recorded are skipped (INSERT OR IGNORE against
    the UNIQUE(user_id, filename) constraint), so repeating a call is
    harmless. All rows share one timestamp and are committed together.
    *filenames* may be any iterable of names.
    """
    # Materialise once so generators work and the count below is accurate.
    filenames = list(filenames)
    # NOTE(review): this stores a local, naive ISO-8601 time ('T' separator),
    # while the column default CURRENT_TIMESTAMP is UTC 'YYYY-MM-DD HH:MM:SS'.
    # Left as-is so existing rows stay comparable.
    timestamp = datetime.now().isoformat()
    conn = sqlite3.connect(str(DATABASE_PATH))
    try:
        conn.executemany(
            'INSERT OR IGNORE INTO downloads (user_id, filename, downloaded_at) VALUES (?, ?, ?)',
            [(user_id, filename, timestamp) for filename in filenames],
        )
        conn.commit()
    finally:
        # Close even if an insert or the commit raises.
        conn.close()
    # Counts names submitted, including any that were already recorded.
    logger.info(f"Recorded {len(filenames)} downloads for user {user_id}")
def create_user(auth_token, name=None):
    """Create a user identified by *auth_token*, optionally with a unique *name*.

    Returns:
        The new user's id, or ``None`` if the token is empty or the insert
        violates a constraint (token already used, or name already taken).
    """
    # An empty token can never authenticate (the handler rejects empty
    # Authorization values), so refuse to create such a user.
    if not auth_token:
        logger.warning("Refusing to create user with an empty auth token")
        return None
    conn = sqlite3.connect(str(DATABASE_PATH))
    try:
        cursor = conn.execute(
            'INSERT INTO users (auth_token, name) VALUES (?, ?)',
            (auth_token, name)
        )
        conn.commit()
        user_id = cursor.lastrowid
        logger.info(f"Created user {name or auth_token[:8]}... with id {user_id}")
        return user_id
    except sqlite3.IntegrityError as e:
        # Both UNIQUE columns raise IntegrityError; only blame the token when
        # it is actually the token constraint that failed.
        if 'users.auth_token' in str(e):
            logger.warning(f"User with token {auth_token[:8]}... already exists")
        else:
            logger.warning(f"Could not create user with token {auth_token[:8]}...: {e}")
        return None
    finally:
        conn.close()
class FileServerHandler(BaseHTTPRequestHandler):
    """HTTP request handler for authenticated, per-user file serving.

    Endpoints:
        POST /get    -- zip of every file in SOURCE_DIR the caller has not
                        downloaded yet; 401 without a valid token, 404 when
                        nothing is new, 500 on server-side failure.
        GET /health  -- liveness probe returning {"status": "ok"}.
    """

    def log_message(self, format, *args):
        """Route BaseHTTPRequestHandler's access/error logging through our logger."""
        logger.info("%s - %s", self.get_client_ip(), format % args)

    def get_client_ip(self):
        """Return the best-known client IP, for logging.

        Prefers X-Forwarded-For (first entry), then X-Real-IP, then the
        socket peer address.
        """
        # BaseHTTPRequestHandler logs some failures (over-long or malformed
        # request line, bad headers, timeouts) before headers are parsed, so
        # self.headers may not exist yet; fall back to the peer address.
        headers = getattr(self, 'headers', None)
        if headers is not None:
            # NOTE(review): these headers are client-controlled unless a proxy
            # overwrites them, and the server binds all interfaces, so a direct
            # client can spoof the logged address. Use for logging only.
            forwarded_for = headers.get('X-Forwarded-For')
            if forwarded_for:
                # X-Forwarded-For can be a comma-separated list, get the first one
                return forwarded_for.split(',')[0].strip()
            real_ip = headers.get('X-Real-IP')
            if real_ip:
                return real_ip.strip()
        return self.address_string()

    def _send_error_response(self, code, message):
        """Send a complete JSON error response of the form {"error": message}."""
        body = json.dumps({"error": message}).encode() + b'\n'
        self.send_response(code)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    @staticmethod
    def _extract_token(auth_header):
        """Return the token from an Authorization header value.

        Accepts "Bearer TOKEN" (scheme matched case-insensitively, extra
        whitespace tolerated) or a bare "TOKEN". Returns '' when none is present.
        """
        value = (auth_header or '').strip()
        scheme, _, credentials = value.partition(' ')
        if scheme.lower() == 'bearer':
            return credentials.strip()
        return value

    def _check_auth(self):
        """Return the user dict for the request's token, or None if missing/invalid."""
        token = self._extract_token(self.headers.get('Authorization', ''))
        if not token:
            return None
        return get_user_by_token(token)

    def _get_files_to_serve(self, user_id):
        """Return the files in SOURCE_DIR that *user_id* has not downloaded yet.

        Returns an empty list when SOURCE_DIR does not exist. Other failures
        (unreadable directory, database errors) propagate so the caller can
        answer 500 instead of a misleading "no new files" 404.
        """
        if not SOURCE_DIR.exists():
            logger.error(f"Source directory does not exist: {SOURCE_DIR}")
            return []
        # Sorted so the archive order is the same from one request to the next.
        all_files = sorted(p for p in SOURCE_DIR.iterdir() if p.is_file())
        downloaded = get_user_downloads(user_id)
        files = [p for p in all_files if p.name not in downloaded]
        logger.info(f"Found {len(all_files)} total files, {len(files)} new for user {user_id}")
        return files

    def _create_zip_archive(self, files):
        """Return the bytes of a DEFLATE-compressed zip containing *files*.

        Entries are stored under their bare file names. The whole archive is
        built in memory. Re-raises any error from adding a file.
        """
        buffer = io.BytesIO()
        with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
            for file_path in files:
                try:
                    archive.write(file_path, file_path.name)
                except Exception as e:
                    logger.error(f"Failed to add {file_path.name} to archive: {e}")
                    raise
                logger.info(f"Added {file_path.name} to archive")
        return buffer.getvalue()

    def do_POST(self):
        """Handle POST /get: send the caller's new files as a zip, then record them."""
        if self.path != '/get':
            self._send_error_response(404, "Endpoint not found")
            return

        user = self._check_auth()
        if not user:
            logger.warning(f"Unauthorized access attempt from {self.get_client_ip()}")
            self._send_error_response(401, "Unauthorized - Invalid token")
            return

        user_id = user['id']
        user_name = user['name'] or f"User {user_id}"
        logger.info(f"Processing request for user: {user_name} (id={user_id})")

        # Use lock to prevent concurrent file operations (only relevant when a
        # threading server drives this handler).
        with file_lock:
            # Once the 200 status line is buffered, no other response may be
            # sent on this connection without corrupting it.
            response_started = False
            try:
                files = self._get_files_to_serve(user_id)
                if not files:
                    self._send_error_response(404, "No new files available for this user")
                    return

                logger.info(f"Creating archive of {len(files)} files for user {user_id}")
                zip_data = self._create_zip_archive(files)

                response_started = True
                self.send_response(200)
                self.send_header('Content-Type', 'application/zip')
                self.send_header('Content-Disposition', 'attachment; filename="files.zip"')
                self.send_header('Content-Length', str(len(zip_data)))
                self.end_headers()
                self.wfile.write(zip_data)
                logger.info(f"Successfully sent {len(zip_data)} bytes to {self.get_client_ip()}")
            except (BrokenPipeError, ConnectionResetError) as e:
                # Client went away: nothing is recorded, so the same files are
                # offered again on its next request.
                logger.warning(f"Client {self.get_client_ip()} disconnected during transfer: {e}")
                return
            except Exception as e:
                logger.error(f"Error processing request: {e}", exc_info=True)
                if not response_started:
                    self._send_error_response(500, "Internal server error")
                return

            # Only reached after the whole payload was handed to the socket.
            try:
                record_downloads(user_id, [f.name for f in files])
            except Exception as e:
                # The 200 response is already complete; appending an error
                # response would corrupt it. The files will be offered again.
                logger.error(f"Failed to record downloads for user {user_id}: {e}", exc_info=True)

    def do_GET(self):
        """Handle GET /health (liveness probe); any other GET path is a 404."""
        if self.path == '/health':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()
            self.wfile.write(b'{"status": "ok"}\n')
        else:
            self._send_error_response(404, "Endpoint not found. Use POST /get")
def run_server():
    """Prepare storage, then serve requests until interrupted with Ctrl+C.

    Both preparation steps are idempotent (mkdir with exist_ok, schema
    creation with IF NOT EXISTS), so they run on every start. The listening
    socket is always closed on exit.
    """
    SOURCE_DIR.mkdir(parents=True, exist_ok=True)
    init_database()

    startup_lines = (
        f"Starting server on {HOST}:{PORT}",
        f"Source directory: {SOURCE_DIR}",
        f"Database: {DATABASE_PATH}",
    )
    for line in startup_lines:
        logger.info(line)

    httpd = HTTPServer((HOST, PORT), FileServerHandler)
    try:
        logger.info("Server is running. Press Ctrl+C to stop.")
        httpd.serve_forever()
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    finally:
        httpd.server_close()
        logger.info("Server closed")
if __name__ == "__main__":
    # Script entry point: importing this module does not start the server.
    run_server()