bug fixed and added
afdeaf committed Apr 20, 2022
1 parent 30a69cc commit 67ed71f
Showing 31 changed files with 917 additions and 251 deletions.
23 changes: 14 additions & 9 deletions DANN/configs/defaults.py
@@ -8,15 +8,16 @@

_C = CN()
_C.SEED = 77
-_C.WORKERS = 8
+_C.WORKERS = 1
_C.TRAINER = 'DANN'
+_C.SHUTDOWN = False

# ========== training ==========
_C.TRAIN = CN()
-_C.TRAIN.TEST_FREQ = 50
-_C.TRAIN.PRINT_FREQ = 50
-_C.TRAIN.SAVE_FREQ = 100
-_C.TRAIN.TTL_ITE = 1000
+_C.TRAIN.TEST_FREQ = 100
+_C.TRAIN.PRINT_FREQ = 100
+_C.TRAIN.SAVE_FREQ = 300
+_C.TRAIN.TTL_ITE = 5000

_C.TRAIN.BATCH_SIZE = 32
_C.TRAIN.LR = 1e-4
@@ -43,13 +44,13 @@

# ========== datasets ==========
_C.DATASET = CN()
-_C.DATASET.NUM_CLASSES = 4
+_C.DATASET.NUM_CLASSES = 9
_C.DATASET.NAME = 'DDS'
_C.DATASET.SOURCE = '20R_0HP'
_C.DATASET.TARGET = '20R_8HP'
-_C.DATASET.ROOT = r'F:\work\jupyter\数据预处理\datasets'
+_C.DATASET.ROOT = r'E:\Raven\jupyter\Transfer Learning\Dataset\dds划分_full'
_C.DATASET.SHUFFLE = True
-_C.DATASET.TEST_SIZE = 0.1
+_C.DATASET.TEST_SIZE = 0.2

# ========== method ==========
_C.METHOD = CN()
@@ -72,6 +73,10 @@ def get_default_and_update_cfg(args):
cfg.DATASET.ROOT = args.data_root
if args.num_classes != 4:
cfg.DATASET.NUM_CLASSES = args.num_classes
+if args.source:
+    cfg.DATASET.SOURCE = args.source
+if args.target:
+    cfg.DATASET.TARGET = args.target

# ====output====
if args.output_root:
@@ -80,7 +85,7 @@ def get_default_and_update_cfg(args):
cfg.TRAIN.OUTPUT_DIR = args.output_dir
else:
# eg:20R_0HP_To_20R_8HP_seed77
-cfg.TRAIN.OUTPUT_DIR = ''.join(cfg.DATASET.SOURCE) + '_To' + '_'.join(cfg.DATASET.TARGET) + '_seed' + str(args.seed)
+cfg.TRAIN.OUTPUT_DIR = ''.join(cfg.DATASET.SOURCE) + '_To_' + ''.join(cfg.DATASET.TARGET) + '_seed' + str(args.seed)
cfg.TRAIN.OUTPUT_CKPT = os.path.join(cfg.TRAIN.OUTPUT_ROOT, 'ckpt', cfg.TRAIN.OUTPUT_DIR)
cfg.TRAIN.OUTPUT_LOG = os.path.join(cfg.TRAIN.OUTPUT_ROOT, 'log', cfg.TRAIN.OUTPUT_DIR)
# make dirs
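The OUTPUT_DIR change above fixes a real naming bug: `'_'.join` over a string interleaves an underscore between every character, while `''.join` on a string is a no-op. A minimal sketch to verify, assuming SOURCE and TARGET are plain strings as in the defaults:

```python
# Sketch verifying the OUTPUT_DIR naming fix (SOURCE/TARGET assumed plain strings).
source, target, seed = '20R_0HP', '20R_8HP', 77

old = ''.join(source) + '_To' + '_'.join(target) + '_seed' + str(seed)
new = ''.join(source) + '_To_' + ''.join(target) + '_seed' + str(seed)

print(old)  # 20R_0HP_To2_0_R___8_H_P_seed77  <- '_'.join interleaved underscores
print(new)  # 20R_0HP_To_20R_8HP_seed77       <- matches the comment in defaults.py
```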
76 changes: 27 additions & 49 deletions DANN/datasets/dds.py
@@ -1,54 +1,43 @@
# Copyright (c) 2022 Raven Stock. email:[email protected]

import os
import warnings
from numpy import source
import torch
import numpy as np
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split


__all__ = ['DDS']


class DDS(object):
def __init__(self, cfg):
'''
base_dir: root directory containing the fault-class folders
source: source domain
target: target domain
batch_size: batch size
test_size: test-set fraction
'''
self.table = ['20R_0HP', '20R_4HP', '20R_8HP',
'30R_0HP', '30R_4HP', '30R_8HP',
'40R_0HP', '40R_4HP', '40R_8HP',]
if cfg.DATASET.SOURCE not in self.table:
raise ValueError("param 'source' error")
if cfg.DATASET.TARGET not in self.table:
raise ValueError("param \'target\' error")
if cfg.DATASET.SOURCE == cfg.DATASET.TARGET:
warnings.warn('source and target are the same!')

self.base_dir = cfg.DATASET.ROOT
-self.soruce = cfg.DATASET.SOURCE
-self.target = cfg.DATASET.TARGET
self.batch_size = cfg.TRAIN.BATCH_SIZE
self.test_size = cfg.DATASET.TEST_SIZE
self.shuffle = cfg.DATASET.SHUFFLE
self.num_workers = cfg.WORKERS
+self.source = cfg.DATASET.SOURCE
+self.target = cfg.DATASET.TARGET

def load(self, domain: str = 'source'):
'''
Load the dataset and return train/test DataLoaders in the style of the official PyTorch training examples (two loaders: train and test); see the official PyTorch training examples for details.
Can also be used to load the train/test sets for an ordinary (non-transfer) network.
'''
assert domain == 'source' or domain == 'target', f'domain {domain} not found'

file_list = os.listdir(self.base_dir)
if domain == 'source':
-file_list = list(map(lambda x: os.path.join(self.base_dir, x, self.soruce + '.npy'), file_list))  # absolute paths of all .npy files
+x_train_path = os.path.join(self.base_dir, self.source, 'x_train.pt')
+x_test_path = os.path.join(self.base_dir, self.source, 'x_test.pt')
+y_train_path = os.path.join(self.base_dir, self.source, 'y_train.pt')
+y_test_path = os.path.join(self.base_dir, self.source, 'y_test.pt')
else:
<<<<<<< HEAD
x_train_path = os.path.join(self.base_dir, self.target, 'x_train.pt')
x_test_path = os.path.join(self.base_dir, self.target, 'x_test.pt')
y_train_path = os.path.join(self.base_dir, self.target, 'y_train.pt')
y_test_path = os.path.join(self.base_dir, self.target, 'y_test.pt')

x_train = torch.load(x_train_path)
x_test = torch.load(x_test_path)
y_train = torch.load(y_train_path)
y_test = torch.load(y_test_path)
=======
file_list = list(map(lambda x: os.path.join(self.base_dir,
x, self.target + '.npy'), file_list))  # absolute paths of all .npy files

@@ -71,25 +60,16 @@ def load(self, domain: str = 'source'):
# use np.random.permutation to get a shuffled row order (permutation)
x_data = x_data[permutation]
label = label[permutation]


x_data = torch.tensor(x_data).to(torch.float32)
x_data = torch.unsqueeze(x_data, dim=1)

label = torch.tensor(label).to(torch.long)
# split into training and test sets
x_train, x_test, y_train, y_test = train_test_split(x_data, label, test_size=self.test_size)
>>>>>>> d8c7acd5ff08caba0c2506de9b082671bcd6f928

# convert to DataLoaders
x_train = torch.tensor(x_train)
x_train = x_train.to(torch.float32)
x_train = torch.unsqueeze(x_train, dim=1) # add a channel dimension

x_test = torch.tensor(x_test)
x_test = x_test.to(torch.float32)
x_test = torch.unsqueeze(x_test, dim=1) # add a channel dimension

y_train = torch.tensor(y_train)
y_train = y_train.to(torch.long)

y_test = torch.tensor(y_test)
y_test = y_test.to(torch.long)

combined_train = []
for x, y in zip(x_train, y_train):
combined_train.append((x, y))
Expand All @@ -100,14 +80,12 @@ def load(self, domain: str = 'source'):

data_train = DataLoader(combined_train,
batch_size=self.batch_size,
-shuffle=True,
-num_workers=self.num_workers,
+shuffle=self.shuffle,
drop_last=True)

data_test = DataLoader(combined_test,
batch_size=self.batch_size,
-shuffle=True,
-num_workers=self.num_workers,
+shuffle=self.shuffle,
drop_last=False)

return data_train, data_test
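For reference, a minimal usage sketch of this loader; the cfg fields mirror `defaults.py`, and the 62x62 input size is inferred from the feature-map comments in `cnn.py` below, so treat the shapes as assumptions:

```python
# Hypothetical usage of DDS; cfg fields follow DANN/configs/defaults.py.
from yacs.config import CfgNode as CN
from datasets.dds import DDS  # assumed import path (DANN/datasets/dds.py)

cfg = CN()
cfg.WORKERS = 1
cfg.TRAIN = CN()
cfg.TRAIN.BATCH_SIZE = 32
cfg.DATASET = CN()
cfg.DATASET.ROOT = r'E:\Raven\jupyter\Transfer Learning\Dataset\dds划分_full'
cfg.DATASET.SOURCE = '20R_0HP'
cfg.DATASET.TARGET = '20R_8HP'
cfg.DATASET.SHUFFLE = True
cfg.DATASET.TEST_SIZE = 0.2

dds = DDS(cfg)
src_train, src_test = dds.load('source')  # source-domain train/test loaders
tar_train, tar_test = dds.load('target')  # target-domain train/test loaders

x, y = next(iter(src_train))
print(x.shape, y.shape)  # expected [32, 1, H, W] and [32]; H = W = 62 per the cnn.py comments
```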
11 changes: 8 additions & 3 deletions DANN/main.py
@@ -5,17 +5,18 @@
from configs.defaults import get_default_and_update_cfg
from utils.utils import set_seed, create_logger
from trainer.dann import *
+import shutil


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=77, type=int)
parser.add_argument('--source', default='20R_0HP', help='Source domain name')
-parser.add_argument('--target', default='20R_8HP', help='Target domain name')
-parser.add_argument('--output_root', default=None, type=str, help='Output root path')
+parser.add_argument('--target', default='20R_4HP', help='Target domain name')
+parser.add_argument('--output_root', default='OUTPUT', type=str, help='Output root path')
parser.add_argument('--output_dir', default=None, type=str, help='Output path, subdir under output_root')
parser.add_argument('--data_root', default=None, type=str, help='path to dataset root')
-parser.add_argument('--num_classes', default=4, type=int, help='The number of classes')
+parser.add_argument('--num_classes', default=9, type=int, help='The number of classes')
args = parser.parse_args()
return args

@@ -35,4 +36,8 @@ def main():


if __name__ == '__main__':
+try:
+    shutil.rmtree(r'E:\Raven\VScode\DANN\OUTPUT')
+except:
+    pass
main()
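Deleting a hard-coded OUTPUT path on every launch makes resuming from checkpoints impossible and breaks on any other machine. A guarded variant under the same argparse setup might look like this; the `--clean` flag is hypothetical, not part of this commit:

```python
# Hypothetical opt-in cleanup tied to the configured output root
# (--clean is illustrative; this commit hard-codes the path instead).
import argparse
import shutil

parser = argparse.ArgumentParser()
parser.add_argument('--output_root', default='OUTPUT', type=str)
parser.add_argument('--clean', action='store_true', help='remove previous outputs before training')
args, _ = parser.parse_known_args()

if args.clean:
    shutil.rmtree(args.output_root, ignore_errors=True)  # no bare except needed
```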
4 changes: 4 additions & 0 deletions DANN/models/base_model.py
@@ -18,6 +18,10 @@ def build_head(self):
'''
Build classification head
'''
+# self.fc = nn.Sequential(
+#     nn.Flatten(),
+#     nn.Linear(self.fdim, self.num_classes)
+# )
self.fc = nn.Linear(self.fdim, self.num_classes)
nn.init.kaiming_normal_(self.fc.weight)
nn.init.zeros_(self.fc.bias)
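Both head variants above have the same parameter count once features are flattened upstream (see `forward_backbone` in `cnn.py` below). A standalone sketch of the active head, assuming fdim = 1600 and num_classes = 9 from this commit's defaults:

```python
# Sketch of the classification head as built above (fdim/num_classes from the defaults).
import torch
from torch import nn

fdim, num_classes = 1600, 9
fc = nn.Linear(fdim, num_classes)
nn.init.kaiming_normal_(fc.weight)
nn.init.zeros_(fc.bias)

features = torch.randn(32, fdim)  # already flattened in forward_backbone
logits = fc(features)             # -> [32, 9]
```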
7 changes: 6 additions & 1 deletion DANN/models/cnn.py
@@ -2,6 +2,7 @@

from models.base_model import BaseModel
from torch import nn
+import torch

__all__ = ['cnn']

@@ -14,16 +15,19 @@ def __init__(self, num_classes: int = 4, **kwargs):
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 29x29
+nn.Dropout2d(0.3),

nn.Conv2d(16, 32, 5), # output: 25x25
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 12x12
+nn.Dropout2d(0.3),

nn.Conv2d(32, 64, 3), # output: 10x10
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2, 2), # output: 5x5; 5*5*64 = 1600 is the Linear input size (_fdim)
+nn.Dropout2d(0.2),
)
self._init_params()
self._fdim = 1600
@@ -53,8 +57,9 @@ def get_backbone_parameters(self) -> list:

def forward_backbone(self, x):
feature = self.feature_extractor(x)
+feature = torch.flatten(feature, 1)
return feature

def cnn(num_classes: int=4, **kwargs):
model = CNN(num_classes=num_classes, **kwargs)
-return model
+return model.cuda()
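The feature-map sizes in the comments can be checked with a dummy forward pass; a quick sketch, assuming 1-channel 62x62 inputs (the size implied by the 29x29 comment after the first 5x5 conv and 2x2 pool):

```python
# Sanity check of the commented feature-map sizes (62x62 input is an assumption).
import torch
from torch import nn

feature_extractor = nn.Sequential(
    nn.Conv2d(1, 16, 5), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.3),   # 58 -> 29
    nn.Conv2d(16, 32, 5), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.3),  # 25 -> 12
    nn.Conv2d(32, 64, 3), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2), nn.Dropout2d(0.2),  # 10 -> 5
)
x = torch.randn(2, 1, 62, 62)
f = torch.flatten(feature_extractor(x), 1)
print(f.shape)  # torch.Size([2, 1600]) -- matches _fdim
```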
4 changes: 2 additions & 2 deletions DANN/models/discriminator.py
@@ -16,7 +16,7 @@ def __init__(self, in_feature: int, hidden_size: int, out_feature: int = 1):
self.relu2 = nn.ReLU()
self.drop1 = nn.Dropout(0.5)
self.drop2 = nn.Dropout(0.5)
-self._init_params()
+# self._init_params()

def _init_params(self):
for layer in self.modules():
@@ -33,7 +33,7 @@ def _init_params(self):
nn.init.zeros_(layer.bias)

def get_parameters(self):
-return [{'params': self.parameters()}]
+return [{'params': self.parameters(), 'lr_mult':1}]

def forward(self, x, coeff: float):
x.register_hook(grl_hook(coeff)) # GRL
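`grl_hook` is referenced above but not shown in this diff. In DANN it is conventionally a gradient-reversal hook: the forward pass is untouched, and on backward the gradient is negated and scaled by `coeff`. A sketch of that convention; the repo's actual implementation may differ:

```python
# Conventional DANN gradient-reversal hook (sketch; not taken from this repo).
def grl_hook(coeff: float):
    def hook(grad):
        # negate and scale the gradient so the feature extractor is pushed
        # toward domain-invariant features while the discriminator trains normally
        return -coeff * grad.clone()
    return hook
```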
28 changes: 14 additions & 14 deletions DANN/trainer/base_trainer.py
@@ -76,12 +76,12 @@ def model_parameters(self):
f'{(sum(p.numel() for p in v.parameters()) / 1e6):.2f}M')

def build_optim(self, parameter_list: list):
-self.optimizer = optim.SGD(
+self.optimizer = optim.Adam(
parameter_list,
lr=self.cfg.TRAIN.LR,
-momentum=self.cfg.OPTIM.MOMENTUM,
+# momentum=self.cfg.OPTIM.MOMENTUM,
weight_decay=self.cfg.OPTIM.WEIGHT_DECAY,
-nesterov=True
+# nesterov=True
)
self.lr_scheduler = inv_lr_scheduler

@@ -92,9 +92,9 @@ def resume_from_ckpt(self):
for k, v in self.registed_models.items():
v.load_state_dict(ckpt[k])
self.optimizer.load_state_dict(ckpt['optimizer'])
-self.start_iter = ckpt['ite']
+self.start_iter = ckpt['iter']
self.best_acc = ckpt['best_acc']
-logging.info(f'> loading ckpt from {last_ckpt} | ite: {self.start_iter} | best_acc: {self.best_acc:.3f}')
+logging.info(f'> loading ckpt from {last_ckpt} | iter: {self.start_iter} | best_acc: {self.best_acc:.3f}')
else:
logging.info('--> training from scratch')

@@ -128,7 +128,7 @@ def train(self):
self.one_step(data_src, data_tar)
if self.iter % self.cfg.TRAIN.SAVE_FREQ == 0 and self.iter != 0:
self.save_model(is_best=False, snap=True)
return self.best_acc
@abstractmethod
def one_step(self, data_src, data_tar):
pass
@@ -163,15 +163,15 @@ def test(self):
# save results
log_dict = {
'I': self.iter,
-'src_acc': src_acc,
-'tar_acc': tar_acc,
-'best_acc': self.best_acc
+'src_acc': round(src_acc, 3),
+'tar_acc': round(tar_acc, 3),
+'best_acc': round(self.best_acc, 3)
}
write_log(self.cfg.TRAIN.OUTPUT_RESFILE, log_dict)

# tensorboard
-self.tb_writer.add_scalar('tar_acc', tar_acc, self.iter)
-self.tb_writer.add_scalar('src_acc', src_acc, self.iter)
+# self.tb_writer.add_scalar('tar_acc', tar_acc, self.iter)
+# self.tb_writer.add_scalar('src_acc', src_acc, self.iter)

self.save_model(is_best=is_best)

@@ -184,7 +184,7 @@ def test_func(self, loader, model):
if i % print_freq == print_freq - 1:
logging.info(' I: {}/{} | acc: {:.3f}'.format(i, len(loader), accs.avg))
data = iter_test.__next__()
-inputs, labels = data['image'].cuda(), data['label'].cuda()
+inputs, labels = data[0].cuda(), data[1].cuda()
outputs_all = model(inputs) # [f, y, ...]
outputs = outputs_all[1]

@@ -196,9 +196,9 @@ def save_model(self, is_best=False, snap=False):
def save_model(self, is_best=False, snap=False):
data_dict = {
'optimizer': self.optimizer.state_dict(),
-'ite': self.iter,
+'iter': self.iter,
'best_acc': self.best_acc
}
for k, v in self.registed_models.items():
data_dict.update({k: v.state_dict()})
-save_model(self.cfg.TRAIN.OUTPUT_CKPT, data_dict=data_dict, ite=self.ite, is_best=is_best, snap=snap)
+save_model(self.cfg.TRAIN.OUTPUT_CKPT, data_dict=data_dict, ite=self.iter, is_best=is_best, snap=snap)
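`inv_lr_scheduler` is assigned but not defined in this diff. DANN-style training commonly uses the inverse decay lr = lr0 * (1 + gamma * p)^(-power), with p the training progress in [0, 1]. A sketch of that common form, honoring the `lr_mult` group key set in `discriminator.py`; the repo's actual function may differ:

```python
# Common DANN inverse-decay schedule (sketch; not the repo's verbatim code).
def inv_lr_scheduler(optimizer, iter_num, max_iter, lr0=1e-4, gamma=10.0, power=0.75):
    p = iter_num / max_iter                 # training progress in [0, 1]
    lr = lr0 * (1 + gamma * p) ** (-power)  # inverse decay
    for group in optimizer.param_groups:
        group['lr'] = lr * group.get('lr_mult', 1)  # per-group multiplier support
    return lr
```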
[Diff truncated: 24 of the 31 changed files are not shown.]