#!/usr/bin/env python3
"""
HTTP file server with user authentication and per-user download tracking.

Serves files from a source directory, tracking which users have downloaded
which files in an SQLite database to prevent duplicate downloads.
"""
|
|
|
|
import logging
|
|
import zipfile
|
|
import io
|
|
import json
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from http.server import HTTPServer, BaseHTTPRequestHandler
|
|
from threading import Lock
|
|
from datetime import datetime
|
|
|
|
# --- Server configuration -------------------------------------------------
# Directory whose regular files are offered for download.
SOURCE_DIR = Path("/data/books/ingest")
# SQLite database holding the users and their download history.
DATABASE_PATH = Path("/data/books/downloads.db")
# Address and port the HTTP server binds to.
HOST = "0.0.0.0"
PORT = 18000

# --- Logging --------------------------------------------------------------
# Records are written both to file_server.log (relative to the current
# working directory) and to a default StreamHandler.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler('file_server.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# --- Shared state ---------------------------------------------------------
# Held by the request handler while it lists, archives and records files.
file_lock = Lock()
|
|
|
|
|
|
def init_database():
    """Initialize the SQLite database and create tables if needed.

    Creates the parent directory of ``DATABASE_PATH`` when it is missing,
    then creates the ``users`` and ``downloads`` tables and their indexes.
    Every statement uses ``IF NOT EXISTS``, so calling this repeatedly is
    safe.

    Raises:
        sqlite3.Error: If the database cannot be opened or a statement
            fails. The connection is closed before the error propagates.
    """
    DATABASE_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(DATABASE_PATH))
    try:
        cursor = conn.cursor()

        # Users table - auth_token is the unique identifier
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                auth_token TEXT NOT NULL UNIQUE,
                name TEXT UNIQUE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_token ON users(auth_token)')

        # Downloads table - tracks which user downloaded which file.
        # UNIQUE(user_id, filename) is what lets INSERT OR IGNORE skip repeats.
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS downloads (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                filename TEXT NOT NULL,
                downloaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                UNIQUE(user_id, filename),
                FOREIGN KEY (user_id) REFERENCES users(id)
            )
        ''')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_downloads_user_id ON downloads(user_id)')

        conn.commit()
    finally:
        # Release the file handle even when a statement above raised.
        conn.close()
    logger.info(f"Database initialized at {DATABASE_PATH}")
|
|
|
|
|
|
def get_user_by_token(auth_token):
    """Look up a user by auth token.

    Args:
        auth_token: Token string compared against ``users.auth_token``.

    Returns:
        A dict ``{'id': <user id>, 'name': <name or None>}`` for the
        matching user, or ``None`` when no user has this token.

    Raises:
        sqlite3.Error: If the query fails. The connection is closed
            before the error propagates.
    """
    conn = sqlite3.connect(str(DATABASE_PATH))
    try:
        row = conn.execute(
            'SELECT id, name FROM users WHERE auth_token = ?', (auth_token,)
        ).fetchone()
    finally:
        conn.close()

    if row is None:
        return None
    user_id, name = row
    return {'id': user_id, 'name': name}
|
|
|
|
|
|
def get_user_downloads(user_id):
    """Get the filenames already downloaded by a user.

    Args:
        user_id: Value matched against ``downloads.user_id``.

    Returns:
        A set of filename strings; empty when the user has no recorded
        downloads.

    Raises:
        sqlite3.Error: If the query fails. The connection is closed
            before the error propagates.
    """
    conn = sqlite3.connect(str(DATABASE_PATH))
    try:
        rows = conn.execute(
            'SELECT filename FROM downloads WHERE user_id = ?', (user_id,)
        ).fetchall()
    finally:
        conn.close()
    return {filename for (filename,) in rows}
|
|
|
|
|
|
def record_downloads(user_id, filenames):
    """Record that a user has downloaded specific files.

    All rows are inserted with the same timestamp and committed together.
    ``INSERT OR IGNORE`` skips any (user_id, filename) pair that would
    violate a uniqueness constraint on the table, so recording the same
    file twice is harmless.

    Args:
        user_id: Id of the downloading user.
        filenames: Sized collection of filename strings (``len()`` is
            used for the log line).

    Raises:
        sqlite3.Error: If the insert or commit fails. The connection is
            closed before the error propagates.
    """
    conn = sqlite3.connect(str(DATABASE_PATH))
    try:
        timestamp = datetime.now().isoformat()
        conn.executemany(
            'INSERT OR IGNORE INTO downloads (user_id, filename, downloaded_at) VALUES (?, ?, ?)',
            [(user_id, filename, timestamp) for filename in filenames],
        )
        conn.commit()
    finally:
        conn.close()
    logger.info(f"Recorded {len(filenames)} downloads for user {user_id}")
|
|
|
|
|
|
def create_user(auth_token, name=None):
    """Create a new user with the given auth token.

    Args:
        auth_token: Token string stored in ``users.auth_token``.
        name: Optional display name stored in ``users.name``.

    Returns:
        The new user's id, or ``None`` when the insert violates an
        integrity constraint (for example a token or name that is
        already taken).
    """
    conn = sqlite3.connect(str(DATABASE_PATH))
    try:
        cursor = conn.execute(
            'INSERT INTO users (auth_token, name) VALUES (?, ?)',
            (auth_token, name)
        )
        conn.commit()
        user_id = cursor.lastrowid
        logger.info(f"Created user {name or auth_token[:8]}... with id {user_id}")
        return user_id
    except sqlite3.IntegrityError as exc:
        # Include SQLite's own message: it names the constraint that
        # failed, so a name clash is not misreported as a token clash.
        logger.warning(
            f"Could not create user with token {auth_token[:8]}...: {exc}"
        )
        return None
    finally:
        conn.close()
|
|
|
|
|
|
class FileServerHandler(BaseHTTPRequestHandler):
    """HTTP request handler for authenticated file serving.

    Endpoints:
        POST /get    -- zip archive of the files in ``SOURCE_DIR`` that the
                        authenticated user has not downloaded yet.
        GET  /health -- small JSON health-check response.
    """

    def log_message(self, format, *args):
        """Override to use our logger instead of stderr."""
        logger.info("%s - %s" % (self.get_client_ip(), format % args))

    def get_client_ip(self):
        """Get the real client IP, considering reverse proxy headers.

        Preference order: first entry of ``X-Forwarded-For``, then
        ``X-Real-IP``, then the address of the direct connection.

        NOTE(review): both headers are client-controlled unless a trusted
        proxy overwrites them; the value is only used for log lines here.
        """
        # The base class can log an error (malformed or over-long request
        # line) before it has parsed any headers, in which case
        # ``self.headers`` does not exist yet. Fall back to the socket
        # address instead of raising AttributeError from the logger.
        headers = getattr(self, 'headers', None)
        if headers is not None:
            # Check X-Forwarded-For header (set by nginx)
            forwarded_for = headers.get('X-Forwarded-For')
            if forwarded_for:
                # X-Forwarded-For can be a comma-separated list, get the first one
                return forwarded_for.split(',')[0].strip()

            # Check X-Real-IP header (alternative nginx header)
            real_ip = headers.get('X-Real-IP')
            if real_ip:
                return real_ip.strip()

        # Fallback to direct connection IP
        return self.address_string()

    def _send_error_response(self, code, message):
        """Send a JSON error response of the form ``{"error": message}``.

        Args:
            code: HTTP status code to send.
            message: Human-readable error text placed in the JSON body.
        """
        body = json.dumps({"error": message}).encode() + b'\n'
        self.send_response(code)
        self.send_header('Content-Type', 'application/json')
        # Explicit length lets clients delimit the body without waiting
        # for the connection to close.
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _check_auth(self):
        """Verify the authentication token and return user info if valid.

        Returns:
            Whatever ``get_user_by_token`` returns for the supplied token,
            or ``None`` when the ``Authorization`` header is missing or
            empty.
        """
        auth_header = self.headers.get('Authorization', '')

        # Support both "Bearer TOKEN" and just "TOKEN" formats
        if auth_header.startswith('Bearer '):
            token = auth_header[len('Bearer '):]
        else:
            token = auth_header

        if not token:
            return None

        # Look up user by auth token in the database
        return get_user_by_token(token)

    def _get_files_to_serve(self, user_id):
        """Get the files in the source directory the user hasn't downloaded yet.

        Args:
            user_id: Id used to look up the user's download history.

        Returns:
            A list of ``Path`` objects. The list is empty when the source
            directory is missing, when nothing is new, or when listing the
            directory or querying the history fails (the error is logged).
        """
        try:
            if not SOURCE_DIR.exists():
                logger.error(f"Source directory does not exist: {SOURCE_DIR}")
                return []

            # Get all files in source directory
            all_files = [f for f in SOURCE_DIR.iterdir() if f.is_file()]

            # Get files already downloaded by this user
            downloaded = get_user_downloads(user_id)

            # Filter out already-downloaded files
            files = [f for f in all_files if f.name not in downloaded]

            logger.info(f"Found {len(all_files)} total files, {len(files)} new for user {user_id}")
            return files
        except Exception as e:
            logger.error(f"Error reading source directory: {e}")
            return []

    def _create_zip_archive(self, files):
        """Create a zip archive of the files in memory.

        Args:
            files: Iterable of ``Path`` objects; each is stored in the
                archive under its bare file name.

        Returns:
            The complete archive as ``bytes``.

        Raises:
            Exception: Any error raised while adding a file is logged and
                re-raised, so a partial archive is never returned.
        """
        zip_buffer = io.BytesIO()

        try:
            with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                for file_path in files:
                    try:
                        zip_file.write(file_path, file_path.name)
                        logger.info(f"Added {file_path.name} to archive")
                    except Exception as e:
                        logger.error(f"Failed to add {file_path.name} to archive: {e}")
                        raise

            # Read the buffer only after the ZipFile has been closed, which
            # is when the central directory is written.
            return zip_buffer.getvalue()
        except Exception as e:
            logger.error(f"Error creating zip archive: {e}")
            raise

    def do_POST(self):
        """Handle POST requests to /get endpoint.

        Responds with 404 for any other path, 401 when authentication
        fails, 404 when the user has no new files, 200 with a zip archive
        otherwise, and 500 when an error occurs before a response has been
        started. Downloads are recorded only after the archive was written
        to the client.
        """
        if self.path != '/get':
            self._send_error_response(404, "Endpoint not found")
            return

        # Check authentication - returns user info if valid
        user = self._check_auth()
        if not user:
            logger.warning(f"Unauthorized access attempt from {self.get_client_ip()}")
            self._send_error_response(401, "Unauthorized - Invalid token")
            return

        user_id = user['id']
        user_name = user['name'] or f"User {user_id}"
        logger.info(f"Processing request for user: {user_name} (id={user_id})")

        # Use lock to prevent concurrent file operations
        with file_lock:
            # Set once any response bytes may have been written; after that
            # a second (500) response would only corrupt the stream.
            response_started = False
            try:
                # Get files to serve (excluding already-downloaded ones)
                files = self._get_files_to_serve(user_id)

                if not files:
                    response_started = True
                    self._send_error_response(404, "No new files available for this user")
                    return

                # Create zip archive
                logger.info(f"Creating archive of {len(files)} files for user {user_id}")
                zip_data = self._create_zip_archive(files)

                # Send the zip file
                response_started = True
                self.send_response(200)
                self.send_header('Content-Type', 'application/zip')
                self.send_header('Content-Disposition', 'attachment; filename="files.zip"')
                self.send_header('Content-Length', str(len(zip_data)))
                self.end_headers()
                self.wfile.write(zip_data)

                logger.info(f"Successfully sent {len(zip_data)} bytes to {self.get_client_ip()}")

                # Record the downloads in the database
                filenames = [f.name for f in files]
                record_downloads(user_id, filenames)

            except Exception as e:
                logger.error(f"Error processing request: {e}", exc_info=True)
                if not response_started:
                    self._send_error_response(500, "Internal server error")

    def do_GET(self):
        """Handle GET requests (health check)."""
        if self.path == '/health':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()
            self.wfile.write(b'{"status": "ok"}\n')
        else:
            self._send_error_response(404, "Endpoint not found. Use POST /get")
|
|
|
|
|
|
def run_server():
    """Prepare storage, then serve HTTP requests until interrupted.

    Creates the source directory if needed, initializes the database,
    binds an ``HTTPServer`` on ``HOST``:``PORT`` and serves until a
    ``KeyboardInterrupt`` arrives. The listening socket is always closed
    on the way out.
    """
    # Make sure there is a directory to serve from and a schema to query.
    SOURCE_DIR.mkdir(parents=True, exist_ok=True)
    init_database()

    startup_lines = (
        f"Starting server on {HOST}:{PORT}",
        f"Source directory: {SOURCE_DIR}",
        f"Database: {DATABASE_PATH}",
    )
    for line in startup_lines:
        logger.info(line)

    httpd = HTTPServer((HOST, PORT), FileServerHandler)
    try:
        logger.info("Server is running. Press Ctrl+C to stop.")
        httpd.serve_forever()
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    finally:
        httpd.server_close()
        logger.info("Server closed")
|
|
|
|
|
|
# Script entry point: start the server only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run_server()
|