initial commit
This commit is contained in:
@@ -0,0 +1,153 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from ib_async import IB
|
||||
|
||||
import dependencies
|
||||
from init_db import ensure_database
|
||||
from routers import charts, health, portfolio, scanner, trades
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s — %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IBKR_HOST = os.getenv("IBKR_HOST", "127.0.0.1")
|
||||
IBKR_PORT = int(os.getenv("IBKR_PORT", 7497))
|
||||
IBKR_CLIENT_ID = int(os.getenv("IBKR_CLIENT_ID", 1))
|
||||
DB_PATH = os.getenv("DB_PATH", "trades.db")
|
||||
|
||||
UI_USERNAME = os.getenv("UI_USERNAME", "")
|
||||
UI_PASSWORD = os.getenv("UI_PASSWORD", "")
|
||||
AUTH_ENABLED = bool(UI_USERNAME and UI_PASSWORD)
|
||||
|
||||
_DEFAULT_SECRET = "your_super_secret_string_123"
|
||||
|
||||
worker_count = int(os.getenv("WEB_CONCURRENCY", "1"))
|
||||
if worker_count > 1:
|
||||
warnings.warn(
|
||||
"WARNING: Multiple workers! SSE/ChartState will break. Use --workers 1",
|
||||
RuntimeWarning,
|
||||
)
|
||||
logger.warning("Multiple workers detected — ChartState is not shared across workers!")
|
||||
|
||||
conn, cursor, db_write_lock = ensure_database(DB_PATH)
|
||||
ib = IB()
|
||||
|
||||
|
||||
def _validate_config() -> None:
|
||||
webhook_secret = os.getenv("WEBHOOK_SECRET", "")
|
||||
if not webhook_secret or webhook_secret == _DEFAULT_SECRET:
|
||||
logger.warning(
|
||||
"SECURITY: WEBHOOK_SECRET is unset or still the default value. "
|
||||
"Generate a strong secret with: openssl rand -hex 32"
|
||||
)
|
||||
if AUTH_ENABLED:
|
||||
logger.info("UI Basic Auth is ENABLED")
|
||||
else:
|
||||
logger.info(
|
||||
"UI Basic Auth is DISABLED — set UI_USERNAME + UI_PASSWORD in .env to enable"
|
||||
)
|
||||
|
||||
|
||||
async def reconnect_loop() -> None:
|
||||
max_attempts = 10
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
await asyncio.sleep(5)
|
||||
logger.info(f"Reconnect attempt {attempt}/{max_attempts}...")
|
||||
try:
|
||||
await ib.connectAsync(IBKR_HOST, IBKR_PORT, clientId=IBKR_CLIENT_ID)
|
||||
dependencies.set_ib_instance(ib)
|
||||
logger.info("Reconnected to IB Gateway successfully.")
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(f"Reconnect attempt {attempt} failed: {exc}")
|
||||
if attempt < max_attempts:
|
||||
await asyncio.sleep(30)
|
||||
logger.error("Max reconnect attempts reached. Manual restart required.")
|
||||
|
||||
|
||||
def on_disconnect() -> None:
|
||||
logger.warning("IB Gateway disconnected. Scheduling reconnect...")
|
||||
asyncio.create_task(reconnect_loop())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
_validate_config()
|
||||
try:
|
||||
await ib.connectAsync(IBKR_HOST, IBKR_PORT, clientId=IBKR_CLIENT_ID)
|
||||
logger.info("Connected to IB Gateway")
|
||||
dependencies.setup_dependencies(ib, conn, cursor, db_write_lock)
|
||||
ib.disconnectedEvent += on_disconnect
|
||||
yield
|
||||
finally:
|
||||
logger.info("Disconnecting from IB")
|
||||
ib.disconnect()
|
||||
conn.close()
|
||||
logger.info("Database connection closed")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# UI pages that require Basic Auth when AUTH_ENABLED
|
||||
_UI_PATHS = {"/", "/scanner", "/tradelog", "/portfolio"}
|
||||
_PUBLIC_PREFIXES = (
|
||||
"/health",
|
||||
"/webhook",
|
||||
"/static/",
|
||||
"/stream",
|
||||
"/history",
|
||||
"/subscribe",
|
||||
"/portfolio/data",
|
||||
"/portfolio/pnl",
|
||||
"/risk/",
|
||||
"/scanner/export",
|
||||
)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def basic_auth_middleware(request: Request, call_next):
|
||||
if not AUTH_ENABLED:
|
||||
return await call_next(request)
|
||||
path = request.url.path
|
||||
if any(path.startswith(p) for p in _PUBLIC_PREFIXES):
|
||||
return await call_next(request)
|
||||
if path not in _UI_PATHS:
|
||||
return await call_next(request)
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
|
||||
username, _, password = decoded.partition(":")
|
||||
if secrets.compare_digest(username, UI_USERNAME) and secrets.compare_digest(
|
||||
password, UI_PASSWORD
|
||||
):
|
||||
return await call_next(request)
|
||||
except Exception:
|
||||
pass
|
||||
return Response(
|
||||
status_code=401,
|
||||
headers={"WWW-Authenticate": 'Basic realm="IBKR Dashboard"'},
|
||||
content="Unauthorized",
|
||||
)
|
||||
|
||||
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
app.include_router(health.router)
|
||||
app.include_router(charts.router)
|
||||
app.include_router(scanner.router)
|
||||
app.include_router(trades.router)
|
||||
app.include_router(portfolio.router)
|
||||
Reference in New Issue
Block a user