forked from gzerveas/mvts_transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added failsafe reqs, documentation for own datasets
- Loading branch information
Showing
4 changed files
with
150 additions
and
386 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
class MachineData(BaseData): | ||
""" | ||
Dataset class for Machine dataset. | ||
Attributes: | ||
all_df: dataframe indexed by ID, with multiple rows corresponding to the same index (sample). | ||
Each row is a time step; Each column contains either metadata (e.g. timestamp) or a feature. | ||
feature_df: contains the subset of columns of `all_df` which correspond to selected features | ||
feature_names: names of columns contained in `feature_df` (same as feature_df.columns) | ||
all_IDs: IDs contained in `all_df`/`feature_df` (same as all_df.index.unique() ) | ||
max_seq_len: maximum sequence (time series) length. If None, script argument `max_seq_len` will be used. | ||
(Moreover, script argument overrides this attribute) | ||
""" | ||
|
||
def __init__(self, root_dir, file_list=None, pattern=None, n_proc=1, limit_size=None, config=None): | ||
|
||
self.set_num_processes(n_proc=n_proc) | ||
|
||
self.all_df = self.load_all(root_dir, file_list=file_list, pattern=pattern) | ||
self.all_df = self.all_df.sort_values(by=['machine_record_index']) # datasets is presorted | ||
self.all_df = self.all_df.set_index('machine_record_index') | ||
self.all_IDs = self.all_df.index.unique() # all sample (session) IDs | ||
self.max_seq_len = 66 | ||
if limit_size is not None: | ||
if limit_size > 1: | ||
limit_size = int(limit_size) | ||
else: # interpret as proportion if in (0, 1] | ||
limit_size = int(limit_size * len(self.all_IDs)) | ||
self.all_IDs = self.all_IDs[:limit_size] | ||
self.all_df = self.all_df.loc[self.all_IDs] | ||
|
||
self.feature_names = ['feed_speed', 'current', 'voltage', 'motor_current', 'power'] | ||
self.feature_df = self.all_df[self.feature_names] | ||
|
||
def load_all(self, root_dir, file_list=None, pattern=None): | ||
""" | ||
Loads datasets from csv files contained in `root_dir` into a dataframe, optionally choosing from `pattern` | ||
Args: | ||
root_dir: directory containing all individual .csv files | ||
file_list: optionally, provide a list of file paths within `root_dir` to consider. | ||
Otherwise, entire `root_dir` contents will be used. | ||
pattern: optionally, apply regex string to select subset of files | ||
Returns: | ||
all_df: a single (possibly concatenated) dataframe with all data corresponding to specified files | ||
""" | ||
# each file name corresponds to another date. Also tools (A, B) and others. | ||
|
||
# Select paths for training and evaluation | ||
if file_list is None: | ||
data_paths = glob.glob(os.path.join(root_dir, '*')) # list of all paths | ||
else: | ||
data_paths = [os.path.join(root_dir, p) for p in file_list] | ||
if len(data_paths) == 0: | ||
raise Exception('No files found using: {}'.format(os.path.join(root_dir, '*'))) | ||
|
||
if pattern is None: | ||
# by default evaluate on | ||
selected_paths = data_paths | ||
else: | ||
selected_paths = list(filter(lambda x: re.search(pattern, x), data_paths)) | ||
|
||
input_paths = [p for p in selected_paths if os.path.isfile(p) and p.endswith('.csv')] | ||
if len(input_paths) == 0: | ||
raise Exception("No .csv files found using pattern: '{}'".format(pattern)) | ||
|
||
if self.n_proc > 1: | ||
# Load in parallel | ||
_n_proc = min(self.n_proc, len(input_paths)) # no more than file_names needed here | ||
logger.info("Loading {} datasets files using {} parallel processes ...".format(len(input_paths), _n_proc)) | ||
with Pool(processes=_n_proc) as pool: | ||
all_df = pd.concat(pool.map(machineData.load_single, input_paths)) | ||
else: # read 1 file at a time | ||
all_df = pd.concat(machineData.load_single(path) for path in input_paths) | ||
|
||
return all_df | ||
|
||
@staticmethod | ||
def load_single(filepath): | ||
df = machineData.read_data(filepath) | ||
df = machineData.select_columns(df) | ||
num_nan = df.isna().sum().sum() | ||
if num_nan > 0: | ||
logger.warning("{} nan values in {} will be replaced by 0".format(num_nan, filepath)) | ||
df = df.fillna(0) | ||
|
||
return df | ||
|
||
@staticmethod | ||
def read_data(filepath): | ||
"""Reads a single .csv, which typically contains a day of datasets of various machine sessions. | ||
""" | ||
df = pd.read_csv(filepath) | ||
return df | ||
|
||
@staticmethod | ||
def select_columns(df): | ||
"""""" | ||
df = df.rename(columns={"per_energy": "power"}) | ||
# Sometimes 'diff_time' is not measured correctly (is 0), and power ('per_energy') becomes infinite | ||
is_error = df['power'] > 1e16 | ||
df.loc[is_error, 'power'] = df.loc[is_error, 'true_energy'] / df['diff_time'].median() | ||
|
||
df['machine_record_index'] = df['machine_record_index'].astype(int) | ||
keep_cols = ['machine_record_index', 'wire_feed_speed', 'current', 'voltage', 'motor_current', 'power'] | ||
df = df[keep_cols] | ||
|
||
return df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
ipdb==0.13.9 | ||
tqdm==4.61.1 | ||
xlrd==1.2.0 | ||
pandas==1.1.3 | ||
xlutils==2.0.0 | ||
tabulate==0.8.9 | ||
xlwt==1.3.0 | ||
torch==1.7.1 | ||
tensorboard==1.15.0 | ||
matplotlib==3.3.4 | ||
numpy==1.19.2 | ||
scikit_learn==0.23.2 | ||
protobuf==3.17.3 | ||
sktime==0.4.1 |
Oops, something went wrong.