From bba6ea2183a97c3d8560ceab68b2a6bb38708c13 Mon Sep 17 00:00:00 2001 From: lss-1138 <57395990+lss-1138@users.noreply.github.com> Date: Wed, 27 Mar 2024 14:55:57 +0800 Subject: [PATCH] Initial commit --- .gitattributes | 2 + .gitignore | 9 + LICENSE | 201 +++++++ README.md | 30 + data_provider/data_factory.py | 51 ++ data_provider/data_loader.py | 396 +++++++++++++ exp/exp_basic.py | 37 ++ exp/exp_main.py | 379 +++++++++++++ layers/AutoCorrelation.py | 164 ++++++ layers/Autoformer_EncDec.py | 192 +++++++ layers/Embed.py | 164 ++++++ layers/FED_FourierCorrelation.py | 133 +++++ layers/FourierCorrelation.py | 512 +++++++++++++++++ layers/MultiWaveletCorrelation.py | 379 +++++++++++++ layers/PatchTST_backbone.py | 379 +++++++++++++ layers/PatchTST_layers.py | 121 ++++ layers/RevIN.py | 61 ++ layers/SelfAttention_Family.py | 166 ++++++ layers/Transformer_EncDec.py | 131 +++++ layers/mwt.py | 910 ++++++++++++++++++++++++++++++ layers/utils.py | 391 +++++++++++++ models/Autoformer.py | 126 +++++ models/DLinear.py | 174 ++++++ models/FEDformer.py | 207 +++++++ models/Film.py | 247 ++++++++ models/Informer.py | 106 ++++ models/Linear.py | 21 + models/PatchTST.py | 92 +++ models/SparseTSF.py | 45 ++ models/Stat_models.py | 120 ++++ models/Transformer.py | 144 +++++ requirements.txt | 5 + run_all.sh | 12 + run_longExp.py | 156 +++++ scripts/SparseTSF/electricity.sh | 30 + scripts/SparseTSF/etth1.sh | 31 + scripts/SparseTSF/etth2.sh | 31 + scripts/SparseTSF/ettm1.sh | 30 + scripts/SparseTSF/ettm2.sh | 30 + scripts/SparseTSF/traffic.sh | 31 + scripts/SparseTSF/weather.sh | 30 + utils/augmentations.py | 163 ++++++ utils/masking.py | 39 ++ utils/metrics.py | 44 ++ utils/timefeatures.py | 134 +++++ utils/tools.py | 121 ++++ 46 files changed, 6977 insertions(+) create mode 100644 .gitattributes create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 data_provider/data_factory.py create mode 100644 data_provider/data_loader.py create mode 100644 exp/exp_basic.py create mode 100644 exp/exp_main.py create mode 100644 layers/AutoCorrelation.py create mode 100644 layers/Autoformer_EncDec.py create mode 100644 layers/Embed.py create mode 100644 layers/FED_FourierCorrelation.py create mode 100644 layers/FourierCorrelation.py create mode 100644 layers/MultiWaveletCorrelation.py create mode 100644 layers/PatchTST_backbone.py create mode 100644 layers/PatchTST_layers.py create mode 100644 layers/RevIN.py create mode 100644 layers/SelfAttention_Family.py create mode 100644 layers/Transformer_EncDec.py create mode 100644 layers/mwt.py create mode 100644 layers/utils.py create mode 100644 models/Autoformer.py create mode 100644 models/DLinear.py create mode 100644 models/FEDformer.py create mode 100644 models/Film.py create mode 100644 models/Informer.py create mode 100644 models/Linear.py create mode 100644 models/PatchTST.py create mode 100644 models/SparseTSF.py create mode 100644 models/Stat_models.py create mode 100644 models/Transformer.py create mode 100644 requirements.txt create mode 100644 run_all.sh create mode 100644 run_longExp.py create mode 100644 scripts/SparseTSF/electricity.sh create mode 100644 scripts/SparseTSF/etth1.sh create mode 100644 scripts/SparseTSF/etth2.sh create mode 100644 scripts/SparseTSF/ettm1.sh create mode 100644 scripts/SparseTSF/ettm2.sh create mode 100644 scripts/SparseTSF/traffic.sh create mode 100644 scripts/SparseTSF/weather.sh create mode 100644 utils/augmentations.py create mode 100644 utils/masking.py create mode 100644 
utils/metrics.py create mode 100644 utils/timefeatures.py create mode 100644 utils/tools.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..72c45eb --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.idea/ +*.pyc +*.pth +dataset/ +result.txt +test_results/ +results/ +logs/ +.DS_Store \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b09cd78 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..cd25ba0
--- /dev/null
+++ b/README.md
@@ -0,0 +1,30 @@
+# SparseTSF
+
+Welcome to the anonymized repository of the SparseTSF paper: "SparseTSF: Modeling Long-term Time Series Forecasting with *1k* Parameters"
+
+### Model Implementation
+
+The implementation code of SparseTSF is available at:
+```
+models/SparseTSF.py
+```
+
+### Training Scripts
+
+The training scripts (including hyperparameter settings) for replicating the SparseTSF results are available at:
+
+```
+scripts/SparseTSF
+```
+
+### Quick Reproduction
+
+You can reproduce all the main results of SparseTSF with the following commands.
+```
+conda create -n SparseTSF python=3.8
+conda activate SparseTSF
+pip install -r requirements.txt
+sh run_all.sh
+```
+
+**Thank you again for your efforts and time.
We will continue to improve this repository after the paper is accepted.** \ No newline at end of file diff --git a/data_provider/data_factory.py b/data_provider/data_factory.py new file mode 100644 index 0000000..25b0595 --- /dev/null +++ b/data_provider/data_factory.py @@ -0,0 +1,51 @@ +from data_provider.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred +from torch.utils.data import DataLoader + +data_dict = { + 'ETTh1': Dataset_ETT_hour, + 'ETTh2': Dataset_ETT_hour, + 'ETTm1': Dataset_ETT_minute, + 'ETTm2': Dataset_ETT_minute, + 'custom': Dataset_Custom, +} + + +def data_provider(args, flag): + Data = data_dict[args.data] + timeenc = 0 if args.embed != 'timeF' else 1 + + if flag == 'test': + shuffle_flag = False + drop_last = True + batch_size = args.batch_size + freq = args.freq + elif flag == 'pred': + shuffle_flag = False + drop_last = False + batch_size = 1 + freq = args.freq + Data = Dataset_Pred + else: + shuffle_flag = True + drop_last = True + batch_size = args.batch_size + freq = args.freq + + data_set = Data( + root_path=args.root_path, + data_path=args.data_path, + flag=flag, + size=[args.seq_len, args.label_len, args.pred_len], + features=args.features, + target=args.target, + timeenc=timeenc, + freq=freq + ) + print(flag, len(data_set)) + data_loader = DataLoader( + data_set, + batch_size=batch_size, + shuffle=shuffle_flag, + num_workers=args.num_workers, + drop_last=drop_last) + return data_set, data_loader diff --git a/data_provider/data_loader.py b/data_provider/data_loader.py new file mode 100644 index 0000000..ab8b4c7 --- /dev/null +++ b/data_provider/data_loader.py @@ -0,0 +1,396 @@ +import os +import numpy as np +import pandas as pd +import os +import torch +from torch.utils.data import Dataset, DataLoader +from sklearn.preprocessing import StandardScaler +from utils.timefeatures import time_features +import warnings + +warnings.filterwarnings('ignore') + + +class Dataset_ETT_hour(Dataset): + def __init__(self, root_path, flag='train', size=None, + features='S', data_path='ETTh1.csv', + target='OT', scale=True, timeenc=0, freq='h'): + # size [seq_len, label_len, pred_len] + # info + if size == None: + self.seq_len = 24 * 4 * 4 + self.label_len = 24 * 4 + self.pred_len = 24 * 4 + else: + self.seq_len = size[0] + self.label_len = size[1] + self.pred_len = size[2] + # init + assert flag in ['train', 'test', 'val'] + type_map = {'train': 0, 'val': 1, 'test': 2} + self.set_type = type_map[flag] + + self.features = features + self.target = target + self.scale = scale + self.timeenc = timeenc + self.freq = freq + + self.root_path = root_path + self.data_path = data_path + self.__read_data__() + + def __read_data__(self): + self.scaler = StandardScaler() + df_raw = pd.read_csv(os.path.join(self.root_path, + self.data_path)) + + border1s = [0, 12 * 30 * 24 - self.seq_len, 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len] + border2s = [12 * 30 * 24, 12 * 30 * 24 + 4 * 30 * 24, 12 * 30 * 24 + 8 * 30 * 24] + border1 = border1s[self.set_type] + border2 = border2s[self.set_type] + + if self.features == 'M' or self.features == 'MS': + cols_data = df_raw.columns[1:] + df_data = df_raw[cols_data] + elif self.features == 'S': + df_data = df_raw[[self.target]] + + if self.scale: + train_data = df_data[border1s[0]:border2s[0]] + self.scaler.fit(train_data.values) + data = self.scaler.transform(df_data.values) + else: + data = df_data.values + + df_stamp = df_raw[['date']][border1:border2] + df_stamp['date'] = pd.to_datetime(df_stamp.date) + if self.timeenc 
== 0: + df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) + df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) + df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) + df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) + data_stamp = df_stamp.drop(['date'], axis=1).values + elif self.timeenc == 1: + data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) + data_stamp = data_stamp.transpose(1, 0) + + + self.data_x = data[border1:border2] + self.data_y = data[border1:border2] + self.data_stamp = data_stamp + + def __getitem__(self, index): + s_begin = index + s_end = s_begin + self.seq_len + r_begin = s_end - self.label_len + r_end = r_begin + self.label_len + self.pred_len + + seq_x = self.data_x[s_begin:s_end] + seq_y = self.data_y[r_begin:r_end] + seq_x_mark = self.data_stamp[s_begin:s_end] + seq_y_mark = self.data_stamp[r_begin:r_end] + + return seq_x, seq_y, seq_x_mark, seq_y_mark + + def __len__(self): + return len(self.data_x) - self.seq_len - self.pred_len + 1 + + def inverse_transform(self, data): + return self.scaler.inverse_transform(data) + + +class Dataset_ETT_minute(Dataset): + def __init__(self, root_path, flag='train', size=None, + features='S', data_path='ETTm1.csv', + target='OT', scale=True, timeenc=0, freq='t'): + # size [seq_len, label_len, pred_len] + # info + if size == None: + self.seq_len = 24 * 4 * 4 + self.label_len = 24 * 4 + self.pred_len = 24 * 4 + else: + self.seq_len = size[0] + self.label_len = size[1] + self.pred_len = size[2] + # init + assert flag in ['train', 'test', 'val'] + type_map = {'train': 0, 'val': 1, 'test': 2} + self.set_type = type_map[flag] + + self.features = features + self.target = target + self.scale = scale + self.timeenc = timeenc + self.freq = freq + + self.root_path = root_path + self.data_path = data_path + self.__read_data__() + + def __read_data__(self): + self.scaler = StandardScaler() + df_raw = pd.read_csv(os.path.join(self.root_path, + self.data_path)) + + border1s = [0, 12 * 30 * 24 * 4 - self.seq_len, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len] + border2s = [12 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4] + border1 = border1s[self.set_type] + border2 = border2s[self.set_type] + + if self.features == 'M' or self.features == 'MS': + cols_data = df_raw.columns[1:] + df_data = df_raw[cols_data] + elif self.features == 'S': + df_data = df_raw[[self.target]] + + if self.scale: + train_data = df_data[border1s[0]:border2s[0]] + self.scaler.fit(train_data.values) + data = self.scaler.transform(df_data.values) + else: + data = df_data.values + + df_stamp = df_raw[['date']][border1:border2] + df_stamp['date'] = pd.to_datetime(df_stamp.date) + if self.timeenc == 0: + df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) + df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) + df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) + df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) + df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1) + df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15) + data_stamp = df_stamp.drop(['date'], axis=1).values + elif self.timeenc == 1: + data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) + data_stamp = data_stamp.transpose(1, 0) + + self.data_x = data[border1:border2] + self.data_y = data[border1:border2] + self.data_stamp = data_stamp + + def __getitem__(self, index): 
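+        # Editor's note (illustrative sizes, not the repository's defaults):
+        # each item is one sliding window over the scaled series. With
+        # seq_len=96, label_len=48, pred_len=96 and index=0:
+        # seq_x = data_x[0:96], while seq_y = data_y[48:192], i.e. the last
+        # label_len steps of the input (decoder warm-start) followed by the
+        # pred_len steps to be forecast.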
+ s_begin = index + s_end = s_begin + self.seq_len + r_begin = s_end - self.label_len + r_end = r_begin + self.label_len + self.pred_len + + seq_x = self.data_x[s_begin:s_end] + seq_y = self.data_y[r_begin:r_end] + seq_x_mark = self.data_stamp[s_begin:s_end] + seq_y_mark = self.data_stamp[r_begin:r_end] + + return seq_x, seq_y, seq_x_mark, seq_y_mark + + def __len__(self): + return len(self.data_x) - self.seq_len - self.pred_len + 1 + + def inverse_transform(self, data): + return self.scaler.inverse_transform(data) + + +class Dataset_Custom(Dataset): + def __init__(self, root_path, flag='train', size=None, + features='S', data_path='ETTh1.csv', + target='OT', scale=True, timeenc=0, freq='h'): + # size [seq_len, label_len, pred_len] + # info + if size == None: + self.seq_len = 24 * 4 * 4 + self.label_len = 24 * 4 + self.pred_len = 24 * 4 + else: + self.seq_len = size[0] + self.label_len = size[1] + self.pred_len = size[2] + # init + assert flag in ['train', 'test', 'val'] + type_map = {'train': 0, 'val': 1, 'test': 2} + self.set_type = type_map[flag] + + self.features = features + self.target = target + self.scale = scale + self.timeenc = timeenc + self.freq = freq + + self.root_path = root_path + self.data_path = data_path + self.__read_data__() + + def __read_data__(self): + self.scaler = StandardScaler() + df_raw = pd.read_csv(os.path.join(self.root_path, + self.data_path)) + + ''' + df_raw.columns: ['date', ...(other features), target feature] + ''' + cols = list(df_raw.columns) + cols.remove(self.target) + cols.remove('date') + df_raw = df_raw[['date'] + cols + [self.target]] + # print(cols) + num_train = int(len(df_raw) * 0.7) + num_test = int(len(df_raw) * 0.2) + num_vali = len(df_raw) - num_train - num_test + border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] + border2s = [num_train, num_train + num_vali, len(df_raw)] + border1 = border1s[self.set_type] + border2 = border2s[self.set_type] + + if self.features == 'M' or self.features == 'MS': + cols_data = df_raw.columns[1:] + df_data = df_raw[cols_data] + elif self.features == 'S': + df_data = df_raw[[self.target]] + + if self.scale: + train_data = df_data[border1s[0]:border2s[0]] + self.scaler.fit(train_data.values) + # print(self.scaler.mean_) + # exit() + data = self.scaler.transform(df_data.values) + else: + data = df_data.values + + df_stamp = df_raw[['date']][border1:border2] + df_stamp['date'] = pd.to_datetime(df_stamp.date) + if self.timeenc == 0: + df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) + df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) + df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) + df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) + data_stamp = df_stamp.drop(['date'], axis=1).values + elif self.timeenc == 1: + data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) + data_stamp = data_stamp.transpose(1, 0) + + self.data_x = data[border1:border2] + self.data_y = data[border1:border2] + self.data_stamp = data_stamp + + def __getitem__(self, index): + s_begin = index + s_end = s_begin + self.seq_len + r_begin = s_end - self.label_len + r_end = r_begin + self.label_len + self.pred_len + + seq_x = self.data_x[s_begin:s_end] + seq_y = self.data_y[r_begin:r_end] + seq_x_mark = self.data_stamp[s_begin:s_end] + seq_y_mark = self.data_stamp[r_begin:r_end] + + return seq_x, seq_y, seq_x_mark, seq_y_mark + + def __len__(self): + return len(self.data_x) - self.seq_len - self.pred_len + 1 + + def 
inverse_transform(self, data): + return self.scaler.inverse_transform(data) + +class Dataset_Pred(Dataset): + def __init__(self, root_path, flag='pred', size=None, + features='S', data_path='ETTh1.csv', + target='OT', scale=True, inverse=False, timeenc=0, freq='15min', cols=None): + # size [seq_len, label_len, pred_len] + # info + if size == None: + self.seq_len = 24 * 4 * 4 + self.label_len = 24 * 4 + self.pred_len = 24 * 4 + else: + self.seq_len = size[0] + self.label_len = size[1] + self.pred_len = size[2] + # init + assert flag in ['pred'] + + self.features = features + self.target = target + self.scale = scale + self.inverse = inverse + self.timeenc = timeenc + self.freq = freq + self.cols = cols + self.root_path = root_path + self.data_path = data_path + self.__read_data__() + + def __read_data__(self): + self.scaler = StandardScaler() + df_raw = pd.read_csv(os.path.join(self.root_path, + self.data_path)) + ''' + df_raw.columns: ['date', ...(other features), target feature] + ''' + if self.cols: + cols = self.cols.copy() + cols.remove(self.target) + else: + cols = list(df_raw.columns) + cols.remove(self.target) + cols.remove('date') + df_raw = df_raw[['date'] + cols + [self.target]] + border1 = len(df_raw) - self.seq_len + border2 = len(df_raw) + + if self.features == 'M' or self.features == 'MS': + cols_data = df_raw.columns[1:] + df_data = df_raw[cols_data] + elif self.features == 'S': + df_data = df_raw[[self.target]] + + if self.scale: + self.scaler.fit(df_data.values) + data = self.scaler.transform(df_data.values) + else: + data = df_data.values + + tmp_stamp = df_raw[['date']][border1:border2] + tmp_stamp['date'] = pd.to_datetime(tmp_stamp.date) + pred_dates = pd.date_range(tmp_stamp.date.values[-1], periods=self.pred_len + 1, freq=self.freq) + + df_stamp = pd.DataFrame(columns=['date']) + df_stamp.date = list(tmp_stamp.date.values) + list(pred_dates[1:]) + if self.timeenc == 0: + df_stamp['month'] = df_stamp.date.apply(lambda row: row.month, 1) + df_stamp['day'] = df_stamp.date.apply(lambda row: row.day, 1) + df_stamp['weekday'] = df_stamp.date.apply(lambda row: row.weekday(), 1) + df_stamp['hour'] = df_stamp.date.apply(lambda row: row.hour, 1) + df_stamp['minute'] = df_stamp.date.apply(lambda row: row.minute, 1) + df_stamp['minute'] = df_stamp.minute.map(lambda x: x // 15) + data_stamp = df_stamp.drop(['date'], axis=1).values + elif self.timeenc == 1: + data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq) + data_stamp = data_stamp.transpose(1, 0) + + self.data_x = data[border1:border2] + if self.inverse: + self.data_y = df_data.values[border1:border2] + else: + self.data_y = data[border1:border2] + self.data_stamp = data_stamp + + def __getitem__(self, index): + s_begin = index + s_end = s_begin + self.seq_len + r_begin = s_end - self.label_len + r_end = r_begin + self.label_len + self.pred_len + + seq_x = self.data_x[s_begin:s_end] + if self.inverse: + seq_y = self.data_x[r_begin:r_begin + self.label_len] + else: + seq_y = self.data_y[r_begin:r_begin + self.label_len] + seq_x_mark = self.data_stamp[s_begin:s_end] + seq_y_mark = self.data_stamp[r_begin:r_end] + + return seq_x, seq_y, seq_x_mark, seq_y_mark + + def __len__(self): + return len(self.data_x) - self.seq_len + 1 + + def inverse_transform(self, data): + return self.scaler.inverse_transform(data) diff --git a/exp/exp_basic.py b/exp/exp_basic.py new file mode 100644 index 0000000..b0090f8 --- /dev/null +++ b/exp/exp_basic.py @@ -0,0 +1,37 @@ +import os +import torch +import numpy as np + 
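+
+# Editor's note: Exp_Basic is a small template base class. Subclasses such as
+# Exp_Main (defined next in exp/exp_main.py) override _build_model, train,
+# vali, and test, while _acquire_device centralizes the GPU-vs-CPU choice so
+# every experiment places its model on the same device.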
+ +class Exp_Basic(object): + def __init__(self, args): + self.args = args + self.device = self._acquire_device() + self.model = self._build_model().to(self.device) + + def _build_model(self): + raise NotImplementedError + return None + + def _acquire_device(self): + if self.args.use_gpu: + os.environ["CUDA_VISIBLE_DEVICES"] = str( + self.args.gpu) if not self.args.use_multi_gpu else self.args.devices + device = torch.device('cuda:{}'.format(self.args.gpu)) + print('Use GPU: cuda:{}'.format(self.args.gpu)) + else: + device = torch.device('cpu') + print('Use CPU') + return device + + def _get_data(self): + pass + + def vali(self): + pass + + def train(self): + pass + + def test(self): + pass diff --git a/exp/exp_main.py b/exp/exp_main.py new file mode 100644 index 0000000..421f0a3 --- /dev/null +++ b/exp/exp_main.py @@ -0,0 +1,379 @@ +from data_provider.data_factory import data_provider +from exp.exp_basic import Exp_Basic +from models import Informer, Autoformer, Transformer, DLinear, Linear, PatchTST, SparseTSF, Film, FEDformer +from utils.tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop +from utils.metrics import metric + +import numpy as np +import torch +import torch.nn as nn +from torch import optim +from torch.optim import lr_scheduler + +import os +import time + +import warnings +import matplotlib.pyplot as plt +import numpy as np + +warnings.filterwarnings('ignore') + + +class Exp_Main(Exp_Basic): + def __init__(self, args): + super(Exp_Main, self).__init__(args) + + def _build_model(self): + model_dict = { + 'Autoformer': Autoformer, + 'Transformer': Transformer, + 'Informer': Informer, + 'DLinear': DLinear, + 'Linear': Linear, + 'PatchTST': PatchTST, + 'SparseTSF': SparseTSF, + 'Film': Film, + 'FEDformer': FEDformer + } + model = model_dict[self.args.model].Model(self.args).float() + + if self.args.use_multi_gpu and self.args.use_gpu: + model = nn.DataParallel(model, device_ids=self.args.device_ids) + return model + + def _get_data(self, flag): + data_set, data_loader = data_provider(self.args, flag) + return data_set, data_loader + + def _select_optimizer(self): + model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate) + return model_optim + + def _select_criterion(self): + if self.args.loss == "mae": + criterion = nn.L1Loss() + elif self.args.loss == "mse": + criterion = nn.MSELoss() + elif self.args.loss == "smooth": + criterion = nn.SmoothL1Loss() + else: + criterion = nn.MSELoss() + return criterion + + def vali(self, vali_data, vali_loader, criterion): + total_loss = [] + self.model.eval() + with torch.no_grad(): + for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader): + batch_x = batch_x.float().to(self.device) + batch_y = batch_y.float() + + batch_x_mark = batch_x_mark.float().to(self.device) + batch_y_mark = batch_y_mark.float().to(self.device) + + # decoder input + dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() + dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) + # encoder - decoder + if self.args.use_amp: + with torch.cuda.amp.autocast(): + if any(substr in self.args.model for substr in {'Linear', 'TST', 'SparseTSF'}): + outputs = self.model(batch_x) + else: + if self.args.output_attention: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] + else: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + else: + if any(substr in self.args.model for substr in {'Linear', 'TST', 
'SparseTSF'}): + outputs = self.model(batch_x) + else: + if self.args.output_attention: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] + else: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + f_dim = -1 if self.args.features == 'MS' else 0 + outputs = outputs[:, -self.args.pred_len:, f_dim:] + batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) + + pred = outputs.detach().cpu() + true = batch_y.detach().cpu() + + loss = criterion(pred, true) + + total_loss.append(loss) + total_loss = np.average(total_loss) + self.model.train() + return total_loss + + def train(self, setting): + train_data, train_loader = self._get_data(flag='train') + vali_data, vali_loader = self._get_data(flag='val') + test_data, test_loader = self._get_data(flag='test') + + path = os.path.join(self.args.checkpoints, setting) + if not os.path.exists(path): + os.makedirs(path) + + time_now = time.time() + + train_steps = len(train_loader) + early_stopping = EarlyStopping(patience=self.args.patience, verbose=True) + + model_optim = self._select_optimizer() + criterion = self._select_criterion() + + if self.args.use_amp: + scaler = torch.cuda.amp.GradScaler() + + scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim, + steps_per_epoch=train_steps, + pct_start=self.args.pct_start, + epochs=self.args.train_epochs, + max_lr=self.args.learning_rate) + + for epoch in range(self.args.train_epochs): + iter_count = 0 + train_loss = [] + self.model.train() + epoch_time = time.time() + for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader): + iter_count += 1 + model_optim.zero_grad() + batch_x = batch_x.float().to(self.device) + + batch_y = batch_y.float().to(self.device) + batch_x_mark = batch_x_mark.float().to(self.device) + batch_y_mark = batch_y_mark.float().to(self.device) + + # decoder input + dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() + dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) + + # encoder - decoder + if self.args.use_amp: + with torch.cuda.amp.autocast(): + if any(substr in self.args.model for substr in {'Linear', 'TST', 'SparseTSF'}): + outputs = self.model(batch_x) + else: + if self.args.output_attention: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] + else: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + + f_dim = -1 if self.args.features == 'MS' else 0 + outputs = outputs[:, -self.args.pred_len:, f_dim:] + batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) + loss = criterion(outputs, batch_y) + train_loss.append(loss.item()) + else: + if any(substr in self.args.model for substr in {'Linear', 'TST', 'SparseTSF'}): + outputs = self.model(batch_x) + else: + if self.args.output_attention: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] + + else: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, batch_y) + # print(outputs.shape,batch_y.shape) + f_dim = -1 if self.args.features == 'MS' else 0 + outputs = outputs[:, -self.args.pred_len:, f_dim:] + batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) + loss = criterion(outputs, batch_y) + train_loss.append(loss.item()) + + if (i + 1) % 100 == 0: + print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())) + speed = (time.time() - time_now) / iter_count + left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i) + print('\tspeed: 
{:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)) + iter_count = 0 + time_now = time.time() + + if self.args.use_amp: + scaler.scale(loss).backward() + scaler.step(model_optim) + scaler.update() + else: + loss.backward() + model_optim.step() + + if self.args.lradj == 'TST': + adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args, printout=False) + scheduler.step() + + print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time)) + train_loss = np.average(train_loss) + vali_loss = self.vali(vali_data, vali_loader, criterion) + test_loss = self.vali(test_data, test_loader, criterion) + + print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format( + epoch + 1, train_steps, train_loss, vali_loss, test_loss)) + early_stopping(vali_loss, self.model, path) + if early_stopping.early_stop: + print("Early stopping") + break + + if self.args.lradj != 'TST': + adjust_learning_rate(model_optim, scheduler, epoch + 1, self.args) + else: + print('Updating learning rate to {}'.format(scheduler.get_last_lr()[0])) + + best_model_path = path + '/' + 'checkpoint.pth' + self.model.load_state_dict(torch.load(best_model_path, map_location="cuda:0")) + + + return self.model + + def test(self, setting, test=0): + test_data, test_loader = self._get_data(flag='test') + + if test: + print('loading model') + self.model.load_state_dict(torch.load(os.path.join('./checkpoints/' + setting, 'checkpoint.pth'), map_location="cuda:0")) + + preds = [] + trues = [] + inputx = [] + folder_path = './test_results/' + setting + '/' + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + self.model.eval() + with torch.no_grad(): + for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader): + batch_x = batch_x.float().to(self.device) + batch_y = batch_y.float().to(self.device) + + batch_x_mark = batch_x_mark.float().to(self.device) + batch_y_mark = batch_y_mark.float().to(self.device) + + # decoder input + dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float() + dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) + # encoder - decoder + if self.args.use_amp: + with torch.cuda.amp.autocast(): + if any(substr in self.args.model for substr in {'Linear', 'TST', 'SparseTSF'}): + outputs = self.model(batch_x) + else: + if self.args.output_attention: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] + else: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + else: + if any(substr in self.args.model for substr in {'Linear', 'TST', 'SparseTSF'}): + outputs = self.model(batch_x) + else: + if self.args.output_attention: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] + + else: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + + f_dim = -1 if self.args.features == 'MS' else 0 + # print(outputs.shape,batch_y.shape) + outputs = outputs[:, -self.args.pred_len:, f_dim:] + batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device) + outputs = outputs.detach().cpu().numpy() + batch_y = batch_y.detach().cpu().numpy() + + pred = outputs # outputs.detach().cpu().numpy() # .squeeze() + true = batch_y # batch_y.detach().cpu().numpy() # .squeeze() + + preds.append(pred) + trues.append(true) + inputx.append(batch_x.detach().cpu().numpy()) + if i % 20 == 0: + input = batch_x.detach().cpu().numpy() + gt = np.concatenate((input[0, :, -1], true[0, :, -1]), axis=0) + pd = np.concatenate((input[0, :, 
-1], pred[0, :, -1]), axis=0) + visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf')) + + if self.args.test_flop: + test_params_flop(self.model, (batch_x.shape[1],batch_x.shape[2])) + # test_params_flop((batch_x.shape[1], batch_x.shape[2])) + exit() + preds = np.array(preds) + trues = np.array(trues) + inputx = np.array(inputx) + + preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1]) + trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1]) + inputx = inputx.reshape(-1, inputx.shape[-2], inputx.shape[-1]) + + # result save + folder_path = './results/' + setting + '/' + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues) + print('mse:{}, mae:{}, rse:{}'.format(mse, mae, rse)) + f = open("result.txt", 'a') + f.write(setting + " \n") + f.write('mse:{}, mae:{}, rse:{}'.format(mse, mae, rse)) + f.write('\n') + f.write('\n') + f.close() + + # np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe,rse, corr])) + np.save(folder_path + 'pred.npy', preds) + # np.save(folder_path + 'true.npy', trues) + # np.save(folder_path + 'x.npy', inputx) + return + + def predict(self, setting, load=False): + pred_data, pred_loader = self._get_data(flag='pred') + + if load: + path = os.path.join(self.args.checkpoints, setting) + best_model_path = path + '/' + 'checkpoint.pth' + self.model.load_state_dict(torch.load(best_model_path)) + + preds = [] + + self.model.eval() + with torch.no_grad(): + for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(pred_loader): + batch_x = batch_x.float().to(self.device) + batch_y = batch_y.float() + batch_x_mark = batch_x_mark.float().to(self.device) + batch_y_mark = batch_y_mark.float().to(self.device) + + # decoder input + dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[2]]).float().to( + batch_y.device) + dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device) + # encoder - decoder + if self.args.use_amp: + with torch.cuda.amp.autocast(): + if any(substr in self.args.model for substr in {'Linear', 'TST', 'SparseTSF'}): + outputs = self.model(batch_x) + else: + if self.args.output_attention: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] + else: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + else: + if any(substr in self.args.model for substr in {'Linear', 'TST', 'SparseTSF'}): + outputs = self.model(batch_x) + else: + if self.args.output_attention: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] + else: + outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark) + pred = outputs.detach().cpu().numpy() # .squeeze() + preds.append(pred) + + preds = np.array(preds) + preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1]) + + # result save + folder_path = './results/' + setting + '/' + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + np.save(folder_path + 'real_prediction.npy', preds) + + return diff --git a/layers/AutoCorrelation.py b/layers/AutoCorrelation.py new file mode 100644 index 0000000..323ca2b --- /dev/null +++ b/layers/AutoCorrelation.py @@ -0,0 +1,164 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt +import numpy as np +import math +from math import sqrt +import os + + +class AutoCorrelation(nn.Module): + """ + AutoCorrelation Mechanism with the following two phases: + (1) period-based dependencies discovery + (2) 
time delay aggregation + This block can replace the self-attention family mechanism seamlessly. + """ + def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False): + super(AutoCorrelation, self).__init__() + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def time_delay_agg_training(self, values, corr): + """ + SpeedUp version of Autocorrelation (a batch-normalization style design) + This is for the training phase. + """ + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # find top k + top_k = int(self.factor * math.log(length)) + mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) + index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] + weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + pattern = torch.roll(tmp_values, -int(index[i]), -1) + delays_agg = delays_agg + pattern * \ + (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) + return delays_agg + + def time_delay_agg_inference(self, values, corr): + """ + SpeedUp version of Autocorrelation (a batch-normalization style design) + This is for the inference phase. + """ + batch = values.shape[0] + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # index init + init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device) + # find top k + top_k = int(self.factor * math.log(length)) + mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) + weights = torch.topk(mean_value, top_k, dim=-1)[0] + delay = torch.topk(mean_value, top_k, dim=-1)[1] + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values.repeat(1, 1, 1, 2) + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) + pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) + delays_agg = delays_agg + pattern * \ + (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) + return delays_agg + + def time_delay_agg_full(self, values, corr): + """ + Standard version of Autocorrelation + """ + batch = values.shape[0] + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # index init + init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(batch, head, channel, 1).to(values.device) + # find top k + top_k = int(self.factor * math.log(length)) + weights = torch.topk(corr, top_k, dim=-1)[0] + delay = torch.topk(corr, top_k, dim=-1)[1] + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values.repeat(1, 1, 1, 2) + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + tmp_delay = init_index + delay[..., i].unsqueeze(-1) + pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) + delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1)) + return delays_agg + + def forward(self, queries, keys, values, attn_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + if L > S: + zeros = torch.zeros_like(queries[:, :(L - S), :]).float() 
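+            # Editor's note: the "period-based dependencies" step a few lines
+            # below uses the Wiener-Khinchin theorem -- the cross-correlation
+            # of q and k is computed per head and channel in O(L log L) as
+            # irfft(rfft(q) * conj(rfft(k))), instead of sliding q over k
+            # explicitly.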
+ values = torch.cat([values, zeros], dim=1) + keys = torch.cat([keys, zeros], dim=1) + else: + values = values[:, :L, :, :] + keys = keys[:, :L, :, :] + + # period-based dependencies + q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) + k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) + res = q_fft * torch.conj(k_fft) + corr = torch.fft.irfft(res, dim=-1) + + # time delay agg + if self.training: + V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) + else: + V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) + + if self.output_attention: + return (V.contiguous(), corr.permute(0, 3, 1, 2)) + else: + return (V.contiguous(), None) + + +class AutoCorrelationLayer(nn.Module): + def __init__(self, correlation, d_model, n_heads, d_keys=None, + d_values=None): + super(AutoCorrelationLayer, self).__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_correlation = correlation + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_correlation( + queries, + keys, + values, + attn_mask + ) + out = out.view(B, L, -1) + + return self.out_projection(out), attn diff --git a/layers/Autoformer_EncDec.py b/layers/Autoformer_EncDec.py new file mode 100644 index 0000000..2bc4a3e --- /dev/null +++ b/layers/Autoformer_EncDec.py @@ -0,0 +1,192 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class my_Layernorm(nn.Module): + """ + Special designed layernorm for the seasonal part + """ + def __init__(self, channels): + super(my_Layernorm, self).__init__() + self.layernorm = nn.LayerNorm(channels) + + def forward(self, x): + x_hat = self.layernorm(x) + bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) + return x_hat - bias + + +class moving_avg(nn.Module): + """ + Moving average block to highlight the trend of time series + """ + def __init__(self, kernel_size, stride): + super(moving_avg, self).__init__() + self.kernel_size = kernel_size + self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) + + def forward(self, x): + # padding on the both ends of time series + front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) + end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) + x = torch.cat([front, x, end], dim=1) + x = self.avg(x.permute(0, 2, 1)) + x = x.permute(0, 2, 1) + return x + + +class series_decomp(nn.Module): + """ + Series decomposition block + """ + def __init__(self, kernel_size): + super(series_decomp, self).__init__() + self.moving_avg = moving_avg(kernel_size, stride=1) + + def forward(self, x): + moving_mean = self.moving_avg(x) + res = x - moving_mean + return res, moving_mean + +class series_decomp_multi(nn.Module): + """ + Series decomposition block + """ + def __init__(self,kernel_size): + super(series_decomp_multi, self).__init__() + self.moving_avg = [moving_avg(kernel, 
stride=1) for kernel in kernel_size] + self.layer = torch.nn.Linear(1,len(kernel_size)) + + def forward(self, x): + moving_mean=[] + for func in self.moving_avg: + moving_avg = func(x) + moving_mean.append(moving_avg.unsqueeze(-1)) + moving_mean=torch.cat(moving_mean,dim=-1) + moving_mean = torch.sum(moving_mean*nn.Softmax(-1)(self.layer(x.unsqueeze(-1))),dim=-1) + res = x - moving_mean + return res, moving_mean + + +class EncoderLayer(nn.Module): + """ + Autoformer encoder layer with the progressive decomposition architecture + """ + def __init__(self, attention, d_model, d_ff=None, moving_avg=25, dropout=0.1, activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) + self.decomp1 = series_decomp(moving_avg) + self.decomp2 = series_decomp(moving_avg) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None): + new_x, attn = self.attention( + x, x, x, + attn_mask=attn_mask + ) + x = x + self.dropout(new_x) + x, _ = self.decomp1(x) + y = x + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + res, _ = self.decomp2(x + y) + return res, attn + + +class Encoder(nn.Module): + """ + Autoformer encoder + """ + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None): + attns = [] + if self.conv_layers is not None: + for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): + x, attn = attn_layer(x, attn_mask=attn_mask) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + + +class DecoderLayer(nn.Module): + """ + Autoformer decoder layer with the progressive decomposition architecture + """ + def __init__(self, self_attention, cross_attention, d_model, c_out, d_ff=None, + moving_avg=25, dropout=0.1, activation="relu"): + super(DecoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False) + self.decomp1 = series_decomp(moving_avg) + self.decomp2 = series_decomp(moving_avg) + self.decomp3 = series_decomp(moving_avg) + self.dropout = nn.Dropout(dropout) + self.projection = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=3, stride=1, padding=1, + padding_mode='circular', bias=False) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None): + x = x + self.dropout(self.self_attention( + x, x, x, + attn_mask=x_mask + )[0]) + x, trend1 = self.decomp1(x) + x = x + self.dropout(self.cross_attention( + x, cross, cross, + attn_mask=cross_mask + )[0]) + x, trend2 = self.decomp2(x) + y = x + y = 
self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + x, trend3 = self.decomp3(x + y) + + residual_trend = trend1 + trend2 + trend3 + residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose(1, 2) + return x, residual_trend + + +class Decoder(nn.Module): + """ + Autoformer encoder + """ + def __init__(self, layers, norm_layer=None, projection=None): + super(Decoder, self).__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): + for layer in self.layers: + x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) + trend = trend + residual_trend + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x, trend diff --git a/layers/Embed.py b/layers/Embed.py new file mode 100644 index 0000000..abdf903 --- /dev/null +++ b/layers/Embed.py @@ -0,0 +1,164 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm +import math + + +class PositionalEmbedding(nn.Module): + def __init__(self, d_model, max_len=5000): + super(PositionalEmbedding, self).__init__() + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model).float() + pe.require_grad = False + + position = torch.arange(0, max_len).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + return self.pe[:, :x.size(1)] + + +class TokenEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(TokenEmbedding, self).__init__() + padding = 1 if torch.__version__ >= '1.5.0' else 2 + self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, + kernel_size=3, padding=padding, padding_mode='circular', bias=False) + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu') + + def forward(self, x): + x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) + return x + + +class FixedEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(FixedEmbedding, self).__init__() + + w = torch.zeros(c_in, d_model).float() + w.require_grad = False + + position = torch.arange(0, c_in).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() + + w[:, 0::2] = torch.sin(position * div_term) + w[:, 1::2] = torch.cos(position * div_term) + + self.emb = nn.Embedding(c_in, d_model) + self.emb.weight = nn.Parameter(w, requires_grad=False) + + def forward(self, x): + return self.emb(x).detach() + + +class TemporalEmbedding(nn.Module): + def __init__(self, d_model, embed_type='fixed', freq='h'): + super(TemporalEmbedding, self).__init__() + + minute_size = 4 + hour_size = 24 + weekday_size = 7 + day_size = 32 + month_size = 13 + + Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding + if freq == 't': + self.minute_embed = Embed(minute_size, d_model) + self.hour_embed = Embed(hour_size, d_model) + self.weekday_embed = Embed(weekday_size, d_model) + self.day_embed = Embed(day_size, d_model) + self.month_embed = Embed(month_size, d_model) + + def forward(self, x): + x = x.long() + + minute_x = 
self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0. + hour_x = self.hour_embed(x[:, :, 3]) + weekday_x = self.weekday_embed(x[:, :, 2]) + day_x = self.day_embed(x[:, :, 1]) + month_x = self.month_embed(x[:, :, 0]) + + return hour_x + weekday_x + day_x + month_x + minute_x + + +class TimeFeatureEmbedding(nn.Module): + def __init__(self, d_model, embed_type='timeF', freq='h'): + super(TimeFeatureEmbedding, self).__init__() + + freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} + d_inp = freq_map[freq] + self.embed = nn.Linear(d_inp, d_model, bias=False) + + def forward(self, x): + return self.embed(x) + + +class DataEmbedding(nn.Module): + def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + super(DataEmbedding, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, + freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( + d_model=d_model, embed_type=embed_type, freq=freq) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x) + return self.dropout(x) + + +class DataEmbedding_wo_pos(nn.Module): + def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + super(DataEmbedding_wo_pos, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, + freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( + d_model=d_model, embed_type=embed_type, freq=freq) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = self.value_embedding(x) + self.temporal_embedding(x_mark) + return self.dropout(x) + +class DataEmbedding_wo_pos_temp(nn.Module): + def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + super(DataEmbedding_wo_pos_temp, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, + freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( + d_model=d_model, embed_type=embed_type, freq=freq) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = self.value_embedding(x) + return self.dropout(x) + +class DataEmbedding_wo_temp(nn.Module): + def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + super(DataEmbedding_wo_temp, self).__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEmbedding(d_model=d_model) + self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, + freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( + d_model=d_model, embed_type=embed_type, freq=freq) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_mark): + x = self.value_embedding(x) + self.position_embedding(x) + return self.dropout(x) \ No newline at end of file diff --git a/layers/FED_FourierCorrelation.py b/layers/FED_FourierCorrelation.py new file mode 100644 index 0000000..a829026 --- /dev/null +++ b/layers/FED_FourierCorrelation.py @@ -0,0 +1,133 @@ +# 
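`DataEmbedding` sums the convolutional value embedding, the calendar embedding, and the positional table. A hedged shape sketch (sizes are illustrative; `freq='h'` implies 4 time features per step via `freq_map`):

```python
import torch
from layers.Embed import DataEmbedding  # class defined above

emb = DataEmbedding(c_in=7, d_model=512, embed_type='timeF', freq='h')
x = torch.randn(32, 96, 7)       # raw values:    [B, L, c_in]
x_mark = torch.randn(32, 96, 4)  # time features: freq_map['h'] == 4 per step

out = emb(x, x_mark)             # value + temporal + positional, then dropout
assert out.shape == (32, 96, 512)
```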
coding=utf-8 +# author=maziqing +# email=maziqing.mzq@alibaba-inc.com + +import numpy as np +import torch +import torch.nn as nn + + +def get_frequency_modes(seq_len, modes=64, mode_select_method='random'): + """ + get modes on frequency domain: + 'random' means sampling randomly; + 'else' means sampling the lowest modes; + """ + modes = min(modes, seq_len // 2) + if mode_select_method == 'random': + index = list(range(0, seq_len // 2)) + np.random.shuffle(index) + index = index[:modes] + else: + index = list(range(0, modes)) + index.sort() + return index + + +# ########## fourier layer ############# +class FourierBlock(nn.Module): + def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'): + super(FourierBlock, self).__init__() + print('fourier enhanced block used!') + """ + 1D Fourier block. It performs representation learning on frequency domain, + it does FFT, linear transform, and Inverse FFT. + """ + # get modes on frequency domain + self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method) + print('modes={}, index={}'.format(modes, self.index)) + + self.scale = (1 / (in_channels * out_channels)) + self.weights1 = nn.Parameter( + self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.cfloat)) + + # Complex multiplication + def compl_mul1d(self, input, weights): + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bhi,hio->bho", input, weights) + + def forward(self, q, k, v, mask): + # size = [B, L, H, E] + B, L, H, E = q.shape + x = q.permute(0, 2, 3, 1) + # Compute Fourier coefficients + x_ft = torch.fft.rfft(x, dim=-1) + # Perform Fourier neural operations + out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) + for wi, i in enumerate(self.index): + if i >= x_ft.shape[3] or wi >= out_ft.shape[3]: + continue + out_ft[:, :, :, wi] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, wi]) + # Return to time domain + x = torch.fft.irfft(out_ft, n=x.size(-1)) + return (x, None) + + +# ########## Fourier Cross Former #################### +class FourierCrossAttention(nn.Module): + def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random', + activation='tanh', policy=0): + super(FourierCrossAttention, self).__init__() + print(' fourier enhanced cross attention used!') + """ + 1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT. 
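`get_frequency_modes` caps the retained modes at `seq_len // 2` and either samples them uniformly at random or keeps the lowest frequencies (any value other than `'random'` selects the lowest). For example (the seed is for a reproducible illustration only):

```python
import numpy as np
from layers.FED_FourierCorrelation import get_frequency_modes  # defined above

np.random.seed(0)  # illustrative only
low = get_frequency_modes(seq_len=96, modes=8, mode_select_method='lowest')
rnd = get_frequency_modes(seq_len=96, modes=8, mode_select_method='random')
print(low)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(rnd)  # 8 indices sampled from range(48), returned sorted
```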
+ """ + self.activation = activation + self.in_channels = in_channels + self.out_channels = out_channels + # get modes for queries and keys (& values) on frequency domain + self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method) + self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method) + + print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q)) + print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv)) + + self.scale = (1 / (in_channels * out_channels)) + self.weights1 = nn.Parameter( + self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index_q), dtype=torch.cfloat)) + + # Complex multiplication + def compl_mul1d(self, input, weights): + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bhi,hio->bho", input, weights) + + def forward(self, q, k, v, mask): + # size = [B, L, H, E] + B, L, H, E = q.shape + xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L] + xk = k.permute(0, 2, 3, 1) + xv = v.permute(0, 2, 3, 1) + + # Compute Fourier coefficients + xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) + xq_ft = torch.fft.rfft(xq, dim=-1) + for i, j in enumerate(self.index_q): + if j >= xq_ft.shape[3]: + continue + xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] + xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat) + xk_ft = torch.fft.rfft(xk, dim=-1) + for i, j in enumerate(self.index_kv): + if j >= xk_ft.shape[3]: + continue + xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] + + # perform attention mechanism on frequency domain + xqk_ft = (torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_)) + if self.activation == 'tanh': + xqk_ft = xqk_ft.tanh() + elif self.activation == 'softmax': + xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) + xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) + else: + raise Exception('{} actiation function is not implemented'.format(self.activation)) + xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_) + xqkvw = torch.einsum("bhex,heox->bhox", xqkv_ft, self.weights1) + out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) + for i, j in enumerate(self.index_q): + if i >= xqkvw.shape[3] or j >= out_ft.shape[3]: + continue + out_ft[:, :, :, j] = xqkvw[:, :, :, i] + # Return to time domain + out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)) + return (out, None) diff --git a/layers/FourierCorrelation.py b/layers/FourierCorrelation.py new file mode 100644 index 0000000..23c95f7 --- /dev/null +++ b/layers/FourierCorrelation.py @@ -0,0 +1,512 @@ +# coding=utf-8 +# author=maziqing +# email=maziqing.mzq@alibaba-inc.com + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from utils.masking import LocalMask + + +# from layers.mwt import MWT_CZ1d + +def get_dynamic_modes(seq_len, modes): + rate1 = seq_len // 96 + if rate1 <= 1: + index = list(range(0, min(seq_len // 2, modes), 1)) + else: + indexes = [i * seq_len / 96 for i in range(0, modes, 1)] + indexes = [i for i in indexes if i <= seq_len // 2] + indexes1 = list(range(0, min(seq_len // 2, modes // 3))) + for i in indexes: + if i % 1 == 0: + indexes1 += [int(i)] + else: + indexes1 += [int(i)] + indexes1 += [int(i) + 1] + index = list(set(indexes1)) + index.sort() + return index[:modes] + + +# Cross Fourier Former +class 
SpectralConvCross1d(nn.Module): + def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes1=0, policy=0): + super(SpectralConvCross1d, self).__init__() + print('corss fourier correlation used!') + + """ + 1D Fourier layer. It does FFT, linear transform, and Inverse FFT. + """ + self.in_channels = in_channels + self.out_channels = out_channels + # self.modes1 = seq_len // 2 + self.modes1 = modes1 + # Number of Fourier modes to multiply, at most floor(N/2) + 1 + if modes1 > 10000: + modes2 = modes1 - 10000 + self.index_q0 = list(range(0, min(seq_len_q // 4, modes2 // 2))) + self.index_q1 = list(range(len(self.index_q0), seq_len_q // 2)) + np.random.shuffle(self.index_q1) + self.index_q1 = self.index_q1[:min(seq_len_q // 4, modes2 // 2)] + self.index_q = self.index_q0 + self.index_q1 + self.index_q.sort() + + self.index_k_v0 = list(range(0, min(seq_len_kv // 4, modes2 // 2))) + self.index_k_v1 = list(range(len(self.index_k_v0), seq_len_kv // 2)) + np.random.shuffle(self.index_k_v1) + self.index_k_v1 = self.index_k_v1[:min(seq_len_kv // 4, modes2 // 2)] + self.index_k_v = self.index_k_v0 + self.index_k_v1 + self.index_k_v.sort() + + elif modes1 > 1000: + modes2 = modes1 - 1000 + self.index_q = list(range(0, seq_len_q // 2)) + np.random.shuffle(self.index_q) + self.index_q = self.index_q[:modes2] + self.index_q.sort() + self.index_k_v = list(range(0, seq_len_kv // 2)) + np.random.shuffle(self.index_k_v) + self.index_k_v = self.index_k_v[:modes2] + self.index_k_v.sort() + elif modes1 < 0: + modes2 = abs(modes1) + self.index_q = get_dynamic_modes(seq_len_q, modes2) + self.index_k_v = list(range(0, min(seq_len_kv // 2, modes2))) + else: + self.index_q = list(range(0, min(seq_len_q // 2, modes1))) + self.index_k_v = list(range(0, min(seq_len_kv // 2, modes1))) + + print('index_q={}'.format(self.index_q)) + print('len mode q={}', len(self.index_q)) + print('index_k_v={}'.format(self.index_k_v)) + print('len mode kv={}', len(self.index_k_v)) + + self.register_buffer('index_q2', torch.tensor(self.index_q)) + # modes = len(index) + + self.scale = (1 / (in_channels * out_channels)) + self.weights1 = nn.Parameter( + self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index_q), dtype=torch.cfloat)) + + # self.conv = nn.Conv1d(in_channels=in_channels//8, out_channels=out_channels//8, kernel_size=3) + + # Complex multiplication + def compl_mul1d(self, input, weights): + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bhi,hio->bho", input, weights) + + def forward(self, q, k, v, mask): + # size = [B, L, H, E] + mask = mask + B, L, H, E = q.shape + xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L] + xk = k.permute(0, 2, 3, 1) + xv = v.permute(0, 2, 3, 1) + # print('xq shape',xq.shape) + # print('xk shape',xk.shape) + # print('xv shape',xv.shape) + + # Compute Fourier coeffcients up to factor of e^(- something constant) + # xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) + xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) + xq_ft = torch.fft.rfft(xq, dim=-1) + for i, j in enumerate(self.index_q): + xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] + + xk_ft_ = torch.zeros(B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat) + xk_ft = torch.fft.rfft(xk, dim=-1) + for i, j in enumerate(self.index_k_v): + xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] + + xqk_ft = (torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_)) + xqk_ft = xqk_ft.tanh() + # xqk_ft 
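`modes1` doubles as a policy switch in this class: small positive values keep the lowest frequencies, values above 1000 sample randomly, values above 10000 mix half lowest with half random, and negative values use `get_dynamic_modes`. A sketch of how a caller selects each policy (hypothetical sizes):

```python
from layers.FourierCorrelation import SpectralConvCross1d  # class defined here

# modes1 encodes both the count and the selection policy:
#   modes1 = 64     -> the 64 lowest frequencies
#   modes1 = 1064   -> 64 randomly sampled frequencies   (1000 + 64)
#   modes1 = 10064  -> 32 lowest + 32 randomly sampled   (10000 + 64)
#   modes1 = -64    -> 64 "dynamic" modes via get_dynamic_modes
layer = SpectralConvCross1d(in_channels=512, out_channels=512,
                            seq_len_q=144, seq_len_kv=96, modes1=1064)
```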
= torch.softmax(abs(xqk_ft),dim=-1) + # xqk_ft = torch.complex(xqk_ft,torch.zeros_like(xqk_ft)) + xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_) + xqkvw = torch.einsum("bhex,heox->bhox", xqkv_ft, self.weights1) + out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) + # out_ft = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) + for i, j in enumerate(self.index_q): + out_ft[:, :, :, j] = xqkvw[:, :, :, i] + # max_len = min(xq.size(-1),720) + out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)) + # raise Exception('aaa') + # size = [B, L, H, E] + # out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=max_len) + return (out, None) + + +class SpectralConvCross1d_local(nn.Module): + def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes1=0): + super(SpectralConvCross1d_local, self).__init__() + print('corss fourier correlation used!') + + """ + 1D Fourier layer. It does FFT, linear transform, and Inverse FFT. + """ + self.in_channels = in_channels + self.out_channels = out_channels + # self.modes1 = seq_len // 2 + self.modes1 = modes1 + # Number of Fourier modes to multiply, at most floor(N/2) + 1 + if modes1 > 10000: + modes2 = modes1 - 10000 + self.index_q0 = list(range(0, min(seq_len_q // 4, modes2 // 2))) + self.index_q1 = list(range(len(self.index_q0), seq_len_q // 2)) + np.random.shuffle(self.index_q1) + self.index_q1 = self.index_q1[:min(seq_len_q // 4, modes2 // 2)] + self.index_q = self.index_q0 + self.index_q1 + self.index_q.sort() + + self.index_k_v0 = list(range(0, min(seq_len_kv // 4, modes2 // 2))) + self.index_k_v1 = list(range(len(self.index_k_v0), seq_len_kv // 2)) + np.random.shuffle(self.index_k_v1) + self.index_k_v1 = self.index_k_v1[:min(seq_len_kv // 4, modes2 // 2)] + self.index_k_v = self.index_k_v0 + self.index_k_v1 + self.index_k_v.sort() + + elif modes1 > 1000: + modes2 = modes1 - 1000 + self.index_q = list(range(0, seq_len_q // 2)) + np.random.shuffle(self.index_q) + self.index_q = self.index_q[:modes2] + self.index_q.sort() + self.index_k_v = list(range(0, seq_len_kv // 2)) + np.random.shuffle(self.index_k_v) + self.index_k_v = self.index_k_v[:modes2] + self.index_k_v.sort() + elif modes1 < 0: + modes2 = abs(modes1) + self.index_q = get_dynamic_modes(seq_len_q, modes2) + self.index_k_v = list(range(0, min(seq_len_kv // 2, modes2))) + else: + self.index_q = list(range(0, min(seq_len_q // 2, modes1))) + self.index_k_v = list(range(0, min(seq_len_kv // 2, modes1))) + print('index_q={}'.format(self.index_q)) + print('index_k_v={}'.format(self.index_k_v)) + + self.register_buffer('index_q2', torch.tensor(self.index_q)) + # modes = len(index) + + self.scale = (1 / (in_channels * out_channels)) + self.weights1 = nn.Parameter( + self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index_q), dtype=torch.cfloat)) + + # mask = self.log_mask(win_len,sub_len) + # self.register_buffer('mask_tri',mask) + + # self.conv = nn.Conv1d(in_channels=in_channels//8, out_channels=out_channels//8, kernel_size=3) + + # Complex multiplication + def compl_mul1d(self, input, weights): + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bhi,hio->bho", input, weights) + + def forward(self, q, k, v, mask): + # size = [B, L, H, E] + mask = mask + B, L, H, E = q.shape + _, S, _, _ = k.shape + + if L > S: + zeros = torch.zeros_like(q[:, :(L - S), :]).float() + v = torch.cat([v, zeros], dim=1) 
+        k = torch.cat([k, zeros], dim=1)
+        else:
+            v = v[:, :L, :, :]
+            k = k[:, :L, :, :]
+
+        scores = torch.einsum("blhe,bshe->bhls", q, k)
+        # self.scale is the weight-initialization factor, not an attention
+        # temperature, so the local attention is scaled by 1/sqrt(E) directly.
+        scale = 1. / np.sqrt(E)
+
+        if mask is None:
+            mask = LocalMask(B, L, L, device=q.device)
+
+        scores.masked_fill_(mask.mask, -np.inf)
+
+        A_local = torch.softmax(scale * scores, dim=-1)
+        V_local = torch.einsum("bhls,bshd->blhd", A_local, v)
+
+        xq = q.permute(0, 2, 3, 1)  # size = [B, H, E, L]
+        xk = k.permute(0, 2, 3, 1)
+        xv = v.permute(0, 2, 3, 1)
+
+        # Compute Fourier coefficients of the selected modes
+        xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)
+        xq_ft = torch.fft.rfft(xq, dim=-1)
+        for i, j in enumerate(self.index_q):
+            xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]
+
+        xk_ft_ = torch.zeros(B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat)
+        xk_ft = torch.fft.rfft(xk, dim=-1)
+        for i, j in enumerate(self.index_k_v):
+            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]
+
+        xqk_ft = (torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_))
+        xqk_ft = xqk_ft.tanh()
+        xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_)
+        xqkvw = torch.einsum("bhex,heox->bhox", xqkv_ft, self.weights1)
+        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
+        for i, j in enumerate(self.index_q):
+            out_ft[:, :, :, j] = xqkvw[:, :, :, i]
+
+        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))
+        out = out.permute(0, 3, 1, 2)  # back to [B, L, H, E]
+        # average the local-attention and frequency-domain branches
+        out = (out + V_local) / 2
+        return (out, None)
+
+
+################################################################
+# 1d fourier layer
+################################################################
+class SpectralConv1d(nn.Module):
+    def __init__(self, in_channels, out_channels, seq_len, modes1=0):
+        super(SpectralConv1d, self).__init__()
+        print('fourier correlation used!')
+
+        """
+        1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
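The three einsums above implement attention entirely in the frequency domain. A standalone shape walkthrough with made-up sizes (`H * E` plays the role of `d_model`, and `w` mirrors the layout of `weights1`):

```python
import torch

B, H, E, Mq, Mkv = 2, 8, 64, 16, 16          # H * E == 512 == d_model here
xq_ft_ = torch.randn(B, H, E, Mq, dtype=torch.cfloat)
xk_ft_ = torch.randn(B, H, E, Mkv, dtype=torch.cfloat)
w = torch.randn(H, E, E, Mq, dtype=torch.cfloat)              # like weights1

scores = torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_)      # [B, H, Mq, Mkv]
agg = torch.einsum("bhxy,bhey->bhex", scores.tanh(), xk_ft_)  # [B, H, E, Mq]
out_modes = torch.einsum("bhex,heox->bhox", agg, w)           # [B, H, E, Mq]
assert out_modes.shape == (B, H, E, Mq)
```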
+ """ + + self.in_channels = in_channels + self.out_channels = out_channels + # self.modes1 = seq_len // 2 + self.modes1 = modes1 + # Number of Fourier modes to multiply, at most floor(N/2) + 1 + if modes1 > 10000: + modes2 = modes1 - 10000 + self.index0 = list(range(0, min(seq_len // 4, modes2 // 2))) + self.index1 = list(range(len(self.index0), seq_len // 2)) + np.random.shuffle(self.index1) + self.index1 = self.index1[:min(seq_len // 4, modes2 // 2)] + self.index = self.index0 + self.index1 + self.index.sort() + elif modes1 > 1000: + modes2 = modes1 - 1000 + self.index = list(range(0, seq_len // 2)) + np.random.shuffle(self.index) + self.index = self.index[:modes2] + else: + self.index = list(range(0, min(seq_len // 2, modes1))) + + print('modes1={}, index={}'.format(modes1, self.index)) + + self.scale = (1 / (in_channels * out_channels)) + self.weights1 = nn.Parameter( + self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.cfloat)) + # self.conv = nn.Conv1d(in_channels=in_channels//8, out_channels=out_channels//8, kernel_size=3) + + # Complex multiplication + def compl_mul1d(self, input, weights): + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bhi,hio->bho", input, weights) + + def forward(self, q, k, v, mask): + # size = [B, L, H, E] + k = k + v = v + mask = mask + B, L, H, E = q.shape + x = q.permute(0, 2, 3, 1) + # batchsize = B + # Compute Fourier coeffcients up to factor of e^(- something constant) + x_ft = torch.fft.rfft(x, dim=-1) + # out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) + # if len(self.index)==0: + # out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) + # else: + # out_ft = torch.zeros(B, H, E, len(self.index), device=x.device, dtype=torch.cfloat) + out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) + + # Multiply relevant Fourier modes + # 取guided modes的版本 + # print('x shape',x.shape) + # print('out_ft shape',out_ft.shape) + # print('x_ft shape',x_ft.shape) + # print('weight shape',self.weights1.shape) + # print('self index',self.index) + for wi, i in enumerate(self.index): + out_ft[:, :, :, wi] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, wi]) + + # 取topk的modes版本 + # topk = torch.topk(torch.sum(x_ft, dim=[0, 1, 2]).abs(), dim=-1, k=self.modes1) + # energy = (topk[0]**2).sum() + # energy90 = 0 + # for index, j in enumerate(topk[0]): + # energy90 += j**2 + # if energy90 >= energy * 0.9: + # break + # for i in topk[1][:index]: + # out_ft[:, :, :, i] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, i]) + + # Return to physical space + x = torch.fft.irfft(out_ft, n=x.size(-1)) + # max_len = min(720,x.size(-1)) + # x = torch.fft.irfft(out_ft, n=max_len) + # size = [B, L, H, E] + return (x, None) + + +class SpectralConv1d_local(nn.Module): + def __init__(self, in_channels, out_channels, seq_len, modes1=0): + super(SpectralConv1d_local, self).__init__() + print('fourier correlation used!') + + """ + 1D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
+ """ + + self.in_channels = in_channels + self.out_channels = out_channels + # self.modes1 = seq_len // 2 + self.modes1 = modes1 + # Number of Fourier modes to multiply, at most floor(N/2) + 1 + if modes1 > 10000: + modes2 = modes1 - 10000 + self.index0 = list(range(0, min(seq_len // 4, modes2 // 2))) + self.index1 = list(range(len(self.index0), seq_len // 2)) + np.random.shuffle(self.index1) + self.index1 = self.index1[:min(seq_len // 4, modes2 // 2)] + self.index = self.index0 + self.index1 + self.index.sort() + elif modes1 > 1000: + modes2 = modes1 - 1000 + self.index = list(range(0, seq_len // 2)) + np.random.shuffle(self.index) + self.index = self.index[:modes2] + else: + self.index = list(range(0, min(seq_len // 2, modes1))) + + print('modes1={}, index={}'.format(modes1, self.index)) + + self.scale = (1 / (in_channels * out_channels)) + self.weights1 = nn.Parameter( + self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.cfloat)) + # self.conv = nn.Conv1d(in_channels=in_channels//8, out_channels=out_channels//8, kernel_size=3) + + # Complex multiplication + def compl_mul1d(self, input, weights): + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bhi,hio->bho", input, weights) + + def forward(self, q, k, v, mask): + # break + # print('local start') + # size = [B, L, H, E] + k = k + v = v + mask = mask + B, L, H, E = q.shape + + _, S, _, _ = k.shape + + scores = torch.einsum("blhe,bshe->bhls", q, k) + scale = self.scale or 1. / sqrt(E) + + if mask is None: + mask = LocalMask(B, L, L, device=q.device) + + scores.masked_fill_(mask.mask, -np.inf) + + A_local = torch.softmax(scale * scores, dim=-1) + V_local = torch.einsum("bhls,bshd->blhd", A_local, v) + + x = q.permute(0, 2, 3, 1) + + x_ft = torch.fft.rfft(x, dim=-1) + + out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat) + + for wi, i in enumerate(self.index): + out_ft[:, :, :, wi] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, wi]) + + x = torch.fft.irfft(out_ft, n=x.size(-1)) + x = (x + V_local) / 2 + # print(break) + + return (x, None) + + +class FNO1d(nn.Module): + def __init__(self, modes, width): + super(FNO1d, self).__init__() + + """ + The overall network. It contains 4 layers of the Fourier layer. + 1. Lift the input to the desire channel dimension by self.fc0 . + 2. 4 layers of the integral operators u' = (W + K)(u). + W defined by self.w; K defined by self.conv . + 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
+ + input: the solution of the initial condition and location (a(x), x) + input shape: (batchsize, x=s, c=2) + output: the solution of a later timestep + output shape: (batchsize, x=s, c=1) + """ + + self.modes1 = 16 + self.width = width + self.padding = 2 # pad the domain if input is non-periodic + self.fc0 = nn.Linear(2, self.width) # input channel is 2: (a(x), x) + + self.conv0 = SpectralConv1d(self.width, self.width, self.modes1) + self.conv1 = SpectralConv1d(self.width, self.width, self.modes1) + self.conv2 = SpectralConv1d(self.width, self.width, self.modes1) + self.conv3 = SpectralConv1d(self.width, self.width, self.modes1) + self.w0 = nn.Conv1d(self.width, self.width, 1) + self.w1 = nn.Conv1d(self.width, self.width, 1) + self.w2 = nn.Conv1d(self.width, self.width, 1) + self.w3 = nn.Conv1d(self.width, self.width, 1) + + self.fc1 = nn.Linear(self.width, 128) + self.fc2 = nn.Linear(128, 1) + + def forward(self, x): + grid = self.get_grid(x.shape, x.device) + x = torch.cat((x, grid), dim=-1) + x = self.fc0(x) + x = x.permute(0, 2, 1) + # x = F.pad(x, [0,self.padding]) # pad the domain if input is non-periodic + + x1 = self.conv0(x) + x2 = self.w0(x) + x = x1 + x2 + x = F.gelu(x) + + x1 = self.conv1(x) + x2 = self.w1(x) + x = x1 + x2 + x = F.gelu(x) + + x1 = self.conv2(x) + x2 = self.w2(x) + x = x1 + x2 + x = F.gelu(x) + + x1 = self.conv3(x) + x2 = self.w3(x) + x = x1 + x2 + + # x = x[..., :-self.padding] # pad the domain if input is non-periodic + x = x.permute(0, 2, 1) + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + return x + + def get_grid(self, shape, device): + batchsize, size_x = shape[0], shape[1] + gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) + gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) + return gridx.to(device) \ No newline at end of file diff --git a/layers/MultiWaveletCorrelation.py b/layers/MultiWaveletCorrelation.py new file mode 100644 index 0000000..e56f6d2 --- /dev/null +++ b/layers/MultiWaveletCorrelation.py @@ -0,0 +1,379 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from typing import List, Tuple +import math +from functools import partial +from einops import rearrange, reduce, repeat +from torch import nn, einsum, diagonal +from math import log2, ceil +import pdb +from utils.masking import LocalMask +from layers.utils import get_filter + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +class MultiWaveletTransform(nn.Module): + """ + 1D multiwavelet block. 
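Note that `FNO1d` appears to be carried over from the original FNO reference code: it constructs `SpectralConv1d(self.width, self.width, self.modes1)` (so `modes1` lands in the `seq_len` slot and the `modes` argument is ignored in favor of the hard-coded 16) and calls `self.conv0(x)` with a single tensor, while the `SpectralConv1d` defined above expects `(q, k, v, mask)`. A minimal spectral layer matching the interface `FNO1d` actually uses might look like this (a hypothetical sketch, not part of the commit):

```python
import torch
import torch.nn as nn

class PlainSpectralConv1d(nn.Module):
    """Hypothetical FNO-style layer: forward(x) with x of shape [B, C, L]."""
    def __init__(self, in_channels, out_channels, modes):
        super().__init__()
        self.modes = modes
        scale = 1 / (in_channels * out_channels)
        self.weights = nn.Parameter(
            scale * torch.rand(in_channels, out_channels, modes, dtype=torch.cfloat))

    def forward(self, x):
        B, C, L = x.shape
        x_ft = torch.fft.rfft(x, dim=-1)
        out_ft = torch.zeros(B, self.weights.shape[1], L // 2 + 1,
                             device=x.device, dtype=torch.cfloat)
        m = min(self.modes, L // 2 + 1)
        # mix channels mode-by-mode on the lowest m frequencies
        out_ft[:, :, :m] = torch.einsum("bim,iom->bom",
                                        x_ft[:, :, :m], self.weights[:, :, :m])
        return torch.fft.irfft(out_ft, n=L)
```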
+ """ + def __init__(self, ich=1, k=8, alpha=16, c=128, + nCZ=1, L=0, base='legendre', attention_dropout=0.1): + super(MultiWaveletTransform, self).__init__() + print('base', base) + self.k = k + self.c = c + self.L = L + self.nCZ = nCZ + self.Lk0 = nn.Linear(ich, c * k) + self.Lk1 = nn.Linear(c * k, ich) + self.ich = ich + self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for i in range(nCZ)) + + def forward(self, queries, keys, values, attn_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + if L > S: + zeros = torch.zeros_like(queries[:, :(L - S), :]).float() + values = torch.cat([values, zeros], dim=1) + keys = torch.cat([keys, zeros], dim=1) + else: + values = values[:, :L, :, :] + keys = keys[:, :L, :, :] + values = values.view(B, L, -1) + + V = self.Lk0(values).view(B, L, self.c, -1) + for i in range(self.nCZ): + V = self.MWT_CZ[i](V) + if i < self.nCZ - 1: + V = F.relu(V) + + V = self.Lk1(V.view(B, L, -1)) + V = V.view(B, L, -1, D) + return (V.contiguous(), None) + + +class MultiWaveletCross(nn.Module): + """ + 1D Multiwavelet Cross Attention layer. + """ + def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes, c=64, + k=8, ich=512, + L=0, + base='legendre', + mode_select_method='random', + initializer=None, activation='tanh', + **kwargs): + super(MultiWaveletCross, self).__init__() + print('base', base) + + self.c = c + self.k = k + self.L = L + H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) + H0r = H0 @ PHI0 + G0r = G0 @ PHI0 + H1r = H1 @ PHI1 + G1r = G1 @ PHI1 + + H0r[np.abs(H0r) < 1e-8] = 0 + H1r[np.abs(H1r) < 1e-8] = 0 + G0r[np.abs(G0r) < 1e-8] = 0 + G1r[np.abs(G1r) < 1e-8] = 0 + self.max_item = 3 + + self.attn1 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, modes=modes, activation=activation, + mode_select_method=mode_select_method) + self.attn2 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, modes=modes, activation=activation, + mode_select_method=mode_select_method) + self.attn3 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, modes=modes, activation=activation, + mode_select_method=mode_select_method) + self.attn4 = FourierCrossAttentionW(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, modes=modes, activation=activation, + mode_select_method=mode_select_method) + self.T0 = nn.Linear(k, k) + self.register_buffer('ec_s', torch.Tensor( + np.concatenate((H0.T, H1.T), axis=0))) + self.register_buffer('ec_d', torch.Tensor( + np.concatenate((G0.T, G1.T), axis=0))) + + self.register_buffer('rc_e', torch.Tensor( + np.concatenate((H0r, G0r), axis=0))) + self.register_buffer('rc_o', torch.Tensor( + np.concatenate((H1r, G1r), axis=0))) + + self.Lk = nn.Linear(ich, c * k) + self.Lq = nn.Linear(ich, c * k) + self.Lv = nn.Linear(ich, c * k) + self.out = nn.Linear(c * k, ich) + self.modes1 = modes + + def forward(self, q, k, v, mask=None): + B, N, H, E = q.shape # (B, N, H, E) torch.Size([3, 768, 8, 2]) + _, S, _, _ = k.shape # (B, S, H, E) torch.Size([3, 96, 8, 2]) + + q = q.view(q.shape[0], q.shape[1], -1) + k = k.view(k.shape[0], k.shape[1], -1) + v = v.view(v.shape[0], v.shape[1], -1) + q = self.Lq(q) + q = q.view(q.shape[0], q.shape[1], self.c, self.k) + k = self.Lk(k) + k = k.view(k.shape[0], k.shape[1], self.c, self.k) + v = self.Lv(v) + v = v.view(v.shape[0], 
v.shape[1], self.c, self.k) + + if N > S: + zeros = torch.zeros_like(q[:, :(N - S), :]).float() + v = torch.cat([v, zeros], dim=1) + k = torch.cat([k, zeros], dim=1) + else: + v = v[:, :N, :, :] + k = k[:, :N, :, :] + + ns = math.floor(np.log2(N)) + nl = pow(2, math.ceil(np.log2(N))) + extra_q = q[:, 0:nl - N, :, :] + extra_k = k[:, 0:nl - N, :, :] + extra_v = v[:, 0:nl - N, :, :] + q = torch.cat([q, extra_q], 1) + k = torch.cat([k, extra_k], 1) + v = torch.cat([v, extra_v], 1) + + Ud_q = torch.jit.annotate(List[Tuple[Tensor]], []) + Ud_k = torch.jit.annotate(List[Tuple[Tensor]], []) + Ud_v = torch.jit.annotate(List[Tuple[Tensor]], []) + + Us_q = torch.jit.annotate(List[Tensor], []) + Us_k = torch.jit.annotate(List[Tensor], []) + Us_v = torch.jit.annotate(List[Tensor], []) + + Ud = torch.jit.annotate(List[Tensor], []) + Us = torch.jit.annotate(List[Tensor], []) + + # decompose + for i in range(ns - self.L): + # print('q shape',q.shape) + d, q = self.wavelet_transform(q) + Ud_q += [tuple([d, q])] + Us_q += [d] + for i in range(ns - self.L): + d, k = self.wavelet_transform(k) + Ud_k += [tuple([d, k])] + Us_k += [d] + for i in range(ns - self.L): + d, v = self.wavelet_transform(v) + Ud_v += [tuple([d, v])] + Us_v += [d] + for i in range(ns - self.L): + dk, sk = Ud_k[i], Us_k[i] + dq, sq = Ud_q[i], Us_q[i] + dv, sv = Ud_v[i], Us_v[i] + Ud += [self.attn1(dq[0], dk[0], dv[0], mask)[0] + self.attn2(dq[1], dk[1], dv[1], mask)[0]] + Us += [self.attn3(sq, sk, sv, mask)[0]] + v = self.attn4(q, k, v, mask)[0] + + # reconstruct + for i in range(ns - 1 - self.L, -1, -1): + v = v + Us[i] + v = torch.cat((v, Ud[i]), -1) + v = self.evenOdd(v) + v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1)) + return (v.contiguous(), None) + + def wavelet_transform(self, x): + xa = torch.cat([x[:, ::2, :, :], + x[:, 1::2, :, :], + ], -1) + d = torch.matmul(xa, self.ec_d) + s = torch.matmul(xa, self.ec_s) + return d, s + + def evenOdd(self, x): + B, N, c, ich = x.shape # (B, N, c, k) + assert ich == 2 * self.k + x_e = torch.matmul(x, self.rc_e) + x_o = torch.matmul(x, self.rc_o) + + x = torch.zeros(B, N * 2, c, self.k, + device=x.device) + x[..., ::2, :, :] = x_e + x[..., 1::2, :, :] = x_o + return x + + +class FourierCrossAttentionW(nn.Module): + def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=16, activation='tanh', + mode_select_method='random'): + super(FourierCrossAttentionW, self).__init__() + print('corss fourier correlation used!') + self.in_channels = in_channels + self.out_channels = out_channels + self.modes1 = modes + self.activation = activation + + def forward(self, q, k, v, mask): + B, L, E, H = q.shape + + xq = q.permute(0, 3, 2, 1) # size = [B, H, E, L] torch.Size([3, 8, 64, 512]) + xk = k.permute(0, 3, 2, 1) + xv = v.permute(0, 3, 2, 1) + self.index_q = list(range(0, min(int(L // 2), self.modes1))) + self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1))) + + # Compute Fourier coefficients + xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) + xq_ft = torch.fft.rfft(xq, dim=-1) + for i, j in enumerate(self.index_q): + xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] + + xk_ft_ = torch.zeros(B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat) + xk_ft = torch.fft.rfft(xk, dim=-1) + for i, j in enumerate(self.index_k_v): + xk_ft_[:, :, :, i] = xk_ft[:, :, :, j] + xqk_ft = (torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_)) + if self.activation == 'tanh': + xqk_ft = xqk_ft.tanh() + elif self.activation == 'softmax': + 
xqk_ft = torch.softmax(abs(xqk_ft), dim=-1) + xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft)) + else: + raise Exception('{} actiation function is not implemented'.format(self.activation)) + xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_) + + xqkvw = xqkv_ft + out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat) + for i, j in enumerate(self.index_q): + out_ft[:, :, :, j] = xqkvw[:, :, :, i] + + out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)).permute(0, 3, 2, 1) + # size = [B, L, H, E] + return (out, None) + + +class sparseKernelFT1d(nn.Module): + def __init__(self, + k, alpha, c=1, + nl=1, + initializer=None, + **kwargs): + super(sparseKernelFT1d, self).__init__() + + self.modes1 = alpha + self.scale = (1 / (c * k * c * k)) + self.weights1 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.cfloat)) + self.weights1.requires_grad = True + self.k = k + + def compl_mul1d(self, x, weights): + # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) + return torch.einsum("bix,iox->box", x, weights) + + def forward(self, x): + B, N, c, k = x.shape # (B, N, c, k) + + x = x.view(B, N, -1) + x = x.permute(0, 2, 1) + x_fft = torch.fft.rfft(x) + # Multiply relevant Fourier modes + l = min(self.modes1, N // 2 + 1) + # l = N//2+1 + out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=torch.cfloat) + out_ft[:, :, :l] = self.compl_mul1d(x_fft[:, :, :l], self.weights1[:, :, :l]) + x = torch.fft.irfft(out_ft, n=N) + x = x.permute(0, 2, 1).view(B, N, c, k) + return x + + +# ## +class MWT_CZ1d(nn.Module): + def __init__(self, + k=3, alpha=64, + L=0, c=1, + base='legendre', + initializer=None, + **kwargs): + super(MWT_CZ1d, self).__init__() + + self.k = k + self.L = L + H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) + H0r = H0 @ PHI0 + G0r = G0 @ PHI0 + H1r = H1 @ PHI1 + G1r = G1 @ PHI1 + + H0r[np.abs(H0r) < 1e-8] = 0 + H1r[np.abs(H1r) < 1e-8] = 0 + G0r[np.abs(G0r) < 1e-8] = 0 + G1r[np.abs(G1r) < 1e-8] = 0 + self.max_item = 3 + + self.A = sparseKernelFT1d(k, alpha, c) + self.B = sparseKernelFT1d(k, alpha, c) + self.C = sparseKernelFT1d(k, alpha, c) + + self.T0 = nn.Linear(k, k) + + self.register_buffer('ec_s', torch.Tensor( + np.concatenate((H0.T, H1.T), axis=0))) + self.register_buffer('ec_d', torch.Tensor( + np.concatenate((G0.T, G1.T), axis=0))) + + self.register_buffer('rc_e', torch.Tensor( + np.concatenate((H0r, G0r), axis=0))) + self.register_buffer('rc_o', torch.Tensor( + np.concatenate((H1r, G1r), axis=0))) + + def forward(self, x): + B, N, c, k = x.shape # (B, N, k) + ns = math.floor(np.log2(N)) + nl = pow(2, math.ceil(np.log2(N))) + extra_x = x[:, 0:nl - N, :, :] + x = torch.cat([x, extra_x], 1) + Ud = torch.jit.annotate(List[Tensor], []) + Us = torch.jit.annotate(List[Tensor], []) + # decompose + for i in range(ns - self.L): + # print('x shape',x.shape) + d, x = self.wavelet_transform(x) + Ud += [self.A(d) + self.B(x)] + Us += [self.C(d)] + x = self.T0(x) # coarsest scale transform + + # reconstruct + for i in range(ns - 1 - self.L, -1, -1): + x = x + Us[i] + x = torch.cat((x, Ud[i]), -1) + x = self.evenOdd(x) + x = x[:, :N, :, :] + + return x + + def wavelet_transform(self, x): + xa = torch.cat([x[:, ::2, :, :], + x[:, 1::2, :, :], + ], -1) + d = torch.matmul(xa, self.ec_d) + s = torch.matmul(xa, self.ec_s) + return d, s + + def evenOdd(self, x): + + B, N, c, ich = x.shape # (B, N, c, k) + assert ich == 2 * self.k + x_e = torch.matmul(x, self.rc_e) + 
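`MWT_CZ1d` first pads the sequence up to the next power of two by recycling the head of the series, then runs `ns - L` levels of decomposition so that every halving in `wavelet_transform` sees an even length. The arithmetic, concretely:

```python
import math

N = 96                              # input length
ns = math.floor(math.log2(N))       # 6 decomposition levels available
nl = 2 ** math.ceil(math.log2(N))   # padded length: 128
print(ns, nl, nl - N)               # 6 128 32 -> the first 32 steps are re-appended
```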
x_o = torch.matmul(x, self.rc_o) + + x = torch.zeros(B, N * 2, c, self.k, + device=x.device) + x[..., ::2, :, :] = x_e + x[..., 1::2, :, :] = x_o + return x \ No newline at end of file diff --git a/layers/PatchTST_backbone.py b/layers/PatchTST_backbone.py new file mode 100644 index 0000000..55f7956 --- /dev/null +++ b/layers/PatchTST_backbone.py @@ -0,0 +1,379 @@ +__all__ = ['PatchTST_backbone'] + +# Cell +from typing import Callable, Optional +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F +import numpy as np + +#from collections import OrderedDict +from layers.PatchTST_layers import * +from layers.RevIN import RevIN + +# Cell +class PatchTST_backbone(nn.Module): + def __init__(self, c_in:int, context_window:int, target_window:int, patch_len:int, stride:int, max_seq_len:Optional[int]=1024, + n_layers:int=3, d_model=128, n_heads=16, d_k:Optional[int]=None, d_v:Optional[int]=None, + d_ff:int=256, norm:str='BatchNorm', attn_dropout:float=0., dropout:float=0., act:str="gelu", key_padding_mask:bool='auto', + padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False, + pe:str='zeros', learn_pe:bool=True, fc_dropout:float=0., head_dropout = 0, padding_patch = None, + pretrain_head:bool=False, head_type = 'flatten', individual = False, revin = True, affine = True, subtract_last = False, + verbose:bool=False, **kwargs): + + super().__init__() + + # RevIn + self.revin = revin + if self.revin: self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last) + + # Patching + self.patch_len = patch_len + self.stride = stride + self.padding_patch = padding_patch + patch_num = int((context_window - patch_len)/stride + 1) + if padding_patch == 'end': # can be modified to general case + self.padding_patch_layer = nn.ReplicationPad1d((0, stride)) + patch_num += 1 + + # Backbone + self.backbone = TSTiEncoder(c_in, patch_num=patch_num, patch_len=patch_len, max_seq_len=max_seq_len, + n_layers=n_layers, d_model=d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, + attn_dropout=attn_dropout, dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, + attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, + pe=pe, learn_pe=learn_pe, verbose=verbose, **kwargs) + + # Head + self.head_nf = d_model * patch_num + self.n_vars = c_in + self.pretrain_head = pretrain_head + self.head_type = head_type + self.individual = individual + + if self.pretrain_head: + self.head = self.create_pretrain_head(self.head_nf, c_in, fc_dropout) # custom head passed as a partial func with all its kwargs + elif head_type == 'flatten': + self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, target_window, head_dropout=head_dropout) + + + def forward(self, z): # z: [bs x nvars x seq_len] + # norm + if self.revin: + z = z.permute(0,2,1) + z = self.revin_layer(z, 'norm') + z = z.permute(0,2,1) + + # do patching + if self.padding_patch == 'end': + z = self.padding_patch_layer(z) + z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride) # z: [bs x nvars x patch_num x patch_len] + z = z.permute(0,1,3,2) # z: [bs x nvars x patch_len x patch_num] + + # model + z = self.backbone(z) # z: [bs x nvars x d_model x patch_num] + z = self.head(z) # z: [bs x nvars x target_window] + + # denorm + if self.revin: + z = z.permute(0,2,1) + z = self.revin_layer(z, 'denorm') + z = z.permute(0,2,1) + return z + + def create_pretrain_head(self, 
head_nf, vars, dropout): + return nn.Sequential(nn.Dropout(dropout), + nn.Conv1d(head_nf, vars, 1) + ) + + +class Flatten_Head(nn.Module): + def __init__(self, individual, n_vars, nf, target_window, head_dropout=0): + super().__init__() + + self.individual = individual + self.n_vars = n_vars + + if self.individual: + self.linears = nn.ModuleList() + self.dropouts = nn.ModuleList() + self.flattens = nn.ModuleList() + for i in range(self.n_vars): + self.flattens.append(nn.Flatten(start_dim=-2)) + self.linears.append(nn.Linear(nf, target_window)) + self.dropouts.append(nn.Dropout(head_dropout)) + else: + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(nf, target_window) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): # x: [bs x nvars x d_model x patch_num] + if self.individual: + x_out = [] + for i in range(self.n_vars): + z = self.flattens[i](x[:,i,:,:]) # z: [bs x d_model * patch_num] + z = self.linears[i](z) # z: [bs x target_window] + z = self.dropouts[i](z) + x_out.append(z) + x = torch.stack(x_out, dim=1) # x: [bs x nvars x target_window] + else: + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + return x + + + + +class TSTiEncoder(nn.Module): #i means channel-independent + def __init__(self, c_in, patch_num, patch_len, max_seq_len=1024, + n_layers=3, d_model=128, n_heads=16, d_k=None, d_v=None, + d_ff=256, norm='BatchNorm', attn_dropout=0., dropout=0., act="gelu", store_attn=False, + key_padding_mask='auto', padding_var=None, attn_mask=None, res_attention=True, pre_norm=False, + pe='zeros', learn_pe=True, verbose=False, **kwargs): + + + super().__init__() + + self.patch_num = patch_num + self.patch_len = patch_len + + # Input encoding + q_len = patch_num + self.W_P = nn.Linear(patch_len, d_model) # Eq 1: projection of feature vectors onto a d-dim vector space + self.seq_len = q_len + + # Positional encoding + self.W_pos = positional_encoding(pe, learn_pe, q_len, d_model) + + # Residual dropout + self.dropout = nn.Dropout(dropout) + + # Encoder + self.encoder = TSTEncoder(q_len, d_model, n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout, + pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=n_layers, store_attn=store_attn) + + + def forward(self, x) -> Tensor: # x: [bs x nvars x patch_len x patch_num] + + n_vars = x.shape[1] + # Input encoding + x = x.permute(0,1,3,2) # x: [bs x nvars x patch_num x patch_len] + x = self.W_P(x) # x: [bs x nvars x patch_num x d_model] + + u = torch.reshape(x, (x.shape[0]*x.shape[1],x.shape[2],x.shape[3])) # u: [bs * nvars x patch_num x d_model] + u = self.dropout(u + self.W_pos) # u: [bs * nvars x patch_num x d_model] + + # Encoder + z = self.encoder(u) # z: [bs * nvars x patch_num x d_model] + z = torch.reshape(z, (-1,n_vars,z.shape[-2],z.shape[-1])) # z: [bs x nvars x patch_num x d_model] + z = z.permute(0,1,3,2) # z: [bs x nvars x d_model x patch_num] + + return z + + + +# Cell +class TSTEncoder(nn.Module): + def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=None, + norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu', + res_attention=False, n_layers=1, pre_norm=False, store_attn=False): + super().__init__() + + self.layers = nn.ModuleList([TSTEncoderLayer(q_len, d_model, n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, + attn_dropout=attn_dropout, dropout=dropout, + activation=activation, res_attention=res_attention, + pre_norm=pre_norm, store_attn=store_attn) for i in range(n_layers)]) + 
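Patching turns each channel's length-`context_window` series into `patch_num` overlapping windows via `unfold`; `padding_patch='end'` replicates the last `stride` steps and adds one extra patch, as the backbone's `forward` shows. A standalone check (sizes are illustrative):

```python
import torch

context_window, patch_len, stride = 336, 16, 8
patch_num = (context_window - patch_len) // stride + 1   # 41
patch_num += 1                                           # padding_patch='end' adds one

z = torch.randn(32, 7, context_window)                   # [bs, nvars, seq_len]
z = torch.nn.ReplicationPad1d((0, stride))(z)            # replicate last 8 steps
z = z.unfold(dimension=-1, size=patch_len, step=stride)  # [bs, nvars, patch_num, patch_len]
assert z.shape == (32, 7, patch_num, patch_len)          # (32, 7, 42, 16)
```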
self.res_attention = res_attention + + def forward(self, src:Tensor, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None): + output = src + scores = None + if self.res_attention: + for mod in self.layers: output, scores = mod(output, prev=scores, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + return output + else: + for mod in self.layers: output = mod(output, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + return output + + + +class TSTEncoderLayer(nn.Module): + def __init__(self, q_len, d_model, n_heads, d_k=None, d_v=None, d_ff=256, store_attn=False, + norm='BatchNorm', attn_dropout=0, dropout=0., bias=True, activation="gelu", res_attention=False, pre_norm=False): + super().__init__() + assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})" + d_k = d_model // n_heads if d_k is None else d_k + d_v = d_model // n_heads if d_v is None else d_v + + # Multi-Head attention + self.res_attention = res_attention + self.self_attn = _MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout, res_attention=res_attention) + + # Add & Norm + self.dropout_attn = nn.Dropout(dropout) + if "batch" in norm.lower(): + self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2)) + else: + self.norm_attn = nn.LayerNorm(d_model) + + # Position-wise Feed-Forward + self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias), + get_activation_fn(activation), + nn.Dropout(dropout), + nn.Linear(d_ff, d_model, bias=bias)) + + # Add & Norm + self.dropout_ffn = nn.Dropout(dropout) + if "batch" in norm.lower(): + self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2)) + else: + self.norm_ffn = nn.LayerNorm(d_model) + + self.pre_norm = pre_norm + self.store_attn = store_attn + + + def forward(self, src:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None) -> Tensor: + + # Multi-Head attention sublayer + if self.pre_norm: + src = self.norm_attn(src) + ## Multi-Head attention + if self.res_attention: + src2, attn, scores = self.self_attn(src, src, src, prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + else: + src2, attn = self.self_attn(src, src, src, key_padding_mask=key_padding_mask, attn_mask=attn_mask) + if self.store_attn: + self.attn = attn + ## Add & Norm + src = src + self.dropout_attn(src2) # Add: residual connection with residual dropout + if not self.pre_norm: + src = self.norm_attn(src) + + # Feed-forward sublayer + if self.pre_norm: + src = self.norm_ffn(src) + ## Position-wise Feed-Forward + src2 = self.ff(src) + ## Add & Norm + src = src + self.dropout_ffn(src2) # Add: residual connection with residual dropout + if not self.pre_norm: + src = self.norm_ffn(src) + + if self.res_attention: + return src, scores + else: + return src + + + + +class _MultiheadAttention(nn.Module): + def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False): + """Multi Head Attention Layer + Input shape: + Q: [batch_size (bs) x max_q_len x d_model] + K, V: [batch_size (bs) x q_len x d_model] + mask: [q_len x q_len] + """ + super().__init__() + d_k = d_model // n_heads if d_k is None else d_k + d_v = d_model // n_heads if d_v is None else d_v + + self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v + + self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias) + self.W_K = nn.Linear(d_model, d_k * n_heads, 
bias=qkv_bias)
+        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)
+
+        # Scaled Dot-Product Attention (multiple heads)
+        self.res_attention = res_attention
+        self.sdp_attn = _ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa)
+
+        # Project output
+        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))
+
+
+    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
+                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
+
+        bs = Q.size(0)
+        if K is None: K = Q
+        if V is None: V = Q
+
+        # Linear (+ split in multiple heads)
+        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)       # q_s : [bs x n_heads x max_q_len x d_k]
+        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)     # k_s : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)
+        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)       # v_s : [bs x n_heads x q_len x d_v]
+
+        # Apply Scaled Dot-Product Attention (multiple heads)
+        if self.res_attention:
+            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+        else:
+            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
+        # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]
+
+        # back to the original input dimensions
+        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)  # output: [bs x q_len x n_heads * d_v]
+        output = self.to_out(output)
+
+        if self.res_attention: return output, attn_weights, attn_scores
+        else: return output, attn_weights
+
+
+class _ScaledDotProductAttention(nn.Module):
+    r"""Scaled Dot-Product Attention module (Attention Is All You Need, Vaswani et al., 2017) with optional residual attention from the previous layer
+    (RealFormer: Transformer Likes Residual Attention, He et al., 2020) and locality self-attention (Vision Transformer for Small-Size Datasets,
+    Lee et al., 2021)"""
+
+    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False):
+        super().__init__()
+        self.attn_dropout = nn.Dropout(attn_dropout)
+        self.res_attention = res_attention
+        head_dim = d_model // n_heads
+        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
+        self.lsa = lsa
+
+    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
+        '''
+        Input shape:
+            q               : [bs x n_heads x max_q_len x d_k]
+            k               : [bs x n_heads x d_k x seq_len]
+            v               : [bs x n_heads x seq_len x d_v]
+            prev            : [bs x n_heads x q_len x seq_len]
+            key_padding_mask: [bs x seq_len]
+            attn_mask       : [1 x seq_len x seq_len]
+        Output shape:
+            output: [bs x n_heads x q_len x d_v]
+            attn  : [bs x n_heads x q_len x seq_len]
+            scores: [bs x n_heads x q_len x seq_len]
+        '''
+
+        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
+        attn_scores = torch.matmul(q, k) * self.scale      # attn_scores : [bs x n_heads x max_q_len x q_len]
+
+        # Add pre-softmax attention scores from the previous layer (optional)
+        if prev is not None: attn_scores = attn_scores + prev
+
+        # Attention mask (optional)
+        if attn_mask is not None:                          # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
+            if attn_mask.dtype ==
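With `res_attention=True`, each layer returns its pre-softmax scores and the next layer adds them to its own (the RealFormer trick wired through `TSTEncoder` above). In miniature (made-up sizes):

```python
import torch
import torch.nn.functional as F

bs, n_heads, q_len, d_k = 2, 8, 10, 16
q = torch.randn(bs, n_heads, q_len, d_k)
k = torch.randn(bs, n_heads, d_k, q_len)    # keys pre-transposed, as in W_K above
v = torch.randn(bs, n_heads, q_len, d_k)

prev = None
for _ in range(3):                          # three "layers" sharing residual scores
    scores = torch.matmul(q, k) * d_k ** -0.5
    if prev is not None:
        scores = scores + prev              # residual attention
    prev = scores
    out = torch.matmul(F.softmax(scores, dim=-1), v)
```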
torch.bool: + attn_scores.masked_fill_(attn_mask, -np.inf) + else: + attn_scores += attn_mask + + # Key padding mask (optional) + if key_padding_mask is not None: # mask with shape [bs x q_len] (only when max_w_len == q_len) + attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf) + + # normalize the attention weights + attn_weights = F.softmax(attn_scores, dim=-1) # attn_weights : [bs x n_heads x max_q_len x q_len] + attn_weights = self.attn_dropout(attn_weights) + + # compute the new values given the attention weights + output = torch.matmul(attn_weights, v) # output: [bs x n_heads x max_q_len x d_v] + + if self.res_attention: return output, attn_weights, attn_scores + else: return output, attn_weights + diff --git a/layers/PatchTST_layers.py b/layers/PatchTST_layers.py new file mode 100644 index 0000000..11b5bd6 --- /dev/null +++ b/layers/PatchTST_layers.py @@ -0,0 +1,121 @@ +__all__ = ['Transpose', 'get_activation_fn', 'moving_avg', 'series_decomp', 'PositionalEncoding', 'SinCosPosEncoding', 'Coord2dPosEncoding', 'Coord1dPosEncoding', 'positional_encoding'] + +import torch +from torch import nn +import math + +class Transpose(nn.Module): + def __init__(self, *dims, contiguous=False): + super().__init__() + self.dims, self.contiguous = dims, contiguous + def forward(self, x): + if self.contiguous: return x.transpose(*self.dims).contiguous() + else: return x.transpose(*self.dims) + + +def get_activation_fn(activation): + if callable(activation): return activation() + elif activation.lower() == "relu": return nn.ReLU() + elif activation.lower() == "gelu": return nn.GELU() + raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable') + + +# decomposition + +class moving_avg(nn.Module): + """ + Moving average block to highlight the trend of time series + """ + def __init__(self, kernel_size, stride): + super(moving_avg, self).__init__() + self.kernel_size = kernel_size + self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) + + def forward(self, x): + # padding on the both ends of time series + front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) + end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) + x = torch.cat([front, x, end], dim=1) + x = self.avg(x.permute(0, 2, 1)) + x = x.permute(0, 2, 1) + return x + + +class series_decomp(nn.Module): + """ + Series decomposition block + """ + def __init__(self, kernel_size): + super(series_decomp, self).__init__() + self.moving_avg = moving_avg(kernel_size, stride=1) + + def forward(self, x): + moving_mean = self.moving_avg(x) + res = x - moving_mean + return res, moving_mean + + + +# pos_encoding + +def PositionalEncoding(q_len, d_model, normalize=True): + pe = torch.zeros(q_len, d_model) + position = torch.arange(0, q_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + if normalize: + pe = pe - pe.mean() + pe = pe / (pe.std() * 10) + return pe + +SinCosPosEncoding = PositionalEncoding + +def Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True, eps=1e-3, verbose=False): + x = .5 if exponential else 1 + i = 0 + for i in range(100): + cpe = 2 * (torch.linspace(0, 1, q_len).reshape(-1, 1) ** x) * (torch.linspace(0, 1, d_model).reshape(1, -1) ** x) - 1 + pv(f'{i:4.0f} {x:5.3f} {cpe.mean():+6.3f}', verbose) + if abs(cpe.mean()) <= eps: break + elif cpe.mean() > eps: x += .001 + else: x -= .001 
+ i += 1 + if normalize: + cpe = cpe - cpe.mean() + cpe = cpe / (cpe.std() * 10) + return cpe + +def Coord1dPosEncoding(q_len, exponential=False, normalize=True): + cpe = (2 * (torch.linspace(0, 1, q_len).reshape(-1, 1)**(.5 if exponential else 1)) - 1) + if normalize: + cpe = cpe - cpe.mean() + cpe = cpe / (cpe.std() * 10) + return cpe + +def positional_encoding(pe, learn_pe, q_len, d_model): + # Positional encoding + if pe == None: + W_pos = torch.empty((q_len, d_model)) # pe = None and learn_pe = False can be used to measure impact of pe + nn.init.uniform_(W_pos, -0.02, 0.02) + learn_pe = False + elif pe == 'zero': + W_pos = torch.empty((q_len, 1)) + nn.init.uniform_(W_pos, -0.02, 0.02) + elif pe == 'zeros': + W_pos = torch.empty((q_len, d_model)) + nn.init.uniform_(W_pos, -0.02, 0.02) + elif pe == 'normal' or pe == 'gauss': + W_pos = torch.zeros((q_len, 1)) + torch.nn.init.normal_(W_pos, mean=0.0, std=0.1) + elif pe == 'uniform': + W_pos = torch.zeros((q_len, 1)) + nn.init.uniform_(W_pos, a=0.0, b=0.1) + elif pe == 'lin1d': W_pos = Coord1dPosEncoding(q_len, exponential=False, normalize=True) + elif pe == 'exp1d': W_pos = Coord1dPosEncoding(q_len, exponential=True, normalize=True) + elif pe == 'lin2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=False, normalize=True) + elif pe == 'exp2d': W_pos = Coord2dPosEncoding(q_len, d_model, exponential=True, normalize=True) + elif pe == 'sincos': W_pos = PositionalEncoding(q_len, d_model, normalize=True) + else: raise ValueError(f"{pe} is not a valid pe (positional encoder. Available types: 'gauss'=='normal', \ + 'zeros', 'zero', uniform', 'lin1d', 'exp1d', 'lin2d', 'exp2d', 'sincos', None.)") + return nn.Parameter(W_pos, requires_grad=learn_pe) \ No newline at end of file diff --git a/layers/RevIN.py b/layers/RevIN.py new file mode 100644 index 0000000..8900170 --- /dev/null +++ b/layers/RevIN.py @@ -0,0 +1,61 @@ +import torch +import torch.nn as nn + +class RevIN(nn.Module): + def __init__(self, num_features: int, eps=1e-5, affine=True, subtract_last=False): + """ + :param num_features: the number of features or channels + :param eps: a value added for numerical stability + :param affine: if True, RevIN has learnable affine parameters + """ + super(RevIN, self).__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + self.subtract_last = subtract_last + if self.affine: + self._init_params() + + def forward(self, x, mode:str): + if mode == 'norm': + self._get_statistics(x) + x = self._normalize(x) + elif mode == 'denorm': + x = self._denormalize(x) + else: raise NotImplementedError + return x + + def _init_params(self): + # initialize RevIN params: (C,) + self.affine_weight = nn.Parameter(torch.ones(self.num_features)) + self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) + + def _get_statistics(self, x): + dim2reduce = tuple(range(1, x.ndim-1)) + if self.subtract_last: + self.last = x[:,-1,:].unsqueeze(1) + else: + self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() + self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() + + def _normalize(self, x): + if self.subtract_last: + x = x - self.last + else: + x = x - self.mean + x = x / self.stdev + if self.affine: + x = x * self.affine_weight + x = x + self.affine_bias + return x + + def _denormalize(self, x): + if self.affine: + x = x - self.affine_bias + x = x / (self.affine_weight + self.eps*self.eps) + x = x * self.stdev + if self.subtract_last: + x = x + self.last + else: + x = 
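`RevIN` is stateful across its two calls: `'norm'` stores per-instance statistics that `'denorm'` later re-applies to the model output, which may have a different length. A typical round trip (illustrative shapes):

```python
import torch
from layers.RevIN import RevIN  # module defined here

revin = RevIN(num_features=7)
x = torch.randn(32, 96, 7)          # [B, L, C] input window

x_norm = revin(x, 'norm')           # stores mean/stdev, normalizes
y_norm = x_norm[:, -48:, :]         # stand-in for a model's normalized forecast
y = revin(y_norm, 'denorm')         # re-applies the stored statistics
assert y.shape == (32, 48, 7)
```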
x + self.mean + return x \ No newline at end of file diff --git a/layers/SelfAttention_Family.py b/layers/SelfAttention_Family.py new file mode 100644 index 0000000..c8138e2 --- /dev/null +++ b/layers/SelfAttention_Family.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import matplotlib.pyplot as plt + +import numpy as np +import math +from math import sqrt +from utils.masking import TriangularCausalMask, ProbMask +import os + + +class FullAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(FullAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1. / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return (V.contiguous(), A) + else: + return (V.contiguous(), None) + + +class ProbAttention(nn.Module): + def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False): + super(ProbAttention, self).__init__() + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q) + # Q [B, H, L, D] + B, H, L_K, E = K.shape + _, _, L_Q, _ = Q.shape + + # calculate the sampled Q_K + K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E) + index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q + K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] + Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() + + # find the Top_k query with sparisty measurement + M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K) + M_top = M.topk(n_top, sorted=False)[1] + + # use the reduced Q to calculate Q_K + Q_reduce = Q[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + M_top, :] # factor*ln(L_q) + Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k + + return Q_K, M_top + + def _get_initial_context(self, V, L_Q): + B, H, L_V, D = V.shape + if not self.mask_flag: + # V_sum = V.sum(dim=-2) + V_sum = V.mean(dim=-2) + contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone() + else: # use mask + assert (L_Q == L_V) # requires that L_Q == L_V, i.e. 
for self-attention only + contex = V.cumsum(dim=-2) + return contex + + def _update_context(self, context_in, V, scores, index, L_Q, attn_mask): + B, H, L_V, D = V.shape + + if self.mask_flag: + attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device) + scores.masked_fill_(attn_mask.mask, -np.inf) + + attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores) + + context_in[torch.arange(B)[:, None, None], + torch.arange(H)[None, :, None], + index, :] = torch.matmul(attn, V).type_as(context_in) + if self.output_attention: + attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device) + attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn + return (context_in, attns) + else: + return (context_in, None) + + def forward(self, queries, keys, values, attn_mask): + B, L_Q, H, D = queries.shape + _, L_K, _, _ = keys.shape + + queries = queries.transpose(2, 1) + keys = keys.transpose(2, 1) + values = values.transpose(2, 1) + + U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k) + u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) + + U_part = U_part if U_part < L_K else L_K + u = u if u < L_Q else L_Q + + scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) + + # add scale factor + scale = self.scale or 1. / sqrt(D) + if scale is not None: + scores_top = scores_top * scale + # get the context + context = self._get_initial_context(values, L_Q) + # update the context with selected top_k queries + context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask) + + return context.contiguous(), attn + + +class AttentionLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None): + super(AttentionLayer, self).__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_attention( + queries, + keys, + values, + attn_mask + ) + out = out.view(B, L, -1) + + return self.out_projection(out), attn diff --git a/layers/Transformer_EncDec.py b/layers/Transformer_EncDec.py new file mode 100644 index 0000000..c0c5789 --- /dev/null +++ b/layers/Transformer_EncDec.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvLayer(nn.Module): + def __init__(self, c_in): + super(ConvLayer, self).__init__() + self.downConv = nn.Conv1d(in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=2, + padding_mode='circular') + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class EncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, 
activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None): + new_x, attn = self.attention( + x, x, x, + attn_mask=attn_mask + ) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm2(x + y), attn + + +class Encoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None): + # x [B, L, D] + attns = [] + if self.conv_layers is not None: + for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): + x, attn = attn_layer(x, attn_mask=attn_mask) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + + +class DecoderLayer(nn.Module): + def __init__(self, self_attention, cross_attention, d_model, d_ff=None, + dropout=0.1, activation="relu"): + super(DecoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None): + x = x + self.dropout(self.self_attention( + x, x, x, + attn_mask=x_mask + )[0]) + x = self.norm1(x) + + x = x + self.dropout(self.cross_attention( + x, cross, cross, + attn_mask=cross_mask + )[0]) + + y = x = self.norm2(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm3(x + y) + + +class Decoder(nn.Module): + def __init__(self, layers, norm_layer=None, projection=None): + super(Decoder, self).__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None): + for layer in self.layers: + x = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x diff --git a/layers/mwt.py b/layers/mwt.py new file mode 100644 index 0000000..d47fc41 --- /dev/null +++ b/layers/mwt.py @@ -0,0 +1,910 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from typing import List, Tuple +import math +from functools import partial +from einops import rearrange, reduce, repeat +from 
torch import nn, einsum, diagonal
+from math import log2, ceil
+from layers.utils import get_filter
+
+
+# from layers.FourierCorrelation import SpectralConvCross1d as #SpectralCross1d
+
+
+class mwt_transform(nn.Module):
+    def __init__(self, ich=1, k=8, alpha=16, c=128, nCZ=1, mask_flag=True,
+                 L=0,
+                 base='legendre', attention_dropout=0.1):
+        super(mwt_transform, self).__init__()
+        print('base', base)
+        self.k = k
+        self.c = c
+        self.L = L
+        self.nCZ = nCZ
+        self.Lk0 = nn.Linear(ich, c * k)
+        self.Lk1 = nn.Linear(c * k, ich)
+        self.ich = ich
+        self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for i in range(nCZ))
+
+    def forward(self, queries, keys, values, attn_mask):
+        B, L, H, E = queries.shape
+        _, S, _, D = values.shape
+        if L > S:
+            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
+            values = torch.cat([values, zeros], dim=1)
+            keys = torch.cat([keys, zeros], dim=1)
+        else:
+            values = values[:, :L, :, :]
+            keys = keys[:, :L, :, :]
+        values = values.view(B, L, -1)
+
+        V = self.Lk0(values).view(B, L, self.c, -1)
+        for i in range(self.nCZ):
+            V = self.MWT_CZ[i](V)
+            if i < self.nCZ - 1:
+                V = F.relu(V)
+
+        V = self.Lk1(V.view(B, L, -1))
+        V = V.view(B, L, -1, D)
+        return (V.contiguous(), None)
+
+
+class mwt_operator(nn.Module):
+    def __init__(self, ich=1, k=8, alpha=16, c=128, nCZ=1, mask_flag=True,
+                 L=0,
+                 base='legendre', attention_dropout=0.1):
+        super(mwt_operator, self).__init__()
+        print('base', base)
+        self.k = k
+        self.c = c
+        self.L = L
+        self.nCZ = nCZ
+        self.Lk0 = nn.Linear(ich, c * k)
+        self.Lk1 = nn.Linear(c * k, ich)
+        self.ich = ich
+        self.MWT_CZ = nn.ModuleList(MWT_CZ1d(k, alpha, L, c, base) for i in range(nCZ))
+
+    def forward(self, values):
+        # values-only operator: unlike mwt_transform, there are no queries/keys to pad against
+        B, L, H, D = values.shape
+        values = values.view(B, L, -1)
+
+        V = self.Lk0(values).view(B, L, self.c, -1)
+        for i in range(self.nCZ):
+            V = self.MWT_CZ[i](V)
+            if i < self.nCZ - 1:
+                V = F.relu(V)
+
+        V = self.Lk1(V.view(B, L, -1))
+        V = V.view(B, L, -1, D)
+        return (V.contiguous(), None)
+
+
+class mwt_transform_cross(nn.Module):
+    def __init__(self, ich=3, k=8, alpha=64, c=128, nCZ=1, mask_flag=True,
+                 L=0,
+                 base='legendre', attention_dropout=0.1):
+        super(mwt_transform_cross, self).__init__()
+        print('base', base)
+        self.k = k
+        self.c = c
+        self.L = L
+        self.nCZ = nCZ
+        # self.Lk0 = nn.Linear(ich,c*k)
+        self.Lk1 = nn.Linear(c * k, ich)
+        self.ich = ich
+        self.MWT_CZ = nn.ModuleList(MWT_CZ1d_cross(k, alpha, L, c, base) for i in range(nCZ))
+
+    def forward(self, q, k, v, attn_mask=None):
+        B, L, H, E = q.shape
+        _, S, _, D = v.shape
+        # if L > S:
+        #     zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
+        #     values = torch.cat([values, zeros], dim=1)
+        #     keys = torch.cat([keys, zeros], dim=1)
+        # else:
+        #     values = values[:, :L, :, :]
+        #     keys = keys[:, :L, :, :]
+        # values = values.view(B,L,-1)
+
+        # V = self.Lk0(values).view(B,L,self.c,-1)
+        for i in range(self.nCZ):
+            v = self.MWT_CZ[i](q, k, v)[0]  # MWT_CZ1d_cross returns a (values, attn) tuple
+            if i < self.nCZ - 1:
+                v = F.relu(v)
+
+        v = self.Lk1(v.view(B, L, -1))
+        v = v.view(B, L, -1, D)
+        return (v.contiguous(), None)
+
+
+class SpectralCross1d(nn.Module):
+    def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes1=16, activation='tanh'):
+        super(SpectralCross1d, self).__init__()
+        print('cross fourier correlation used!')
+
+        """
1D Fourier layer. It does FFT, linear transform, and Inverse FFT. + """ + self.in_channels = in_channels + self.out_channels = out_channels + # self.modes1 = seq_len // 2 + self.modes1 = modes1 + self.activation = activation + # Number of Fourier modes to multiply, at most floor(N/2) + 1 + # if modes1 == 100: + # self.index_q = list(range(0, seq_len_q//2)) + # np.random.shuffle(self.index_q) + # self.index_q = self.index_q[:16] + # self.index_q.sort() + # self.index_k_v = list(range(0, seq_len_kv // 2)) + # np.random.shuffle(self.index_k_v) + # self.index_k_v = self.index_k_v[:16] + # self.index_k_v.sort() + # elif modes1 < 0: + # modes2 = abs(modes1) + # self.index_q = get_dynamic_modes(seq_len_q, modes2) + # self.index_k_v = list(range(0, min(seq_len_kv // 2, modes2))) + # else: + # self.index_q = list(range(0, min(seq_len_q//2, modes1))) + # self.index_k_v = list(range(0, min(seq_len_kv//2, modes1))) + # print('index_q={}'.format(self.index_q)) + # print('index_k_v={}'.format(self.index_k_v)) + if modes1 > 10000: + modes2 = modes1 - 10000 + self.index_q0 = list(range(0, min(seq_len_q // 4, modes2 // 2))) + self.index_q1 = list(range(len(self.index_q0), seq_len_q // 2)) + np.random.shuffle(self.index_q1) + self.index_q1 = self.index_q1[:min(seq_len_q // 4, modes2 // 2)] + self.index_q = self.index_q0 + self.index_q1 + self.index_q.sort() + + self.index_k_v0 = list(range(0, min(seq_len_kv // 4, modes2 // 2))) + self.index_k_v1 = list(range(len(self.index_k_v0), seq_len_kv // 2)) + np.random.shuffle(self.index_k_v1) + self.index_k_v1 = self.index_k_v1[:min(seq_len_kv // 4, modes2 // 2)] + self.index_k_v = self.index_k_v0 + self.index_k_v1 + self.index_k_v.sort() + + elif modes1 > 1000: + modes2 = modes1 - 1000 + self.index_q = list(range(0, seq_len_q // 2)) + np.random.shuffle(self.index_q) + self.index_q = self.index_q[:modes2] + self.index_q.sort() + self.index_k_v = list(range(0, seq_len_kv // 2)) + np.random.shuffle(self.index_k_v) + self.index_k_v = self.index_k_v[:modes2] + self.index_k_v.sort() + elif modes1 < 0: + modes2 = abs(modes1) + self.index_q = get_dynamic_modes(seq_len_q, modes2) + self.index_k_v = list(range(0, min(seq_len_kv // 2, modes2))) + else: + self.index_q = list(range(0, min(seq_len_q // 2, modes1))) + self.index_k_v = list(range(0, min(seq_len_kv // 2, modes1))) + + print('index_q={}'.format(self.index_q)) + print('len mode q={}', len(self.index_q)) + print('index_k_v={}'.format(self.index_k_v)) + print('len mode kv={}', len(self.index_k_v)) + + self.register_buffer('index_q2', torch.tensor(self.index_q)) + + # self.scale = (1 / (in_channels * out_channels)) + # self.weights1 = nn.Parameter( + # self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index_q), dtype=torch.cfloat)) + def forward(self, q, k, v, mask): + # size = [B, L, H, E] + mask = mask + B, L, E, H = q.shape + xq = q.permute(0, 3, 2, 1) # size = [B, H, E, L] + xk = k.permute(0, 3, 2, 1) + xv = v.permute(0, 3, 2, 1) + self.index_q = list(range(0, min(int(L // 2), self.modes1))) + self.index_k_v = list(range(0, min(int(xv.shape[3] // 2), self.modes1))) + + # Compute Fourier coeffcients up to factor of e^(- something constant) + xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat) + xq_ft = torch.fft.rfft(xq, dim=-1) + for i, j in enumerate(self.index_q): + xq_ft_[:, :, :, i] = xq_ft[:, :, :, j] + + xk_ft_ = torch.zeros(B, H, E, len(self.index_k_v), device=xq.device, dtype=torch.cfloat) + xk_ft = torch.fft.rfft(xk, dim=-1) + for i, j in 
enumerate(self.index_k_v):
+            xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]
+        xqk_ft = (torch.einsum("bhex,bhey->bhxy", xq_ft_, xk_ft_))
+        if self.activation == 'tanh':
+            xqk_ft = xqk_ft.tanh()
+        elif self.activation == 'softmax':
+            xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)
+            xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))
+        else:
+            raise Exception('{} activation function is not implemented'.format(self.activation))
+        xqkv_ft = torch.einsum("bhxy,bhey->bhex", xqk_ft, xk_ft_)
+        # xqkvw = torch.einsum("bhex,heox->bhox", xqkv_ft, self.weights1)
+        xqkvw = xqkv_ft
+        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)
+        for i, j in enumerate(self.index_q):
+            out_ft[:, :, :, j] = xqkvw[:, :, :, i]
+
+        out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1)).permute(0, 3, 2, 1)
+        # size = [B, L, H, E]
+        return (out, None)
+
+
+def get_initializer(name):
+    if name == 'xavier_normal':
+        init_ = partial(nn.init.xavier_normal_)
+    elif name == 'kaiming_uniform':
+        init_ = partial(nn.init.kaiming_uniform_)
+    elif name == 'kaiming_normal':
+        init_ = partial(nn.init.kaiming_normal_)
+    else:
+        raise ValueError('{} initializer is not supported'.format(name))
+    return init_
+
+
+class sparseKernel1d(nn.Module):
+    def __init__(self,
+                 k, alpha, c=1,
+                 nl=1,
+                 initializer=None,
+                 **kwargs):
+        super(sparseKernel1d, self).__init__()
+
+        self.k = k
+        self.Li = nn.Linear(c * k, 128)
+        self.conv = self.convBlock(c * k, 128)
+        self.Lo = nn.Linear(128, c * k)
+
+    def forward(self, x):
+        B, N, c, ich = x.shape  # (B, N, c, k)
+        x = x.view(B, N, -1)
+        x = x.permute(0, 2, 1)
+        x = self.conv(x)
+        x = x.permute(0, 2, 1)
+        x = self.Lo(x)
+        x = x.view(B, N, c, ich)
+        return x
+
+    def convBlock(self, ich, och):
+        net = nn.Sequential(
+            nn.Conv1d(ich, och, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+        return net
+
+
+def compl_mul1d(x, weights):
+    # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
+    return torch.einsum("bix,iox->box", x, weights)
+
+
+class ComplexConv(nn.Module):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, contiguous: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.contiguous = contiguous
+        self.weight = nn.Parameter(
+            torch.complex((1 / (in_features * out_features)) * torch.Tensor(out_features, in_features, 1),
+                          (1 / (in_features * out_features)) * torch.Tensor(out_features, in_features, 1)))
+        if bias:
+            self.bias = nn.Parameter(torch.complex(torch.Tensor(out_features),
+                                                   torch.Tensor(out_features)))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_normal_(self.weight)
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def forward(self, x):
+        return _apply_linear_complex(F.conv1d, x, self.weight, self.bias, self.contiguous)
+
+
+def _apply_linear_complex(conv_fn, x, weight, bias, contiguous=True):
+    x_r, x_i = x.real, x.imag
+    w_r, w_i = weight.real, weight.imag
+    b_r, b_i = (None, None) if bias is None else (bias.real, bias.imag)
+    y_rr = conv_fn(x_r, w_r, b_r, padding=0)
+    y_ir = conv_fn(x_i, w_r, b_r, padding=0)
+    y_ri = conv_fn(x_r, w_i, b_i, padding=0)
+    y_ii = conv_fn(x_i, w_i, b_i, padding=0)
+    return torch.complex(y_rr - y_ii, y_ir + y_ri)
+
+
+device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") + + +class Sparsemax(nn.Module): + """Sparsemax function.""" + + def __init__(self, dim=None): + """Initialize sparsemax activation + + Args: + dim (int, optional): The dimension over which to apply the sparsemax function. + """ + super(Sparsemax, self).__init__() + + self.dim = -1 if dim is None else dim + + def forward(self, input): + """Forward function. + Args: + input (torch.Tensor): Input tensor. First dimension should be the batch size + Returns: + torch.Tensor: [batch_size x number_of_logits] Output tensor + """ + # Sparsemax currently only handles 2-dim tensors, + # so we reshape to a convenient shape and reshape back after sparsemax + input = input.transpose(0, self.dim) + original_size = input.size() + input = input.reshape(input.size(0), -1) + input = input.transpose(0, 1) + dim = 1 + + number_of_logits = input.size(dim) + + # Translate input by max for numerical stability + input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input) + + # Sort input in descending order. + # (NOTE: Can be replaced with linear time selection method described here: + # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html) + zs = torch.sort(input=input, dim=dim, descending=True)[0] + range = torch.arange(start=1, end=number_of_logits + 1, step=1, device=device, dtype=input.dtype).view(1, -1) + range = range.expand_as(zs) + + # Determine sparsity of projection + bound = 1 + range * zs + cumulative_sum_zs = torch.cumsum(zs, dim) + is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type()) + k = torch.max(is_gt * range, dim, keepdim=True)[0] + + # Compute threshold function + zs_sparse = is_gt * zs + + # Compute taus + taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k + taus = taus.expand_as(input) + + # Sparsemax + self.output = torch.max(torch.zeros_like(input), input - taus) + + # Reshape back to original shape + output = self.output + output = output.transpose(0, 1) + output = output.reshape(original_size) + output = output.transpose(0, self.dim) + + return output + + def backward(self, grad_output): + """Backward function.""" + dim = 1 + + nonzeros = torch.ne(self.output, 0) + sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim) + self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output)) + + return self.grad_input + + +def softmax_complex(z, dim): + ''' + Complex-valued Neural Networks with Non-parametric Activation Functions + (Eq. 
36) + https://arxiv.org/pdf/1802.08026.pdf + ''' + # if "ComplexTensor" in z.__class__.__name__: + # result = Sparsemax(dim=dim)(abs(z)) + # else: + # result = Sparsemax(dim=dim)(z) + # result = Sparsemax(dim=dim)(z.real)+ Sparsemax(dim=dim)(z.imag) + result = Sparsemax(dim=dim)(abs(z)) + return result + + +class sparseKernelFT1d(nn.Module): + def __init__(self, + k, alpha, c=1, + nl=1, + initializer=None, + **kwargs): + super(sparseKernelFT1d, self).__init__() + + self.modes1 = alpha + self.scale = (1 / (c * k * c * k)) + self.weights1 = nn.Parameter(self.scale * torch.rand(c * k, c * k, self.modes1, dtype=torch.cfloat)) + self.weights1.requires_grad = True + self.k = k + # self.conv = ComplexConv(self.modes1,self.modes1) + # self.unet = UNET_1D(input_dim=self.modes1,output_dim=self.modes1,layer_n=16,kernel_size=3, depth=1) + + def forward(self, x): + B, N, c, k = x.shape # (B, N, c, k) + + x = x.view(B, N, -1) + x = x.permute(0, 2, 1) + x_fft = torch.fft.rfft(x) + # Multiply relevant Fourier modes + l = min(self.modes1, N // 2 + 1) + # l = N//2+1 + out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=torch.cfloat) + out_ft[:, :, :l] = compl_mul1d(x_fft[:, :, :l], self.weights1[:, :, :l]) + x = torch.fft.irfft(out_ft, n=N) + x = x.permute(0, 2, 1).view(B, N, c, k) + return x + + +class sparseKernelFT1d_pre(nn.Module): + def __init__(self, alpha, + **kwargs): + super(sparseKernelFT1d_pre, self).__init__() + + self.modes1 = alpha + + def forward(self, x): + B, N, c, k = x.shape # (B, N, c, k) + + x = x.view(B, N, -1) + x = x.permute(0, 2, 1) + x_fft = torch.fft.rfft(x) + # Multiply relevant Fourier modes + l = min(self.modes1, N // 2 + 1) + # l = N//2+1 + out_ft = torch.zeros(B, c * k, N // 2 + 1, device=x.device, dtype=torch.cfloat) + out_ft[:, :, :l] = x_fft[:, :, :l] + # x = torch.fft.irfft(out_ft, n=N) + # x = x.permute(0, 2, 1).view(B, N, c, k) + return out_ft, N + + +class conbr_block(nn.Module): + def __init__(self, in_layer, out_layer, kernel_size, stride, dilation): + super(conbr_block, self).__init__() + + self.conv1 = nn.Conv1d(in_layer, out_layer, kernel_size=kernel_size, stride=stride, dilation=dilation, + padding=1, + bias=True) + self.bn = nn.BatchNorm1d(out_layer) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.bn(x) + out = self.relu(x) + # rint('out shape',out.shape) + + return out + + +class se_block(nn.Module): + def __init__(self, in_layer, out_layer): + super(se_block, self).__init__() + + self.conv1 = nn.Conv1d(in_layer, out_layer // 8, kernel_size=1, padding=0) + self.conv2 = nn.Conv1d(out_layer // 8, in_layer, kernel_size=1, padding=0) + self.fc = nn.Linear(1, out_layer // 8) + self.fc2 = nn.Linear(out_layer // 8, out_layer) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x_se = nn.functional.adaptive_avg_pool1d(x, 1) + x_se = self.conv1(x_se) + x_se = self.relu(x_se) + x_se = self.conv2(x_se) + x_se = self.sigmoid(x_se) + + x_out = torch.add(x, x_se) + return x_out + + +class re_block(nn.Module): + def __init__(self, in_layer, out_layer, kernel_size, dilation): + super(re_block, self).__init__() + + self.cbr1 = conbr_block(in_layer, out_layer, kernel_size, 1, dilation) + self.cbr2 = conbr_block(out_layer, out_layer, kernel_size, 1, dilation) + self.seblock = se_block(out_layer, out_layer) + + def forward(self, x): + # print('x shape',x.shape) + x_re = self.cbr1(x) + # print('x re shape', x_re.shape) + x_re = self.cbr2(x_re) + # print('x re shape', x_re.shape) + x_re = self.seblock(x_re) 
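+        # NOTE: se_block gates with a sigmoid over a pooled channel summary but
+        # merges it additively (torch.add) rather than the usual channel-wise
+        # multiply; re_block then wraps everything in its own residual (x + x_re).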
+ # print('x re shape', x_re.shape) + x_out = torch.add(x, x_re) + return x_out + + +class UNET_1D(nn.Module): + def __init__(self, input_dim, output_dim, layer_n, kernel_size, depth): + super(UNET_1D, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.layer_n = layer_n + self.kernel_size = kernel_size + self.depth = depth + + self.AvgPool1D1 = nn.AvgPool1d(input_dim, stride=4, padding=3) + # self.AvgPool1D2 = nn.AvgPool1d(input_dim, stride=25) + + self.layer1 = self.down_layer(self.input_dim, self.layer_n, self.kernel_size, 1, 2) + self.layer2 = self.down_layer(self.layer_n, int(self.layer_n * 2), self.kernel_size, 4, 2) + self.layer3 = self.down_layer(int(self.layer_n * 2) + int(self.input_dim), int(self.layer_n * 3), + self.kernel_size, 4, 2) + # self.cbr_up1 = conbr_block(int(self.layer_n*7), int(self.layer_n*3), self.kernel_size, 1, 1) + self.cbr_up2 = conbr_block(int(self.layer_n * 5), int(self.layer_n * 2), self.kernel_size, 1, 1) + self.cbr_up3 = conbr_block(int(self.layer_n * 3), self.layer_n, self.kernel_size, 1, 1) + self.upsample = nn.Upsample(scale_factor=4, mode='nearest') + + self.outcov = nn.Conv1d(self.layer_n, self.output_dim, kernel_size=self.kernel_size, stride=1, padding=3) + + def down_layer(self, input_layer, out_layer, kernel, stride, depth): + block = [] + block.append(conbr_block(input_layer, out_layer, kernel, stride, 1)) + for i in range(depth): + block.append(re_block(out_layer, out_layer, kernel, 1)) + return nn.Sequential(*block) + + def forward(self, x): + pool_x1 = self.AvgPool1D1(x) + # print('pool x1',pool_x1.shape) + #############Encoder##################### + + out_0 = self.layer1(x) + out_1 = self.layer2(out_0) + + x = torch.cat([out_1, pool_x1], 1) + x = self.layer3(x) + + #############Decoder#################### + + up = self.upsample(x) + # print('up shape',up.shape) + # print('out1 shape',out_1.shape) + up = torch.cat([up, out_1], 1) + up = self.cbr_up2(up) + + up = self.upsample(up) + up = torch.cat([up, out_0], 1) + up = self.cbr_up3(up) + + out = self.outcov(up) + return out + + +class MWT_CZ1d(nn.Module): + def __init__(self, + k=3, alpha=64, + L=0, c=1, + base='legendre', + initializer=get_initializer('xavier_normal'), + **kwargs): + super(MWT_CZ1d, self).__init__() + + self.k = k + self.L = L + H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) + H0r = H0 @ PHI0 + G0r = G0 @ PHI0 + H1r = H1 @ PHI1 + G1r = G1 @ PHI1 + + H0r[np.abs(H0r) < 1e-8] = 0 + H1r[np.abs(H1r) < 1e-8] = 0 + G0r[np.abs(G0r) < 1e-8] = 0 + G1r[np.abs(G1r) < 1e-8] = 0 + self.max_item = 3 + + # self.A = nn.ModuleList(sparseKernelFT1d(k, alpha, c) for _ in range(self.max_item)) + # self.B = nn.ModuleList(sparseKernelFT1d(k, alpha, c) for _ in range(self.max_item)) + # self.C = nn.ModuleList(sparseKernelFT1d(k, alpha, c) for _ in range(self.max_item)) + + self.A = sparseKernelFT1d(k, alpha, c) + self.B = sparseKernelFT1d(k, alpha, c) + self.C = sparseKernelFT1d(k, alpha, c) + + self.T0 = nn.Linear(k, k) + + self.register_buffer('ec_s', torch.Tensor( + np.concatenate((H0.T, H1.T), axis=0))) + self.register_buffer('ec_d', torch.Tensor( + np.concatenate((G0.T, G1.T), axis=0))) + + self.register_buffer('rc_e', torch.Tensor( + np.concatenate((H0r, G0r), axis=0))) + self.register_buffer('rc_o', torch.Tensor( + np.concatenate((H1r, G1r), axis=0))) + + # self.ec_s = torch.nn.Parameter(torch.Tensor(np.concatenate((H0.T, H1.T), axis=0))) + # self.ec_d = torch.nn.Parameter(torch.Tensor(np.concatenate((G0.T, G1.T), axis=0))) + # self.rc_e = 
torch.nn.Parameter(torch.Tensor(np.concatenate((H0r, G0r), axis=0))) + # self.rc_o = torch.nn.Parameter(torch.Tensor(np.concatenate((H1r, G1r), axis=0))) + + def forward(self, x): + + B, N, c, k = x.shape # (B, N, k) + + ns = math.floor(np.log2(N)) + nl = pow(2, math.ceil(np.log2(N))) + extra_x = x[:, 0:nl - N, :, :] + x = torch.cat([x, extra_x], 1) + # print('x shape raw',x.shape) + Ud = torch.jit.annotate(List[Tensor], []) + Us = torch.jit.annotate(List[Tensor], []) + # decompose + for i in range(ns - self.L): + # print('x shape',x.shape) + d, x = self.wavelet_transform(x) + Ud += [self.A(d) + self.B(x)] + Us += [self.C(d)] + + # print('x shape decomposed',x.shape) + x = self.T0(x) # coarsest scale transform + + # reconstruct + for i in range(ns - 1 - self.L, -1, -1): + # print('Us {} shape {}'.format(i,Us[i].shape)) + x = x + Us[i] + x = torch.cat((x, Ud[i]), -1) + # print('reconsturct step {} shape {}'.format(i,x.shape)) + x = self.evenOdd(x) + # raise Exception('test break') + x = x[:, :N, :, :] + # print('new x shape',x.shape) + # raise Exception('break') + + return x + + def wavelet_transform(self, x): + xa = torch.cat([x[:, ::2, :, :], + x[:, 1::2, :, :], + ], -1) + d = torch.matmul(xa, self.ec_d) + s = torch.matmul(xa, self.ec_s) + return d, s + + def evenOdd(self, x): + + B, N, c, ich = x.shape # (B, N, c, k) + assert ich == 2 * self.k + x_e = torch.matmul(x, self.rc_e) + x_o = torch.matmul(x, self.rc_o) + + x = torch.zeros(B, N * 2, c, self.k, + device=x.device) + x[..., ::2, :, :] = x_e + x[..., 1::2, :, :] = x_o + return x + + +class MWT1d(nn.Module): + def __init__(self, + ich=1, k=3, alpha=2, c=1, + nCZ=3, + L=0, + base='legendre', + initializer=get_initializer('xavier_normal'), + **kwargs): + super(MWT1d, self).__init__() + + self.k = k + self.c = c + self.L = L + self.nCZ = nCZ + self.Lk = nn.Linear(ich, c * k) + + self.MWT_CZ = nn.ModuleList( + [MWT_CZ1d(k, alpha, L, c, base, + initializer) for _ in range(nCZ)] + ) + self.Lc0 = nn.Linear(c * k, 128) + self.Lc1 = nn.Linear(128, 1) + + if initializer is not None: + self.reset_parameters(initializer) + + def forward(self, x): + + B, N, ich = x.shape # (B, N, d) + ns = math.floor(np.log2(N)) + x = self.Lk(x) + x = x.view(B, N, self.c, self.k) + + for i in range(self.nCZ): + x = self.MWT_CZ[i](x) + if i < self.nCZ - 1: + x = F.relu(x) + + x = x.view(B, N, -1) # collapse c and k + x = self.Lc0(x) + x = F.relu(x) + x = self.Lc1(x) + return x.squeeze() + + def reset_parameters(self, initializer): + initializer(self.Lc0.weight) + initializer(self.Lc1.weight) + + +def exists(val): + return val is not None + + +class MWT_CZ1d_cross(nn.Module): + def __init__(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes1, c=64, + k=8, ich=512, + L=0, + base='legendre', + initializer=get_initializer('xavier_normal'), activation='tanh', + + **kwargs): + super(MWT_CZ1d_cross, self).__init__() + print('base', base) + + self.c = c + + self.k = k + self.L = L + H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) + H0r = H0 @ PHI0 + G0r = G0 @ PHI0 + H1r = H1 @ PHI1 + G1r = G1 @ PHI1 + + H0r[np.abs(H0r) < 1e-8] = 0 + H1r[np.abs(H1r) < 1e-8] = 0 + G0r[np.abs(G0r) < 1e-8] = 0 + G1r[np.abs(G1r) < 1e-8] = 0 + self.max_item = 3 + + # self.pre = sparseKernelFT1d_pre(alpha) + + self.attn1 = SpectralCross1d(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, modes1=modes1, activation=activation) + self.attn2 = SpectralCross1d(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, + 
seq_len_kv=seq_len_kv, modes1=modes1, activation=activation) + self.attn3 = SpectralCross1d(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, modes1=modes1, activation=activation) + self.attn4 = SpectralCross1d(in_channels=in_channels, out_channels=out_channels, seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, modes1=modes1, activation=activation) + + self.T0 = nn.Linear(k, k) + + self.register_buffer('ec_s', torch.Tensor( + np.concatenate((H0.T, H1.T), axis=0))) + self.register_buffer('ec_d', torch.Tensor( + np.concatenate((G0.T, G1.T), axis=0))) + + self.register_buffer('rc_e', torch.Tensor( + np.concatenate((H0r, G0r), axis=0))) + self.register_buffer('rc_o', torch.Tensor( + np.concatenate((H1r, G1r), axis=0))) + + self.Lk = nn.Linear(ich, c * k) + self.Lq = nn.Linear(ich, c * k) + self.Lv = nn.Linear(ich, c * k) + self.out = nn.Linear(c * k, ich) + self.modes1 = modes1 + + def forward(self, q, k, v, mask=None): + B, N, H, E = q.shape # (B, N, k) + _, S, _, _ = k.shape + + q = q.view(q.shape[0], q.shape[1], -1) + k = k.view(k.shape[0], k.shape[1], -1) + v = v.view(v.shape[0], v.shape[1], -1) + q = self.Lq(q) + q = q.view(q.shape[0], q.shape[1], self.c, self.k) + k = self.Lk(k) + k = k.view(k.shape[0], k.shape[1], self.c, self.k) + v = self.Lv(v) + v = v.view(v.shape[0], v.shape[1], self.c, self.k) + + if N > S: + zeros = torch.zeros_like(q[:, :(N - S), :]).float() + v = torch.cat([v, zeros], dim=1) + k = torch.cat([k, zeros], dim=1) + else: + v = v[:, :N, :, :] + k = k[:, :N, :, :] + + ns = math.floor(np.log2(N)) + nl = pow(2, math.ceil(np.log2(N))) + extra_q = q[:, 0:nl - N, :, :] + extra_k = k[:, 0:nl - N, :, :] + extra_v = v[:, 0:nl - N, :, :] + q = torch.cat([q, extra_q], 1) + k = torch.cat([k, extra_k], 1) + v = torch.cat([v, extra_v], 1) + + Ud_q = torch.jit.annotate(List[Tuple[Tensor]], []) + Ud_k = torch.jit.annotate(List[Tuple[Tensor]], []) + Ud_v = torch.jit.annotate(List[Tuple[Tensor]], []) + + Us_q = torch.jit.annotate(List[Tensor], []) + Us_k = torch.jit.annotate(List[Tensor], []) + Us_v = torch.jit.annotate(List[Tensor], []) + + Ud = torch.jit.annotate(List[Tensor], []) + Us = torch.jit.annotate(List[Tensor], []) + + # decompose + for i in range(ns - self.L): + # print('q shape',q.shape) + d, q = self.wavelet_transform(q) + Ud_q += [tuple([d, q])] + Us_q += [d] + for i in range(ns - self.L): + d, k = self.wavelet_transform(k) + Ud_k += [tuple([d, k])] + Us_k += [d] + for i in range(ns - self.L): + d, v = self.wavelet_transform(v) + Ud_v += [tuple([d, v])] + Us_v += [d] + for i in range(ns - self.L): + dk, sk = Ud_k[i], Us_k[i] + dq, sq = Ud_q[i], Us_q[i] + dv, sv = Ud_v[i], Us_v[i] + Ud += [self.attn1(dq[0], dk[0], dv[0], mask)[0] + self.attn2(dq[1], dk[1], dv[1], mask)[0]] + Us += [self.attn3(sq, sk, sv, mask)[0]] + v = self.attn4(q, k, v, mask)[0] + + # reconstruct + for i in range(ns - 1 - self.L, -1, -1): + v = v + Us[i] + v = torch.cat((v, Ud[i]), -1) + v = self.evenOdd(v) + # raise Exception('test break') + v = self.out(v[:, :N, :, :].contiguous().view(B, N, -1)) + return (v.contiguous(), None) + + def wavelet_transform(self, x): + xa = torch.cat([x[:, ::2, :, :], + x[:, 1::2, :, :], + ], -1) + d = torch.matmul(xa, self.ec_d) + s = torch.matmul(xa, self.ec_s) + return d, s + + def evenOdd(self, x): + + B, N, c, ich = x.shape # (B, N, c, k) + assert ich == 2 * self.k + x_e = torch.matmul(x, self.rc_e) + x_o = torch.matmul(x, self.rc_o) + + x = torch.zeros(B, N * 2, c, self.k, + device=x.device) + x[..., ::2, :, :] = x_e + 
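# interleave: even indices take x_e, odd indices take x_o (length N -> 2N) + 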
x[..., 1::2, :, :] = x_o + return x \ No newline at end of file diff --git a/layers/utils.py b/layers/utils.py new file mode 100644 index 0000000..6c06092 --- /dev/null +++ b/layers/utils.py @@ -0,0 +1,391 @@ +import torch +import torch.nn as nn + +import numpy as np +from functools import partial + +from scipy.special import eval_legendre +from sympy import Poly, legendre, Symbol, chebyshevt + + +def legendreDer(k, x): + def _legendre(k, x): + return (2 * k + 1) * eval_legendre(k, x) + + out = 0 + for i in np.arange(k - 1, -1, -2): + out += _legendre(i, x) + return out + + +def phi_(phi_c, x, lb=0, ub=1): + mask = np.logical_or(x < lb, x > ub) * 1.0 + return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1 - mask) + + +def get_phi_psi(k, base): + x = Symbol('x') + phi_coeff = np.zeros((k, k)) + phi_2x_coeff = np.zeros((k, k)) + if base == 'legendre': + for ki in range(k): + coeff_ = Poly(legendre(ki, 2 * x - 1), x).all_coeffs() + phi_coeff[ki, :ki + 1] = np.flip(np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64)) + coeff_ = Poly(legendre(ki, 4 * x - 1), x).all_coeffs() + phi_2x_coeff[ki, :ki + 1] = np.flip(np.sqrt(2) * np.sqrt(2 * ki + 1) * np.array(coeff_).astype(np.float64)) + + psi1_coeff = np.zeros((k, k)) + psi2_coeff = np.zeros((k, k)) + for ki in range(k): + psi1_coeff[ki, :] = phi_2x_coeff[ki, :] + for i in range(k): + a = phi_2x_coeff[ki, :ki + 1] + b = phi_coeff[i, :i + 1] + prod_ = np.convolve(a, b) + prod_[np.abs(prod_) < 1e-8] = 0 + proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum() + psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :] + psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :] + for j in range(ki): + a = phi_2x_coeff[ki, :ki + 1] + b = psi1_coeff[j, :] + prod_ = np.convolve(a, b) + prod_[np.abs(prod_) < 1e-8] = 0 + proj_ = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum() + psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :] + psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :] + + a = psi1_coeff[ki, :] + prod_ = np.convolve(a, a) + prod_[np.abs(prod_) < 1e-8] = 0 + norm1 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * np.power(0.5, 1 + np.arange(len(prod_)))).sum() + + a = psi2_coeff[ki, :] + prod_ = np.convolve(a, a) + prod_[np.abs(prod_) < 1e-8] = 0 + norm2 = (prod_ * 1 / (np.arange(len(prod_)) + 1) * (1 - np.power(0.5, 1 + np.arange(len(prod_))))).sum() + norm_ = np.sqrt(norm1 + norm2) + psi1_coeff[ki, :] /= norm_ + psi2_coeff[ki, :] /= norm_ + psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0 + psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0 + + phi = [np.poly1d(np.flip(phi_coeff[i, :])) for i in range(k)] + psi1 = [np.poly1d(np.flip(psi1_coeff[i, :])) for i in range(k)] + psi2 = [np.poly1d(np.flip(psi2_coeff[i, :])) for i in range(k)] + + elif base == 'chebyshev': + for ki in range(k): + if ki == 0: + phi_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi) + phi_2x_coeff[ki, :ki + 1] = np.sqrt(2 / np.pi) * np.sqrt(2) + else: + coeff_ = Poly(chebyshevt(ki, 2 * x - 1), x).all_coeffs() + phi_coeff[ki, :ki + 1] = np.flip(2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64)) + coeff_ = Poly(chebyshevt(ki, 4 * x - 1), x).all_coeffs() + phi_2x_coeff[ki, :ki + 1] = np.flip( + np.sqrt(2) * 2 / np.sqrt(np.pi) * np.array(coeff_).astype(np.float64)) + + phi = [partial(phi_, phi_coeff[i, :]) for i in range(k)] + + x = Symbol('x') + kUse = 2 * k + roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots() + x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) + # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid 
the case of 0.5 belonging to both phi(2x) and phi(2x-1) + # not needed for our purpose here, we use even k always to avoid + wm = np.pi / kUse / 2 + + psi1_coeff = np.zeros((k, k)) + psi2_coeff = np.zeros((k, k)) + + psi1 = [[] for _ in range(k)] + psi2 = [[] for _ in range(k)] + + for ki in range(k): + psi1_coeff[ki, :] = phi_2x_coeff[ki, :] + for i in range(k): + proj_ = (wm * phi[i](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum() + psi1_coeff[ki, :] -= proj_ * phi_coeff[i, :] + psi2_coeff[ki, :] -= proj_ * phi_coeff[i, :] + + for j in range(ki): + proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2 * x_m)).sum() + psi1_coeff[ki, :] -= proj_ * psi1_coeff[j, :] + psi2_coeff[ki, :] -= proj_ * psi2_coeff[j, :] + + psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5) + psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5, ub=1) + + norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum() + norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum() + + norm_ = np.sqrt(norm1 + norm2) + psi1_coeff[ki, :] /= norm_ + psi2_coeff[ki, :] /= norm_ + psi1_coeff[np.abs(psi1_coeff) < 1e-8] = 0 + psi2_coeff[np.abs(psi2_coeff) < 1e-8] = 0 + + psi1[ki] = partial(phi_, psi1_coeff[ki, :], lb=0, ub=0.5 + 1e-16) + psi2[ki] = partial(phi_, psi2_coeff[ki, :], lb=0.5 + 1e-16, ub=1) + + return phi, psi1, psi2 + + +def get_filter(base, k): + def psi(psi1, psi2, i, inp): + mask = (inp <= 0.5) * 1.0 + return psi1[i](inp) * mask + psi2[i](inp) * (1 - mask) + + if base not in ['legendre', 'chebyshev']: + raise Exception('Base not supported') + + x = Symbol('x') + H0 = np.zeros((k, k)) + H1 = np.zeros((k, k)) + G0 = np.zeros((k, k)) + G1 = np.zeros((k, k)) + PHI0 = np.zeros((k, k)) + PHI1 = np.zeros((k, k)) + phi, psi1, psi2 = get_phi_psi(k, base) + if base == 'legendre': + roots = Poly(legendre(k, 2 * x - 1)).all_roots() + x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) + wm = 1 / k / legendreDer(k, 2 * x_m - 1) / eval_legendre(k - 1, 2 * x_m - 1) + + for ki in range(k): + for kpi in range(k): + H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum() + G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum() + H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum() + G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum() + + PHI0 = np.eye(k) + PHI1 = np.eye(k) + + elif base == 'chebyshev': + x = Symbol('x') + kUse = 2 * k + roots = Poly(chebyshevt(kUse, 2 * x - 1)).all_roots() + x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) + # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) + # not needed for our purpose here, we use even k always to avoid + wm = np.pi / kUse / 2 + + for ki in range(k): + for kpi in range(k): + H0[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki](x_m / 2) * phi[kpi](x_m)).sum() + G0[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m / 2) * phi[kpi](x_m)).sum() + H1[ki, kpi] = 1 / np.sqrt(2) * (wm * phi[ki]((x_m + 1) / 2) * phi[kpi](x_m)).sum() + G1[ki, kpi] = 1 / np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m + 1) / 2) * phi[kpi](x_m)).sum() + + PHI0[ki, kpi] = (wm * phi[ki](2 * x_m) * phi[kpi](2 * x_m)).sum() * 2 + PHI1[ki, kpi] = (wm * phi[ki](2 * x_m - 1) * phi[kpi](2 * x_m - 1)).sum() * 2 + + PHI0[np.abs(PHI0) < 1e-8] = 0 + PHI1[np.abs(PHI1) < 1e-8] = 0 + + H0[np.abs(H0) < 1e-8] = 0 + H1[np.abs(H1) < 1e-8] = 0 + G0[np.abs(G0) < 1e-8] = 0 + G1[np.abs(G1) < 1e-8] = 0 + + return H0, H1, G0, G1, PHI0, 
PHI1 + + +def train(model, train_loader, optimizer, epoch, device, verbose=0, + lossFn=None, lr_schedule=None, + post_proc=lambda args: args): + if lossFn is None: + lossFn = nn.MSELoss() + + model.train() + + total_loss = 0. + + for batch_idx, (data, target) in enumerate(train_loader): + bs = len(data) + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + + output = model(data) + + target = post_proc(target) + output = post_proc(output) + loss = lossFn(output.view(bs, -1), target.view(bs, -1)) + + loss.backward() + optimizer.step() + total_loss += loss.sum().item() + if lr_schedule is not None: lr_schedule.step() + + if verbose > 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item())) + + return total_loss / len(train_loader.dataset) + + +def test(model, test_loader, device, verbose=0, lossFn=None, + post_proc=lambda args: args): + model.eval() + if lossFn is None: + lossFn = nn.MSELoss() + + total_loss = 0. + predictions = [] + + with torch.no_grad(): + for data, target in test_loader: + bs = len(data) + + data, target = data.to(device), target.to(device) + output = model(data) + output = post_proc(output) + + loss = lossFn(output.view(bs, -1), target.view(bs, -1)) + total_loss += loss.sum().item() + + return total_loss / len(test_loader.dataset) + + +# Till EoF +# taken from FNO paper: +# https://github.com/zongyi-li/fourier_neural_operator + +# normalization, pointwise gaussian +class UnitGaussianNormalizer(object): + def __init__(self, x, eps=0.00001): + super(UnitGaussianNormalizer, self).__init__() + + # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T + self.mean = torch.mean(x, 0) + self.std = torch.std(x, 0) + self.eps = eps + + def encode(self, x): + x = (x - self.mean) / (self.std + self.eps) + return x + + def decode(self, x, sample_idx=None): + if sample_idx is None: + std = self.std + self.eps # n + mean = self.mean + else: + if len(self.mean.shape) == len(sample_idx[0].shape): + std = self.std[sample_idx] + self.eps # batch*n + mean = self.mean[sample_idx] + if len(self.mean.shape) > len(sample_idx[0].shape): + std = self.std[:, sample_idx] + self.eps # T*batch*n + mean = self.mean[:, sample_idx] + + # x is in shape of batch*n or T*batch*n + x = (x * std) + mean + return x + + def cuda(self): + self.mean = self.mean.cuda() + self.std = self.std.cuda() + + def cpu(self): + self.mean = self.mean.cpu() + self.std = self.std.cpu() + + +# normalization, Gaussian +class GaussianNormalizer(object): + def __init__(self, x, eps=0.00001): + super(GaussianNormalizer, self).__init__() + + self.mean = torch.mean(x) + self.std = torch.std(x) + self.eps = eps + + def encode(self, x): + x = (x - self.mean) / (self.std + self.eps) + return x + + def decode(self, x, sample_idx=None): + x = (x * (self.std + self.eps)) + self.mean + return x + + def cuda(self): + self.mean = self.mean.cuda() + self.std = self.std.cuda() + + def cpu(self): + self.mean = self.mean.cpu() + self.std = self.std.cpu() + + +# normalization, scaling by range +class RangeNormalizer(object): + def __init__(self, x, low=0.0, high=1.0): + super(RangeNormalizer, self).__init__() + mymin = torch.min(x, 0)[0].view(-1) + mymax = torch.max(x, 0)[0].view(-1) + + self.a = (high - low) / (mymax - mymin) + self.b = -self.a * mymax + high + + def encode(self, x): + s = x.size() + x = x.view(s[0], -1) + x = self.a * x + self.b + x = x.view(s) + return x + + def decode(self, 
x): + s = x.size() + x = x.view(s[0], -1) + x = (x - self.b) / self.a + x = x.view(s) + return x + + +class LpLoss(object): + def __init__(self, d=2, p=2, size_average=True, reduction=True): + super(LpLoss, self).__init__() + + # Dimension and Lp-norm type are postive + assert d > 0 and p > 0 + + self.d = d + self.p = p + self.reduction = reduction + self.size_average = size_average + + def abs(self, x, y): + num_examples = x.size()[0] + + # Assume uniform mesh + h = 1.0 / (x.size()[1] - 1.0) + + all_norms = (h ** (self.d / self.p)) * torch.norm(x.view(num_examples, -1) - y.view(num_examples, -1), self.p, + 1) + + if self.reduction: + if self.size_average: + return torch.mean(all_norms) + else: + return torch.sum(all_norms) + + return all_norms + + def rel(self, x, y): + num_examples = x.size()[0] + + diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1) + y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1) + + if self.reduction: + if self.size_average: + return torch.mean(diff_norms / y_norms) + else: + return torch.sum(diff_norms / y_norms) + + return diff_norms / y_norms + + def __call__(self, x, y): + return self.rel(x, y) \ No newline at end of file diff --git a/models/Autoformer.py b/models/Autoformer.py new file mode 100644 index 0000000..7f0f441 --- /dev/null +++ b/models/Autoformer.py @@ -0,0 +1,126 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from layers.Embed import DataEmbedding, DataEmbedding_wo_pos,DataEmbedding_wo_pos_temp,DataEmbedding_wo_temp +from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer +from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp +import math +import numpy as np + + +class Model(nn.Module): + """ + Autoformer is the first method to achieve the series-wise connection, + with inherent O(LlogL) complexity + """ + def __init__(self, configs): + super(Model, self).__init__() + self.seq_len = configs.seq_len + self.label_len = configs.label_len + self.pred_len = configs.pred_len + self.output_attention = configs.output_attention + + # Decomp + kernel_size = configs.moving_avg + self.decomp = series_decomp(kernel_size) + + # Embedding + # The series-wise connection inherently contains the sequential information. + # Thus, we can discard the position embedding of transformers. 
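+        # embed_type: 0/2 -> value + temporal embedding (no positional); 1 -> value +
+        # positional + temporal; 3 -> value + positional (no temporal); 4 -> value embedding only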
+ if configs.embed_type == 0: + self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + elif configs.embed_type == 1: + self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + elif configs.embed_type == 2: + self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + + elif configs.embed_type == 3: + self.enc_embedding = DataEmbedding_wo_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + elif configs.embed_type == 4: + self.enc_embedding = DataEmbedding_wo_pos_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_pos_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AutoCorrelationLayer( + AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, + output_attention=configs.output_attention), + configs.d_model, configs.n_heads), + configs.d_model, + configs.d_ff, + moving_avg=configs.moving_avg, + dropout=configs.dropout, + activation=configs.activation + ) for l in range(configs.e_layers) + ], + norm_layer=my_Layernorm(configs.d_model) + ) + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + AutoCorrelationLayer( + AutoCorrelation(True, configs.factor, attention_dropout=configs.dropout, + output_attention=False), + configs.d_model, configs.n_heads), + AutoCorrelationLayer( + AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, + output_attention=False), + configs.d_model, configs.n_heads), + configs.d_model, + configs.c_out, + configs.d_ff, + moving_avg=configs.moving_avg, + dropout=configs.dropout, + activation=configs.activation, + ) + for l in range(configs.d_layers) + ], + norm_layer=my_Layernorm(configs.d_model), + projection=nn.Linear(configs.d_model, configs.c_out, bias=True) + ) + + def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, + enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): + + x_mark_enc = torch.zeros(x_enc.shape[0], x_enc.shape[1], 4).to(x_enc.device) + x_dec = torch.zeros(x_enc.shape[0], 48+720, x_enc.shape[2]).to(x_enc.device) + x_mark_dec = torch.zeros(x_enc.shape[0], 48+720, 4).to(x_enc.device) + + # decomp init + mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) + zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]], device=x_enc.device) + seasonal_init, trend_init = self.decomp(x_enc) + # decoder input + trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1) + seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1) + # enc + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) + # dec + dec_out = self.dec_embedding(seasonal_init, x_mark_dec) + seasonal_part, trend_part = 
self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, + trend=trend_init) + # final + dec_out = trend_part + seasonal_part + + if self.output_attention: + return dec_out[:, -self.pred_len:, :], attns + else: + return dec_out[:, -self.pred_len:, :] # [B, L, D] diff --git a/models/DLinear.py b/models/DLinear.py new file mode 100644 index 0000000..5317dd5 --- /dev/null +++ b/models/DLinear.py @@ -0,0 +1,174 @@ +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# import numpy as np +# +# class moving_avg(nn.Module): +# """ +# Moving average block to highlight the trend of time series +# """ +# def __init__(self, kernel_size, stride): +# super(moving_avg, self).__init__() +# self.kernel_size = kernel_size +# self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) +# +# def forward(self, x): +# # padding on the both ends of time series +# front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) +# end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) +# x = torch.cat([front, x, end], dim=1) +# x = self.avg(x.permute(0, 2, 1)) +# x = x.permute(0, 2, 1) +# return x +# +# +# class series_decomp(nn.Module): +# """ +# Series decomposition block +# """ +# def __init__(self, kernel_size): +# super(series_decomp, self).__init__() +# self.moving_avg = moving_avg(kernel_size, stride=1) +# +# def forward(self, x): +# moving_mean = self.moving_avg(x) +# res = x - moving_mean +# return res, moving_mean +# +# class Model(nn.Module): +# """ +# Decomposition-Linear +# """ +# def __init__(self, configs): +# super(Model, self).__init__() +# self.seq_len = configs.seq_len +# self.pred_len = configs.pred_len +# +# # Decompsition Kernel Size +# kernel_size = 25 +# self.decompsition = series_decomp(kernel_size) +# self.individual = configs.individual +# self.channels = configs.enc_in +# +# if self.individual: +# self.Linear_Seasonal = nn.ModuleList() +# self.Linear_Trend = nn.ModuleList() +# +# for i in range(self.channels): +# self.Linear_Seasonal.append(nn.Linear(self.seq_len,self.pred_len)) +# self.Linear_Trend.append(nn.Linear(self.seq_len,self.pred_len)) +# +# # Use this two lines if you want to visualize the weights +# # self.Linear_Seasonal[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) +# # self.Linear_Trend[i].weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) +# else: +# self.Linear_Seasonal = nn.Linear(self.seq_len,self.pred_len) +# self.Linear_Trend = nn.Linear(self.seq_len,self.pred_len) +# +# # Use this two lines if you want to visualize the weights +# # self.Linear_Seasonal.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) +# # self.Linear_Trend.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) +# +# def forward(self, x): +# # x: [Batch, Input length, Channel] +# seasonal_init, trend_init = self.decompsition(x) +# seasonal_init, trend_init = seasonal_init.permute(0,2,1), trend_init.permute(0,2,1) +# if self.individual: +# seasonal_output = torch.zeros([seasonal_init.size(0),seasonal_init.size(1),self.pred_len],dtype=seasonal_init.dtype).to(seasonal_init.device) +# trend_output = torch.zeros([trend_init.size(0),trend_init.size(1),self.pred_len],dtype=trend_init.dtype).to(trend_init.device) +# for i in range(self.channels): +# seasonal_output[:,i,:] = self.Linear_Seasonal[i](seasonal_init[:,i,:]) +# trend_output[:,i,:] = self.Linear_Trend[i](trend_init[:,i,:]) +# else: +# seasonal_output = 
self.Linear_Seasonal(seasonal_init)
+#             trend_output = self.Linear_Trend(trend_init)
+#
+#         x = seasonal_output + trend_output
+#         return x.permute(0,2,1) # to [Batch, Output length, Channel]
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+
+class moving_avg(nn.Module):
+    """
+    Moving average block to highlight the trend of time series
+    """
+
+    def __init__(self, kernel_size, stride):
+        super(moving_avg, self).__init__()
+        self.kernel_size = kernel_size
+        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
+
+    def forward(self, x):
+        # padding on both ends of the time series
+        front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1)
+        end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1)
+        x = torch.cat([front, x, end], dim=1)
+        x = self.avg(x.permute(0, 2, 1))
+        x = x.permute(0, 2, 1)
+        return x
+
+
+class series_decomp(nn.Module):
+    """
+    Series decomposition block
+    """
+
+    def __init__(self, kernel_size):
+        super(series_decomp, self).__init__()
+        self.moving_avg = moving_avg(kernel_size, stride=1)
+
+    def forward(self, x):
+        moving_mean = self.moving_avg(x)
+        res = x - moving_mean
+        return res, moving_mean
+
+
+class Model(nn.Module):
+    """
+    Decomposition-Linear
+    """
+
+    def __init__(self, configs):
+        super(Model, self).__init__()
+        self.seq_len = configs.seq_len
+        self.pred_len = configs.pred_len
+
+        # Decomposition kernel size
+        kernel_size = 25
+        self.decomposition = series_decomp(kernel_size)
+        self.individual = configs.individual
+        self.enc_in = configs.enc_in
+        self.period_len = 24
+
+        self.seg_num_x = self.seq_len // self.period_len
+        self.seg_num_y = self.pred_len // self.period_len
+
+        # self.Linear_Seasonal = nn.Linear(self.seg_num_x, self.seg_num_y, bias=False)
+        # self.Linear_Trend = nn.Linear(self.seg_num_x, self.seg_num_y, bias=False)
+        self.Linear_Seasonal = nn.Linear(self.seq_len, self.pred_len)
+        self.Linear_Trend = nn.Linear(self.seq_len, self.pred_len)
+
+    def forward(self, x):
+        # x: [Batch, Input length, Channel]
+        seasonal_init, trend_init = self.decomposition(x)
+        seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
+
+        seasonal_output = self.Linear_Seasonal(seasonal_init)
+        trend_output = self.Linear_Trend(trend_init)
+
+        # seasonal_init = seasonal_init.reshape(-1, self.seg_num_x, self.period_len).permute(0, 2, 1)
+        # seasonal_output = self.Linear_Seasonal(seasonal_init) # bc,w,m
+        # seasonal_output = seasonal_output.permute(0, 2, 1).reshape(x.size(0), self.enc_in, self.pred_len)
+        #
+        # trend_init = trend_init.reshape(-1, self.seg_num_x, self.period_len).permute(0, 2, 1)
+        # trend_output = self.Linear_Trend(trend_init) # bc,w,m
+        # trend_output = trend_output.permute(0, 2, 1).reshape(x.size(0), self.enc_in, self.pred_len)
+
+        x = seasonal_output + trend_output
+        return x.permute(0, 2, 1)  # to [Batch, Output length, Channel]
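+
+# Minimal usage sketch for the Decomposition-Linear model above (hypothetical
+# `Cfg` namespace; only the fields the constructor reads are set):
+#
+#   class Cfg: seq_len, pred_len, individual, enc_in = 336, 96, 0, 7
+#   model = Model(Cfg)
+#   y = model(torch.randn(4, Cfg.seq_len, Cfg.enc_in))   # -> [4, 96, 7]
diff --git a/models/FEDformer.py b/models/FEDformer.py
new file mode 100644
index 0000000..ee89bf6
--- /dev/null
+++ b/models/FEDformer.py
@@ -0,0 +1,207 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from layers.Embed import DataEmbedding, DataEmbedding_wo_pos, DataEmbedding_wo_pos_temp, DataEmbedding_wo_temp
+from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer
+from layers.FED_FourierCorrelation import FourierBlock, FourierCrossAttention
+from layers.MultiWaveletCorrelation import MultiWaveletCross, MultiWaveletTransform
+from layers.SelfAttention_Family import FullAttention,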
ProbAttention +# from layers.FED_wo_decomp import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp, series_decomp_multi +from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp, series_decomp_multi +import math +import numpy as np + + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +class Model(nn.Module): + """ + FEDformer performs the attention mechanism on frequency domain and achieved O(N) complexity + """ + def __init__(self, configs): + super(Model, self).__init__() + self.version = 'Fourier' + self.mode_select = 'random' + self.modes = 64 + self.seq_len = configs.seq_len + self.label_len = configs.label_len + self.pred_len = configs.pred_len + self.output_attention = False + + # Decomp + kernel_size = configs.moving_avg + if isinstance(kernel_size, list): + self.decomp = series_decomp_multi(kernel_size) + else: + self.decomp = series_decomp(kernel_size) + + # Embedding + # The series-wise connection inherently contains the sequential information. + # Thus, we can discard the position embedding of transformers. + # self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, + # configs.dropout) + # self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, + # configs.dropout) + if configs.embed_type == 0: + self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + elif configs.embed_type == 1: + self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + elif configs.embed_type == 2: + self.enc_embedding = DataEmbedding_wo_pos_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_pos_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + elif configs.embed_type == 3: + self.enc_embedding = DataEmbedding_wo_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + + if self.version == 'Wavelets': + encoder_self_att = MultiWaveletTransform(ich=configs.d_model, L=configs.L, base=configs.base) + decoder_self_att = MultiWaveletTransform(ich=configs.d_model, L=configs.L, base=configs.base) + decoder_cross_att = MultiWaveletCross(in_channels=configs.d_model, + out_channels=configs.d_model, + seq_len_q=self.seq_len // 2 + self.pred_len, + seq_len_kv=self.seq_len, + modes=self.modes, + ich=configs.d_model, + base=configs.base, + activation=configs.cross_activation) + else: + encoder_self_att = FourierBlock(in_channels=configs.d_model, + out_channels=configs.d_model, + seq_len=self.seq_len, + modes=self.modes, + mode_select_method=self.mode_select) + decoder_self_att = FourierBlock(in_channels=configs.d_model, + out_channels=configs.d_model, + seq_len=self.seq_len//2+self.pred_len, + modes=self.modes, + mode_select_method=self.mode_select) + decoder_cross_att = FourierCrossAttention(in_channels=configs.d_model, + out_channels=configs.d_model, + 
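+                                                      # Worked numbers (sketch): the constructor pins version='Fourier'
+                                                      # and modes=64 regardless of the Configs in __main__ below; with
+                                                      # seq_len=96 and pred_len=96, the prints that follow give
+                                                      #   enc_modes = min(64, 96 // 2) = 48
+                                                      #   dec_modes = min(64, (96 // 2 + 96) // 2) = min(64, 72) = 64
+                                                      # i.e. the frequency blocks keep at most `modes` rFFT components.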
seq_len_q=self.seq_len//2+self.pred_len, + seq_len_kv=self.seq_len, + modes=self.modes, + mode_select_method=self.mode_select) + # Encoder + enc_modes = int(min(self.modes, configs.seq_len//2)) + dec_modes = int(min(self.modes, (configs.seq_len//2+configs.pred_len)//2)) + print('enc_modes: {}, dec_modes: {}'.format(enc_modes, dec_modes)) + + self.encoder = Encoder( + [ + EncoderLayer( + AutoCorrelationLayer( + encoder_self_att, + configs.d_model, configs.n_heads), + + configs.d_model, + configs.d_ff, + moving_avg=configs.moving_avg, + dropout=configs.dropout, + activation=configs.activation + ) for l in range(configs.e_layers) + ], + norm_layer=my_Layernorm(configs.d_model) + ) + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + AutoCorrelationLayer( + decoder_self_att, + configs.d_model, configs.n_heads), + AutoCorrelationLayer( + decoder_cross_att, + configs.d_model, configs.n_heads), + configs.d_model, + configs.c_out, + configs.d_ff, + moving_avg=configs.moving_avg, + dropout=configs.dropout, + activation=configs.activation, + ) + for l in range(configs.d_layers) + ], + norm_layer=my_Layernorm(configs.d_model), + projection=nn.Linear(configs.d_model, configs.c_out, bias=True) + ) + + def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, + enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): + + x_mark_enc = torch.zeros(x_enc.shape[0], x_enc.shape[1], 4).to(x_enc.device) + x_dec = torch.zeros(x_enc.shape[0], 48+720, x_enc.shape[2]).to(x_enc.device) + x_mark_dec = torch.zeros(x_enc.shape[0], 48+720, 4).to(x_enc.device) + + # decomp init + mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) + zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]]).to(device) # cuda() + seasonal_init, trend_init = self.decomp(x_enc) + # decoder input + trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1) + seasonal_init = F.pad(seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len)) + # enc + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) + # dec + dec_out = self.dec_embedding(seasonal_init, x_mark_dec) + seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, + trend=trend_init) + # final + dec_out = trend_part + seasonal_part + + if self.output_attention: + return dec_out[:, -self.pred_len:, :], attns + else: + return dec_out[:, -self.pred_len:, :] # [B, L, D] + +if __name__ == '__main__': + class Configs(object): + ab = 0 + modes = 32 + mode_select = 'random' + # version = 'Fourier' + version = 'Wavelets' + moving_avg = [12, 24] + L = 1 + base = 'legendre' + cross_activation = 'tanh' + seq_len = 96 + label_len = 48 + pred_len = 96 + output_attention = True + enc_in = 7 + dec_in = 7 + d_model = 16 + embed = 'timeF' + dropout = 0.05 + freq = 'h' + factor = 1 + n_heads = 8 + d_ff = 16 + e_layers = 2 + d_layers = 1 + c_out = 7 + activation = 'gelu' + wavelet = 0 + + configs = Configs() + model = Model(configs) + + print('parameter number is {}'.format(sum(p.numel() for p in model.parameters()))) + enc = torch.randn([3, configs.seq_len, 7]) + enc_mark = torch.randn([3, configs.seq_len, 4]) + + dec = torch.randn([3, configs.seq_len//2+configs.pred_len, 7]) + dec_mark = torch.randn([3, configs.seq_len//2+configs.pred_len, 4]) + out = model.forward(enc, enc_mark, dec, dec_mark) + print(out) \ No newline at end of file diff --git a/models/Film.py b/models/Film.py new file mode 100644 index 
0000000..ce65d7a --- /dev/null +++ b/models/Film.py @@ -0,0 +1,247 @@ +import sys +import os + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +import torch.nn as nn +import torch.nn.functional as F +from layers.Embed import DataEmbedding, DataEmbedding_wo_pos +from layers.AutoCorrelation import AutoCorrelation, AutoCorrelationLayer +from layers.FourierCorrelation import SpectralConv1d, SpectralConvCross1d, SpectralConv1d_local, \ + SpectralConvCross1d_local +from layers.mwt import MWT_CZ1d_cross, mwt_transform +from layers.SelfAttention_Family import FullAttention, ProbAttention +from layers.Autoformer_EncDec import Encoder, Decoder, EncoderLayer, DecoderLayer, my_Layernorm, series_decomp, \ + series_decomp_multi +import math +import numpy as np + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class Model(nn.Module): + """ + Autoformer is the first method to achieve the series-wise connection, + with inherent O(LlogL) complexity + """ + + def __init__(self, configs): + super(Model, self).__init__() + # self.modes = configs.modes + self.seq_len = configs.seq_len + self.label_len = configs.label_len + self.pred_len = configs.pred_len + self.output_attention = configs.output_attention + + # Decomp + kernel_size = configs.moving_avg + # self.decomp = series_decomp(kernel_size) + kernel_size = [kernel_size] + self.decomp = series_decomp_multi(kernel_size) + + # Embedding + # The series-wise connection inherently contains the sequential information. + # Thus, we can discard the position embedding of transformers. + self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + + configs.ab = 2 + + if configs.ab == 0: + encoder_self_att = SpectralConv1d(in_channels=configs.d_model, out_channels=configs.d_model, + seq_len=self.seq_len, modes1=configs.modes1) + decoder_self_att = SpectralConv1d(in_channels=configs.d_model, out_channels=configs.d_model, + seq_len=self.seq_len // 2 + self.pred_len, modes1=configs.modes1) + decoder_cross_att = SpectralConvCross1d(in_channels=configs.d_model, out_channels=configs.d_model, + seq_len_q=self.seq_len // 2 + self.pred_len, + seq_len_kv=self.seq_len, modes1=configs.modes1) + elif configs.ab == 1: + encoder_self_att = SpectralConv1d(in_channels=configs.d_model, out_channels=configs.d_model, + seq_len=self.seq_len) + decoder_self_att = SpectralConv1d(in_channels=configs.d_model, out_channels=configs.d_model, + seq_len=self.seq_len // 2 + self.pred_len) + decoder_cross_att = AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, + output_attention=False, configs=configs) + elif configs.ab == 2: + encoder_self_att = AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, + output_attention=configs.output_attention) + decoder_self_att = AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, + output_attention=configs.output_attention) + decoder_cross_att = SpectralConvCross1d(in_channels=configs.d_model, out_channels=configs.d_model, + seq_len_q=self.seq_len // 2 + self.pred_len, + seq_len_kv=self.seq_len) + elif configs.ab == 3: + encoder_self_att = AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout, + output_attention=configs.output_attention, configs=configs) + decoder_self_att = AutoCorrelation(False, configs.factor, 
attention_dropout=configs.dropout,
+                                                output_attention=configs.output_attention, configs=configs)
+            decoder_cross_att = AutoCorrelation(False, configs.factor, attention_dropout=configs.dropout,
+                                                output_attention=False, configs=configs)
+        elif configs.ab == 4:
+            encoder_self_att = FullAttention(False, configs.factor, attention_dropout=configs.dropout,
+                                             output_attention=configs.output_attention)
+            decoder_self_att = FullAttention(False, configs.factor, attention_dropout=configs.dropout,
+                                             output_attention=configs.output_attention)
+            decoder_cross_att = FullAttention(False, configs.factor, attention_dropout=configs.dropout,
+                                              output_attention=configs.output_attention)
+        elif configs.ab == 8:
+            encoder_self_att = ProbAttention(False, configs.factor, attention_dropout=configs.dropout,
+                                             output_attention=configs.output_attention)
+            decoder_self_att = ProbAttention(False, configs.factor, attention_dropout=configs.dropout,
+                                             output_attention=configs.output_attention)
+            decoder_cross_att = ProbAttention(False, configs.factor, attention_dropout=configs.dropout,
+                                              output_attention=configs.output_attention)
+        elif configs.ab == 5:
+            encoder_self_att = SpectralConvCross1d(in_channels=configs.d_model, out_channels=configs.d_model,
+                                                   seq_len_q=self.seq_len, seq_len_kv=self.seq_len)
+            decoder_self_att = SpectralConvCross1d(in_channels=configs.d_model, out_channels=configs.d_model,
+                                                   seq_len_q=self.seq_len // 2 + self.pred_len,
+                                                   seq_len_kv=self.seq_len // 2 + self.pred_len)
+            decoder_cross_att = SpectralConvCross1d(in_channels=configs.d_model, out_channels=configs.d_model,
+                                                    seq_len_q=self.seq_len // 2 + self.pred_len,
+                                                    seq_len_kv=self.seq_len)
+        elif configs.ab == 6:
+            encoder_self_att = SpectralConv1d(in_channels=configs.d_model, out_channels=configs.d_model,
+                                              seq_len=self.seq_len, modes1=configs.modes1)
+            decoder_self_att = SpectralConv1d(in_channels=configs.d_model, out_channels=configs.d_model,
+                                              seq_len=self.seq_len // 2 + self.pred_len, modes1=configs.modes1)
+            decoder_cross_att = SpectralConvCross1d_local(in_channels=configs.d_model, out_channels=configs.d_model,
+                                                          seq_len_q=self.seq_len // 2 + self.pred_len,
+                                                          seq_len_kv=self.seq_len, modes1=configs.modes1)
+        elif configs.ab == 7:
+            # encoder_self_att = SpectralConv1d(in_channels=configs.d_model, out_channels=configs.d_model, seq_len=self.seq_len, modes1=configs.modes1)
+            # decoder_self_att = SpectralConv1d(in_channels=configs.d_model, out_channels=configs.d_model, seq_len=self.seq_len//2+self.pred_len, modes1=configs.modes1)
+            # encoder_self_att = mwt_transform(ich=configs.d_model, L=3, alpha=int(self.pred_len/2+1))
+            # decoder_self_att = mwt_transform(ich=configs.d_model, L=3, alpha=int(self.pred_len/2+1))
+            encoder_self_att = mwt_transform(ich=configs.d_model, L=configs.L, base=configs.base)
+            decoder_self_att = mwt_transform(ich=configs.d_model, L=configs.L, base=configs.base)
+            # decoder_cross_att = SpectralConvCross1d(in_channels=configs.d_model, out_channels=configs.d_model,seq_len_q=self.seq_len//2+self.pred_len, seq_len_kv=self.seq_len, modes1=configs.modes1)
+            decoder_cross_att = MWT_CZ1d_cross(in_channels=configs.d_model, out_channels=configs.d_model,
+                                               seq_len_q=self.seq_len // 2 + self.pred_len, seq_len_kv=self.seq_len,
+                                               modes1=configs.modes1, ich=configs.d_model, base=configs.base,
+                                               activation=configs.cross_activation)
+        elif configs.ab == 8:  # NOTE: duplicates the ProbAttention branch above, so this branch is unreachable as written
+            encoder_self_att = SpectralConv1d_local(in_channels=configs.d_model, out_channels=configs.d_model,
+                                                    seq_len=self.seq_len, modes1=configs.modes1)
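+            # Ablation map wired above (summary; configs.ab is pinned to 2 earlier
+            # in __init__): 0/1/6 Fourier (SpectralConv1d) self-attention variants,
+            # 2 AutoCorrelation self + Fourier cross, 3 AutoCorrelation throughout,
+            # 4 FullAttention, 8 ProbAttention, 5 Fourier cross throughout,
+            # 7 multiwavelet blocks.
+            decoder_self_att =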
SpectralConv1d_local(in_channels=configs.d_model, out_channels=configs.d_model, + seq_len=self.seq_len // 2 + self.pred_len, modes1=configs.modes1) + decoder_cross_att = SpectralConvCross1d_local(in_channels=configs.d_model, out_channels=configs.d_model, + seq_len_q=self.seq_len // 2 + self.pred_len, + seq_len_kv=self.seq_len, modes1=configs.modes1) + + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AutoCorrelationLayer( + encoder_self_att, + configs.d_model, configs.n_heads), + + configs.d_model, + configs.d_ff, + moving_avg=configs.moving_avg, + dropout=configs.dropout, + activation=configs.activation + ) for l in range(configs.e_layers) + ], + norm_layer=my_Layernorm(configs.d_model) + ) + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + AutoCorrelationLayer( + decoder_self_att, + configs.d_model, configs.n_heads), + AutoCorrelationLayer( + decoder_cross_att, + configs.d_model, configs.n_heads), + configs.d_model, + configs.c_out, + configs.d_ff, + moving_avg=configs.moving_avg, + dropout=configs.dropout, + activation=configs.activation, + ) + for l in range(configs.d_layers) + ], + norm_layer=my_Layernorm(configs.d_model), + projection=nn.Linear(configs.d_model, configs.c_out, bias=True) + ) + a = 2 + + def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, + enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): + x_mark_enc = torch.zeros(x_enc.shape[0], x_enc.shape[1], 4).to(x_enc.device) + x_dec = torch.zeros(x_enc.shape[0], 48+720, x_enc.shape[2]).to(x_enc.device) + x_mark_dec = torch.zeros(x_enc.shape[0], 48+720, 4).to(x_enc.device) + # decomp init + mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) + zeros = torch.zeros([x_dec.shape[0], self.pred_len, x_dec.shape[2]]).to(device) # cuda() + seasonal_init, trend_init = self.decomp(x_enc) + # decoder input + trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1) + seasonal_init = F.pad(seasonal_init[:, -self.label_len:, :], (0, 0, 0, self.pred_len)) + # seasonal_init1 = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1) + # enc + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) + # dec + dec_out = self.dec_embedding(seasonal_init, x_mark_dec) + seasonal_part, trend_part = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, + trend=trend_init) + # final + dec_out = trend_part + seasonal_part + + if self.output_attention: + return dec_out[:, -self.pred_len:, :], attns + else: + return dec_out[:, -self.pred_len:, :] # [B, L, D] + + +if __name__ == '__main__': + class Configs(object): + ab = 0 + modes1 = 32 + seq_len = 336 + label_len = 48 + pred_len = 720 + output_attention = True + enc_in = 7 + dec_in = 7 + d_model = 16 + embed = 'timeF' + dropout = 0.05 + freq = 'h' + factor = 1 + n_heads = 8 + d_ff = 16 + e_layers = 2 + d_layers = 1 + moving_avg = [25] + c_out = 7 + activation = 'gelu' + wavelet = 0 + + + configs = Configs() + model = Model(configs) + + enc = torch.randn([32, configs.seq_len, 7]) + enc_mark = torch.randn([32, configs.seq_len, 4]) + + dec = torch.randn([32, configs.label_len + configs.pred_len, 7]) + dec_mark = torch.randn([32, configs.label_len + configs.pred_len, 4]) + out = model.forward(enc, enc_mark, dec, dec_mark) + print('input shape', enc.shape) + print('output shape', out[0].shape) + a = 1 + + + def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + + print('model size', 
count_parameters(model) / (1024 * 1024)) \ No newline at end of file diff --git a/models/Informer.py b/models/Informer.py new file mode 100644 index 0000000..3ce850a --- /dev/null +++ b/models/Informer.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from utils.masking import TriangularCausalMask, ProbMask +from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer +from layers.SelfAttention_Family import FullAttention, ProbAttention, AttentionLayer +from layers.Embed import DataEmbedding,DataEmbedding_wo_pos,DataEmbedding_wo_temp,DataEmbedding_wo_pos_temp +import numpy as np + + +class Model(nn.Module): + """ + Informer with Propspare attention in O(LlogL) complexity + """ + def __init__(self, configs): + super(Model, self).__init__() + self.pred_len = configs.pred_len + self.output_attention = configs.output_attention + + # Embedding + if configs.embed_type == 0: + self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + elif configs.embed_type == 1: + self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + elif configs.embed_type == 2: + self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + + elif configs.embed_type == 3: + self.enc_embedding = DataEmbedding_wo_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + elif configs.embed_type == 4: + self.enc_embedding = DataEmbedding_wo_pos_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + self.dec_embedding = DataEmbedding_wo_pos_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, + configs.dropout) + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + ProbAttention(False, configs.factor, attention_dropout=configs.dropout, + output_attention=configs.output_attention), + configs.d_model, configs.n_heads), + configs.d_model, + configs.d_ff, + dropout=configs.dropout, + activation=configs.activation + ) for l in range(configs.e_layers) + ], + [ + ConvLayer( + configs.d_model + ) for l in range(configs.e_layers - 1) + ] if configs.distil else None, + norm_layer=torch.nn.LayerNorm(configs.d_model) + ) + # Decoder + self.decoder = Decoder( + [ + DecoderLayer( + AttentionLayer( + ProbAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False), + configs.d_model, configs.n_heads), + AttentionLayer( + ProbAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False), + configs.d_model, configs.n_heads), + configs.d_model, + configs.d_ff, + dropout=configs.dropout, + activation=configs.activation, + ) + for l in range(configs.d_layers) + ], + norm_layer=torch.nn.LayerNorm(configs.d_model), + projection=nn.Linear(configs.d_model, configs.c_out, bias=True) + ) + + def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, + 
enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): + + x_mark_enc = torch.zeros(x_enc.shape[0], x_enc.shape[1], 4).to(x_enc.device) + x_dec = torch.zeros(x_enc.shape[0], 48+720, x_enc.shape[2]).to(x_enc.device) + x_mark_dec = torch.zeros(x_enc.shape[0], 48+720, 4).to(x_enc.device) + + + enc_out = self.enc_embedding(x_enc, x_mark_enc) + enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) + + dec_out = self.dec_embedding(x_dec, x_mark_dec) + dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) + + if self.output_attention: + return dec_out[:, -self.pred_len:, :], attns + else: + return dec_out[:, -self.pred_len:, :] # [B, L, D] diff --git a/models/Linear.py b/models/Linear.py new file mode 100644 index 0000000..9ab0b3d --- /dev/null +++ b/models/Linear.py @@ -0,0 +1,21 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class Model(nn.Module): + """ + Just one Linear layer + """ + def __init__(self, configs): + super(Model, self).__init__() + self.seq_len = configs.seq_len + self.pred_len = configs.pred_len + self.Linear = nn.Linear(self.seq_len, self.pred_len) + # Use this line if you want to visualize the weights + self.Linear.weight = nn.Parameter((1/self.seq_len)*torch.ones([self.pred_len,self.seq_len])) + + def forward(self, x): + # x: [Batch, Input length, Channel] + x = self.Linear(x.permute(0,2,1)).permute(0,2,1) + return x # [Batch, Output length, Channel] \ No newline at end of file diff --git a/models/PatchTST.py b/models/PatchTST.py new file mode 100644 index 0000000..5906049 --- /dev/null +++ b/models/PatchTST.py @@ -0,0 +1,92 @@ +__all__ = ['PatchTST'] + +# Cell +from typing import Callable, Optional +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F +import numpy as np + +from layers.PatchTST_backbone import PatchTST_backbone +from layers.PatchTST_layers import series_decomp + + +class Model(nn.Module): + def __init__(self, configs, max_seq_len:Optional[int]=1024, d_k:Optional[int]=None, d_v:Optional[int]=None, norm:str='BatchNorm', attn_dropout:float=0., + act:str="gelu", key_padding_mask:bool='auto',padding_var:Optional[int]=None, attn_mask:Optional[Tensor]=None, res_attention:bool=True, + pre_norm:bool=False, store_attn:bool=False, pe:str='zeros', learn_pe:bool=True, pretrain_head:bool=False, head_type = 'flatten', verbose:bool=False, **kwargs): + + super().__init__() + + # load parameters + c_in = configs.enc_in + context_window = configs.seq_len + target_window = configs.pred_len + + n_layers = configs.e_layers + n_heads = configs.n_heads + d_model = configs.d_model + d_ff = configs.d_ff + dropout = configs.dropout + fc_dropout = configs.fc_dropout + head_dropout = configs.head_dropout + + individual = configs.individual + + patch_len = configs.patch_len + stride = configs.stride + padding_patch = configs.padding_patch + + revin = configs.revin + affine = configs.affine + subtract_last = configs.subtract_last + + decomposition = configs.decomposition + kernel_size = configs.kernel_size + + + # model + self.decomposition = decomposition + if self.decomposition: + self.decomp_module = series_decomp(kernel_size) + self.model_trend = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride, + max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model, + n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, + dropout=dropout, act=act, 
key_padding_mask=key_padding_mask, padding_var=padding_var, + attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, + pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch, + pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine, + subtract_last=subtract_last, verbose=verbose, **kwargs) + self.model_res = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride, + max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model, + n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, + dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, + attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, + pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch, + pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine, + subtract_last=subtract_last, verbose=verbose, **kwargs) + else: + self.model = PatchTST_backbone(c_in=c_in, context_window = context_window, target_window=target_window, patch_len=patch_len, stride=stride, + max_seq_len=max_seq_len, n_layers=n_layers, d_model=d_model, + n_heads=n_heads, d_k=d_k, d_v=d_v, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, + dropout=dropout, act=act, key_padding_mask=key_padding_mask, padding_var=padding_var, + attn_mask=attn_mask, res_attention=res_attention, pre_norm=pre_norm, store_attn=store_attn, + pe=pe, learn_pe=learn_pe, fc_dropout=fc_dropout, head_dropout=head_dropout, padding_patch = padding_patch, + pretrain_head=pretrain_head, head_type=head_type, individual=individual, revin=revin, affine=affine, + subtract_last=subtract_last, verbose=verbose, **kwargs) + + + def forward(self, x): # x: [Batch, Input length, Channel] + if self.decomposition: + res_init, trend_init = self.decomp_module(x) + res_init, trend_init = res_init.permute(0,2,1), trend_init.permute(0,2,1) # x: [Batch, Channel, Input length] + res = self.model_res(res_init) + trend = self.model_trend(trend_init) + x = res + trend + x = x.permute(0,2,1) # x: [Batch, Input length, Channel] + else: + x = x.permute(0,2,1) # x: [Batch, Channel, Input length] + x = self.model(x) + x = x.permute(0,2,1) # x: [Batch, Input length, Channel] + return x \ No newline at end of file diff --git a/models/SparseTSF.py b/models/SparseTSF.py new file mode 100644 index 0000000..0b9d358 --- /dev/null +++ b/models/SparseTSF.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn +from layers.Embed import PositionalEmbedding + +class Model(nn.Module): + def __init__(self, configs): + super(Model, self).__init__() + + # get parameters + self.seq_len = configs.seq_len + self.pred_len = configs.pred_len + self.enc_in = configs.enc_in + self.period_len = configs.period_len + + self.seg_num_x = self.seq_len // self.period_len + self.seg_num_y = self.pred_len // self.period_len + + self.conv1d = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1 + 2 * self.period_len // 2, + stride=1, padding=self.period_len // 2, padding_mode="zeros", bias=False) + + self.linear = nn.Linear(self.seg_num_x, self.seg_num_y, bias=False) + + + def forward(self, x): + batch_size = x.shape[0] + # normalization and permute b,s,c -> b,c,s + seq_mean = torch.mean(x, dim=1).unsqueeze(1) + x = (x - seq_mean).permute(0, 2, 1) + + # 1D convolution 
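+        # Worked numbers (sketch, ETTh-style run: seq_len=720, pred_len=96, period_len=24):
+        #   conv kernel = 1 + 2 * (24 // 2) = 25 with padding 12 -> length-preserving
+        #   smoothing shared across all channels; seg_num_x = 720 // 24 = 30 input
+        #   segments, seg_num_y = 96 // 24 = 4 output segments, so the single linear
+        #   layer below maps 30 -> 4 values per channel and per intra-period phase.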
aggregation, with a residual connection back to the raw series
+        x = self.conv1d(x.reshape(-1, 1, self.seq_len)).reshape(-1, self.enc_in, self.seq_len) + x
+
+        # downsampling: b,c,s -> bc,n,w -> bc,w,n
+        x = x.reshape(-1, self.seg_num_x, self.period_len).permute(0, 2, 1)
+
+        # sparse forecasting
+        y = self.linear(x)  # bc,w,m
+
+        # upsampling: bc,w,m -> bc,m,w -> b,c,s
+        y = y.permute(0, 2, 1).reshape(batch_size, self.enc_in, self.pred_len)
+
+        # permute and denorm
+        y = y.permute(0, 2, 1) + seq_mean
+
+        return y
diff --git a/models/Stat_models.py b/models/Stat_models.py
new file mode 100644
index 0000000..21e853b
--- /dev/null
+++ b/models/Stat_models.py
@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from tqdm import tqdm
+import pmdarima as pm
+import threading
+from sklearn.ensemble import GradientBoostingRegressor
+
+class Naive_repeat(nn.Module):
+    def __init__(self, configs):
+        super(Naive_repeat, self).__init__()
+        self.pred_len = configs.pred_len
+
+    def forward(self, x):
+        B, L, D = x.shape
+        # torch.Tensor.repeat takes per-dimension repeat counts (no numpy-style axis=)
+        x = x[:, -1, :].reshape(B, 1, D).repeat(1, self.pred_len, 1)
+        return x  # [B, pred_len, D]
+
+class Naive_thread(threading.Thread):
+    def __init__(self, func, args=()):
+        super(Naive_thread, self).__init__()
+        self.func = func
+        self.args = args
+
+    def run(self):
+        self.results = self.func(*self.args)
+
+    def return_result(self):
+        threading.Thread.join(self)
+        return self.results
+
+def _arima(seq, pred_len, bt, i):
+    model = pm.auto_arima(seq)
+    forecasts = model.predict(pred_len)
+    return forecasts, bt, i
+
+class Arima(nn.Module):
+    """
+    Extremely slow; subsample the data to < 0.1 before using
+    """
+    def __init__(self, configs):
+        super(Arima, self).__init__()
+        self.pred_len = configs.pred_len
+
+    def forward(self, x):
+        result = np.zeros([x.shape[0], self.pred_len, x.shape[2]])
+        threads = []
+        for bt, seqs in tqdm(enumerate(x)):
+            for i in range(seqs.shape[-1]):
+                seq = seqs[:, i]
+                one_seq = Naive_thread(func=_arima, args=(seq, self.pred_len, bt, i))
+                threads.append(one_seq)
+                threads[-1].start()
+        for every_thread in tqdm(threads):
+            forecast, bt, i = every_thread.return_result()
+            result[bt, :, i] = forecast
+
+        return result  # [B, L, D]
+
+def _sarima(season, seq, pred_len, bt, i):
+    model = pm.auto_arima(seq, seasonal=True, m=season)
+    forecasts = model.predict(pred_len)
+    return forecasts, bt, i
+
+class SArima(nn.Module):
+    """
+    Even slower than Arima; subsample the data to < 0.01 before using
+    """
+    def __init__(self, configs):
+        super(SArima, self).__init__()
+        self.pred_len = configs.pred_len
+        self.seq_len = configs.seq_len
+        self.season = 24
+        if 'ETTm' in configs.data_path:  # the data files are named ETTm*.csv
+            self.season = 12
+        elif 'ILI' in configs.data_path:
+            self.season = 1
+        if self.season >= self.seq_len:
+            self.season = 1
+
+    def forward(self, x):
+        result = np.zeros([x.shape[0], self.pred_len, x.shape[2]])
+        threads = []
+        for bt, seqs in tqdm(enumerate(x)):
+            for i in range(seqs.shape[-1]):
+                seq = seqs[:, i]
+                one_seq = Naive_thread(func=_sarima, args=(self.season, seq, self.pred_len, bt, i))
+                threads.append(one_seq)
+                threads[-1].start()
+        for every_thread in tqdm(threads):
+            forecast, bt, i = every_thread.return_result()
+            result[bt, :, i] = forecast
+        return result  # [B, L, D]
+
+def _gbrt(seq, seq_len, pred_len, bt, i):
+    model = GradientBoostingRegressor()
+    model.fit(np.arange(seq_len).reshape(-1, 1), seq.reshape(-1, 1))
+    forecasts = model.predict(np.arange(seq_len, seq_len + pred_len).reshape(-1, 1))
+    return forecasts, bt, i
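+
+# Note on the GBRT baseline below: _gbrt fits one GradientBoostingRegressor per
+# channel on the bare time index (t -> value) and extrapolates pred_len steps,
+# fanned out over the same Naive_thread workers as the ARIMA baselines above --
+# a pure curve-fitting reference, not a learned forecaster.
+class GBRT(nn.Module):
+    def __init__(self, configs):
+        super(GBRT, self).__init__()
+        self.seq_len = configs.seq_len
+        self.pred_len =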
configs.pred_len + + def forward(self, x): + result = np.zeros([x.shape[0],self.pred_len,x.shape[2]]) + threads = [] + for bt,seqs in tqdm(enumerate(x)): + for i in range(seqs.shape[-1]): + seq = seqs[:,i] + one_seq = Naive_thread(func=_gbrt,args=(seq,self.seq_len,self.pred_len,bt,i)) + threads.append(one_seq) + threads[-1].start() + for every_thread in tqdm(threads): + forcast,bt,i = every_thread.return_result() + result[bt,:,i] = forcast + return result # [B, L, D] diff --git a/models/Transformer.py b/models/Transformer.py new file mode 100644 index 0000000..f690dfb --- /dev/null +++ b/models/Transformer.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn +from layers.Embed import PositionalEmbedding +import numpy as np + + + +class Model(nn.Module): + """ + Vanilla Transformer with O(L^2) complexity + """ + def __init__(self, configs): + super(Model, self).__init__() + self.seq_len = configs.seq_len + self.pred_len = configs.pred_len + self.enc_in = configs.enc_in + + self.d_model = 128 + self.n_heads = 4 + self.e_layers = 2 + self.d_layers = 2 + self.d_ff = 256 + + self.transformer_model = nn.Transformer(d_model=self.d_model, nhead=self.n_heads, num_encoder_layers=self.e_layers, + num_decoder_layers=self.d_layers, dim_feedforward=self.d_ff, batch_first=True) + + self.pe = PositionalEmbedding(self.d_model) + + self.input = nn.Linear(self.enc_in, self.d_model) + self.output = nn.Linear(self.d_model, self.enc_in) + + + + def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, + enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): + batch_size = x_enc.shape[0] + + enc_inp = self.input(x_enc) + enc_inp = enc_inp + self.pe(enc_inp) + dec_inp = torch.zeros(batch_size, self.pred_len, self.d_model).float().to(x_enc.device) + dec_inp = dec_inp + self.pe(dec_inp) + + out = self.transformer_model(enc_inp, dec_inp) + + y = self.output(out) + + return y + + + +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer, ConvLayer +# from layers.SelfAttention_Family import FullAttention, AttentionLayer +# from layers.Embed import DataEmbedding,DataEmbedding_wo_pos,DataEmbedding_wo_temp,DataEmbedding_wo_pos_temp +# import numpy as np +# +# +# class Model(nn.Module): +# """ +# Vanilla Transformer with O(L^2) complexity +# """ +# def __init__(self, configs): +# super(Model, self).__init__() +# self.pred_len = configs.pred_len +# self.output_attention = configs.output_attention +# +# # Embedding +# if configs.embed_type == 0: +# self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, +# configs.dropout) +# self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, +# configs.dropout) +# elif configs.embed_type == 1: +# self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, +# configs.dropout) +# self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq, +# configs.dropout) +# elif configs.embed_type == 2: +# self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed, configs.freq, +# configs.dropout) +# self.dec_embedding = DataEmbedding_wo_pos(configs.dec_in, configs.d_model, configs.embed, configs.freq, +# configs.dropout) +# +# elif configs.embed_type == 3: +# self.enc_embedding = DataEmbedding_wo_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, +# configs.dropout) +# 
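+# Note on the active implementation above: unlike this commented-out
+# encoder/decoder variant, it is non-autoregressive -- the decoder is fed
+# pred_len all-zero queries plus positional encodings and emits the whole
+# horizon in one pass. Sketch (assuming configs provides seq_len, pred_len, enc_in):
+#
+#   model = Model(configs)
+#   y = model(torch.randn(8, configs.seq_len, configs.enc_in))   # -> [8, pred_len, enc_in]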
self.dec_embedding = DataEmbedding_wo_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, +# configs.dropout) +# elif configs.embed_type == 4: +# self.enc_embedding = DataEmbedding_wo_pos_temp(configs.enc_in, configs.d_model, configs.embed, configs.freq, +# configs.dropout) +# self.dec_embedding = DataEmbedding_wo_pos_temp(configs.dec_in, configs.d_model, configs.embed, configs.freq, +# configs.dropout) +# # Encoder +# self.encoder = Encoder( +# [ +# EncoderLayer( +# AttentionLayer( +# FullAttention(False, configs.factor, attention_dropout=configs.dropout, +# output_attention=configs.output_attention), configs.d_model, configs.n_heads), +# configs.d_model, +# configs.d_ff, +# dropout=configs.dropout, +# activation=configs.activation +# ) for l in range(configs.e_layers) +# ], +# norm_layer=torch.nn.LayerNorm(configs.d_model) +# ) +# # Decoder +# self.decoder = Decoder( +# [ +# DecoderLayer( +# AttentionLayer( +# FullAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False), +# configs.d_model, configs.n_heads), +# AttentionLayer( +# FullAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False), +# configs.d_model, configs.n_heads), +# configs.d_model, +# configs.d_ff, +# dropout=configs.dropout, +# activation=configs.activation, +# ) +# for l in range(configs.d_layers) +# ], +# norm_layer=torch.nn.LayerNorm(configs.d_model), +# projection=nn.Linear(configs.d_model, configs.c_out, bias=True) +# ) +# +# def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, +# enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None): +# +# enc_out = self.enc_embedding(x_enc, x_mark_enc) +# enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask) +# +# dec_out = self.dec_embedding(x_dec, x_mark_dec) +# dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask) +# +# if self.output_attention: +# return dec_out[:, -self.pred_len:, :], attns +# else: +# return dec_out[:, -self.pred_len:, :] # [B, L, D] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9309089 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +numpy +matplotlib +pandas +scikit-learn +torch \ No newline at end of file diff --git a/run_all.sh b/run_all.sh new file mode 100644 index 0000000..08861fe --- /dev/null +++ b/run_all.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +sh scripts/SparseTSF/etth1.sh; +sh scripts/SparseTSF/etth2.sh; +sh scripts/SparseTSF/electricity.sh; +sh scripts/SparseTSF/traffic.sh; +sh scripts/SparseTSF/ettm1.sh; +sh scripts/SparseTSF/ettm2.sh; +sh scripts/SparseTSF/weather.sh; + + + diff --git a/run_longExp.py b/run_longExp.py new file mode 100644 index 0000000..f11b882 --- /dev/null +++ b/run_longExp.py @@ -0,0 +1,156 @@ +import argparse +import os +import torch +from exp.exp_main import Exp_Main +import random +import numpy as np + +parser = argparse.ArgumentParser(description='SparseTSF & other models for Time Series Forecasting') + +# basic config +parser.add_argument('--is_training', type=int, required=True, default=1, help='status') +parser.add_argument('--model_id', type=str, required=True, default='test', help='model id') +parser.add_argument('--model', type=str, required=True, default='SparseTSF', help='model name') + +# data loader +parser.add_argument('--data', type=str, required=True, default='ETTm1', help='dataset type') +parser.add_argument('--root_path', type=str, default='./data/ETT/', help='root path of the data file') +parser.add_argument('--data_path', 
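+# Example invocation (sketch; mirrors scripts/SparseTSF/etth1.sh later in this patch):
+#   python -u run_longExp.py --is_training 1 --root_path ./dataset/ --data_path ETTh1.csv \
+#          --model_id ETTh1_720_96 --model SparseTSF --data ETTh1 --features M \
+#          --seq_len 720 --pred_len 96 --period_len 24 --enc_in 7 \
+#          --itr 1 --batch_size 256 --learning_rate 0.02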
type=str, default='ETTh1.csv', help='data file') +parser.add_argument('--features', type=str, default='M', + help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate') +parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task') +parser.add_argument('--freq', type=str, default='h', + help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h') +parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints') + +# forecasting task +parser.add_argument('--seq_len', type=int, default=96, help='input sequence length') +parser.add_argument('--label_len', type=int, default=48, help='start token length') +parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length') + +# SparseTSF +parser.add_argument('--period_len', type=int, default=24, help='period length') + +# DLinear +#parser.add_argument('--individual', action='store_true', default=False, help='DLinear: a linear layer for each variate(channel) individually') + +# PatchTST +parser.add_argument('--fc_dropout', type=float, default=0.05, help='fully connected dropout') +parser.add_argument('--head_dropout', type=float, default=0.0, help='head dropout') +parser.add_argument('--patch_len', type=int, default=16, help='patch length') +parser.add_argument('--stride', type=int, default=8, help='stride') +parser.add_argument('--padding_patch', default='end', help='None: None; end: padding on the end') +parser.add_argument('--revin', type=int, default=1, help='RevIN; True 1 False 0') +parser.add_argument('--affine', type=int, default=0, help='RevIN-affine; True 1 False 0') +parser.add_argument('--subtract_last', type=int, default=0, help='0: subtract mean; 1: subtract last') +parser.add_argument('--decomposition', type=int, default=0, help='decomposition; True 1 False 0') +parser.add_argument('--kernel_size', type=int, default=25, help='decomposition-kernel') +parser.add_argument('--individual', type=int, default=0, help='individual head; True 1 False 0') + +# Formers +parser.add_argument('--embed_type', type=int, default=0, help='0: default 1: value embedding + temporal embedding + positional embedding 2: value embedding + temporal embedding 3: value embedding + positional embedding 4: value embedding') +parser.add_argument('--enc_in', type=int, default=7, help='encoder input size') # DLinear with --individual, use this hyperparameter as the number of channels +parser.add_argument('--dec_in', type=int, default=7, help='decoder input size') +parser.add_argument('--c_out', type=int, default=7, help='output size') +parser.add_argument('--d_model', type=int, default=512, help='dimension of model') +parser.add_argument('--n_heads', type=int, default=8, help='num of heads') +parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers') +parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers') +parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn') +parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average') +parser.add_argument('--factor', type=int, default=1, help='attn factor') +parser.add_argument('--distil', action='store_false', + help='whether to use distilling in encoder, using this argument means not using distilling', 
+                    default=True)
+parser.add_argument('--dropout', type=float, default=0.05, help='dropout')
+parser.add_argument('--embed', type=str, default='learned',
+                    help='time features encoding, options:[timeF, fixed, learned]')
+parser.add_argument('--activation', type=str, default='gelu', help='activation')
+parser.add_argument('--output_attention', action='store_true', default=False, help='whether to output attention in encoder')
+parser.add_argument('--do_predict', action='store_true', help='whether to predict unseen future data')
+
+# optimization
+parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
+parser.add_argument('--itr', type=int, default=2, help='number of repeated experiment runs')
+parser.add_argument('--train_epochs', type=int, default=100, help='train epochs')
+parser.add_argument('--batch_size', type=int, default=128, help='batch size of train input data')
+parser.add_argument('--patience', type=int, default=100, help='early stopping patience')
+parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
+parser.add_argument('--des', type=str, default='test', help='exp description')
+parser.add_argument('--loss', type=str, default='mse', help='loss function')
+parser.add_argument('--lradj', type=str, default='type3', help='adjust learning rate')
+parser.add_argument('--pct_start', type=float, default=0.3, help='pct_start')
+parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
+
+# GPU
+parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu')
+parser.add_argument('--gpu', type=int, default=0, help='gpu')
+parser.add_argument('--use_multi_gpu', type=int, help='use multiple gpus', default=0)
+parser.add_argument('--devices', type=str, default='0,1', help='device ids of multiple gpus')
+parser.add_argument('--test_flop', action='store_true', default=False, help='See utils/tools for usage')
+
+args = parser.parse_args()
+
+# random seeds, one per repeated run
+fix_seed_list = range(2023, 2033)
+
+
+args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False
+
+if args.use_gpu and args.use_multi_gpu:
+    args.devices = args.devices.replace(' ', '')
+    device_ids = args.devices.split(',')
+    args.device_ids = [int(id_) for id_ in device_ids]
+    args.gpu = args.device_ids[0]
+
+print('Args in experiment:')
+print(args)
+
+Exp = Exp_Main
+
+if args.is_training:
+    for ii in range(args.itr):
+        random.seed(fix_seed_list[ii])
+        torch.manual_seed(fix_seed_list[ii])
+        np.random.seed(fix_seed_list[ii])
+        # setting record of experiments
+        setting = '{}_{}_{}_ft{}_sl{}_pl{}_{}_{}_seed{}'.format(
+            args.model_id,
+            args.model,
+            args.data,
+            args.features,
+            args.seq_len,
+            args.pred_len,
+            args.des,
+            ii,
+            fix_seed_list[ii])
+
+        exp = Exp(args)  # set experiments
+        print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
+        exp.train(setting)
+
+        print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
+        exp.test(setting)
+
+        if args.do_predict:
+            print('>>>>>>>predicting : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
+            exp.predict(setting, True)
+
+        torch.cuda.empty_cache()
+else:
+    ii = 0
+    setting = '{}_{}_{}_ft{}_sl{}_pl{}_{}_{}_seed{}'.format(
+        args.model_id,
+        args.model,
+        args.data,
+        args.features,
+        args.seq_len,
+        args.pred_len,
+        args.des,
+        ii,
+        fix_seed_list[ii])
+
+    exp = Exp(args)  # set experiments
+    print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting))
+    exp.test(setting, test=1)
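+    # Example resulting tag (sketch, assuming the ETTh1 script settings below):
+    #   ETTh1_720_96_SparseTSF_ETTh1_ftM_sl720_pl96_test_0_seed2023
+    # i.e. artifacts from repeated runs never collide, since the run index and
+    # seed are part of the name.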
torch.cuda.empty_cache() diff --git a/scripts/SparseTSF/electricity.sh b/scripts/SparseTSF/electricity.sh new file mode 100644 index 0000000..4af991f --- /dev/null +++ b/scripts/SparseTSF/electricity.sh @@ -0,0 +1,30 @@ +if [ ! -d "./logs" ]; then + mkdir ./logs +fi + +model_name=SparseTSF + +root_path_name=./dataset/ +data_path_name=electricity.csv +model_id_name=Electricity +data_name=custom + +seq_len=720 +for pred_len in 96 192 336 720 +do + python -u run_longExp.py \ + --is_training 1 \ + --root_path $root_path_name \ + --data_path $data_path_name \ + --model_id $model_id_name'_'$seq_len'_'$pred_len \ + --model $model_name \ + --data $data_name \ + --features M \ + --seq_len $seq_len \ + --pred_len $pred_len \ + --period_len 24 \ + --enc_in 321 \ + --train_epochs 30 \ + --patience 5 \ + --itr 1 --batch_size 128 --learning_rate 0.02 +done diff --git a/scripts/SparseTSF/etth1.sh b/scripts/SparseTSF/etth1.sh new file mode 100644 index 0000000..93d4b8d --- /dev/null +++ b/scripts/SparseTSF/etth1.sh @@ -0,0 +1,31 @@ +if [ ! -d "./logs" ]; then + mkdir ./logs +fi + +model_name=SparseTSF + +root_path_name=./dataset/ +data_path_name=ETTh1.csv +model_id_name=ETTh1 +data_name=ETTh1 + +seq_len=720 +for pred_len in 96 192 336 720 +do + python -u run_longExp.py \ + --is_training 1 \ + --root_path $root_path_name \ + --data_path $data_path_name \ + --model_id $model_id_name'_'$seq_len'_'$pred_len \ + --model $model_name \ + --data $data_name \ + --features M \ + --seq_len $seq_len \ + --pred_len $pred_len \ + --period_len 24 \ + --enc_in 7 \ + --train_epochs 30 \ + --patience 5 \ + --itr 1 --batch_size 256 --learning_rate 0.02 +done + diff --git a/scripts/SparseTSF/etth2.sh b/scripts/SparseTSF/etth2.sh new file mode 100644 index 0000000..8979ad7 --- /dev/null +++ b/scripts/SparseTSF/etth2.sh @@ -0,0 +1,31 @@ +if [ ! -d "./logs" ]; then + mkdir ./logs +fi + +model_name=SparseTSF + +root_path_name=./dataset/ +data_path_name=ETTh2.csv +model_id_name=ETTh2 +data_name=ETTh2 + +seq_len=720 +for pred_len in 96 192 336 720 +do + python -u run_longExp.py \ + --is_training 1 \ + --root_path $root_path_name \ + --data_path $data_path_name \ + --model_id $model_id_name'_'$seq_len'_'$pred_len \ + --model $model_name \ + --data $data_name \ + --features M \ + --seq_len $seq_len \ + --pred_len $pred_len \ + --period_len 24 \ + --enc_in 7 \ + --train_epochs 30 \ + --patience 5 \ + --itr 1 --batch_size 256 --learning_rate 0.03 +done + diff --git a/scripts/SparseTSF/ettm1.sh b/scripts/SparseTSF/ettm1.sh new file mode 100644 index 0000000..874ae3f --- /dev/null +++ b/scripts/SparseTSF/ettm1.sh @@ -0,0 +1,30 @@ +if [ ! -d "./logs" ]; then + mkdir ./logs +fi + +model_name=SparseTSF + +root_path_name=./dataset/ +data_path_name=ETTm1.csv +model_id_name=ETTm1 +data_name=ETTm1 + +seq_len=720 +for pred_len in 96 192 336 720 +do + python -u run_longExp.py \ + --is_training 1 \ + --root_path $root_path_name \ + --data_path $data_path_name \ + --model_id $model_id_name'_'$seq_len'_'$pred_len \ + --model $model_name \ + --data $data_name \ + --features M \ + --seq_len $seq_len \ + --pred_len $pred_len \ + --period_len 4 \ + --enc_in 7 \ + --train_epochs 30 \ + --patience 5 \ + --itr 1 --batch_size 256 --learning_rate 0.02 +done diff --git a/scripts/SparseTSF/ettm2.sh b/scripts/SparseTSF/ettm2.sh new file mode 100644 index 0000000..f814123 --- /dev/null +++ b/scripts/SparseTSF/ettm2.sh @@ -0,0 +1,30 @@ +if [ ! 
-d "./logs" ]; then + mkdir ./logs +fi + +model_name=SparseTSF + +root_path_name=./dataset/ +data_path_name=ETTm2.csv +model_id_name=ETTm2 +data_name=ETTm2 + +seq_len=720 +for pred_len in 96 192 336 720 +do + python -u run_longExp.py \ + --is_training 1 \ + --root_path $root_path_name \ + --data_path $data_path_name \ + --model_id $model_id_name'_'$seq_len'_'$pred_len \ + --model $model_name \ + --data $data_name \ + --features M \ + --seq_len $seq_len \ + --pred_len $pred_len \ + --period_len 4 \ + --enc_in 7 \ + --train_epochs 30 \ + --patience 5 \ + --itr 1 --batch_size 256 --learning_rate 0.02 +done diff --git a/scripts/SparseTSF/traffic.sh b/scripts/SparseTSF/traffic.sh new file mode 100644 index 0000000..fb15172 --- /dev/null +++ b/scripts/SparseTSF/traffic.sh @@ -0,0 +1,31 @@ +if [ ! -d "./logs" ]; then + mkdir ./logs +fi + +model_name=SparseTSF + +root_path_name=./dataset/ +data_path_name=traffic.csv +model_id_name=traffic +data_name=custom + +seq_len=720 +for pred_len in 96 192 336 720 +do + python -u run_longExp.py \ + --is_training 1 \ + --root_path $root_path_name \ + --data_path $data_path_name \ + --model_id $model_id_name'_'$seq_len'_'$pred_len \ + --model $model_name \ + --data $data_name \ + --features M \ + --seq_len $seq_len \ + --pred_len $pred_len \ + --period_len 24 \ + --enc_in 862 \ + --train_epochs 30 \ + --patience 5 \ + --itr 1 --batch_size 128 --learning_rate 0.03 +done + diff --git a/scripts/SparseTSF/weather.sh b/scripts/SparseTSF/weather.sh new file mode 100644 index 0000000..895f331 --- /dev/null +++ b/scripts/SparseTSF/weather.sh @@ -0,0 +1,30 @@ +if [ ! -d "./logs" ]; then + mkdir ./logs +fi + +model_name=SparseTSF + +root_path_name=./dataset/ +data_path_name=weather.csv +model_id_name=weather +data_name=custom + +seq_len=720 +for pred_len in 96 192 336 720 +do + python -u run_longExp.py \ + --is_training 1 \ + --root_path $root_path_name \ + --data_path $data_path_name \ + --model_id $model_id_name'_'$seq_len'_'$pred_len \ + --model $model_name \ + --data $data_name \ + --features M \ + --seq_len $seq_len \ + --pred_len $pred_len \ + --period_len 4 \ + --enc_in 21 \ + --train_epochs 30 \ + --patience 5 \ + --itr 1 --batch_size 256 --learning_rate 0.02 +done diff --git a/utils/augmentations.py b/utils/augmentations.py new file mode 100644 index 0000000..4587641 --- /dev/null +++ b/utils/augmentations.py @@ -0,0 +1,163 @@ +import torch +import numpy as np + + +def augmentation(augment_time): + if augment_time == 'batch': + return BatchAugmentation() + elif augment_time == 'dataset': + return DatasetAugmentation() + + +class BatchAugmentation(): + def __init__(self): + pass + + def freq_mask(self, x, y, rate=0.5, dim=1): + xy = torch.cat([x, y], dim=1) + xy_f = torch.fft.rfft(xy, dim=dim) + m = torch.cuda.FloatTensor(xy_f.shape).uniform_() < rate + freal = xy_f.real.masked_fill(m, 0) + fimag = xy_f.imag.masked_fill(m, 0) + xy_f = torch.complex(freal, fimag) + xy = torch.fft.irfft(xy_f, dim=dim) + return xy + + def freq_mix(self, x, y, rate=0.5, dim=1): + xy = torch.cat([x, y], dim=dim) + xy_f = torch.fft.rfft(xy, dim=dim) + + m = torch.cuda.FloatTensor(xy_f.shape).uniform_() < rate + amp = abs(xy_f) + _, index = amp.sort(dim=dim, descending=True) + dominant_mask = index > 2 + m = torch.bitwise_and(m, dominant_mask) + freal = xy_f.real.masked_fill(m, 0) + fimag = xy_f.imag.masked_fill(m, 0) + + b_idx = np.arange(x.shape[0]) + np.random.shuffle(b_idx) + x2, y2 = x[b_idx], y[b_idx] + xy2 = torch.cat([x2, y2], dim=dim) + xy2_f = torch.fft.rfft(xy2, dim=dim) + + 
m = torch.bitwise_not(m) + freal2 = xy2_f.real.masked_fill(m, 0) + fimag2 = xy2_f.imag.masked_fill(m, 0) + + freal += freal2 + fimag += fimag2 + + xy_f = torch.complex(freal, fimag) + + xy = torch.fft.irfft(xy_f, dim=dim) + return xy + + def noise(self, x, y, rate=0.05, dim=1): + xy = torch.cat([x, y], dim=1) + noise_xy = (torch.rand(xy.shape) - 0.5) * 0.1 + xy = xy + noise_xy.cuda() + return xy + + def noise_input(self, x, y, rate=0.05, dim=1): + noise = (torch.rand(x.shape) - 0.5) * 0.1 + x = x + noise.cuda() + xy = torch.cat([x, y], dim=1) + return xy + + def vFlip(self, x, y, rate=0.05, dim=1): + # vertically flip the xy + xy = torch.cat([x, y], dim=1) + xy = -xy + return xy + + def hFlip(self, x, y, rate=0.05, dim=1): + # horizontally flip the xy + xy = torch.cat([x, y], dim=1) + # reverse the order of dim 1 + xy = xy.flip(dims=[dim]) + return xy + + def time_combination(self, x, y, rate=0.5, dim=1): + xy = torch.cat([x, y], dim=dim) + + b_idx = np.arange(x.shape[0]) + np.random.shuffle(b_idx) + x2, y2 = x[b_idx], y[b_idx] + xy2 = torch.cat([x2, y2], dim=dim) + + xy = (xy + xy2) / 2 + return xy + + def magnitude_warping(self, x, y, rate=0.5, dim=1): + pass + + def linear_upsampling(self, x, y, rate=0.5, dim=1): + xy = torch.cat([x, y], dim=dim) + original_shape = xy.shape + # randomly cut a segment from xy the length should be half of it + # generate a random integer from 0 to the length of xy + start_point = np.random.randint(0, original_shape[1] // 2) + + xy = xy[:, start_point:start_point + original_shape[1] // 2, :] + + # interpolate the xy to the original_shape + xy = xy.permute(0, 2, 1) + xy = torch.nn.functional.interpolate(xy, scale_factor=2, mode='linear') + xy = xy.permute(0, 2, 1) + return xy + + +class DatasetAugmentation(): + def __init__(self): + pass + + def freq_dropout(self, x, y, dropout_rate=0.2, dim=0, keep_dominant=True): + x, y = torch.from_numpy(x), torch.from_numpy(y) + + xy = torch.cat([x, y], dim=0) + xy_f = torch.fft.rfft(xy, dim=0) + + m = torch.FloatTensor(xy_f.shape).uniform_() < dropout_rate + + # amp = abs(xy_f) + # _,index = amp.sort(dim=dim, descending=True) + # dominant_mask = index > 5 + # m = torch.bitwise_and(m,dominant_mask) + + freal = xy_f.real.masked_fill(m, 0) + fimag = xy_f.imag.masked_fill(m, 0) + xy_f = torch.complex(freal, fimag) + xy = torch.fft.irfft(xy_f, dim=dim) + + x, y = xy[:x.shape[0], :].numpy(), xy[-y.shape[0]:, :].numpy() + return x, y + + def freq_mix(self, x, y, x2, y2, dropout_rate=0.2): + x, y = torch.from_numpy(x), torch.from_numpy(y) + + xy = torch.cat([x, y], dim=0) + xy_f = torch.fft.rfft(xy, dim=0) + m = torch.FloatTensor(xy_f.shape).uniform_() < dropout_rate + amp = abs(xy_f) + _, index = amp.sort(dim=0, descending=True) + dominant_mask = index > 2 + m = torch.bitwise_and(m, dominant_mask) + freal = xy_f.real.masked_fill(m, 0) + fimag = xy_f.imag.masked_fill(m, 0) + + x2, y2 = torch.from_numpy(x2), torch.from_numpy(y2) + xy2 = torch.cat([x2, y2], dim=0) + xy2_f = torch.fft.rfft(xy2, dim=0) + + m = torch.bitwise_not(m) + freal2 = xy2_f.real.masked_fill(m, 0) + fimag2 = xy2_f.imag.masked_fill(m, 0) + + freal += freal2 + fimag += fimag2 + + xy_f = torch.complex(freal, fimag) + xy = torch.fft.irfft(xy_f, dim=0) + x, y = xy[:x.shape[0], :].numpy(), xy[-y.shape[0]:, :].numpy() + return x, y \ No newline at end of file diff --git a/utils/masking.py b/utils/masking.py new file mode 100644 index 0000000..7924ab8 --- /dev/null +++ b/utils/masking.py @@ -0,0 +1,39 @@ +import torch + + +class TriangularCausalMask(): + def 
+class TriangularCausalMask():
+    def __init__(self, B, L, device="cpu"):
+        mask_shape = [B, 1, L, L]
+        with torch.no_grad():
+            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
+
+    @property
+    def mask(self):
+        return self._mask
+
+
+class ProbMask():
+    def __init__(self, B, H, L, index, scores, device="cpu"):
+        _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
+        _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
+        indicator = _mask_ex[torch.arange(B)[:, None, None],
+                             torch.arange(H)[None, :, None],
+                             index, :].to(device)
+        self._mask = indicator.view(scores.shape).to(device)
+
+    @property
+    def mask(self):
+        return self._mask
+
+
+class LocalMask():
+    def __init__(self, B, L, S, device="cpu"):
+        mask_shape = [B, 1, L, S]
+        with torch.no_grad():
+            self.len = math.ceil(np.log2(L))
+            self._mask1 = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
+            self._mask2 = ~torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=-self.len).to(device)
+            # union of the two bans: blocks the future and anything more than
+            # roughly log2(L) steps in the past, leaving a local causal window
+            self._mask = self._mask1 + self._mask2
+
+    @property
+    def mask(self):
+        return self._mask
diff --git a/utils/metrics.py b/utils/metrics.py
new file mode 100644
index 0000000..bb6544b
--- /dev/null
+++ b/utils/metrics.py
@@ -0,0 +1,44 @@
+import numpy as np
+
+
+def RSE(pred, true):
+    return np.sqrt(np.sum((true - pred) ** 2)) / np.sqrt(np.sum((true - true.mean()) ** 2))
+
+
+def CORR(pred, true):
+    u = ((true - true.mean(0)) * (pred - pred.mean(0))).sum(0)
+    d = np.sqrt(((true - true.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
+    d += 1e-12
+    return 0.01 * (u / d).mean(-1)  # the 0.01 scale factor is kept from prior LTSF codebases
+
+
+def MAE(pred, true):
+    return np.mean(np.abs(pred - true))
+
+
+def MSE(pred, true):
+    return np.mean((pred - true) ** 2)
+
+
+def RMSE(pred, true):
+    return np.sqrt(MSE(pred, true))
+
+
+def MAPE(pred, true):
+    return np.mean(np.abs((pred - true) / true))
+
+
+def MSPE(pred, true):
+    return np.mean(np.square((pred - true) / true))
+
+
+def metric(pred, true):
+    mae = MAE(pred, true)
+    mse = MSE(pred, true)
+    rmse = RMSE(pred, true)
+    mape = MAPE(pred, true)
+    mspe = MSPE(pred, true)
+    rse = RSE(pred, true)
+    corr = CORR(pred, true)
+
+    return mae, mse, rmse, mape, mspe, rse, corr
diff --git a/utils/timefeatures.py b/utils/timefeatures.py
new file mode 100644
index 0000000..f5678f0
--- /dev/null
+++ b/utils/timefeatures.py
@@ -0,0 +1,134 @@
+from typing import List
+
+import numpy as np
+import pandas as pd
+from pandas.tseries import offsets
+from pandas.tseries.frequencies import to_offset
+
+
+class TimeFeature:
+    def __init__(self):
+        pass
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        pass
+
+    def __repr__(self):
+        return self.__class__.__name__ + "()"
+
+
+class SecondOfMinute(TimeFeature):
+    """Second of minute encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return index.second / 59.0 - 0.5
+
+
+class MinuteOfHour(TimeFeature):
+    """Minute of hour encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return index.minute / 59.0 - 0.5
+
+
+class HourOfDay(TimeFeature):
+    """Hour of day encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return index.hour / 23.0 - 0.5
+
+
+class DayOfWeek(TimeFeature):
+    """Day of week encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return index.dayofweek / 6.0 - 0.5
+
+
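+# Quick sanity check of the encoding convention (illustrative):
+#   HourOfDay()(pd.DatetimeIndex(['2024-01-01 00:00', '2024-01-01 23:00']))
+#   -> [-0.5, 0.5]; every feature maps its natural range onto [-0.5, 0.5].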
+class DayOfMonth(TimeFeature):
+    """Day of month encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return (index.day - 1) / 30.0 - 0.5
+
+
+class DayOfYear(TimeFeature):
+    """Day of year encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return (index.dayofyear - 1) / 365.0 - 0.5
+
+
+class MonthOfYear(TimeFeature):
+    """Month of year encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return (index.month - 1) / 11.0 - 0.5
+
+
+class WeekOfYear(TimeFeature):
+    """Week of year encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return (index.isocalendar().week - 1) / 52.0 - 0.5
+
+
+def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
+    """
+    Returns a list of time features that will be appropriate for the given frequency string.
+
+    Parameters
+    ----------
+    freq_str
+        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
+    """
+
+    features_by_offsets = {
+        offsets.YearEnd: [],
+        offsets.QuarterEnd: [MonthOfYear],
+        offsets.MonthEnd: [MonthOfYear],
+        offsets.Week: [DayOfMonth, WeekOfYear],
+        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
+        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
+        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
+        offsets.Minute: [
+            MinuteOfHour,
+            HourOfDay,
+            DayOfWeek,
+            DayOfMonth,
+            DayOfYear,
+        ],
+        offsets.Second: [
+            SecondOfMinute,
+            MinuteOfHour,
+            HourOfDay,
+            DayOfWeek,
+            DayOfMonth,
+            DayOfYear,
+        ],
+    }
+
+    offset = to_offset(freq_str)
+
+    for offset_type, feature_classes in features_by_offsets.items():
+        if isinstance(offset, offset_type):
+            return [cls() for cls in feature_classes]
+
+    supported_freq_msg = f"""
+    Unsupported frequency {freq_str}
+    The following frequencies are supported:
+        Y - yearly
+            alias: A
+        M - monthly
+        W - weekly
+        D - daily
+        B - business days
+        H - hourly
+        T - minutely
+            alias: min
+        S - secondly
+    """
+    raise RuntimeError(supported_freq_msg)
+
+
+def time_features(dates, freq='h'):
+    return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])
diff --git a/utils/tools.py b/utils/tools.py
new file mode 100644
index 0000000..5eff6b2
--- /dev/null
+++ b/utils/tools.py
@@ -0,0 +1,121 @@
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+import time
+
+plt.switch_backend('agg')
+
+
+def adjust_learning_rate(optimizer, scheduler, epoch, args, printout=True):
+    # lr = args.learning_rate * (0.2 ** (epoch // 2))
+    if args.lradj == 'type1':
+        lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch - 1) // 1))}
+    elif args.lradj == 'type2':
+        lr_adjust = {
+            2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
+            10: 5e-7, 15: 1e-7, 20: 5e-8
+        }
+    elif args.lradj == 'type3':
+        lr_adjust = {epoch: args.learning_rate if epoch < 3 else args.learning_rate * (0.8 ** ((epoch - 3) // 1))}
+    elif args.lradj == 'constant':
+        lr_adjust = {epoch: args.learning_rate}
+    elif args.lradj == '3':
+        lr_adjust = {epoch: args.learning_rate if epoch < 10 else args.learning_rate * 0.1}
+    elif args.lradj == '4':
+        lr_adjust = {epoch: args.learning_rate if epoch < 15 else args.learning_rate * 0.1}
+    elif args.lradj == '5':
+        lr_adjust = {epoch: args.learning_rate if epoch < 25 else args.learning_rate * 0.1}
+    elif args.lradj == '6':
+        lr_adjust = {epoch: args.learning_rate if epoch < 5 else args.learning_rate * 0.1}
+    elif args.lradj == 'TST':
+        lr_adjust = {epoch: scheduler.get_last_lr()[0]}
+    else:
+        lr_adjust = {}  # unknown schedule: leave the learning rate untouched
+
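+    # only epochs listed in the schedule trigger an update; e.g. 'type2'
+    # updates at epochs 2, 4, 6, ... and keeps the optimizer's lr otherwise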
+    if epoch in lr_adjust.keys():
+        lr = lr_adjust[epoch]
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+        if printout:
+            print('Updating learning rate to {}'.format(lr))
+
+
+class EarlyStopping:
+    def __init__(self, patience=7, verbose=False, delta=0):
+        self.patience = patience
+        self.verbose = verbose
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.val_loss_min = np.inf
+        self.delta = delta
+
+    def __call__(self, val_loss, model, path):
+        score = -val_loss
+        if self.best_score is None:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model, path)
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.save_checkpoint(val_loss, model, path)
+            self.counter = 0
+
+    def save_checkpoint(self, val_loss, model, path):
+        if self.verbose:
+            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
+        torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
+        self.val_loss_min = val_loss
+
+
+class dotdict(dict):
+    """dot.notation access to dictionary attributes"""
+    __getattr__ = dict.get
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+
+class StandardScaler():
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def transform(self, data):
+        return (data - self.mean) / self.std
+
+    def inverse_transform(self, data):
+        return (data * self.std) + self.mean
+
+
+def visual(true, preds=None, name='./pic/test.pdf'):
+    """
+    Results visualization
+    """
+    plt.figure()
+    plt.plot(true, label='GroundTruth', linewidth=2)
+    if preds is not None:
+        plt.plot(preds, label='Prediction', linewidth=2)
+    plt.legend()
+    plt.savefig(name, bbox_inches='tight')
+
+
+def test_params_flop(model, x_shape):
+    """
+    To measure a former-style model's FLOPs, give default values to the inputs of
+    model.forward(); the code below can only pass one argument to forward().
+    """
+    # model_params = 0
+    # for parameter in model.parameters():
+    #     model_params += parameter.numel()
+    # print('INFO: Trainable parameter count: {:.2f}M'.format(model_params / 1000000.0))
+    from ptflops import get_model_complexity_info
+    with torch.cuda.device(0):
+        macs, params = get_model_complexity_info(model.cuda(), x_shape, as_strings=True, print_per_layer_stat=False)
+        print('{:<30} {:<8}'.format('Computational complexity: ', macs))
+        print('{:<30} {:<8}'.format('Number of parameters: ', params))
\ No newline at end of file