Skip to content

Commit

Permalink
add signals
Browse files Browse the repository at this point in the history
  • Loading branch information
robswc committed Mar 17, 2023
1 parent 25afc3c commit 8f09b87
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 14 deletions.
13 changes: 11 additions & 2 deletions app/api/api_v1/endpoints/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from components import Strategy
from components.backtest.backtest import BacktestResult
from components.ohlc import DataAdapter
from components.ohlc.symbol import Symbol

router = APIRouter()

Expand Down Expand Up @@ -57,4 +56,14 @@ async def run_strategy(request: RunStrategyRequest):
ohlc = da.get_data(data_path, symbol="AAPL")

backtest_result, plots = strategy.run(data=ohlc, parameters=parameters, plots=True)
return RunStrategyResponse(backtest=backtest_result, plots=[p.as_dict() for p in plots])
return RunStrategyResponse(backtest=backtest_result, plots=[p.as_dict() for p in plots])

class SignalsRequest(BaseModel):
signal_type: str
strategy: RunStrategyRequest

@router.post("/signals")
async def run_signals(request: SignalsRequest):
"""Run signals"""
print(request.signal_type)
pass
11 changes: 2 additions & 9 deletions app/components/backtest/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ def test(self):


# use concurrent futures to test orders in parallel

# original code
# for p in positions:
# p.test(ohlc=self.data)
# print(str(p))

# new code
logger.debug(f'Testing {len(positions)} positions in parallel...')
with concurrent.futures.ThreadPoolExecutor() as executor:
for p in positions:
Expand All @@ -73,13 +66,13 @@ def test(self):
# calculate win/loss ratio
losing_trades = len([p for p in positions if p.pnl < 0])
winning_trades = len([p for p in positions if p.pnl > 0])
wl_ratio = round(winning_trades / losing_trades, 2) if losing_trades > 0 else 0 if winning_trades == 0 else 1
wl_ratio = round(winning_trades / losing_trades, 2)

# create backtest result
self.result = BacktestResult(
pnl=sum([position.pnl for position in positions]),
wl_ratio=wl_ratio,
trades=0,
trades=len(positions),
winning_trades=winning_trades,
losing_trades=losing_trades,
positions=positions,
Expand Down
13 changes: 12 additions & 1 deletion app/components/orders/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
Similar to Alpaca-py
"""


class OrderValidationError(Exception):
pass


def abbr_type(order_type):
if order_type == OrderType.MARKET:
return "mkt"
Expand All @@ -21,6 +23,7 @@ def abbr_type(order_type):
elif order_type == OrderType.STOP:
return "stp"


class TimeInForce(str, Enum):
DAY = "day"
GTC = "gtc"
Expand All @@ -29,17 +32,20 @@ class TimeInForce(str, Enum):
IOC = "ioc"
FOK = "fok"


class OrderType(str, Enum):
MARKET = "market"
LIMIT = "limit"
STOP = "stop"
STOP_LIMIT = "stop_limit"
TRAILING_STOP = "trailing_stop"


class OrderSide(str, Enum):
BUY = "buy"
SELL = "sell"


class Order(BaseModel):
id: Optional[str]
timestamp: int
Expand All @@ -52,6 +58,7 @@ class Order(BaseModel):
extended_hours: Optional[bool]
client_order_id: Optional[str]
filled_avg_price: Optional[float]

# take_profit: Optional[TakeProfitRequest]
# stop_loss: Optional[StopLossRequest]

Expand All @@ -77,6 +84,11 @@ def __hash__(self):
def get_id(self):
return hashlib.md5(str(hash(self)).encode()).hexdigest()

# will eventually make this a proper attribute
@property
def price(self):
return self.filled_avg_price

class Config:
schema_extra = {
"example": {
Expand All @@ -95,7 +107,6 @@ def __init__(self, **data):
logger.error(e)
raise e


# if valid, set id
self.id = self.get_id()

Expand Down
10 changes: 9 additions & 1 deletion app/components/orders/position.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
from typing import Optional, List

from loguru import logger
from pydantic import BaseModel

from components.orders.order import Order
Expand Down Expand Up @@ -35,6 +36,7 @@ class Position(BaseModel):

def _get_effect(self, order: Order):
"""Get the effect of an order on a position."""
print(self.size, self.size + order.qty)
if abs(self.size) < abs(self.size + order.qty):
return PositionEffect.ADD
else:
Expand All @@ -44,7 +46,7 @@ def _get_size(self):
"""Get the size of all orders in the position."""
return sum([o.qty for o in self.orders])

def test(self, ohlc: 'OHLC'):
def test(self, ohlc: 'OHLC' = None):
# iterate over all orders, handling each one
for o in self.orders:
self.handle_order(o)
Expand All @@ -68,6 +70,10 @@ def add_closing_order(self, ohlc: 'OHLC'):

def handle_order(self, order: Order):

if order.type == 'stop':
logger.warning('Stop orders are not yet supported!')
return

# if the position is missing a side, set it to the side of the first order
if self.side is None:
self.side = order.side
Expand All @@ -76,6 +82,7 @@ def handle_order(self, order: Order):
effect = self._get_effect(order)

if effect == PositionEffect.ADD:
print('add')
# if the position is opened, set the opened timestamp
self.opened_timestamp = order.timestamp if not self.opened_timestamp else self.opened_timestamp
# since the position is added, we need to calculate the cost basis
Expand All @@ -84,6 +91,7 @@ def handle_order(self, order: Order):
self.average_entry_price = self.cost_basis / (self.size + order.qty)

if effect == PositionEffect.REDUCE:
print('reduce')
# if the position is closed, set the closed timestamp
self.closed_timestamp = order.timestamp
# since the position is reduced, we need to calculate the realized pnl
Expand Down
31 changes: 31 additions & 0 deletions app/components/orders/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Optional

from pydantic import BaseModel


class Signal(BaseModel):
order_type: Optional[str] = None
side: Optional[str] = None
quantity: Optional[int] = None
# symbol: Optional[str] = None
price: Optional[float] = None

def from_position(self, position: 'Position'):
self.order_type = position.orders[0].type
self.side = 'sell' if position.side == 'sell' else 'buy'
self.quantity = position.orders[0].qty
self.price = position.average_entry_price
return self



class BracketSignal(Signal):
stop_loss: Optional[float] = None
take_profit: Optional[float] = None

def from_position(self, position: 'Position'):
super().from_position(position)
# TODO: add validation for these
self.stop_loss = [o for o in position.orders if o.type == 'stop'][0].price
self.take_profit = [o for o in position.orders if o.type == 'limit'][0].price
return self
3 changes: 2 additions & 1 deletion app/components/strategy/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ def run(self, data: OHLC, parameters: dict = None, **kwargs):
b = Backtest(strategy=self, data=data)

# runs the backtest
b.test()
if self.positions.positions or self.orders.orders:
b.test()

# run after methods
for method in self._after_methods:
Expand Down
66 changes: 66 additions & 0 deletions app/tests/test_signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from components.orders.order import Order
from components.orders.position import Position
from components.orders.signals import Signal, BracketSignal

ROOT_ORDER = Order(
type='market',
side='buy',
qty=100,
symbol='AAPL',
filled_avg_price=100,
timestamp=1000
)

class TestSignals:
def test_basic_signal(self):
p = Position(
orders=[ROOT_ORDER],
)

p.test()

s = Signal().from_position(p)

assert s.order_type == 'market'
assert s.side == 'buy'
assert s.quantity == 100
assert s.price == 100


def test_bracket_signal(self):

stop_order = Order(
type='stop',
side='sell',
qty=-100,
symbol='AAPL',
filled_avg_price=90,
timestamp=1100,
)
limit_order = Order(
type='limit',
side='sell',
qty=-100,
symbol='AAPL',
filled_avg_price=110,
timestamp=1100,
)


p = Position(
orders=[ROOT_ORDER, stop_order, limit_order],
)

p.test()

s = BracketSignal().from_position(p)

assert s.order_type == 'market'
assert s.side == 'buy'
assert s.price == 100
assert s.stop_loss == 90
assert s.take_profit == 110

# check that serializing to JSON works as expected
assert s.json() == '{"order_type": "market", "side": "buy", "quantity": 100, "price": 100.0, ' \
'"stop_loss": 90.0, "take_profit": 110.0}'

0 comments on commit 8f09b87

Please sign in to comment.