forked from AIStream-Peelout/flow-forecast
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtime_model_test.py
104 lines (93 loc) · 3.64 KB
/
time_model_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import unittest
from flood_forecast.model_dict_function import pytorch_model_dict as pytorch_model_dict1
from flood_forecast.time_model import PyTorchForecast
import os
import torch
class TimeSeriesModelTest(unittest.TestCase):
    """Smoke tests for building ``PyTorchForecast`` wrappers from small fixtures.

    Every test constructs models from ``test_init/keag_small.csv`` (located next
    to this file), passing that same CSV for all three data-path arguments.
    """

    def setUp(self):
        # Rebuilt before every test, so tests that overwrite ``model_params``
        # cannot leak their settings into other tests.
        self.test_dir = os.path.dirname(os.path.abspath(__file__))
        self.test_path = os.path.join(self.test_dir, "test_init")
        self.keag_file = os.path.join(self.test_path, "keag_small.csv")
        self.model_params = {
            "metrics": ["MSE", "DilateLoss"],
            "model_params": {
                "number_time_series": 3,
            },
            "inference_params": {
                "hours_to_forecast": 16,
            },
            "dataset_params": {
                "forecast_history": 20,
                "class": "default",
                "forecast_length": 20,
                "relevant_cols": ["cfs", "temp", "precip"],
                "target_col": ["cfs"],
                "interpolate": False,
            },
            "wandb": False,
        }

    def _make_model(self, model_name, params=None):
        """Build a ``PyTorchForecast`` named ``model_name`` on the keag fixture.

        The fixture CSV is passed for all three data-path arguments (presumably
        train / validation / test — confirm against ``PyTorchForecast``).
        ``params`` defaults to ``self.model_params`` as prepared by ``setUp``.
        """
        if params is None:
            params = self.model_params
        return PyTorchForecast(model_name, self.keag_file, self.keag_file, self.keag_file, params)

    def test_pytorch_model_dict(self):
        self.assertIsInstance(pytorch_model_dict1, dict)

    def test_pytorch_wrapper_default(self):
        model = self._make_model("MultiAttnHeadSimple")
        # 3 == "number_time_series" from setUp.
        self.assertEqual(model.model.dense_shape.in_features, 3)
        # No d_model / heads are configured, so these are presumably the
        # model's own defaults.
        self.assertEqual(model.model.multi_attn.embed_dim, 128)
        self.assertEqual(model.model.multi_attn.num_heads, 8)

    def test_pytorch_wrapper_custom(self):
        self.model_params["model_params"] = {"number_time_series": 6, "d_model": 112}
        model = self._make_model("MultiAttnHeadSimple")
        self.assertEqual(model.model.dense_shape.in_features, 6)
        self.assertEqual(model.model.multi_attn.embed_dim, 112)

    def test_model_save(self):
        import tempfile

        model = self._make_model("MultiAttnHeadSimple")
        # Save under a throwaway directory instead of ./output so the test leaves
        # no artifacts behind. A not-yet-existing subdirectory is used so
        # save_model sees the same situation as a fresh "output" path.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_model(os.path.join(tmp_dir, "output"), 0)
        # Presumably forecast_history (20) x len(relevant_cols) (3).
        self.assertEqual(model.training[0][0].shape, torch.Size([20, 3]))

    def test_simple_transformer(self):
        self.model_params["model_params"] = {
            "seq_length": 19,
            "number_time_series": 6,
            "d_model": 136,
            "n_heads": 8,
        }
        model = self._make_model("SimpleTransformer")
        self.assertEqual(model.model.dense_shape.in_features, 6)
        # Mask is square in the configured seq_length.
        self.assertEqual(model.model.mask.shape, torch.Size([19, 19]))

    def test_data_correct(self):
        model = self._make_model("MultiAttnHeadSimple")
        # Previously this test only built the model and asserted nothing, so it
        # could never fail on bad data. Check the first training sample's first
        # element has the shape implied by the dataset params in setUp.
        self.assertEqual(model.training[0][0].shape, torch.Size([20, 3]))

    def test_informer_init(self):
        import json

        config_path = os.path.join(self.test_dir, "test_informer.json")
        with open(config_path, encoding="utf-8") as config_file:
            json_params = json.load(config_file)
        inf = self._make_model("Informer", json_params)
        self.assertIsNotNone(inf)
        self.assertEqual(inf.model.label_len, 10)
if __name__ == "__main__":
    # Running this file directly (``python time_model_test.py``) executes the
    # tests via unittest's command-line runner.
    unittest.main()