"""FastAPI entry point for the IBKR dashboard.

Reads configuration from the environment, opens the trade database, connects
to IB Gateway for the lifetime of the application, optionally protects the UI
pages with HTTP Basic Auth, and mounts the static files and API routers.
"""

import asyncio
import base64
import logging
import os
import secrets
import warnings
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from ib_async import IB

import dependencies
from init_db import ensure_database
from routers import charts, health, portfolio, scanner, trades

load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s — %(message)s",
)
logger = logging.getLogger(__name__)

# --- Configuration (environment variables, with defaults) -------------------
IBKR_HOST = os.getenv("IBKR_HOST", "127.0.0.1")
IBKR_PORT = int(os.getenv("IBKR_PORT", 7497))
IBKR_CLIENT_ID = int(os.getenv("IBKR_CLIENT_ID", 1))
DB_PATH = os.getenv("DB_PATH", "trades.db")

UI_USERNAME = os.getenv("UI_USERNAME", "")
UI_PASSWORD = os.getenv("UI_PASSWORD", "")
# Basic Auth is only enforced when BOTH credentials are configured.
AUTH_ENABLED = bool(UI_USERNAME and UI_PASSWORD)

# Placeholder webhook secret; _validate_config warns if it is still in use.
_DEFAULT_SECRET = "your_super_secret_string_123"

worker_count = int(os.getenv("WEB_CONCURRENCY", "1"))
if worker_count > 1:
    warnings.warn(
        "WARNING: Multiple workers! SSE/ChartState will break. Use --workers 1",
        RuntimeWarning,
    )
    logger.warning("Multiple workers detected — ChartState is not shared across workers!")

conn, cursor, db_write_lock = ensure_database(DB_PATH)
ib = IB()

# The in-flight reconnect task, if any. Holding the reference keeps the task
# from being garbage-collected mid-run and lets us avoid stacking several
# reconnect loops when multiple disconnect events arrive.
_reconnect_task = None


def _validate_config() -> None:
    """Log warnings/info about security-relevant configuration at startup."""
    webhook_secret = os.getenv("WEBHOOK_SECRET", "")
    if not webhook_secret or webhook_secret == _DEFAULT_SECRET:
        logger.warning(
            "SECURITY: WEBHOOK_SECRET is unset or still the default value. "
            "Generate a strong secret with: openssl rand -hex 32"
        )
    if AUTH_ENABLED:
        logger.info("UI Basic Auth is ENABLED")
    else:
        logger.info(
            "UI Basic Auth is DISABLED — set UI_USERNAME + UI_PASSWORD in .env to enable"
        )


async def reconnect_loop() -> None:
    """Try to reconnect to IB Gateway, giving up after a fixed number of attempts.

    Each attempt is preceded by a 5 second pause; a failed attempt (other than
    the last one) is followed by a further 30 second pause. On success the
    shared IB instance is re-registered with ``dependencies`` and the loop
    returns.
    """
    max_attempts = 10
    for attempt in range(1, max_attempts + 1):
        await asyncio.sleep(5)
        logger.info("Reconnect attempt %s/%s...", attempt, max_attempts)
        try:
            await ib.connectAsync(IBKR_HOST, IBKR_PORT, clientId=IBKR_CLIENT_ID)
            dependencies.set_ib_instance(ib)
            logger.info("Reconnected to IB Gateway successfully.")
            return
        except Exception as exc:
            logger.warning("Reconnect attempt %s failed: %s", attempt, exc)
            if attempt < max_attempts:
                await asyncio.sleep(30)
    logger.error("Max reconnect attempts reached. Manual restart required.")


def on_disconnect() -> None:
    """Handle an IB disconnect event by scheduling ``reconnect_loop``.

    Only one reconnect loop runs at a time: if a previous loop is still in
    flight, the event is logged and no additional task is created.
    """
    global _reconnect_task
    logger.warning("IB Gateway disconnected. Scheduling reconnect...")
    if _reconnect_task is not None and not _reconnect_task.done():
        # A loop is already working on it; a second one would only race it.
        return
    _reconnect_task = asyncio.create_task(reconnect_loop())


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Connect to IB Gateway on startup and release resources on shutdown.

    Startup validates the configuration, connects to IB, wires the shared
    objects into ``dependencies`` and subscribes ``on_disconnect``. Shutdown
    (also reached if startup fails) disconnects from IB and closes the
    database connection.
    """
    _validate_config()
    handler_attached = False
    try:
        await ib.connectAsync(IBKR_HOST, IBKR_PORT, clientId=IBKR_CLIENT_ID)
        logger.info("Connected to IB Gateway")
        dependencies.setup_dependencies(ib, conn, cursor, db_write_lock)
        ib.disconnectedEvent += on_disconnect
        handler_attached = True
        yield
    finally:
        # Unsubscribe first so the deliberate disconnect below does not
        # schedule a reconnect while the application is shutting down.
        if handler_attached:
            ib.disconnectedEvent -= on_disconnect
        if _reconnect_task is not None and not _reconnect_task.done():
            _reconnect_task.cancel()
        logger.info("Disconnecting from IB")
        ib.disconnect()
        conn.close()
        logger.info("Database connection closed")


app = FastAPI(lifespan=lifespan)

# UI pages that require Basic Auth when AUTH_ENABLED
_UI_PATHS = {"/", "/scanner", "/tradelog", "/portfolio"}
_PUBLIC_PREFIXES = (
    "/health",
    "/webhook",
    "/static/",
    "/stream",
    "/history",
    "/subscribe",
    "/portfolio/data",
    "/portfolio/pnl",
    "/risk/",
    "/scanner/export",
)


def _credentials_match(auth_header: str) -> bool:
    """Return True if *auth_header* carries the configured Basic credentials.

    Malformed headers (wrong scheme, invalid base64, non-UTF-8 payload) are
    treated as a mismatch rather than an error.
    """
    if not auth_header.startswith("Basic "):
        return False
    try:
        decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
    except ValueError:
        # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
        return False
    username, _, password = decoded.partition(":")
    # Compare as bytes: compare_digest rejects non-ASCII str with TypeError.
    # Both comparisons are evaluated so the username check does not
    # short-circuit the password check.
    user_ok = secrets.compare_digest(
        username.encode("utf-8"), UI_USERNAME.encode("utf-8")
    )
    password_ok = secrets.compare_digest(
        password.encode("utf-8"), UI_PASSWORD.encode("utf-8")
    )
    return user_ok and password_ok


@app.middleware("http")
async def basic_auth_middleware(request: Request, call_next):
    """Require HTTP Basic Auth on the UI pages when auth is enabled.

    Requests pass straight through when auth is disabled, when the path starts
    with one of ``_PUBLIC_PREFIXES``, or when the path is not one of
    ``_UI_PATHS``. Otherwise a valid ``Authorization: Basic`` header is
    required; without one a 401 challenge is returned.
    """
    if not AUTH_ENABLED:
        return await call_next(request)

    path = request.url.path
    if path.startswith(_PUBLIC_PREFIXES) or path not in _UI_PATHS:
        return await call_next(request)

    if _credentials_match(request.headers.get("Authorization", "")):
        # Deliberately outside any try/except: errors raised by the downstream
        # handler must propagate instead of being reported as a 401.
        return await call_next(request)

    return Response(
        status_code=401,
        headers={"WWW-Authenticate": 'Basic realm="IBKR Dashboard"'},
        content="Unauthorized",
    )


app.mount("/static", StaticFiles(directory="static"), name="static")

app.include_router(health.router)
app.include_router(charts.router)
app.include_router(scanner.router)
app.include_router(trades.router)
app.include_router(portfolio.router)