initial commit
This commit is contained in:
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Phase 2 test suite — uses TestClient with mocked IB and DB dependencies.
|
||||
Run: pytest tests/test_phase2.py -v
|
||||
"""
|
||||
import sqlite3
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
import unittest.mock as mock
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure project root is on sys.path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
# ── Minimal ib_async stubs so the module imports without a real installation ──
|
||||
def _make_ib_async_stub():
|
||||
mod = types.ModuleType("ib_async")
|
||||
|
||||
class _FakeIB:
|
||||
def positions(self):
|
||||
return []
|
||||
async def reqAccountSummaryAsync(self):
|
||||
return []
|
||||
def placeOrder(self, *a, **kw):
|
||||
return mock.MagicMock(isDone=lambda: False, fills=[])
|
||||
def cancelOrder(self, *a, **kw):
|
||||
pass
|
||||
|
||||
mod.IB = _FakeIB
|
||||
mod.Stock = mock.MagicMock
|
||||
mod.MarketOrder = mock.MagicMock
|
||||
mod.LimitOrder = mock.MagicMock
|
||||
mod.ScannerSubscription = mock.MagicMock
|
||||
|
||||
def _bracket_order(action, qty, lp, tp, sl):
|
||||
return [mock.MagicMock(), mock.MagicMock(), mock.MagicMock()]
|
||||
|
||||
mod.BracketOrder = mock.MagicMock
|
||||
return mod, _FakeIB, _bracket_order
|
||||
|
||||
|
||||
_ib_mod, _FakeIB, _bracket_fn = _make_ib_async_stub()
|
||||
sys.modules.setdefault("ib_async", _ib_mod)
|
||||
|
||||
# ── Patch dotenv so .env is not required during tests ──
|
||||
sys.modules.setdefault(
|
||||
"dotenv",
|
||||
types.SimpleNamespace(load_dotenv=lambda: None),
|
||||
)
|
||||
|
||||
# ── In-memory SQLite DB for tests ──
|
||||
_test_conn = sqlite3.connect(":memory:", check_same_thread=False)
|
||||
_test_cursor = _test_conn.cursor()
|
||||
_test_cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp INTEGER NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
quantity INTEGER NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
_test_conn.commit()
|
||||
_test_lock = threading.Lock()
|
||||
|
||||
import dependencies # noqa: E402 — after stubs are in sys.modules
|
||||
|
||||
_fake_ib = _FakeIB()
|
||||
dependencies.setup_dependencies(_fake_ib, _test_conn, _test_cursor, _test_lock)
|
||||
|
||||
# ── Import app AFTER dependencies are wired ──
|
||||
from fastapi.testclient import TestClient # noqa: E402
|
||||
|
||||
# Bypass lifespan for tests
|
||||
from fastapi import FastAPI # noqa: E402
|
||||
from routers import charts, scanner, trades, portfolio # noqa: E402
|
||||
|
||||
_app = FastAPI()
|
||||
_app.include_router(charts.router)
|
||||
_app.include_router(scanner.router)
|
||||
_app.include_router(trades.router)
|
||||
_app.include_router(portfolio.router)
|
||||
|
||||
client = TestClient(_app, raise_server_exceptions=False)
|
||||
|
||||
WRONG_SECRET = "wrong_secret"
|
||||
RIGHT_SECRET = "your_super_secret_string_123"
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestHomePage:
|
||||
def test_home_200(self):
|
||||
r = client.get("/")
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_home_has_chart(self):
|
||||
r = client.get("/")
|
||||
assert "chart" in r.text.lower()
|
||||
|
||||
|
||||
class TestPortfolio:
|
||||
def test_portfolio_page_200(self):
|
||||
r = client.get("/portfolio")
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_portfolio_data_is_list(self):
|
||||
r = client.get("/portfolio/data")
|
||||
assert r.status_code == 200
|
||||
assert isinstance(r.json(), list)
|
||||
|
||||
def test_portfolio_pnl_returns_dict(self):
|
||||
r = client.get("/portfolio/pnl")
|
||||
assert r.status_code == 200
|
||||
assert isinstance(r.json(), dict)
|
||||
|
||||
|
||||
class TestRiskStatus:
|
||||
def test_risk_status_200(self):
|
||||
r = client.get("/risk/status")
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_risk_status_has_limits(self):
|
||||
r = client.get("/risk/status")
|
||||
body = r.json()
|
||||
assert "limits" in body
|
||||
assert "max_daily_loss" in body["limits"]
|
||||
assert "max_positions" in body["limits"]
|
||||
assert "max_order_value" in body["limits"]
|
||||
|
||||
|
||||
class TestWebhookAuth:
|
||||
def test_wrong_secret_403(self):
|
||||
r = client.post(
|
||||
"/webhook",
|
||||
json={
|
||||
"secret": WRONG_SECRET,
|
||||
"symbol": "AAPL",
|
||||
"strategy": {"order_action": "BUY", "order_contracts": 1},
|
||||
},
|
||||
)
|
||||
assert r.status_code == 403
|
||||
|
||||
def test_invalid_action_400(self):
|
||||
r = client.post(
|
||||
"/webhook",
|
||||
json={
|
||||
"secret": RIGHT_SECRET,
|
||||
"symbol": "AAPL",
|
||||
"strategy": {"order_action": "HOLD", "order_contracts": 1},
|
||||
},
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
class TestAdvancedOrders:
|
||||
def test_limit_order_missing_price_400(self):
|
||||
r = client.post(
|
||||
"/webhook",
|
||||
json={
|
||||
"secret": RIGHT_SECRET,
|
||||
"symbol": "AAPL",
|
||||
"order_type": "LIMIT",
|
||||
"strategy": {"order_action": "BUY", "order_contracts": 1},
|
||||
},
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
def test_bracket_order_missing_fields_400(self):
|
||||
r = client.post(
|
||||
"/webhook",
|
||||
json={
|
||||
"secret": RIGHT_SECRET,
|
||||
"symbol": "AAPL",
|
||||
"order_type": "BRACKET",
|
||||
"limit_price": 150.0,
|
||||
"strategy": {"order_action": "BUY", "order_contracts": 1},
|
||||
},
|
||||
)
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
class TestTradeLog:
|
||||
def test_tradelog_200(self):
|
||||
r = client.get("/tradelog")
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_tradelog_shows_table(self):
|
||||
r = client.get("/tradelog")
|
||||
assert "Trade Log" in r.text
|
||||
|
||||
|
||||
class TestScanner:
|
||||
def test_scanner_page_200(self):
|
||||
r = client.get("/scanner")
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_scanner_export_requires_ib(self):
|
||||
# Without a real IB connection the export endpoint still returns a response
|
||||
r = client.post("/scanner/export", json={"scan_type": "HOT_BY_VOLUME"})
|
||||
# Either a CSV or an error JSON — both are valid without IB
|
||||
assert r.status_code in (200, 500)
|
||||
|
||||
|
||||
class TestRiskManagerUnit:
|
||||
def test_order_value_passes(self):
|
||||
from risk_manager import RiskManager
|
||||
rm = RiskManager()
|
||||
assert rm.check_order_value(1, 100.0) is True
|
||||
|
||||
def test_order_value_fails(self):
|
||||
from risk_manager import RiskManager
|
||||
rm = RiskManager()
|
||||
rm.max_order_value = 100.0
|
||||
assert rm.check_order_value(10, 200.0) is False
|
||||
|
||||
def test_position_count_passes(self):
|
||||
from risk_manager import RiskManager
|
||||
rm = RiskManager()
|
||||
assert rm.check_position_count(_fake_ib) is True
|
||||
Reference in New Issue
Block a user