154 lines
4.7 KiB
Python
154 lines
4.7 KiB
Python
import asyncio
|
|
import base64
|
|
import logging
|
|
import os
|
|
import secrets
|
|
import warnings
|
|
from contextlib import asynccontextmanager
|
|
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse, Response
|
|
from fastapi.staticfiles import StaticFiles
|
|
from ib_async import IB
|
|
|
|
import dependencies
|
|
from init_db import ensure_database
|
|
from routers import charts, health, portfolio, scanner, trades
|
|
|
|
# Load variables from a .env file into the process environment (python-dotenv)
# so the os.getenv() configuration lookups further down can see them.
load_dotenv()

# Root logging configuration for the whole process; INFO and above.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s — %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
# IB Gateway / TWS connection settings (all overridable via environment).
IBKR_HOST = os.getenv("IBKR_HOST", "127.0.0.1")
IBKR_PORT = int(os.getenv("IBKR_PORT", "7497"))
IBKR_CLIENT_ID = int(os.getenv("IBKR_CLIENT_ID", "1"))
DB_PATH = os.getenv("DB_PATH", "trades.db")

# Optional HTTP Basic Auth for the UI pages: enabled only when BOTH the
# username and the password are non-empty.
UI_USERNAME = os.getenv("UI_USERNAME", "")
UI_PASSWORD = os.getenv("UI_PASSWORD", "")
AUTH_ENABLED = bool(UI_USERNAME) and bool(UI_PASSWORD)

# Known placeholder value for WEBHOOK_SECRET; _validate_config() warns when
# the configured secret still equals it.
_DEFAULT_SECRET = "your_super_secret_string_123"
|
|
|
|
# Per the messages below, SSE/ChartState is not shared across worker
# processes, so more than one worker breaks it. Only WEB_CONCURRENCY is
# inspected here.
worker_count = int(os.getenv("WEB_CONCURRENCY", "1"))
_multiple_workers = worker_count > 1
if _multiple_workers:
    warnings.warn(
        "WARNING: Multiple workers! SSE/ChartState will break. Use --workers 1",
        category=RuntimeWarning,
    )
    logger.warning(
        "Multiple workers detected — ChartState is not shared across workers!"
    )
|
|
|
|
# Process-wide shared resources, created once at import time. lifespan()
# hands them to dependencies.setup_dependencies() on startup and closes the
# database connection on shutdown.
(conn, cursor, db_write_lock) = ensure_database(DB_PATH)

# The single IB client for the whole app; connected in lifespan() and
# reconnected by reconnect_loop().
ib: IB = IB()
|
|
|
|
|
|
def _validate_config() -> None:
    """Log the state of security-relevant configuration at startup.

    - Warns when ``WEBHOOK_SECRET`` is unset or still equals the placeholder
      ``_DEFAULT_SECRET``.
    - Warns when exactly one of ``UI_USERNAME`` / ``UI_PASSWORD`` is set,
      because UI Basic Auth then stays disabled.
    - Logs whether UI Basic Auth is enabled.

    This function only logs; it does not stop startup.
    """
    webhook_secret = os.getenv("WEBHOOK_SECRET", "")
    if not webhook_secret or webhook_secret == _DEFAULT_SECRET:
        logger.warning(
            "SECURITY: WEBHOOK_SECRET is unset or still the default value. "
            "Generate a strong secret with: openssl rand -hex 32"
        )

    # AUTH_ENABLED requires both values; a half-configured pair silently
    # leaves the UI unauthenticated, which an info-level line can hide.
    if bool(UI_USERNAME) != bool(UI_PASSWORD):
        logger.warning(
            "SECURITY: only one of UI_USERNAME / UI_PASSWORD is set — "
            "UI Basic Auth stays DISABLED until both are configured"
        )

    if AUTH_ENABLED:
        logger.info("UI Basic Auth is ENABLED")
    else:
        logger.info(
            "UI Basic Auth is DISABLED — set UI_USERNAME + UI_PASSWORD in .env to enable"
        )
|
|
|
|
|
|
async def reconnect_loop() -> None:
    """Try to re-establish the IB Gateway connection, up to 10 times.

    Each attempt is preceded by a 5 s pause. A failed attempt, other than the
    last, adds a further 30 s back-off. On success the ``ib`` instance is
    re-registered via ``dependencies.set_ib_instance`` and the coroutine
    returns. If every attempt fails, an error is logged and the coroutine
    returns without raising.
    """
    max_attempts = 10
    pre_attempt_delay = 5
    failure_backoff = 30

    attempt = 0
    while attempt < max_attempts:
        attempt += 1
        await asyncio.sleep(pre_attempt_delay)
        logger.info(f"Reconnect attempt {attempt}/{max_attempts}...")
        try:
            await ib.connectAsync(IBKR_HOST, IBKR_PORT, clientId=IBKR_CLIENT_ID)
            dependencies.set_ib_instance(ib)
            logger.info("Reconnected to IB Gateway successfully.")
            return
        except Exception as exc:
            logger.warning(f"Reconnect attempt {attempt} failed: {exc}")
            is_last_attempt = attempt == max_attempts
            if not is_last_attempt:
                await asyncio.sleep(failure_backoff)

    logger.error("Max reconnect attempts reached. Manual restart required.")
|
|
|
|
|
|
# Strong reference to the in-flight reconnect task. The asyncio event loop
# keeps only weak references to tasks, so an unreferenced task can be
# garbage-collected before it finishes. Holding it here also lets
# on_disconnect() avoid starting a second loop while one is still running.
_reconnect_task: "asyncio.Task[None] | None" = None


def on_disconnect() -> None:
    """Handle an IB Gateway disconnect by scheduling ``reconnect_loop()``.

    Intended to be attached to ``ib.disconnectedEvent``. Must be called while
    an asyncio event loop is running, because it uses ``asyncio.create_task``.
    If a previously scheduled reconnect loop has not finished yet, no new
    loop is started.
    """
    global _reconnect_task
    if _reconnect_task is not None and not _reconnect_task.done():
        # A failed reconnect attempt may itself report a disconnect; do not
        # stack additional loops on top of the running one.
        logger.info("Reconnect already in progress; ignoring disconnect event.")
        return
    logger.warning("IB Gateway disconnected. Scheduling reconnect...")
    _reconnect_task = asyncio.create_task(reconnect_loop())
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage IB Gateway and database resources for the app's lifetime.

    Startup:
    - validate configuration;
    - connect to IB Gateway;
    - register shared dependencies;
    - attach the auto-reconnect handler.

    Shutdown (also reached if startup fails):
    - detach the reconnect handler;
    - disconnect from IB;
    - close the database connection.
    """
    _validate_config()
    reconnect_handler_attached = False
    try:
        await ib.connectAsync(IBKR_HOST, IBKR_PORT, clientId=IBKR_CLIENT_ID)
        logger.info("Connected to IB Gateway")
        dependencies.setup_dependencies(ib, conn, cursor, db_write_lock)
        ib.disconnectedEvent += on_disconnect
        reconnect_handler_attached = True
        yield
    finally:
        # Detach before disconnecting: ib.disconnect() can fire
        # disconnectedEvent, which would otherwise schedule a reconnect
        # while the app is shutting down.
        if reconnect_handler_attached:
            ib.disconnectedEvent -= on_disconnect
        logger.info("Disconnecting from IB")
        try:
            ib.disconnect()
        finally:
            # Close the database even if the IB disconnect raised.
            conn.close()
            logger.info("Database connection closed")


app = FastAPI(lifespan=lifespan)
|
|
|
|
# Exact request paths of the UI pages that basic_auth_middleware protects
# when AUTH_ENABLED. Any path not listed here is served without credentials.
_UI_PATHS = {"/", "/portfolio", "/scanner", "/tradelog"}

# Path prefixes that basic_auth_middleware always lets through unchecked.
# Order is kept as originally declared.
_PUBLIC_PREFIXES = (
    "/health", "/webhook", "/static/",
    "/stream", "/history", "/subscribe",
    "/portfolio/data", "/portfolio/pnl", "/risk/",
    "/scanner/export",
)
|
|
|
|
|
|
def _basic_auth_ok(auth_header: str) -> bool:
    """Return True if *auth_header* carries the configured UI credentials.

    Expects an ``Authorization`` header value of the form
    ``Basic <base64(username:password)>``. The scheme name is matched
    case-insensitively. Malformed or missing headers fail the check rather
    than raising.
    """
    scheme, _, encoded = auth_header.partition(" ")
    if scheme.lower() != "basic":
        return False
    try:
        decoded = base64.b64decode(encoded)
    except ValueError:  # binascii.Error is a ValueError subclass
        return False
    username, _, password = decoded.partition(b":")
    # Compare UTF-8 bytes: secrets.compare_digest() raises TypeError for str
    # arguments containing non-ASCII characters. Both comparisons are always
    # evaluated so a wrong username does not short-circuit the check.
    user_ok = secrets.compare_digest(username, UI_USERNAME.encode("utf-8"))
    pass_ok = secrets.compare_digest(password, UI_PASSWORD.encode("utf-8"))
    return user_ok and pass_ok


@app.middleware("http")
async def basic_auth_middleware(request: Request, call_next):
    """Enforce HTTP Basic Auth on the UI pages when AUTH_ENABLED.

    Protection rules:
    - Only the exact paths in ``_UI_PATHS`` require credentials.
    - Paths starting with one of ``_PUBLIC_PREFIXES``, and all other paths,
      are passed straight through.
    - An unauthenticated UI request gets a 401 with a ``WWW-Authenticate``
      challenge.
    """
    path = request.url.path
    needs_auth = (
        AUTH_ENABLED
        and path in _UI_PATHS
        and not path.startswith(_PUBLIC_PREFIXES)
    )
    # call_next() is deliberately outside any try/except: errors raised by
    # the downstream app must propagate, not be disguised as a 401.
    if not needs_auth or _basic_auth_ok(request.headers.get("Authorization", "")):
        return await call_next(request)

    return Response(
        status_code=401,
        headers={"WWW-Authenticate": 'Basic realm="IBKR Dashboard"'},
        content="Unauthorized",
    )
|
|
|
|
|
|
# Serve files from ./static under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Register each feature router, in the same order as before.
for _router_module in (health, charts, scanner, trades, portfolio):
    app.include_router(_router_module.router)
del _router_module
|