Skip to content

Commit

Permalink
add calc_train_trade_data
Browse files Browse the repository at this point in the history
  • Loading branch information
zhumingpassional committed Mar 25, 2023
1 parent 9a5262e commit 2190446
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions finrl/meta/data_processors/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Tuple

import numpy as np
import pandas as pd


# filename: str
Expand Down Expand Up @@ -92,11 +93,11 @@ def calc_dates(
# return: train_starts, train_ends, trade_starts, trade_ends, each of length num_subsets_if_rolling
# start is inclusive, end is exclusive. The max of endIndex is len(dates) - 1.
def calc_train_trade_starts_ends_if_rolling(
init_train_dates: list[str], init_trade_dates: list[str], trade_window_length2: int
init_train_dates: list[str], init_trade_dates: list[str], rolling_window_length: int
) -> tuple[list[str], list[str], list[str], list[str]]:
trade_dates_length = len(init_trade_dates)
train_window_length = len(init_train_dates)
trade_window_length = min(trade_window_length2, trade_dates_length)
trade_window_length = min(rolling_window_length, trade_dates_length)
num_subsets_if_rolling = int(np.ceil(trade_dates_length / trade_window_length))
print("num_subsets_if_rolling: ", num_subsets_if_rolling)
dates = np.concatenate((init_train_dates, init_trade_dates), axis=0)
Expand All @@ -121,3 +122,29 @@ def calc_train_trade_starts_ends_if_rolling(
print("trade_starts: ", trade_starts)
print("trade_ends__: ", trade_ends)
return train_starts, train_ends, trade_starts, trade_ends


def _select_date_window(
    data: pd.DataFrame,
    date_col: str,
    start: str,
    end: str,
) -> pd.DataFrame:
    """Return the rows of *data* with ``start <= data[date_col] < end``.

    The result is re-indexed so that every row carries the 0-based ordinal of
    its date within the window (rows sharing a date share an index value),
    numbered in order of first appearance.
    """
    dates = data[date_col]
    # Explicit copy: the result is re-labelled below and handed to the caller,
    # so it must not be tied to the frame that was passed in.
    window = data.loc[(dates >= start) & (dates < end)].copy()
    # factorize()[0] gives the integer code of each row's date value.
    window.index = window[date_col].factorize()[0]
    return window


def calc_train_trade_data(
    i: int,
    train_starts: list[str],
    train_ends: list[str],
    trade_starts: list[str],
    trade_ends: list[str],
    init_train_data: pd.DataFrame,
    init_trade_data: pd.DataFrame,
    date_col: str,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Cut the train and trade frames down to the ``i``-th window.

    Args:
        i: Position of the window to extract; used to index all four
            start/end lists.
        train_starts: Start bound (inclusive) of each training window.
        train_ends: End bound (exclusive) of each training window.
        trade_starts: Start bound (inclusive) of each trading window.
        trade_ends: End bound (exclusive) of each trading window.
        init_train_data: Frame the training window is selected from.
        init_trade_data: Frame the trading window is selected from.
        date_col: Name of the column compared against the bounds. Its values
            must be comparable with the bounds using ``>=`` and ``<``.

    Returns:
        ``(train_data, trade_data)``. Each frame holds the rows inside its
        window and is indexed by the 0-based ordinal of the row's date within
        that window. The input frames are left unchanged.

    Raises:
        IndexError: If ``i`` is out of range for any of the bound lists.
        KeyError: If ``date_col`` is not a column of an input frame.
    """
    train_data = _select_date_window(
        init_train_data, date_col, train_starts[i], train_ends[i]
    )
    trade_data = _select_date_window(
        init_trade_data, date_col, trade_starts[i], trade_ends[i]
    )
    return train_data, trade_data

0 comments on commit 2190446

Please sign in to comment.