diff --git a/app/api/api_v1/endpoints/strategy.py b/app/api/api_v1/endpoints/strategy.py index 772e4a3..88c931e 100644 --- a/app/api/api_v1/endpoints/strategy.py +++ b/app/api/api_v1/endpoints/strategy.py @@ -9,4 +9,4 @@ @router.get("/strategy") async def list_all_strategies(): strategies = Strategy.objects.all() - return [s().as_model() for s in strategies] \ No newline at end of file + return [s.as_model() for s in strategies] \ No newline at end of file diff --git a/app/components/manager/__init__.py b/app/components/manager/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/components/manager/manager.py b/app/components/manager/manager.py new file mode 100644 index 0000000..93c5c70 --- /dev/null +++ b/app/components/manager/manager.py @@ -0,0 +1,24 @@ +from loguru import logger + + +class ComponentManager: + + _components = [] + + @classmethod + def register(cls, component): + if component in cls._components: + return + cls._components.append(component) + logger.debug(f'Registered component {component} ({component.__module__})') + + @classmethod + def all(cls): + return [o() for o in cls._components] + + @classmethod + def get(cls, name): + for component in cls._components: + if component.__name__ == name: + return component() + raise ValueError(f'Component {name} not found.') \ No newline at end of file diff --git a/app/components/ohlc/data_adapters/adapter.py b/app/components/ohlc/data_adapters/adapter.py index 323e1ff..fa9d575 100644 --- a/app/components/ohlc/data_adapters/adapter.py +++ b/app/components/ohlc/data_adapters/adapter.py @@ -1,16 +1,26 @@ +from loguru import logger + +from components.manager.manager import ComponentManager from components.ohlc import OHLC -ADAPTER_MAP = {} +class DataAdapterManager(ComponentManager): + _components = [] class DataAdapter: """Base class for data adapters.""" - def __new__(cls, *args, **kwargs): - if cls not in ADAPTER_MAP: - ADAPTER_MAP[cls] = super().__new__(cls) - return ADAPTER_MAP[cls] + + objects = DataAdapterManager + + @classmethod + def register(cls): + cls.objects.register(cls) def __init__(self): self.name = self.__class__.__name__ + self.objects.register(self) + + # register + self.register() def get_data(self, *args, **kwargs) -> OHLC: raise NotImplementedError @@ -18,6 +28,7 @@ def get_data(self, *args, **kwargs) -> OHLC: class CSVAdapter(DataAdapter): """CSV Adapter, loads data from a csv file.""" + def get_data(self, path: str, symbol: str): """ Loads data from a csv file. diff --git a/app/components/orders/order.py b/app/components/orders/order.py index e445e96..abbeab7 100644 --- a/app/components/orders/order.py +++ b/app/components/orders/order.py @@ -1,3 +1,5 @@ +import datetime +import hashlib from enum import Enum from typing import Optional @@ -30,6 +32,8 @@ class OrderSide(str, Enum): SELL = "sell" class Order(BaseModel): + id: Optional[str] + timestamp: Optional[int] symbol: str qty: Optional[float] notional: Optional[float] @@ -38,12 +42,25 @@ class Order(BaseModel): time_in_force: TimeInForce = TimeInForce.GTC extended_hours: Optional[bool] client_order_id: Optional[str] + filled_avg_price: Optional[float] # take_profit: Optional[TakeProfitRequest] # stop_loss: Optional[StopLossRequest] + def __str__(self): + dt = datetime.datetime.fromtimestamp(self.timestamp).strftime("%Y-%m-%d %H:%M:%S") + return f'{self.type} {self.side} {self.qty} {self.symbol} @ {self.filled_avg_price} ({dt})' + + def __hash__(self): + h = hashlib.sha256(f'{self.timestamp}{self.symbol}{self.qty}{self.side}{self.type}'.encode()) + return int(h.hexdigest(), 16) + + def get_id(self): + return hashlib.md5(str(hash(self)).encode()).hexdigest() + class Config: schema_extra = { "example": { + "id": "b6b6b6b6-b6b6-b6b6-b6b6-b6b6b6b6b6b6", "symbol": "AAPL", "qty": 100, "side": "buy", @@ -61,3 +78,6 @@ def __init__(self, **data): raise OrderValidationError("qty must be greater than 0.") if self.notional is not None and self.notional <= 0: raise OrderValidationError("notional must be greater than 0.") + + # if valid, set id + self.id = self.get_id() diff --git a/app/components/strategy/strategy.py b/app/components/strategy/strategy.py index d873a55..e03f86d 100644 --- a/app/components/strategy/strategy.py +++ b/app/components/strategy/strategy.py @@ -6,69 +6,51 @@ from loguru import logger from pydantic import BaseModel +from components.manager.manager import ComponentManager from components.ohlc import OHLC from components.orders.order_manager import OrderManager from components.parameter import BaseParameter, Parameter, ParameterModel from components.strategy.decorators import extract_decorators +class StrategyManager(ComponentManager): + _components = [] class StrategyModel(BaseModel): name: str parameters: List[ParameterModel] - -class StrategyManager: - _strategies = [] - - @classmethod - def register(cls, strategy): - # check if the strategy is already registered - if strategy in cls._strategies: - return - cls._strategies.append(strategy) - logger.debug(f'Registered strategy {cls.__name__} ({strategy.__module__})') - - @classmethod - def all(cls): - return [s for s in cls._strategies] - - @classmethod - def get(cls, name): - for s in cls._strategies: - if s.name == name: - return s - raise Exception(f'No strategy found with name {name}') - - class BaseStrategy: - parameters: List[BaseParameter] = [] - _loop_index = 0 - - # strategy decorators - _step_methods = [] - _before_methods = [] - _after_methods = [] - - # each strategy gets a new order manager - orders = OrderManager() - objects = StrategyManager + @classmethod + def register(cls): + cls.objects.register(cls) + def __init__(self, data: Union[OHLC, None] = None): self.name = self.__class__.__name__ if data is None: data = OHLC() self.data = data + self.parameters: List[BaseParameter] = [] self._set_parameters() - # register the strategy - self.objects.register(self) + self.register() + + self._loop_index = 0 + + # strategy decorators + self._step_methods = [] + self._before_methods = [] + self._after_methods = [] befores, steps, afters = extract_decorators(self) self._before_methods = befores self._step_methods = steps self._after_methods = afters + # each strategy gets a new order manager + self.orders = OrderManager() + def as_model(self) -> StrategyModel: return StrategyModel( name=self.name, diff --git a/app/core/commands/strategy/__init__.py b/app/core/commands/strategy/__init__.py index 6c5f9a0..5a311dd 100644 --- a/app/core/commands/strategy/__init__.py +++ b/app/core/commands/strategy/__init__.py @@ -21,5 +21,4 @@ def list_strategies(): List all strategies """ typer.echo("Listing strategies") - from utils import strategy_loader # noqa: F401 ListStrategies().handle() \ No newline at end of file diff --git a/app/main.py b/app/main.py index afbe8b2..d1a1f47 100644 --- a/app/main.py +++ b/app/main.py @@ -1,7 +1,7 @@ from fastapi import FastAPI from api.api_v1.api import api_router -from utils import strategy_loader # noqa: F401 +from utils.loaders import load_all # noqa: F401 app = FastAPI( title="Stratis API", diff --git a/app/settings.py b/app/settings.py new file mode 100644 index 0000000..904f9ab --- /dev/null +++ b/app/settings.py @@ -0,0 +1,8 @@ +from components.ohlc import CSVAdapter +from components.ohlc.data_adapters.api_adapter import APIDataAdapter + +# add any custom data adapters here +DATA_ADAPTERS = [ + CSVAdapter, + APIDataAdapter +] \ No newline at end of file diff --git a/app/storage/strategies/examples/sma_cross_over.py b/app/storage/strategies/examples/sma_cross_over.py index 4817a7c..2b77633 100644 --- a/app/storage/strategies/examples/sma_cross_over.py +++ b/app/storage/strategies/examples/sma_cross_over.py @@ -1,6 +1,6 @@ from components import Parameter from components import Strategy, on_step -from components.strategy import Series, ta +from components.strategy import ta class SMACrossOver(Strategy): @@ -9,10 +9,11 @@ class SMACrossOver(Strategy): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # all_close = self.data.all('close') - # self.sma_fast = ta.sma(all_close, int(self.sma_fast_length)) - # self.sma_slow = ta.sma(all_close, int(self.sma_slow_length)) + all_close = self.data.all('close') + self.sma_fast = ta.sma(all_close, int(self.sma_fast_length)) + self.sma_slow = ta.sma(all_close, int(self.sma_slow_length)) - # @on_step - # def check_for_crossover(self): - # print(float(self.sma_fast), float(self.sma_slow)) + @on_step + def check_for_crossover(self): + # add logic to crossover here + pass diff --git a/app/tests/test_adapter.py b/app/tests/test_adapter.py index 8d20625..b8772c8 100644 --- a/app/tests/test_adapter.py +++ b/app/tests/test_adapter.py @@ -3,15 +3,11 @@ def test_init_data_adapter(self): from components.ohlc import DataAdapter adapter = DataAdapter() assert adapter.name == 'DataAdapter' - adapter2 = DataAdapter() - assert adapter == adapter2 # test csv adapter from components.ohlc import CSVAdapter csv_adapter = CSVAdapter() assert csv_adapter.name == 'CSVAdapter' - csv_adapter2 = CSVAdapter() - assert csv_adapter == csv_adapter2 def test_csv_adapter(self): from components.ohlc import CSVAdapter diff --git a/app/tests/test_order.py b/app/tests/test_order.py new file mode 100644 index 0000000..50083c6 --- /dev/null +++ b/app/tests/test_order.py @@ -0,0 +1,37 @@ +from datetime import datetime + +from components.orders.order import Order, OrderSide as Side, OrderType + + +class TestOrders: + def test_hashing(self): + ts = datetime.now().timestamp() + fake_order_1 = Order( + symbol='BTCUSDT', + side=Side.BUY, + type=OrderType.MARKET, + qty=1, + timestamp=ts, + ) + fake_order_2 = Order( + symbol='BTCUSDT', + side=Side.BUY, + type=OrderType.MARKET, + qty=1, + timestamp=ts, + ) + assert fake_order_1 == fake_order_2 + + fake_order_3 = Order( + symbol='BTCUSDT', + side=Side.BUY, + type=OrderType.MARKET, + qty=2, + datetime=ts, + ) + + assert fake_order_1 != fake_order_3 + + # test IDs + assert fake_order_1.get_id() == fake_order_2.get_id() + assert fake_order_1.get_id() != fake_order_3.get_id() diff --git a/app/tests/test_strategy.py b/app/tests/test_strategy.py index c2cfeec..70697a3 100644 --- a/app/tests/test_strategy.py +++ b/app/tests/test_strategy.py @@ -12,8 +12,8 @@ def test_initializing_examples(self): from storage.strategies.examples.sma_cross_over import SMACrossOver strategy = SMACrossOver() assert strategy.name == 'SMACrossOver' - assert int(strategy.sma_fast) == 10 - assert int(strategy.sma_slow) == 20 + assert int(strategy.sma_fast_length) == 10 + assert int(strategy.sma_slow_length) == 60 def test_run_strategy(self): from storage.strategies.examples.sma_cross_over import SMACrossOver @@ -32,13 +32,10 @@ def test_ohlc_demo(self): ) def test_load_strategies(self): - from utils.strategy_loader import import_all_strategies + from utils.loaders.strategy_loader import import_all_strategies strategies = import_all_strategies() assert len(strategies) > 0 def test_strategy_manager(self): - from utils import strategy_loader assert len(Strategy.objects.all()) > 0 - for s in Strategy.objects.all(): - print(s) assert Strategy.objects.get('SMACrossOver').name == 'SMACrossOver' \ No newline at end of file diff --git a/app/utils/loaders/__init__.py b/app/utils/loaders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/utils/loaders/load_all.py b/app/utils/loaders/load_all.py new file mode 100644 index 0000000..04089e0 --- /dev/null +++ b/app/utils/loaders/load_all.py @@ -0,0 +1,36 @@ +import inspect +from pathlib import Path +from importlib import import_module + +from loguru import logger + +from components.ohlc import DataAdapter +from components.strategy.strategy import BaseStrategy + + +def import_components(path, component_type): + logger.debug(f'Importing {component_type.__name__}(s) from {path}...') + components = [] + app_path = Path(__file__).parent.parent.parent + paths = app_path.joinpath(path).rglob('*.py') + for path in paths: + module_name = path.as_posix().replace('/', '.').replace('.py', '').split('app.')[1] + module = import_module(module_name) + + for name, obj in inspect.getmembers(module): + if inspect.isclass(obj) and issubclass(obj, component_type): + if obj.__name__ != component_type.__name__: + components.append(obj) + logger.info(f'\t->\t{obj.__name__} ({obj.__module__})') + return components + +# load all components +data_adapters = import_components('components/ohlc/data_adapters', DataAdapter) +strategies = import_components('storage/strategies', BaseStrategy) + +# register all components +for adapter in data_adapters: + adapter.register() + +for strategy in strategies: + strategy.register() \ No newline at end of file diff --git a/app/utils/strategy_loader.py b/app/utils/loaders/strategy_loader.py similarity index 88% rename from app/utils/strategy_loader.py rename to app/utils/loaders/strategy_loader.py index ce51538..3b63bcb 100644 --- a/app/utils/strategy_loader.py +++ b/app/utils/loaders/strategy_loader.py @@ -9,10 +9,9 @@ def import_all_strategies() -> List[Type[BaseStrategy]]: strategies = [] - # get all paths in the strategies folder, including subfolders, assuming sources root is app/ - paths = Path(__file__).parent.parent.joinpath('storage/strategies').rglob('*.py') + app_path = Path(__file__).parent.parent.parent + paths = app_path.joinpath('storage/strategies').rglob('*.py') for path in paths: - # get the module name from the path module_name = path.as_posix().replace('/', '.').replace('.py', '') module_name = module_name.split('app.')[1] diff --git a/test_main.http b/test_main.http deleted file mode 100644 index a2d81a9..0000000 --- a/test_main.http +++ /dev/null @@ -1,11 +0,0 @@ -# Test your FastAPI endpoints - -GET http://127.0.0.1:8000/ -Accept: application/json - -### - -GET http://127.0.0.1:8000/hello/User -Accept: application/json - -###