"""
Phase 2 test suite — uses TestClient with mocked IB and DB dependencies.

ib_async and dotenv are replaced by in-process stubs, the database is an
in-memory SQLite connection, and the routers are mounted on a bare FastAPI app
(no lifespan handler) so the suite runs without a broker connection or a .env.

Run: pytest tests/test_phase2.py -v
"""

import sqlite3
import sys
import threading
import types
import unittest.mock as mock
from pathlib import Path

import pytest

# The project root is the parent of the directory holding this test module.
# Putting it first on sys.path lets `dependencies`, `routers` and
# `risk_manager` be imported as top-level modules.
_PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(_PROJECT_ROOT))


# ── Minimal ib_async stubs so the module imports without a real installation ──


def _make_ib_async_stub():
    """Build a stand-in ``ib_async`` module for the tests.

    Returns:
        A tuple ``(module, fake_ib_class, bracket_order_fn)``:

        * ``module`` — a ``types.ModuleType("ib_async")`` exposing ``IB``,
          ``Stock``, ``MarketOrder``, ``LimitOrder``, ``ScannerSubscription``
          and ``BracketOrder``.
        * ``fake_ib_class`` — the ``IB`` replacement: no positions, an empty
          account summary, ``placeOrder`` returning a trade mock that is never
          done and has no fills, and a no-op ``cancelOrder``.
        * ``bracket_order_fn`` — returns three fresh order mocks; its argument
          order (action, qty, limit, take-profit, stop-loss) mirrors
          ``IB.bracketOrder``.
    """
    mod = types.ModuleType("ib_async")

    def _stub_ctor(*field_names):
        """Return a callable standing in for an ib_async class constructor.

        ``mock.MagicMock`` cannot be used directly as the constructor: its
        positional parameters are ``spec``, ``side_effect``, ``return_value``,
        ... so ``MagicMock("AAPL", "SMART", "USD")`` builds a mock spec'd on
        ``str`` (reading ``.symbol`` raises AttributeError) rather than a
        contract carrying those values. This factory instead maps positional
        arguments onto ``field_names`` and keyword arguments onto same-named
        attributes of a fresh MagicMock. Note the result is a plain callable,
        not a class.
        """

        def _construct(*args, **kwargs):
            obj = mock.MagicMock()
            # Positional args beyond the declared field names are ignored.
            for field, value in zip(field_names, args):
                setattr(obj, field, value)
            # setattr (rather than MagicMock(**kwargs)) so keys such as
            # "name" or "spec" become plain attributes, not mock configuration.
            for key, value in kwargs.items():
                setattr(obj, key, value)
            return obj

        return _construct

    def _bracket_order(action, qty, lp, tp, sl):
        # Parent, take-profit and stop-loss order placeholders.
        return [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()]

    class _FakeIB:
        def positions(self):
            return []

        async def reqAccountSummaryAsync(self):
            return []

        def placeOrder(self, *a, **kw):
            # A trade that never completes and has no fills.
            return mock.MagicMock(isDone=lambda: False, fills=[])

        def cancelOrder(self, *a, **kw):
            pass

        def bracketOrder(
            self, action, quantity, limitPrice, takeProfitPrice, stopLossPrice, **kw
        ):
            # Routes the ib_async-style method through the module's helper.
            return _bracket_order(
                action, quantity, limitPrice, takeProfitPrice, stopLossPrice
            )

    mod.IB = _FakeIB
    # Positional field names follow the ib_async constructors:
    # Stock(symbol, exchange, currency), MarketOrder(action, totalQuantity),
    # LimitOrder(action, totalQuantity, lmtPrice),
    # BracketOrder(parent, takeProfit, stopLoss).
    mod.Stock = _stub_ctor("symbol", "exchange", "currency")
    mod.MarketOrder = _stub_ctor("action", "totalQuantity")
    mod.LimitOrder = _stub_ctor("action", "totalQuantity", "lmtPrice")
    mod.ScannerSubscription = _stub_ctor()
    mod.BracketOrder = _stub_ctor("parent", "takeProfit", "stopLoss")

    return mod, _FakeIB, _bracket_order


_ib_mod, _FakeIB, _bracket_fn = _make_ib_async_stub()

# Register the stub only if no module named "ib_async" is loaded yet; an
# already-imported module is left in place.
if "ib_async" not in sys.modules:
    sys.modules["ib_async"] = _ib_mod


# ── Patch dotenv so .env is not required during tests ──
# The stub's load_dotenv accepts and ignores any arguments: callers commonly
# pass a path or keywords such as override=True, which a zero-argument stub
# would reject with TypeError at import time. Like setdefault everywhere here,
# an already-loaded "dotenv" module is left untouched.
sys.modules.setdefault(
    "dotenv",
    types.SimpleNamespace(load_dotenv=lambda *_args, **_kwargs: None),
)


# ── In-memory SQLite DB for tests ──
# check_same_thread=False allows the connection to be used from threads other
# than the one that created it; a lock is created alongside and handed to the
# app together with the connection and cursor.
_TRADES_SCHEMA = """
CREATE TABLE IF NOT EXISTS trades (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp INTEGER NOT NULL,
    symbol TEXT NOT NULL,
    action TEXT NOT NULL,
    quantity INTEGER NOT NULL,
    price REAL NOT NULL,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""

_test_conn = sqlite3.connect(":memory:", check_same_thread=False)
_test_cursor = _test_conn.cursor()
_test_cursor.execute(_TRADES_SCHEMA)
_test_conn.commit()

_test_lock = threading.Lock()


import dependencies  # noqa: E402 — after stubs are in sys.modules

# Hand the single shared fake IB client and the in-memory DB handles to the
# app's dependency setup before any router module is imported below.
_fake_ib = _FakeIB()
dependencies.setup_dependencies(
    _fake_ib,
    _test_conn,
    _test_cursor,
    _test_lock,
)


# ── Import app AFTER dependencies are wired ──
from fastapi.testclient import TestClient  # noqa: E402

# Bypass lifespan for tests: the routers are mounted on a bare FastAPI()
# created without a lifespan handler.
from fastapi import FastAPI  # noqa: E402
from routers import charts, scanner, trades, portfolio  # noqa: E402

_app = FastAPI()
for _router_module in (charts, scanner, trades, portfolio):
    _app.include_router(_router_module.router)
del _router_module

# raise_server_exceptions=False: unhandled errors inside a route come back as
# HTTP 500 responses instead of being re-raised into the test.
# NOTE(review): the tests request /, /portfolio, /risk/status, /webhook,
# /tradelog and /scanner — confirm all of these are served by the four routers
# mounted above.
client = TestClient(_app, raise_server_exceptions=False)


# Webhook secrets: one the handler must reject, one it must accept.
# NOTE(review): dotenv is stubbed, so no .env is loaded — RIGHT_SECRET must
# match whatever secret the webhook handler falls back to; confirm.
WRONG_SECRET, RIGHT_SECRET = "wrong_secret", "your_super_secret_string_123"


# ── Tests ─────────────────────────────────────────────────────────────────────


class TestHomePage:
    """Smoke tests for the page served at "/"."""

    def test_home_200(self):
        response = client.get("/")
        assert response.status_code == 200

    def test_home_has_chart(self):
        page_text = client.get("/").text
        assert "chart" in page_text.lower()


class TestPortfolio:
    """The portfolio page and its two JSON endpoints."""

    def test_portfolio_page_200(self):
        assert client.get("/portfolio").status_code == 200

    def test_portfolio_data_is_list(self):
        response = client.get("/portfolio/data")
        assert response.status_code == 200
        payload = response.json()
        assert isinstance(payload, list)

    def test_portfolio_pnl_returns_dict(self):
        response = client.get("/portfolio/pnl")
        assert response.status_code == 200
        payload = response.json()
        assert isinstance(payload, dict)


class TestRiskStatus:
    """The /risk/status endpoint and the risk limits it reports."""

    def test_risk_status_200(self):
        r = client.get("/risk/status")
        assert r.status_code == 200

    def test_risk_status_has_limits(self):
        r = client.get("/risk/status")
        # Check the status before decoding: the client is built with
        # raise_server_exceptions=False, so a server error arrives as a 500
        # whose body may not be JSON, and r.json() would then fail with a
        # decode error that hides the real cause.
        assert r.status_code == 200, r.text
        body = r.json()
        assert "limits" in body
        limits = body["limits"]
        for key in ("max_daily_loss", "max_positions", "max_order_value"):
            assert key in limits, f"missing limit {key!r}"


class TestWebhookAuth:
    """Secret checking and action validation on POST /webhook."""

    @staticmethod
    def _post(secret, order_action):
        # One-contract AAPL signal with the given secret and action.
        payload = {
            "secret": secret,
            "symbol": "AAPL",
            "strategy": {"order_action": order_action, "order_contracts": 1},
        }
        return client.post("/webhook", json=payload)

    def test_wrong_secret_403(self):
        assert self._post(WRONG_SECRET, "BUY").status_code == 403

    def test_invalid_action_400(self):
        assert self._post(RIGHT_SECRET, "HOLD").status_code == 400


class TestAdvancedOrders:
    """Rejection of incomplete LIMIT and BRACKET orders on POST /webhook."""

    @staticmethod
    def _order_payload(**extra):
        # Authenticated one-contract AAPL buy, plus the order-type fields given.
        payload = {"secret": RIGHT_SECRET, "symbol": "AAPL"}
        payload.update(extra)
        payload["strategy"] = {"order_action": "BUY", "order_contracts": 1}
        return payload

    def test_limit_order_missing_price_400(self):
        # LIMIT order sent without any limit_price.
        body = self._order_payload(order_type="LIMIT")
        assert client.post("/webhook", json=body).status_code == 400

    def test_bracket_order_missing_fields_400(self):
        # BRACKET order where only limit_price is supplied; the other bracket
        # fields are omitted.
        body = self._order_payload(order_type="BRACKET", limit_price=150.0)
        assert client.post("/webhook", json=body).status_code == 400


class TestTradeLog:
    """The /tradelog HTML page."""

    def test_tradelog_200(self):
        assert client.get("/tradelog").status_code == 200

    def test_tradelog_shows_table(self):
        html = client.get("/tradelog").text
        assert "Trade Log" in html


class TestScanner:
    """The scanner page and its CSV export endpoint."""

    def test_scanner_page_200(self):
        response = client.get("/scanner")
        assert response.status_code == 200

    def test_scanner_export_requires_ib(self):
        # The fake IB defines no scanner methods, so the export endpoint may
        # answer with a CSV (200) or an error (500); both are accepted.
        # NOTE(review): because raise_server_exceptions=False, an unhandled
        # crash in the route also yields 500 and passes this check.
        response = client.post("/scanner/export", json={"scan_type": "HOT_BY_VOLUME"})
        assert response.status_code in {200, 500}


class TestRiskManagerUnit:
|
|
def test_order_value_passes(self):
|
|
from risk_manager import RiskManager
|
|
rm = RiskManager()
|
|
assert rm.check_order_value(1, 100.0) is True
|
|
|
|
def test_order_value_fails(self):
|
|
from risk_manager import RiskManager
|
|
rm = RiskManager()
|
|
rm.max_order_value = 100.0
|
|
assert rm.check_order_value(10, 200.0) is False
|
|
|
|
def test_position_count_passes(self):
|
|
from risk_manager import RiskManager
|
|
rm = RiskManager()
|
|
assert rm.check_position_count(_fake_ib) is True
|