Skip to content

Commit

Permalink
3080 setup
Browse files Browse the repository at this point in the history
  • Loading branch information
hwnam831 committed Jan 17, 2023
1 parent 805d619 commit cc2a09b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, c_in, c_mid, dilation=1):

def forward(self, x):
out = self.block(x)
out += x
out = out + x
return out

class CNNModel(nn.Module):
Expand Down
24 changes: 13 additions & 11 deletions Util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_parser():
parser.add_argument(
"--gen",
type=str,
choices=['gau', 'sin', 'adv', 'off', 'cnn', 'rnn', 'mlp'],
choices=['gau', 'sin', 'adv', 'off', 'cnn', 'rnn', 'mlp', 'rnn3'],
default='adv',
help='Generator choices')
parser.add_argument(
Expand Down Expand Up @@ -134,12 +134,12 @@ def __init__(self, args):
self.both = False
if args.victim == 'both': # 'both'
self.both=True
file_prefix='core4ToSlice3'
trainset = RingDataset(file_prefix+'_train.pkl', threshold=args.window)
testset = RingDataset(file_prefix+'_test.pkl', threshold=args.window, std=trainset.std)
valset = RingDataset(file_prefix+'_valid.pkl', threshold=args.window, std=trainset.std)
file_prefix='rsa_noise'
trainset = EDDSADataset(file_prefix+'_train.pkl')
testset = EDDSADataset(file_prefix+'_test.pkl', std=trainset.std, window=trainset.window)
valset = EDDSADataset(file_prefix+'_valid.pkl', std=trainset.std, window=trainset.window)
self.window = args.window
file_prefix2='eddsa'
file_prefix2='eddsa_noise'
trainset2 = EDDSADataset(file_prefix2+'_train.pkl', std=trainset.std)
testset2 = EDDSADataset(file_prefix2+'_test.pkl', std=trainset.std)
valset2 = EDDSADataset(file_prefix2+'_valid.pkl', std=trainset.std)
Expand All @@ -151,13 +151,13 @@ def __init__(self, args):
valset = EDDSADataset(file_prefix+'_valid.pkl', std=trainset.std, window=trainset.window)
self.window = trainset.window
elif args.victim == 'rsa':
file_prefix='core4ToSlice3'
trainset = RingDataset(file_prefix+'_train.pkl', threshold=args.window)
testset = RingDataset(file_prefix+'_test.pkl', threshold=args.window, std=trainset.std)
valset = RingDataset(file_prefix+'_valid.pkl', threshold=args.window, std=trainset.std)
file_prefix='rsa_noise'
trainset = EDDSADataset(file_prefix+'_train.pkl')
testset = EDDSADataset(file_prefix+'_test.pkl', std=trainset.std, window=trainset.window)
valset = EDDSADataset(file_prefix+'_valid.pkl', std=trainset.std, window=trainset.window)
self.window = args.window
elif args.victim == 'eddsa':
file_prefix='eddsa'
file_prefix='eddsa_noise'
trainset = EDDSADataset(file_prefix+'_train.pkl')
testset = EDDSADataset(file_prefix+'_test.pkl', std=trainset.std)
valset = EDDSADataset(file_prefix+'_valid.pkl', std=trainset.std)
Expand Down Expand Up @@ -185,6 +185,8 @@ def __init__(self, args):
self.gen=Models.CNNGenerator(self.window, scale=0.25, dim=args.dim).cuda()
elif args.gen == 'adv':
self.gen=Models.RNNGenerator2(self.window, scale=0.25, dim=args.dim).cuda()
elif args.gen == 'rnn3':
self.gen=Models.RNNGenerator3(self.window, scale=0.25, dim=args.dim).cuda()
elif args.gen == 'rnn':
self.gen=Models.RNNGenerator(self.window, scale=0.25, dim=args.dim).cuda()
elif args.gen == 'mlp':
Expand Down

0 comments on commit cc2a09b

Please sign in to comment.