From 6d8a425a701bd8cdbe678aa36cc0387320ad9df9 Mon Sep 17 00:00:00 2001
From: Thomas Faour
Date: Sat, 10 Jan 2026 21:09:24 -0500
Subject: [PATCH] Add support for multiple users

---
 add-user.py |  94 ++++++++++++++++++++++
 server.py   | 223 +++++++++++++++++++++++++++++++++++-----------------
 2 files changed, 245 insertions(+), 72 deletions(-)
 create mode 100755 add-user.py

diff --git a/add-user.py b/add-user.py
new file mode 100755
index 0000000..8213e5c
--- /dev/null
+++ b/add-user.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+"""CLI tool to manage users in the KOReaderServerFetcher database."""
+
+import argparse
+import sqlite3
+import secrets
+import sys
+
+
+def add_user(db_path, username):
+    """Add a new user with a random token."""
+    token = secrets.token_urlsafe(32)
+
+    conn = sqlite3.connect(db_path)
+    cursor = conn.cursor()
+
+    # Ensure table exists
+    cursor.execute('''
+        CREATE TABLE IF NOT EXISTS users (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            auth_token TEXT NOT NULL UNIQUE,
+            name TEXT UNIQUE,
+            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+        )
+    ''')
+
+    try:
+        cursor.execute(
+            'INSERT INTO users (auth_token, name) VALUES (?, ?)',
+            (token, username)
+        )
+        conn.commit()
+        print(f"Created user: {username}")
+        print(f"Token: {token}")
+        return True
+    except sqlite3.IntegrityError:
+        print(f"Error: Username '{username}' already exists", file=sys.stderr)
+        return False
+    finally:
+        conn.close()
+
+
+def delete_user(db_path, username):
+    """Delete a user by their username."""
+    conn = sqlite3.connect(db_path)
+    cursor = conn.cursor()
+
+    # First check if user exists
+    cursor.execute('SELECT id, auth_token FROM users WHERE name = ?', (username,))
+    result = cursor.fetchone()
+
+    if not result:
+        print(f"Error: No user found with username '{username}'", file=sys.stderr)
+        conn.close()
+        return False
+
+    user_id, token = result
+
+    # Delete user's downloads first
+    cursor.execute('DELETE FROM downloads WHERE user_id = ?', (user_id,))
+    downloads_deleted = cursor.rowcount
+
+    # Delete the user
+    cursor.execute('DELETE FROM users WHERE id = ?', (user_id,))
+    conn.commit()
+    conn.close()
+
+    print(f"Deleted user: {username}")
+    if downloads_deleted > 0:
+        print(f"Also removed {downloads_deleted} download record(s)")
+    return True
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Manage users in the KOReaderServerFetcher database')
+    parser.add_argument('--db', required=True, help='Path to the SQLite database')
+    parser.add_argument('--username', help='Username for the user')
+    parser.add_argument('--delete', action='store_true', help='Delete a user instead of adding')
+
+    args = parser.parse_args()
+
+    if not args.username:
+        parser.error('--username is required')
+
+    if args.delete:
+        success = delete_user(args.db, args.username)
+    else:
+        success = add_user(args.db, args.username)
+
+    sys.exit(0 if success else 1)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/server.py b/server.py
index 2d0b923..113e6c2 100644
--- a/server.py
+++ b/server.py
@@ -1,22 +1,23 @@
 #!/usr/bin/env python3
 """
-Robust HTTP file server with authentication and automatic file archiving.
-Serves files from a source directory and moves them to an archive directory after serving.
+HTTP file server with user authentication and per-user download tracking.
+Serves files from a source directory, tracking which users have downloaded which files
+in an SQLite database to prevent duplicate downloads.
""" -import os -import shutil import logging import zipfile import io +import json +import sqlite3 from pathlib import Path from http.server import HTTPServer, BaseHTTPRequestHandler from threading import Lock +from datetime import datetime # Configuration -AUTH_TOKEN = "chai7pu5oosigh4Ahzajoocheich9hio" SOURCE_DIR = Path("/data/books/ingest") -ARCHIVE_DIR = Path("/data/books/served") +DATABASE_PATH = Path("/data/books/downloads.db") HOST = "0.0.0.0" PORT = 18000 @@ -35,6 +36,98 @@ logger = logging.getLogger(__name__) file_lock = Lock() +def init_database(): + """Initialize the SQLite database and create tables if needed.""" + DATABASE_PATH.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(DATABASE_PATH)) + cursor = conn.cursor() + + # Users table - auth_token is the unique identifier + cursor.execute(''' + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + auth_token TEXT NOT NULL UNIQUE, + name TEXT UNIQUE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + ''') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_auth_token ON users(auth_token)') + + # Downloads table - tracks which user downloaded which file + cursor.execute(''' + CREATE TABLE IF NOT EXISTS downloads ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + filename TEXT NOT NULL, + downloaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, filename), + FOREIGN KEY (user_id) REFERENCES users(id) + ) + ''') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_downloads_user_id ON downloads(user_id)') + + conn.commit() + conn.close() + logger.info(f"Database initialized at {DATABASE_PATH}") + + +def get_user_by_token(auth_token): + """Get user ID by auth token. Returns user_id if valid, None if not.""" + conn = sqlite3.connect(str(DATABASE_PATH)) + cursor = conn.cursor() + cursor.execute('SELECT id, name FROM users WHERE auth_token = ?', (auth_token,)) + result = cursor.fetchone() + conn.close() + if result: + return {'id': result[0], 'name': result[1]} + return None + + +def get_user_downloads(user_id): + """Get list of filenames already downloaded by a user.""" + conn = sqlite3.connect(str(DATABASE_PATH)) + cursor = conn.cursor() + cursor.execute('SELECT filename FROM downloads WHERE user_id = ?', (user_id,)) + filenames = {row[0] for row in cursor.fetchall()} + conn.close() + return filenames + + +def record_downloads(user_id, filenames): + """Record that a user has downloaded specific files.""" + conn = sqlite3.connect(str(DATABASE_PATH)) + cursor = conn.cursor() + timestamp = datetime.now().isoformat() + for filename in filenames: + cursor.execute( + 'INSERT OR IGNORE INTO downloads (user_id, filename, downloaded_at) VALUES (?, ?, ?)', + (user_id, filename, timestamp) + ) + conn.commit() + conn.close() + logger.info(f"Recorded {len(filenames)} downloads for user {user_id}") + + +def create_user(auth_token, name=None): + """Create a new user with the given auth token. Returns user_id or None if token exists.""" + conn = sqlite3.connect(str(DATABASE_PATH)) + cursor = conn.cursor() + try: + cursor.execute( + 'INSERT INTO users (auth_token, name) VALUES (?, ?)', + (auth_token, name) + ) + conn.commit() + user_id = cursor.lastrowid + logger.info(f"Created user {name or auth_token[:8]}... with id {user_id}") + return user_id + except sqlite3.IntegrityError: + logger.warning(f"User with token {auth_token[:8]}... 
already exists") + return None + finally: + conn.close() + + class FileServerHandler(BaseHTTPRequestHandler): """HTTP request handler for authenticated file serving.""" @@ -57,35 +150,47 @@ class FileServerHandler(BaseHTTPRequestHandler): # Fallback to direct connection IP return self.address_string() - + def _send_error_response(self, code, message): """Send a JSON error response.""" self.send_response(code) self.send_header('Content-Type', 'application/json') self.end_headers() - self.wfile.write(f'{{"error": "{message}"}}\n'.encode()) - + self.wfile.write(json.dumps({"error": message}).encode() + b'\n') + def _check_auth(self): - """Verify the authentication token.""" + """Verify the authentication token and return user info if valid.""" auth_header = self.headers.get('Authorization', '') - + # Support both "Bearer TOKEN" and just "TOKEN" formats if auth_header.startswith('Bearer '): token = auth_header[7:] else: token = auth_header - - return token == AUTH_TOKEN + + if not token: + return None + + # Look up user by auth token in the database + return get_user_by_token(token) - def _get_files_to_serve(self): - """Get list of files in the source directory.""" + def _get_files_to_serve(self, user_id): + """Get list of files in the source directory that the user hasn't downloaded yet.""" try: if not SOURCE_DIR.exists(): logger.error(f"Source directory does not exist: {SOURCE_DIR}") return [] - - files = [f for f in SOURCE_DIR.iterdir() if f.is_file()] - logger.info(f"Found {len(files)} files to serve") + + # Get all files in source directory + all_files = [f for f in SOURCE_DIR.iterdir() if f.is_file()] + + # Get files already downloaded by this user + downloaded = get_user_downloads(user_id) + + # Filter out already-downloaded files + files = [f for f in all_files if f.name not in downloaded] + + logger.info(f"Found {len(all_files)} total files, {len(files)} new for user {user_id}") return files except Exception as e: logger.error(f"Error reading source directory: {e}") @@ -110,65 +215,37 @@ class FileServerHandler(BaseHTTPRequestHandler): logger.error(f"Error creating zip archive: {e}") raise - def _move_files_to_archive(self, files): - """Move served files to the archive directory.""" - # Ensure archive directory exists - ARCHIVE_DIR.mkdir(parents=True, exist_ok=True) - - moved_count = 0 - failed_files = [] - - for file_path in files: - try: - dest_path = ARCHIVE_DIR / file_path.name - - # Handle duplicate filenames - if dest_path.exists(): - base = dest_path.stem - suffix = dest_path.suffix - counter = 1 - while dest_path.exists(): - dest_path = ARCHIVE_DIR / f"{base}_{counter}{suffix}" - counter += 1 - - shutil.move(str(file_path), str(dest_path)) - logger.info(f"Moved {file_path.name} to {dest_path}") - moved_count += 1 - except Exception as e: - logger.error(f"Failed to move {file_path.name}: {e}") - failed_files.append(file_path.name) - - if failed_files: - logger.warning(f"Failed to move {len(failed_files)} files: {failed_files}") - - return moved_count, failed_files - def do_POST(self): """Handle POST requests to /get endpoint.""" if self.path != '/get': self._send_error_response(404, "Endpoint not found") return - - # Check authentication - if not self._check_auth(): + + # Check authentication - returns user info if valid + user = self._check_auth() + if not user: logger.warning(f"Unauthorized access attempt from {self.get_client_ip()}") self._send_error_response(401, "Unauthorized - Invalid token") return - + + user_id = user['id'] + user_name = user['name'] or f"User 
{user_id}" + logger.info(f"Processing request for user: {user_name} (id={user_id})") + # Use lock to prevent concurrent file operations with file_lock: try: - # Get files to serve - files = self._get_files_to_serve() - + # Get files to serve (excluding already-downloaded ones) + files = self._get_files_to_serve(user_id) + if not files: - self._send_error_response(404, "No files available to serve") + self._send_error_response(404, "No new files available for this user") return - + # Create zip archive - logger.info(f"Creating archive of {len(files)} files") + logger.info(f"Creating archive of {len(files)} files for user {user_id}") zip_data = self._create_zip_archive(files) - + # Send the zip file self.send_response(200) self.send_header('Content-Type', 'application/zip') @@ -176,13 +253,13 @@ class FileServerHandler(BaseHTTPRequestHandler): self.send_header('Content-Length', str(len(zip_data))) self.end_headers() self.wfile.write(zip_data) - + logger.info(f"Successfully sent {len(zip_data)} bytes to {self.get_client_ip()}") - - # Move files to archive directory - moved, failed = self._move_files_to_archive(files) - logger.info(f"Archived {moved}/{len(files)} files") - + + # Record the downloads in the database + filenames = [f.name for f in files] + record_downloads(user_id, filenames) + except Exception as e: logger.error(f"Error processing request: {e}", exc_info=True) self._send_error_response(500, "Internal server error") @@ -200,13 +277,15 @@ class FileServerHandler(BaseHTTPRequestHandler): def run_server(): """Start the HTTP server.""" - # Ensure directories exist + # Ensure source directory exists SOURCE_DIR.mkdir(parents=True, exist_ok=True) - ARCHIVE_DIR.mkdir(parents=True, exist_ok=True) - + + # Initialize the database + init_database() + logger.info(f"Starting server on {HOST}:{PORT}") logger.info(f"Source directory: {SOURCE_DIR}") - logger.info(f"Archive directory: {ARCHIVE_DIR}") + logger.info(f"Database: {DATABASE_PATH}") server = HTTPServer((HOST, PORT), FileServerHandler)