-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
31 changed files
with
917 additions
and
251 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,43 @@ | ||
# Copyright (c) 2022 Raven Stock. email:[email protected] | ||
|
||
import os | ||
from numpy import source | ||
import torch | ||
import numpy as np | ||
from torch.utils.data import DataLoader | ||
from sklearn.model_selection import train_test_split | ||
|
||
|
||
__all__ = ['DDS'] | ||
|
||
|
||
class DDS(object): | ||
def __init__(self, cfg): | ||
''' | ||
base_dir:存放故障类别的根目录 | ||
source:源域 | ||
target:目标域 | ||
batch_size:batch_size | ||
test_size:测试集大小 | ||
''' | ||
self.table = ['20R_0HP', '20R_4HP', '20R_8HP', | ||
'30R_0HP', '30R_4HP', '30R_8HP', | ||
'40R_0HP', '40R_4HP', '40R_8HP',] | ||
if cfg.DATASET.SOURCE not in self.table: | ||
raise ValueError("param \'soruce\' error") | ||
if cfg.DATASET.TARGET not in self.table: | ||
raise ValueError("param \'target\' error") | ||
if cfg.DATASET.SOURCE == cfg.DATASET.TARGET: | ||
Warning('source and target are the same param!') | ||
|
||
self.base_dir = cfg.DATASET.ROOT | ||
self.soruce = cfg.DATASET.SOURCE | ||
self.target = cfg.DATASET.TARGET | ||
self.batch_size = cfg.TRAIN.BATCH_SIZE | ||
self.test_size = cfg.DATASET.TEST_SIZE | ||
self.shuffle = cfg.DATASET.SHUFFLE | ||
self.num_workers = cfg.WORKERS | ||
self.source = cfg.DATASET.SOURCE | ||
self.target = cfg.DATASET.TARGET | ||
|
||
def load(self, domain: str = 'source'): | ||
''' | ||
加载数据集,返回pytorch官方提供的训练代码的DataLoader的样子(训练和测试,共两个),具体请参考torch官方的训练示例。 | ||
加载数据集, 返回pytorch官方提供的训练代码的DataLoader的样子(训练和测试,共两个), 具体请参考torch官方的训练示例。 | ||
也可用于普通网络的训练集、测试集加载 | ||
''' | ||
assert domain == 'source' or domain == 'target', f'domain {domain} not found' | ||
|
||
file_list = os.listdir(self.base_dir) | ||
if domain == 'source': | ||
file_list = list(map(lambda x:os.path.join(self.base_dir, | ||
x, self.soruce+'.npy'), file_list)) # 所有.npy的绝对路径 | ||
x_train_path = os.path.join(self.base_dir, self.source, 'x_train.pt') | ||
x_test_path = os.path.join(self.base_dir, self.source, 'x_test.pt') | ||
y_train_path = os.path.join(self.base_dir, self.source, 'y_train.pt') | ||
y_test_path = os.path.join(self.base_dir, self.source, 'y_test.pt') | ||
else: | ||
<<<<<<< HEAD | ||
x_train_path = os.path.join(self.base_dir, self.target, 'x_train.pt') | ||
x_test_path = os.path.join(self.base_dir, self.target, 'x_test.pt') | ||
y_train_path = os.path.join(self.base_dir, self.target, 'y_train.pt') | ||
y_test_path = os.path.join(self.base_dir, self.target, 'y_test.pt') | ||
|
||
x_train = torch.load(x_train_path) | ||
x_test = torch.load(x_test_path) | ||
y_train = torch.load(y_train_path) | ||
y_test = torch.load(y_test_path) | ||
======= | ||
file_list = list(map(lambda x:os.path.join(self.base_dir, | ||
x, self.target+'.npy'), file_list)) # 所有.npy的绝对路径 | ||
|
||
|
@@ -71,25 +60,16 @@ def load(self, domain: str = 'source'): | |
# 利用np.random.permutaion函数,获得打乱后的行数,输出permutation | ||
x_data = x_data[permutation] | ||
label = label[permutation] | ||
|
||
|
||
x_data = torch.tensor(x_data).to(torch.float32) | ||
x_data = torch.unsqueeze(x_data, dim=1) | ||
|
||
label = torch.tensor(label).to(torch.long) | ||
# 划分训练集和测试集 | ||
x_train, x_test, y_train, y_test = train_test_split(x_data, label, test_size=self.test_size) | ||
>>>>>>> d8c7acd5ff08caba0c2506de9b082671bcd6f928 | ||
|
||
# 转化成DataLoader | ||
x_train = torch.tensor(x_train) | ||
x_train = x_train.to(torch.float32) | ||
x_train = torch.unsqueeze(x_train, dim=1) # 添加一个维度,通道数 | ||
|
||
x_test = torch.tensor(x_test) | ||
x_test = x_test.to(torch.float32) | ||
x_test = torch.unsqueeze(x_test, dim=1) # 添加一个维度,通道数 | ||
|
||
y_train = torch.tensor(y_train) | ||
y_train = y_train.to(torch.long) | ||
|
||
y_test = torch.tensor(y_test) | ||
y_test = y_test.to(torch.long) | ||
|
||
combined_train = [] | ||
for x, y in zip(x_train, y_train): | ||
combined_train.append((x, y)) | ||
|
@@ -100,14 +80,12 @@ def load(self, domain: str = 'source'): | |
|
||
data_train = DataLoader(combined_train, | ||
batch_size=self.batch_size, | ||
shuffle=True, | ||
num_workers=self.num_workers, | ||
shuffle=self.shuffle, | ||
drop_last=True) | ||
|
||
data_test = DataLoader(combined_test, | ||
batch_size=self.batch_size, | ||
shuffle=True, | ||
num_workers=self.num_workers, | ||
shuffle=self.shuffle, | ||
drop_last=False) | ||
|
||
return data_train, data_test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.