import asyncio
import logging
import os
from datetime import datetime, timezone

logger = logging.getLogger(__name__)


class RiskManager:
    """Pre-trade risk gate: daily loss, open-position count and order notional.

    Limits are read once from the environment when the instance is created:

    * ``MAX_DAILY_LOSS``  (default ``"500.0"``)   -- positive magnitude of the
      loss for the day at which trading is blocked.
    * ``MAX_POSITIONS``   (default ``"5"``)       -- number of non-flat
      positions at which new orders are rejected.
    * ``MAX_ORDER_VALUE`` (default ``"10000.0"``) -- largest allowed
      ``|quantity * price|`` for a single order.
    """

    def __init__(self):
        self.max_daily_loss = float(os.getenv("MAX_DAILY_LOSS", "500.0"))
        self.max_positions = int(os.getenv("MAX_POSITIONS", "5"))
        self.max_order_value = float(os.getenv("MAX_ORDER_VALUE", "10000.0"))

    def _loss_within_limit(self, loss: float) -> bool:
        """Return True while *loss* (positive = money lost) is below the limit.

        Shared by the live and the DB check so both trip at the same
        boundary: reaching ``max_daily_loss`` exactly counts as a breach.
        A NaN loss compares False and therefore fails closed.
        """
        return loss < self.max_daily_loss

    @staticmethod
    def _utc_day_start() -> int:
        """Return the Unix timestamp (whole seconds) of today's 00:00 UTC."""
        midnight = datetime.now(timezone.utc).replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        return int(midnight.timestamp())

    def check_daily_loss(self, cursor) -> bool:
        """DB-based realized P&L proxy — used as fallback.

        Sums today's ``price * quantity`` separately for BUY and SELL rows of
        the ``trades`` table and treats ``buy_value - sell_value`` as the
        day's loss.

        NOTE(review): this is net cash outflow, not realized P&L. Opening new
        positions today counts as "loss" until they are sold, so this
        fallback can block trading after modest net buying. It errs on the
        safe side, but a per-symbol cost-basis calculation would be more
        accurate.
        NOTE(review): assumes ``trades.timestamp`` holds Unix epoch seconds
        and that the trading day is the UTC calendar day — confirm.

        Args:
            cursor: DB-API cursor using the ``?`` paramstyle (e.g. sqlite3).

        Returns:
            True if the proxy loss is below ``max_daily_loss``.
        """
        # A single statement, so both sums are read from the same snapshot.
        cursor.execute(
            "SELECT "
            "SUM(CASE WHEN action='BUY' THEN price * quantity END), "
            "SUM(CASE WHEN action='SELL' THEN price * quantity END) "
            "FROM trades WHERE timestamp >= ?",
            (self._utc_day_start(),),
        )
        row = cursor.fetchone()
        # SUM over no matching rows yields NULL -> treat as zero.
        buy_value = row[0] or 0.0
        sell_value = row[1] or 0.0
        realized_loss = buy_value - sell_value
        return self._loss_within_limit(realized_loss)

    @staticmethod
    async def _fetch_realized_pnl(ib):
        """Return RealizedPnL of the first managed account, or None if absent.

        Uses ``ib.accountSummaryAsync(account)``, which returns the list of
        account values. (``reqAccountSummaryAsync()`` only fills the client's
        cache and resolves to None. Each call to it also starts another
        account-summary subscription, and TWS caps how many can be active.)

        Raises:
            asyncio.TimeoutError: if the summary is not available within 5 s.
            Exception: any API error, or ``ValueError`` for a non-numeric value.
        """
        accounts = ib.managedAccounts()
        if not accounts:
            logger.warning("No managed accounts returned — using DB P&L fallback")
            return None
        # NOTE: only the first managed account is checked.
        account = accounts[0]

        summary = await asyncio.wait_for(ib.accountSummaryAsync(account), timeout=5.0)
        matches = [
            item
            for item in summary
            if item.tag == "RealizedPnL" and item.account == account
        ]
        if not matches:
            logger.warning(
                "No RealizedPnL in account summary for %s — using DB P&L fallback",
                account,
            )
            return None

        # With "$LEDGER:ALL" IB reports one RealizedPnL row per currency plus a
        # "BASE" row in the account's base currency; prefer BASE if present.
        chosen = next(
            (m for m in matches if getattr(m, "currency", None) == "BASE"),
            matches[0],
        )
        return float(chosen.value)

    async def check_daily_loss_live(self, ib, cursor) -> bool:
        """
        Fetch today's RealizedPnL from IBKR account summary.
        Falls back to DB proxy on timeout, API error, or missing data.

        NOTE(review): whether IB's RealizedPnL covers exactly the current
        trading day is not established here — confirm against IB docs.

        Args:
            ib: connected IB client (ib_insync / ib_async style API).
            cursor: DB cursor passed to :meth:`check_daily_loss` on fallback.

        Returns:
            True if the day's loss is below ``max_daily_loss``.
        """
        try:
            realized_pnl = await self._fetch_realized_pnl(ib)
        except asyncio.TimeoutError:
            logger.warning("Account summary request timed out — using DB P&L fallback")
            realized_pnl = None
        except Exception as exc:  # deliberate best-effort: any API failure -> fallback
            logger.warning("Live P&L check failed (%s) — using DB P&L fallback", exc)
            realized_pnl = None

        # Called outside the try so a DB error propagates once instead of
        # being swallowed and retried.
        if realized_pnl is None:
            return self.check_daily_loss(cursor)

        passed = self._loss_within_limit(-realized_pnl)
        if not passed:
            logger.warning(
                "Daily loss limit hit: RealizedPnL=%.2f, limit=-%.2f",
                realized_pnl,
                self.max_daily_loss,
            )
        return passed

    def check_position_count(self, ib) -> bool:
        """Return True if the number of non-flat positions is below the limit.

        Positions with ``position == 0`` are ignored.
        """
        open_positions = [p for p in ib.positions() if p.position != 0]
        return len(open_positions) < self.max_positions

    def check_order_value(self, quantity, price) -> bool:
        """Return True if the order notional ``|quantity * price|`` is within the limit.

        The magnitude is compared, so a negatively signed quantity cannot
        bypass the cap. ``float()`` raises ``ValueError``/``TypeError`` for
        non-numeric inputs; a NaN notional fails the comparison (fails closed).
        """
        notional = abs(float(quantity) * float(price))
        return notional <= self.max_order_value

    async def run_all_checks(self, ib, cursor, quantity, price) -> dict:
        """Run every risk check and report which ones failed.

        Returns:
            ``{"passed": bool, "failed_checks": [str, ...]}``. The failure
            codes are listed in check order: daily loss, position count,
            order value.
        """
        failed = []
        if not await self.check_daily_loss_live(ib, cursor):
            failed.append("daily_loss_limit_exceeded")
        if not self.check_position_count(ib):
            failed.append("max_positions_reached")
        if not self.check_order_value(quantity, price):
            failed.append("order_value_exceeds_limit")
        return {"passed": len(failed) == 0, "failed_checks": failed}