Skip to content

Commit

Permalink
Merge branch 'deepar' of https://github.com/aimclub/Fedot.Industrial
Browse files Browse the repository at this point in the history
…into deepar
  • Loading branch information
leostre committed Jun 5, 2024
2 parents a46bd96 + 17beebf commit bfeebc0
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions fedot_ind/core/models/nn/network_impl/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def _predict(self, test_data, output_mode, hidden_state=None, **output_kw):
def predict_for_fit(self, test_data):
output_mode = 'predictions'
forecast_idx_predict = np.arange(start=test_data.idx[-1],
stop=test_data.idx[-1] + self.forecast_length,
stop=test_data.idx[-1] + self.forecast_length,
step=1)

Expand Down Expand Up @@ -488,6 +489,12 @@ def __ts_to_input_data(self, input_data: Union[InputData, pd.DataFrame]):
return train_input


@convert_to_3d_torch_array
def _create_torch_loader(self, train_data, is_train):
batch_size = self.batch_size if is_train else self.forecast_length



@convert_to_3d_torch_array
def _create_torch_loader(self, train_data, is_train):
batch_size = self.batch_size if is_train else self.forecast_length
Expand All @@ -498,9 +505,13 @@ def _create_torch_loader(self, train_data, is_train):
else:
features, target = train_data.features, train_data.target

else:
features, target = train_data.features, train_data.target

train_loader = torch.utils.data.DataLoader(
data.TensorDataset(features, target),
batch_size=batch_size, shuffle=False,
batch_size=batch_size, shuffle=False,
)
return train_loader

0 comments on commit bfeebc0

Please sign in to comment.