Skip to content

Commit

Permalink
Fixbacktest (PaddlePaddle#153)
Browse files Browse the repository at this point in the history
* fix backtest

* fix backtest
  • Loading branch information
QGN123 authored Oct 24, 2022
1 parent 68a7f5f commit b9a9826
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ msgid ""
" set to simulate the realy prediction."
msgstr ""
"``predict_window`` 是每次预测的窗口长度。``stride`` "
"是两次连续预测之间的移动步长。在大多数情况下我们需要自定义这两个参数来模拟真是的预测场景。"
"是两次连续预测之间的移动步长。在大多数情况下我们需要自定义这两个参数来模拟真实的预测场景。"

#: ../../source/modules/backtest/overview.rst:123
#: e4d054512d7440d3a79f7444fae9f856
Expand Down
14 changes: 6 additions & 8 deletions paddlets/tests/utils/test_backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def setUp(self):
))
known_cov = TimeSeries.load_from_dataframe(
pd.DataFrame(
np.random.randn(2500, 2).astype(np.float32),
index=pd.date_range("2022-01-01", periods=2500, freq="15T"),
np.random.randn(2000, 2).astype(np.float32),
index=pd.date_range("2022-01-01", periods=2000, freq="15T"),
columns=["b1", "c1"]
))
static_cov = {"f": 1, "g": 2}
Expand Down Expand Up @@ -83,9 +83,8 @@ def test_backtest(self):
score, predicts = backtest(self.tsdataset1, lstnet, start=pd.Timestamp('2022-01-07T12'), predict_window=50, stride=50,
return_predicts=True)

start = 624
data_len = len(self.tsdataset1.get_target())
assert len(predicts.get_target()) == data_len - start

assert len(predicts.get_target()) == 1300

# case3 add window,stride, window != stride
lstnet = LSTNetRegressor(
Expand All @@ -107,9 +106,8 @@ def test_backtest(self):
lstnet.fit(self.tsdataset1, self.tsdataset1)
score, predicts = backtest(self.tsdataset1, lstnet, start=200, predict_window=50, stride=50, return_predicts=True)

start = 200 + 4 * 4
data_len = len(self.tsdataset1.get_target())
assert len(predicts.get_target()) == data_len - start

assert len(predicts.get_target()) == 1700

# case5 add return score
lstnet = LSTNetRegressor(
Expand Down
7 changes: 5 additions & 2 deletions paddlets/utils/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,15 @@ def _check():
for _ in tqdm(range(predict_rounds), desc=TQDM_PREFIX, disable=not verbose):
data._target, rest = all_target.split(index)
rest_len = len(rest)
if rest_len < predict_window + model_skip_chunk_len:

if rest_len < model_out_chunk_len + model_skip_chunk_len:
if data.known_cov is not None:
target_end_time = data._target.end_time
known_index = data.known_cov.get_index_at_point(target_end_time)
if len(data.known_cov) - known_index - 1 < predict_window + model_skip_chunk_len:
if len(data.known_cov) - known_index - 1 < model_out_chunk_len + model_skip_chunk_len:
break

if rest_len < predict_window + model_skip_chunk_len:
predict_window = rest_len
predict_window = predict_window - model_skip_chunk_len

Expand Down

0 comments on commit b9a9826

Please sign in to comment.