Skip to content

Commit

Permalink
backtesting updates
Browse files Browse the repository at this point in the history
  • Loading branch information
robswc committed Mar 16, 2023
1 parent ec93ac5 commit 8e5d198
Show file tree
Hide file tree
Showing 11 changed files with 290 additions and 51 deletions.
19 changes: 13 additions & 6 deletions app/api/api_v1/endpoints/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

class RunStrategyRequest(BaseModel):
strategy: str
parameters: List[dict]
adapter: str
data: str

Expand All @@ -32,22 +33,28 @@ async def get_strategy(name: str):
except ValueError:
return Response(status_code=404)

@router.post("/", response_model=BacktestResult)
class RunStrategyResponse(BaseModel):
backtest: BacktestResult
plots: List[dict]

@router.post("/", response_model=RunStrategyResponse)
async def run_strategy(request: RunStrategyRequest):
"""Run a strategy with data"""

# get arguments from request
name = request.strategy
data_adapter_name = request.adapter
data = str(request.data)
parameters = {p['name']: p['value'] for p in request.parameters}

# get strategy and data adapter
da = DataAdapter.objects.get(name=data_adapter_name)
strategy = Strategy.objects.get(name=name)

# start from app root
# start from app root, get data
app_path = Path(__file__).parent.parent.parent.parent
data_path = app_path / data

ohlc = da.get_data(data_path, symbol="AAPL")


backtest_result = strategy.run(data=ohlc)
return backtest_result
backtest_result, plots = strategy.run(data=ohlc, parameters=parameters, plots=True)
return RunStrategyResponse(backtest=backtest_result, plots=[p.as_dict() for p in plots])
70 changes: 64 additions & 6 deletions app/components/backtest/backtest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
from typing import Optional, List

from loguru import logger
from pydantic import BaseModel

from components.orders.position import Position
from components.orders.order import Order
from components.orders.position import Position, PositionClosedException


def get_effect(position: Position, order: Order):
"""Get the effect of an order on a position."""
if position.get_size() > 0:
if order.side == 'buy':
return 'increase'
elif order.side == 'sell':
return 'decrease'
elif position.get_size() < 0:
if order.side == 'buy':
return 'decrease'
elif order.side == 'sell':
return 'increase'
else:
return 'increase'

class BacktestResult(BaseModel):
pnl: float
wl_ratio: float
Expand All @@ -15,19 +32,60 @@ class BacktestResult(BaseModel):
winning_trades: int
losing_trades: int
positions: List[Position]
orders: List[Order]

class Backtest:
def __init__(self, data, strategy):
self.data = data
self.strategy = strategy
self.result = None

def _sort_orders(self, orders: List[Order]):
return sorted(orders, key=lambda x: x.timestamp)

def test(self):
logger.debug(f'Starting backtest for strategy {self.strategy.name}')


# sort orders by timestamp
sorted_orders = self._sort_orders(self.strategy.orders.all())

# positions
positions = []
last_position = None

# build the positions
for order in sorted_orders:
# main logic
try:
last_position.add_order(order)
except PositionClosedException:
p = Position()
p.add_order(order)
positions.append(p)
last_position = p
except AttributeError:
p = Position()
p.add_order(order)
positions.append(p)
last_position = p

# loop through each position
for position in positions:
position.test(ohlc=self.data)

# calculate win/loss ratio
losing_trades = len([position for position in positions if position.pnl < 0])
winning_trades = len([position for position in positions if position.pnl > 0])
wl_ratio = round(winning_trades / losing_trades, 2) if losing_trades > 0 else 0 if winning_trades == 0 else 1

# create backtest result
self.result = BacktestResult(
pnl=0,
wl_ratio=0,
pnl=sum([position.pnl for position in positions]),
wl_ratio=wl_ratio,
trades=0,
winning_trades=0,
losing_trades=0,
positions=[],
winning_trades=winning_trades,
losing_trades=losing_trades,
positions=positions,
orders=self.strategy.orders.all(),
)
2 changes: 1 addition & 1 deletion app/components/orders/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Order(BaseModel):
id: Optional[str]
timestamp: int
symbol: str
qty: Optional[int]
qty: int
notional: Optional[float]
side: OrderSide
type: OrderType
Expand Down
87 changes: 85 additions & 2 deletions app/components/orders/position.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional
import random
from typing import Optional, List

from pydantic import BaseModel

Expand All @@ -8,9 +9,91 @@
class PositionValidationException(Exception):
pass

class PositionClosedException(Exception):
pass

class Position(BaseModel):
root_order: Order
orders: List[Order] = []
closed: bool = False
average_entry_price: Optional[float] = None
average_exit_price: Optional[float] = None
size: Optional[int] = None
side: Optional[str] = None
pnl: Optional[float] = None
timestamp: Optional[int] = None

# TODO: restructure position to be more efficient

def test(self, ohlc: 'OHLC'):
self.pnl = self.calc_pnl()

def add_order(self, order: Order):
if self.closed:
raise PositionClosedException('Position is already closed')

# else we add the order
self.orders.append(order)

long_qty = sum([o.qty for o in self.orders if o.side == 'buy'])
short_qty = sum([o.qty for o in self.orders if o.side == 'sell'])
self.closed = long_qty == short_qty

def get_size(self):
return sum([order.qty for order in self.orders])

def get_all_buy_orders(self):
return [order for order in self.orders if order.side == 'buy']

def get_all_sell_orders(self):
return [order for order in self.orders if order.side == 'sell']

def get_average_entry_price(self):
if self.orders[0].side == 'buy':
return sum([o.filled_avg_price for o in self.get_all_buy_orders()]) / len(self.get_all_buy_orders())
else:
return sum([o.filled_avg_price for o in self.get_all_sell_orders()]) / len(self.get_all_sell_orders())

def get_average_exit_price(self):
if self.orders[0].side == 'buy':
return sum([o.filled_avg_price for o in self.get_all_sell_orders()]) / len(self.get_all_sell_orders())
else:
return sum([o.filled_avg_price for o in self.get_all_buy_orders()]) / len(self.get_all_buy_orders())

def get_side(self):
return self.orders[0].side

def get_timestamp(self):
return self.orders[-1].timestamp

def calc_pnl(self):
if self.closed:
# root order direction
root_order_side = self.orders[0].side

# calculate pnl
if root_order_side == 'buy':
return (self.get_average_exit_price() - self.get_average_entry_price()) * self.get_size()
else:
return (self.get_average_entry_price() - self.get_average_exit_price()) * self.get_size()
return 0

def dict(self, **kwargs):
d = super().dict(**kwargs)
if self.closed:
d['average_entry_price'] = self.get_average_entry_price()
d['average_exit_price'] = self.get_average_exit_price()
d['size'] = self.get_size() / 2
d['side'] = self.get_side()
d['timestamp'] = self.get_timestamp()
d['pnl'] = self.calc_pnl()
else:
d['average_entry_price'] = self.get_average_entry_price()
d['size'] = self.get_size()
d['side'] = self.get_side()
d['timestamp'] = self.get_timestamp()
d['pnl'] = self.calc_pnl()
return d


class BracketPosition(BaseModel):
take_profit: Optional[Order]
Expand Down
6 changes: 5 additions & 1 deletion app/components/strategy/builtins/ta/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,8 @@
class Logic:
@staticmethod
def crossover(a: 'Series', b: 'Series') -> bool:
return a > b and a.shift(1) < b.shift(1)
return a > b and a.shift(1) < b.shift(1)

@staticmethod
def crossunder(a: 'Series', b: 'Series') -> bool:
return a < b and a.shift(1) > b.shift(1)
9 changes: 9 additions & 0 deletions app/components/strategy/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ def __init__(self, data: Union[list, pd.Series]):
def advance_index(self):
self._loop_index += 1

def as_list(self):
if isinstance(self._data, list):
return self._data
if isinstance(self._data, pd.Series):
df = self._data.copy()
# replace NaN with previous value
df.fillna(method='backfill', inplace=True)
return df.tolist()

def __repr__(self):
return str(float(self))

Expand Down
55 changes: 49 additions & 6 deletions app/components/strategy/strategy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import json
import sys
from typing import List, Union

Expand All @@ -13,6 +14,24 @@
from components.parameter import BaseParameter, Parameter, ParameterModel
from components.strategy.decorators import extract_decorators

class PlotConfig(BaseModel):
color: str = 'blue'
type: str = 'line'
lineWidth: int = 1

class Plot:
def __init__(self, series: 'Series', **kwargs):
self.data = series.as_list()
self.name = kwargs.get('name', None)
self.config = PlotConfig(**kwargs)

def as_dict(self):
return {
'name': self.name,
'data': self.data,
'config': self.config.dict()
}

class StrategyManager(ComponentManager):
_components = []

Expand Down Expand Up @@ -60,6 +79,13 @@ def __init__(self, data: Union[OHLC, None] = None):
# each strategy gets a new order manager
self.orders = OrderManager(self)

# each strategy gets plots
self.plots = []

def export_plots(self, plots: List[Plot]):
self.plots = plots


def as_model(self) -> StrategyModel:
return StrategyModel(
name=self.name,
Expand Down Expand Up @@ -108,7 +134,22 @@ def _get_all_series_data(self):
series.append(self.__getattribute__(attr))
return series

def run(self, data: OHLC, parameters: dict = None):
def _get_all_plots(self):
# will use in the future to get all plots
plots = []
for attr in dir(self):
if isinstance(self.__getattribute__(attr), Plot):
plots.append(self.__getattribute__(attr))
return plots

def run(self, data: OHLC, parameters: dict = None, **kwargs):

# set parameters
if parameters is not None:
for p in self.parameters:
if p.name in parameters:
p.value = parameters[p.name]

self.__init__(data=data)
self._setup_data(data)
self._create_series()
Expand All @@ -130,10 +171,7 @@ def run(self, data: OHLC, parameters: dict = None):
self.data.advance_index()

# handle backtest
b = Backtest(
strategy=self,
data=data,
)
b = Backtest(strategy=self, data=data)

# runs the backtest
b.test()
Expand All @@ -142,7 +180,12 @@ def run(self, data: OHLC, parameters: dict = None):
for method in self._after_methods:
getattr(self, method)()

# return the results of the backtest
# get all plots
plots = self.plots

if kwargs.get('plots', False):
logger.debug(f'Requested plots, found {len(plots)}')
return b.result, plots
return b.result


Loading

0 comments on commit 8e5d198

Please sign in to comment.