Skip to content

Commit

Permalink
dataloader manager refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
t1user committed Jun 24, 2024
1 parent 0fb0e1f commit 55da37c
Showing 1 changed file with 107 additions and 120 deletions.
227 changes: 107 additions & 120 deletions ib_tools/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import ib_insync as ibi
import pandas as pd
import pytz
from typing_extensions import Protocol

from ib_tools.config import CONFIG
Expand Down Expand Up @@ -212,28 +211,18 @@ class ExactFutureContractSelector(FutureContractSelector):
pass


class _DataWriter:
"""
Helper class, whose only purpose is running validators, which
would be otherwise difficult in a dataclass.
"""

barSize: ClassVar = Validator(bar_size_validator)
wts: ClassVar = Validator(wts_validator)


@dataclass
class DataWriter(_DataWriter):
class DataWriter:
"""Interface between dataloader and datastore"""

store: AbstractBaseStore
contract: ibi.Contract
head: datetime
barSize: str # type: ignore
wts: str # type: ignore
barSize: str
wts: str
aggression: float = 2
# !!!!! it was pytz.timezone("Europe/Berlin") here, INVESTIGATE!!
now: datetime = field(default_factory=partial(datetime.now, pytz.utc))
now: datetime = field(default_factory=partial(datetime.now, timezone.utc))
fill_gaps: bool = CONFIG.get("fill_gaps", True)

def __post_init__(self) -> None:
Expand Down Expand Up @@ -563,63 +552,6 @@ def __repr__(self):
return f"{self.from_date} - {self.to_date}, update: {self.update}"


@dataclass
class ContractHolder:
"""Container class ensuring contract list kept after re-connect"""

ib: ibi.IB
source: str # csv file name with contract list
store: AbstractBaseStore
wts: str # whatToShow ib api parameter
barSize: str # ib api parameter
# how big series request at each call (1 = normal, 2 = double, etc.)
aggression: int = 1
objects: list = field(default_factory=list)
items: list[DataWriter] = field(default_factory=list)

def get_items(self):
for o in self.objects:
try:
headTimeStamp = self.ib.reqHeadTimeStamp(
o, whatToShow=self.wts, useRTH=False, formatDate=2
)

if headTimeStamp == []:
log.warning(
(
f"Unavailable headTimeStamp for {o.localSymbol}. "
f"No data will be downloaded"
)
)
continue
except Exception:
log.exception("Exception while getting headTimeStamp")
continue

try:
self.items.append(
DataWriter(
self.store,
o,
headTimeStamp,
barSize=self.barSize,
wts=self.wts,
aggression=self.aggression,
)
)
except Exception as e:
log.exception(f"Error ignored for object {o}", e)
# raise

def __call__(self):
log.debug("holder called")
log.debug(f"items: {self.items}")
if not self.items:
self.get_items()
log.debug(f"items obtained: {self.items}")
return self.items


def duration_in_secs(barSize: str):
"""Given duration string return duration in seconds int"""
number, time = barSize.split(" ")
Expand Down Expand Up @@ -710,8 +642,9 @@ def __repr__(self):


class Timer:
holder: Deque[datetime] = deque(maxlen=100)

def __init__(self, seconds: int, requests: int) -> None:
self.holder: Deque[datetime] = deque(maxlen=requests)
self.seconds = seconds
self._seconds = timedelta(seconds=seconds)
self.requests = requests
Expand All @@ -721,17 +654,6 @@ def __init__(self, seconds: int, requests: int) -> None:
def start(self) -> datetime:
return self.holder[0]

def register(self) -> None:
"""
Register data request made for future tracking of number of
requests made in the unit of time. To be called by
Pacer.__aexit__ method after every request made to ib.
"""
self.holder.append(datetime.now())
# log.debug(
# f'{self} registered request {len(self.holder)} at: {self.holder[-1]} '
# f'(elapsed: {self.elapsed_time()}sec)')

def elapsed_time(self) -> timedelta:
return datetime.now() - self.start()

Expand Down Expand Up @@ -782,8 +704,7 @@ async def __aenter__(self):
await asyncio.sleep(timer.sleep_time() + 2 * random.random())
# await asyncio.sleep(1)
# register request time right before exiting the context
for timer in self.timers:
timer.register()
Timer.holder.append(datetime.now())

async def __aexit__(self, exc_type, exc_value, exc_tb):
pass
Expand Down Expand Up @@ -835,7 +756,7 @@ def pacer(
def validate_age(contract: DataWriter) -> bool:
"""
IB doesn't permit to request data for bars < 30secs older than 6
monts. Trying to push it here with 30secs.
months. Trying to push it here with 30secs.
"""
if duration_in_secs(contract.barSize) < 30 and contract.next_date:
assert isinstance(contract.next_date, datetime)
Expand Down Expand Up @@ -876,22 +797,22 @@ async def worker(name: str, queue: asyncio.Queue, pacer: Pacer, ib: ibi.IB) -> N
queue.task_done()


async def main(holder: ContractHolder, ib: ibi.IB) -> None:
async def main(manager: Manager, ib: ibi.IB) -> None:

contracts = holder()
log.debug(f"Holder: {contracts}")
number_of_workers = min(len(contracts), MAX_NUMBER_OF_WORKERS)
writers = manager.writers
log.debug(f"{writers=}")
number_of_workers = min(len(writers), MAX_NUMBER_OF_WORKERS)

log.debug(
f"main function started, " f"retrieving data for {len(contracts)} instruments"
f"main function started, " f"retrieving data for {len(writers)} instruments"
)

queue: asyncio.Queue[DataWriter] = asyncio.LifoQueue()
for contract in contracts:
await queue.put(contract)
for writer in writers:
await queue.put(writer)
p = pacer(
holder.barSize,
holder.wts,
writer.barSize,
writer.wts,
restrictions=[(2, 6), (1200, 60 - number_of_workers)],
)
log.debug(f"Pacer initialized: {p}")
Expand Down Expand Up @@ -920,42 +841,108 @@ async def main(holder: ContractHolder, ib: ibi.IB) -> None:
await asyncio.gather(*workers)


async def sequence(ib: ibi.IB):
sources = pd.read_csv(CONFIG["source"], keep_default_na=False).to_dict("records")
log.debug(f"{sources=}")
ContractSelector.set_ib(ib)
objects = []
for s in sources:
objects.extend(ContractSelector.from_kwargs(**s).objects()) # type: ignore
await ib.qualifyContractsAsync(*objects)
log.debug(f"objects: {objects}")

# object where data is stored
# store = ArcticStore(f"{wts}_{barSize}")

holder = ContractHolder(
ib=ib,
source=CONFIG["source"],
store=CONFIG["datastore"],
wts=CONFIG["wts"],
barSize=CONFIG["barSize"],
aggression=CONFIG.get("aggression", 1),
objects=objects,
)
class _Manager:
"""
Helper class, whose only purpose is running validators, which
would be otherwise difficult in a dataclass.
"""

barSize: ClassVar = Validator(bar_size_validator)
wts: ClassVar = Validator(wts_validator)


@dataclass
class Manager(_Manager):
ib: ibi.IB
barSize: str = CONFIG["barSize"] # type: ignore
wts: str = CONFIG["wts"] # type: ignore
aggression: int = CONFIG["aggression"]
store: AbstractBaseStore = CONFIG["datastore"]

@functools.cached_property
def sources(self) -> list[dict]:
return pd.read_csv(CONFIG["source"], keep_default_na=False).to_dict("records")

@functools.cached_property
def contracts(self) -> list[ibi.Contract]:
return ibi.util.run(self._contracts())

async def _contracts(self) -> list[ibi.Contract]:
ContractSelector.set_ib(self.ib)
contracts = []
for s in self.sources:
contracts.extend(ContractSelector.from_kwargs(**s).objects())
await self.ib.qualifyContractsAsync(*contracts)
log.debug(f"{contracts=}")
return contracts

await main(holder, ib)
@functools.cached_property
def headstamps(self):
return ibi.uitl.run(self._headstamps())

async def _headstamps(self) -> dict[ibi.Contract, datetime]:
headstamps = {}
for c in self.contracts:
if c_ := await self.headstamp(c):
headstamps[c] = c_
return headstamps

@functools.cached_property
def writers(self) -> list[DataWriter]:
return [
self.init_writer(
self.store, contract, headstamp, self.barSize, self.wts, self.aggression
)
for contract, headstamp in self.headstamps.items()
]

@staticmethod
def init_writer(
store: AbstractBaseStore,
contract: ibi.Contract,
headTimeStamp: datetime,
barSize: str,
wts: str,
aggression: int,
):
DataWriter(
store,
contract,
headTimeStamp,
barSize=barSize,
wts=wts,
aggression=aggression,
)

async def headstamp(self, contract: ibi.Contract):
try:
headTimeStamp = await self.ib.reqHeadTimeStampAsync(
contract, whatToShow=self.wts, useRTH=False, formatDate=2
)

if headTimeStamp == []:
log.warning(
(
f"Unavailable headTimeStamp for {contract}. "
f"No data will be downloaded"
)
)
except Exception:
log.exception(f"Exception while getting headTimeStamp for {contract}")
return headTimeStamp


def start():

ibi.util.patchAsyncio()
ib = ibi.IB()

manager = Manager()
asyncio.get_event_loop().set_debug(True)
# util.logToConsole(logging.ERROR)
log.debug("Will start...")

Connection(ib, partial(sequence, ib), watchdog=CONFIG.get("watchdog", True))
Connection(ib, partial(main, ib, manager), watchdog=CONFIG.get("watchdog", True))

log.debug("script finished, about to disconnect")
ib.disconnect()
Expand Down

0 comments on commit 55da37c

Please sign in to comment.