Commit 5d29981

Begin refactoring to skip the Hive dataset and ingest directly from CSV files.

Jim White committed Sep 24, 2024
1 parent 6fa6ea2 commit 5d29981
Showing 3 changed files with 68 additions and 37 deletions.
67 changes: 36 additions & 31 deletions src/concat_all_aggs.py
@@ -1,5 +1,5 @@
 import shutil
-from typing import Iterator
+from typing import Iterator, Tuple
 from config import PolygonConfig
 
 import argparse
@@ -13,12 +13,12 @@
 import pandas as pd
 
 
-def csv_agg_scanner(
-    paths: list,
+def generate_tables_from_csv_files(
+    paths: Iterator,
     schema: pa.Schema,
     start_timestamp: pd.Timestamp,
     limit_timestamp: pd.Timestamp,
-) -> Iterator[pa.RecordBatch]:
+) -> Iterator[pa.Table]:
     empty_table = schema.empty_table()
     # TODO: Find which column(s) need to be cast to int64 from the schema.
     empty_table = empty_table.set_column(
@@ -64,32 +64,20 @@ def csv_agg_scanner
             continue
 
         print(f"{path=}")
-        for batch in table.to_batches():
-            yield batch
+        yield table
     print(f"{skipped_tables=}")
 
 
-def concat_all_aggs_from_csv(
+def generate_csv_agg_tables(
     config: PolygonConfig,
-    aggs_pattern: str = "**/*.csv.gz",
-    overwrite: bool = False,
-) -> list:
+) -> Tuple[pa.Schema, Iterator[pa.Table]]:
     """zipline does bundle ingestion one ticker at a time."""
-    if os.path.exists(config.by_ticker_hive_dir):
-        if overwrite:
-            print(f"Removing {config.by_ticker_hive_dir=}")
-            shutil.rmtree(config.by_ticker_hive_dir)
-        else:
-            raise FileExistsError(
-                f"{config.by_ticker_hive_dir=} exists and overwrite is False."
-            )
-
     # We sort by path because they have the year and month in the dir names and the date in the filename.
     paths = sorted(
         list(
             glob.glob(
-                os.path.join(config.aggs_dir, aggs_pattern),
-                recursive="**" in aggs_pattern,
+                os.path.join(config.aggs_dir, config.csv_paths_pattern),
+                recursive="**" in config.csv_paths_pattern,
             )
         )
    )
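
The collapsed region above holds the per-file read loop; the visible change is that the generator now yields whole pa.Table objects instead of individual record batches. A minimal sketch of that yielding pattern (the read options and the empty-file check are illustrative assumptions, not the commit's exact code):

import pyarrow as pa
import pyarrow.csv as pa_csv

def tables_from_csvs(paths, schema: pa.Schema):
    # Yield one pyarrow Table per flat file; callers decide how to batch them.
    for path in paths:
        table = pa_csv.read_csv(
            path,  # read_csv infers gzip compression from the .csv.gz suffix
            convert_options=pa_csv.ConvertOptions(column_types=schema),
        )
        if table.num_rows == 0:
            continue  # skip empty files
        yield table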
@@ -126,18 +114,36 @@ def concat_all_aggs_from_csv
         ]
     )
 
-    agg_scanner = pa_ds.Scanner.from_batches(
-        csv_agg_scanner(
-            paths=paths,
-            schema=polygon_aggs_schema,
-            start_timestamp=config.start_timestamp,
-            limit_timestamp=config.end_timestamp + pd.to_timedelta(1, unit="day"),
-        ),
+    return polygon_aggs_schema, generate_tables_from_csv_files(
+        paths=paths,
+        schema=polygon_aggs_schema,
+        start_timestamp=config.start_timestamp,
+        limit_timestamp=config.end_timestamp + pd.to_timedelta(1, unit="day"),
+    )
+
+
+def generate_batches_from_tables(tables):
+    for table in tables:
+        yield from table.to_batches()
+
+
+def concat_all_aggs_from_csv(
+    config: PolygonConfig,
+    overwrite: bool = False,
+) -> None:
+    if os.path.exists(config.by_ticker_hive_dir):
+        if overwrite:
+            print(f"Removing {config.by_ticker_hive_dir=}")
+            shutil.rmtree(config.by_ticker_hive_dir)
+        else:
+            raise FileExistsError(
+                f"{config.by_ticker_hive_dir=} exists and overwrite is False."
+            )
+
+    schema, tables = generate_csv_agg_tables(config)
     pa_ds.write_dataset(
-        agg_scanner,
+        generate_batches_from_tables(tables),
+        schema=schema,
         base_dir=config.by_ticker_hive_dir,
         format="parquet",
         existing_data_behavior="overwrite_or_ignore",
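
For context on the new pipeline: pa_ds.write_dataset accepts a plain iterable of RecordBatches when an explicit schema accompanies it, which is exactly what generate_batches_from_tables feeds it (note yield from, so individual batches are emitted rather than lists of batches). A self-contained toy of the same pattern, where the base_dir, partitioning, and sample values are assumptions for illustration:

import pyarrow as pa
import pyarrow.dataset as pa_ds

schema = pa.schema([("ticker", pa.string()), ("close", pa.float64())])

def tables():
    yield pa.table({"ticker": ["AAPL", "MSFT"], "close": [1.0, 2.0]}, schema=schema)
    yield pa.table({"ticker": ["AAPL"], "close": [3.0]}, schema=schema)

def batches(tables_iter):
    for t in tables_iter:
        yield from t.to_batches()  # flatten each Table into RecordBatches

pa_ds.write_dataset(
    batches(tables()),
    schema=schema,  # required when passing a bare iterable of batches
    base_dir="/tmp/by_ticker_hive",  # stand-in for config.by_ticker_hive_dir
    format="parquet",
    partitioning=["ticker"],  # illustrative; this commit's partitioning isn't shown
    partitioning_flavor="hive",
    existing_data_behavior="overwrite_or_ignore",
)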
@@ -160,7 +166,7 @@ def concat_all_aggs_from_csv
 
     # TODO: These defaults should be None but for dev convenience they are set for my local config.
     parser.add_argument("--data_dir", default="/Volumes/Oahu/Mirror/files.polygon.io")
-    parser.add_argument("--aggs_pattern", default="**/*.csv.gz")
+    # parser.add_argument("--aggs_pattern", default="**/*.csv.gz")
     # parser.add_argument("--aggs_pattern", default="2020/10/**/*.csv.gz")
 
     args = parser.parse_args()
@@ -179,6 +185,5 @@ def concat_all_aggs_from_csv
 
     concat_all_aggs_from_csv(
         config=config,
-        aggs_pattern=args.aggs_pattern,
         overwrite=args.overwrite,
     )
1 change: 1 addition & 0 deletions src/config.py
@@ -51,6 +51,7 @@ def __init__(
         self.flat_files_dir = environ.get(
             "POLYGON_FLAT_FILES_DIR", os.path.join(self.data_dir, "flatfiles")
         )
+        self.csv_paths_pattern = environ.get("POLYGON_FLAT_FILES_CSV_PATTERN", "**/*.csv.gz")
         self.agg_time = agg_time
         self.asset_files_dir = os.path.join(self.flat_files_dir, self.asset_subdir)
         self.minute_aggs_dir = os.path.join(self.asset_files_dir, "minute_aggs_v1")
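
The new csv_paths_pattern setting lets the file glob be narrowed without code changes, which is what makes the removed --aggs_pattern CLI flag redundant. For example, to restrict a dev run to a single month (the value mirrors the commented-out argparse default):

import os

# Must be set before PolygonConfig is constructed.
os.environ["POLYGON_FLAT_FILES_CSV_PATTERN"] = "2020/10/**/*.csv.gz"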
37 changes: 31 additions & 6 deletions src/zipline_polygon_bundle.py
@@ -1,17 +1,22 @@
 from zipline.data.bundles import register
 
 from config import PolygonConfig
-from concat_all_aggs import concat_all_aggs_from_csv
+from concat_all_aggs import generate_csv_agg_tables
 
 import polygon
 from urllib3 import HTTPResponse
 
 import pyarrow
+import pyarrow.compute
+import pyarrow as pa
+from pyarrow import dataset as pa_ds
+from pyarrow import csv as pa_csv
 
 import pandas as pd
 import datetime
 import os
 import logging
+import glob
 
 
 def polygon_equities_bundle_day(
@@ -86,6 +91,7 @@ def polygon_equities_bundle_minute
 
 # TODO: Change warnings to be relative to number of days in the range.
 
+
 def load_polygon_splits(
     config: PolygonConfig, first_start_end: datetime.date, last_end_date: datetime.date
 ) -> pd.DataFrame:
@@ -198,6 +204,26 @@ def load_dividends
     return dividends
 
 
+def generate_all_aggs_from_csv(
+    config: PolygonConfig,
+    calendar,
+    start_session,
+    end_session,
+    ticker_to_sid: dict[str, int],
+    dates_with_data: set,
+):
+    schema, tables = generate_csv_agg_tables(config)
+    for table in tables:
+        table = table.rename_columns({"ticker": "symbol", "window_start": "day"})
+        table = table.sort_by([("symbol", "ascending")])
+        symbols = sorted(set(table.column("symbol").to_pylist()))
+        for sid, symbol in enumerate(symbols):
+            ticker_to_sid[symbol] = sid
+            df = table.filter(
+                pyarrow.compute.field("symbol") == pyarrow.scalar(symbol)
+            ).to_pandas()
+
+
 def polygon_equities_bundle(
     config: PolygonConfig,
     asset_db_writer,
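
generate_all_aggs_from_csv is still a stub here, as the commit message says: the per-symbol df is computed but not yet consumed. The symbol-slicing idiom it uses works like this in isolation (toy table; the values are made up for illustration):

import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({
    "symbol": ["MSFT", "AAPL", "AAPL"],
    "day": ["2020-10-01", "2020-10-01", "2020-10-02"],
    "close": [206.0, 116.0, 112.0],
})
table = table.sort_by([("symbol", "ascending")])
for sid, symbol in enumerate(sorted(set(table.column("symbol").to_pylist()))):
    # filter() accepts a compute Expression, yielding one DataFrame per symbol.
    df = table.filter(pc.field("symbol") == pa.scalar(symbol)).to_pandas()
    print(sid, symbol, len(df))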
@@ -224,10 +250,9 @@ def polygon_equities_bundle(
         )
     )
 
-    if not os.path.exists(config.by_ticker_hive_dir):
-        concat_all_aggs_from_csv(config)
-
-    aggregates = pyarrow.dataset.dataset(config.by_ticker_hive_dir)
+    aggregates = generate_all_aggs_from_csv(
+        config, calendar, start_session, end_session, ticker_to_sid, dates_with_data
+    )
 
     ticker_to_sid = {}
     dates_with_data = set()
@@ -426,7 +451,7 @@ def register_polygon_equities_bundle
     # include_asset_types=None,
 ):
     if agg_time not in ["day", "minute"]:
-        raise ValueError(f"agg_time must be 'day' or 'minute', not {agg_time}")
+        raise ValueError(f"agg_time must be 'day' or 'minute', not '{agg_time}'")
     register(
         bundlename,
         (
