Skip to content

Commit

Permalink
change tests, change managers
Browse files Browse the repository at this point in the history
  • Loading branch information
robswc committed Mar 14, 2023
1 parent 064ac25 commit 3095967
Show file tree
Hide file tree
Showing 17 changed files with 175 additions and 76 deletions.
2 changes: 1 addition & 1 deletion app/api/api_v1/endpoints/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
@router.get("/strategy")
async def list_all_strategies():
strategies = Strategy.objects.all()
return [s().as_model() for s in strategies]
return [s.as_model() for s in strategies]
Empty file.
24 changes: 24 additions & 0 deletions app/components/manager/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from loguru import logger


class ComponentManager:

_components = []

@classmethod
def register(cls, component):
if component in cls._components:
return
cls._components.append(component)
logger.debug(f'Registered component {component} ({component.__module__})')

@classmethod
def all(cls):
return [o() for o in cls._components]

@classmethod
def get(cls, name):
for component in cls._components:
if component.__name__ == name:
return component()
raise ValueError(f'Component {name} not found.')
21 changes: 16 additions & 5 deletions app/components/ohlc/data_adapters/adapter.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
from loguru import logger

from components.manager.manager import ComponentManager
from components.ohlc import OHLC

ADAPTER_MAP = {}
class DataAdapterManager(ComponentManager):
_components = []

class DataAdapter:
"""Base class for data adapters."""
def __new__(cls, *args, **kwargs):
if cls not in ADAPTER_MAP:
ADAPTER_MAP[cls] = super().__new__(cls)
return ADAPTER_MAP[cls]

objects = DataAdapterManager

@classmethod
def register(cls):
cls.objects.register(cls)

def __init__(self):
self.name = self.__class__.__name__
self.objects.register(self)

# register
self.register()

def get_data(self, *args, **kwargs) -> OHLC:
raise NotImplementedError


class CSVAdapter(DataAdapter):
"""CSV Adapter, loads data from a csv file."""

def get_data(self, path: str, symbol: str):
"""
Loads data from a csv file.
Expand Down
20 changes: 20 additions & 0 deletions app/components/orders/order.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import datetime
import hashlib
from enum import Enum
from typing import Optional

Expand Down Expand Up @@ -30,6 +32,8 @@ class OrderSide(str, Enum):
SELL = "sell"

class Order(BaseModel):
id: Optional[str]
timestamp: Optional[int]
symbol: str
qty: Optional[float]
notional: Optional[float]
Expand All @@ -38,12 +42,25 @@ class Order(BaseModel):
time_in_force: TimeInForce = TimeInForce.GTC
extended_hours: Optional[bool]
client_order_id: Optional[str]
filled_avg_price: Optional[float]
# take_profit: Optional[TakeProfitRequest]
# stop_loss: Optional[StopLossRequest]

def __str__(self):
dt = datetime.datetime.fromtimestamp(self.timestamp).strftime("%Y-%m-%d %H:%M:%S")
return f'{self.type} {self.side} {self.qty} {self.symbol} @ {self.filled_avg_price} ({dt})'

def __hash__(self):
h = hashlib.sha256(f'{self.timestamp}{self.symbol}{self.qty}{self.side}{self.type}'.encode())
return int(h.hexdigest(), 16)

def get_id(self):
return hashlib.md5(str(hash(self)).encode()).hexdigest()

class Config:
schema_extra = {
"example": {
"id": "b6b6b6b6-b6b6-b6b6-b6b6-b6b6b6b6b6b6",
"symbol": "AAPL",
"qty": 100,
"side": "buy",
Expand All @@ -61,3 +78,6 @@ def __init__(self, **data):
raise OrderValidationError("qty must be greater than 0.")
if self.notional is not None and self.notional <= 0:
raise OrderValidationError("notional must be greater than 0.")

# if valid, set id
self.id = self.get_id()
56 changes: 19 additions & 37 deletions app/components/strategy/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,69 +6,51 @@
from loguru import logger
from pydantic import BaseModel

from components.manager.manager import ComponentManager
from components.ohlc import OHLC
from components.orders.order_manager import OrderManager
from components.parameter import BaseParameter, Parameter, ParameterModel
from components.strategy.decorators import extract_decorators

class StrategyManager(ComponentManager):
_components = []

class StrategyModel(BaseModel):
name: str
parameters: List[ParameterModel]


class StrategyManager:
_strategies = []

@classmethod
def register(cls, strategy):
# check if the strategy is already registered
if strategy in cls._strategies:
return
cls._strategies.append(strategy)
logger.debug(f'Registered strategy {cls.__name__} ({strategy.__module__})')

@classmethod
def all(cls):
return [s for s in cls._strategies]

@classmethod
def get(cls, name):
for s in cls._strategies:
if s.name == name:
return s
raise Exception(f'No strategy found with name {name}')


class BaseStrategy:
parameters: List[BaseParameter] = []
_loop_index = 0

# strategy decorators
_step_methods = []
_before_methods = []
_after_methods = []

# each strategy gets a new order manager
orders = OrderManager()

objects = StrategyManager

@classmethod
def register(cls):
cls.objects.register(cls)

def __init__(self, data: Union[OHLC, None] = None):
self.name = self.__class__.__name__
if data is None:
data = OHLC()
self.data = data
self.parameters: List[BaseParameter] = []
self._set_parameters()

# register the strategy
self.objects.register(self)
self.register()

self._loop_index = 0

# strategy decorators
self._step_methods = []
self._before_methods = []
self._after_methods = []

befores, steps, afters = extract_decorators(self)
self._before_methods = befores
self._step_methods = steps
self._after_methods = afters

# each strategy gets a new order manager
self.orders = OrderManager()

def as_model(self) -> StrategyModel:
return StrategyModel(
name=self.name,
Expand Down
1 change: 0 additions & 1 deletion app/core/commands/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,4 @@ def list_strategies():
List all strategies
"""
typer.echo("Listing strategies")
from utils import strategy_loader # noqa: F401
ListStrategies().handle()
2 changes: 1 addition & 1 deletion app/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import FastAPI

from api.api_v1.api import api_router
from utils import strategy_loader # noqa: F401
from utils.loaders import load_all # noqa: F401

app = FastAPI(
title="Stratis API",
Expand Down
8 changes: 8 additions & 0 deletions app/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from components.ohlc import CSVAdapter
from components.ohlc.data_adapters.api_adapter import APIDataAdapter

# add any custom data adapters here
DATA_ADAPTERS = [
CSVAdapter,
APIDataAdapter
]
15 changes: 8 additions & 7 deletions app/storage/strategies/examples/sma_cross_over.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from components import Parameter
from components import Strategy, on_step
from components.strategy import Series, ta
from components.strategy import ta


class SMACrossOver(Strategy):
Expand All @@ -9,10 +9,11 @@ class SMACrossOver(Strategy):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# all_close = self.data.all('close')
# self.sma_fast = ta.sma(all_close, int(self.sma_fast_length))
# self.sma_slow = ta.sma(all_close, int(self.sma_slow_length))
all_close = self.data.all('close')
self.sma_fast = ta.sma(all_close, int(self.sma_fast_length))
self.sma_slow = ta.sma(all_close, int(self.sma_slow_length))

# @on_step
# def check_for_crossover(self):
# print(float(self.sma_fast), float(self.sma_slow))
@on_step
def check_for_crossover(self):
# add logic to crossover here
pass
4 changes: 0 additions & 4 deletions app/tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@ def test_init_data_adapter(self):
from components.ohlc import DataAdapter
adapter = DataAdapter()
assert adapter.name == 'DataAdapter'
adapter2 = DataAdapter()
assert adapter == adapter2

# test csv adapter
from components.ohlc import CSVAdapter
csv_adapter = CSVAdapter()
assert csv_adapter.name == 'CSVAdapter'
csv_adapter2 = CSVAdapter()
assert csv_adapter == csv_adapter2

def test_csv_adapter(self):
from components.ohlc import CSVAdapter
Expand Down
37 changes: 37 additions & 0 deletions app/tests/test_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from datetime import datetime

from components.orders.order import Order, OrderSide as Side, OrderType


class TestOrders:
def test_hashing(self):
ts = datetime.now().timestamp()
fake_order_1 = Order(
symbol='BTCUSDT',
side=Side.BUY,
type=OrderType.MARKET,
qty=1,
timestamp=ts,
)
fake_order_2 = Order(
symbol='BTCUSDT',
side=Side.BUY,
type=OrderType.MARKET,
qty=1,
timestamp=ts,
)
assert fake_order_1 == fake_order_2

fake_order_3 = Order(
symbol='BTCUSDT',
side=Side.BUY,
type=OrderType.MARKET,
qty=2,
datetime=ts,
)

assert fake_order_1 != fake_order_3

# test IDs
assert fake_order_1.get_id() == fake_order_2.get_id()
assert fake_order_1.get_id() != fake_order_3.get_id()
9 changes: 3 additions & 6 deletions app/tests/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def test_initializing_examples(self):
from storage.strategies.examples.sma_cross_over import SMACrossOver
strategy = SMACrossOver()
assert strategy.name == 'SMACrossOver'
assert int(strategy.sma_fast) == 10
assert int(strategy.sma_slow) == 20
assert int(strategy.sma_fast_length) == 10
assert int(strategy.sma_slow_length) == 60

def test_run_strategy(self):
from storage.strategies.examples.sma_cross_over import SMACrossOver
Expand All @@ -32,13 +32,10 @@ def test_ohlc_demo(self):
)

def test_load_strategies(self):
from utils.strategy_loader import import_all_strategies
from utils.loaders.strategy_loader import import_all_strategies
strategies = import_all_strategies()
assert len(strategies) > 0

def test_strategy_manager(self):
from utils import strategy_loader
assert len(Strategy.objects.all()) > 0
for s in Strategy.objects.all():
print(s)
assert Strategy.objects.get('SMACrossOver').name == 'SMACrossOver'
Empty file added app/utils/loaders/__init__.py
Empty file.
36 changes: 36 additions & 0 deletions app/utils/loaders/load_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import inspect
from pathlib import Path
from importlib import import_module

from loguru import logger

from components.ohlc import DataAdapter
from components.strategy.strategy import BaseStrategy


def import_components(path, component_type):
logger.debug(f'Importing {component_type.__name__}(s) from {path}...')
components = []
app_path = Path(__file__).parent.parent.parent
paths = app_path.joinpath(path).rglob('*.py')
for path in paths:
module_name = path.as_posix().replace('/', '.').replace('.py', '').split('app.')[1]
module = import_module(module_name)

for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, component_type):
if obj.__name__ != component_type.__name__:
components.append(obj)
logger.info(f'\t->\t{obj.__name__} ({obj.__module__})')
return components

# load all components
data_adapters = import_components('components/ohlc/data_adapters', DataAdapter)
strategies = import_components('storage/strategies', BaseStrategy)

# register all components
for adapter in data_adapters:
adapter.register()

for strategy in strategies:
strategy.register()
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@

def import_all_strategies() -> List[Type[BaseStrategy]]:
strategies = []
# get all paths in the strategies folder, including subfolders, assuming sources root is app/
paths = Path(__file__).parent.parent.joinpath('storage/strategies').rglob('*.py')
app_path = Path(__file__).parent.parent.parent
paths = app_path.joinpath('storage/strategies').rglob('*.py')
for path in paths:

# get the module name from the path
module_name = path.as_posix().replace('/', '.').replace('.py', '')
module_name = module_name.split('app.')[1]
Expand Down
Loading

0 comments on commit 3095967

Please sign in to comment.