better gan compression
hwnam831 committed Feb 22, 2021
1 parent 02b06cb commit 866200d
Showing 4 changed files with 276 additions and 67 deletions.
104 changes: 85 additions & 19 deletions Compressor.py
@@ -11,6 +11,7 @@
 import torch.nn as nn
 import re
 import time
+from sklearn import svm
 
 window=32 #this is fixed
 
@@ -58,13 +59,18 @@ def get_args():
     parser.add_argument(
         "--student",
         type=int,
-        default='16',
+        default='12',
         help='student channel dimension')
     parser.add_argument(
         "--lr",
         type=float,
-        default=2e-5,
+        default=5e-5,
         help='Default learning rate')
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        default='models',
+        help='where to find the pth files')
 
     return parser.parse_args()
 
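The new --model_path flag makes the checkpoint directory configurable instead of hard-coding 'models' throughout the script. A hypothetical invocation with the updated defaults (the --file_prefix value is a placeholder, not a name from this diff):

python Compressor.py --file_prefix <prefix> --student 12 --lr 5e-5 --model_path models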
@@ -79,6 +85,7 @@ def quantizer(arr, std=8):
 
 if __name__ == '__main__':
     args = get_args()
+    path = args.model_path
     dataset = RingDataset.RingDataset(args.file_prefix+'_train.pkl', threshold=args.threshold)
     testset = RingDataset.RingDataset(args.file_prefix+'_test.pkl', threshold=args.threshold)
     valset = RingDataset.RingDataset(args.file_prefix+'_valid.pkl', threshold=args.threshold)
@@ -89,8 +96,8 @@ def quantizer(arr, std=8):
     valloader = DataLoader(valset, batch_size=args.batch_size, num_workers=4, shuffle=True)
 
     gen=RNNGenerator2(args.threshold, scale=0.25, dim=args.dim, drop=0.0).cuda()
-    assert os.path.isfile('./models/best_{}_{}.pth'.format('adv', args.dim))
-    gen.load_state_dict(torch.load('./models/best_{}_{}.pth'.format('adv', args.dim)))
+    assert os.path.isfile('./'+ path +'/best_{}_{}.pth'.format('adv', args.dim))
+    gen.load_state_dict(torch.load('./'+ path +'/best_{}_{}.pth'.format('adv', args.dim)))
 
     student=QGRU2(args.threshold, scale=0.25, dim=args.student, drop=0.0).cuda()
     distiller = Distiller(args.threshold, args.dim, args.student, lamb_r = 0.1).cuda()
@@ -109,13 +116,29 @@ def quantizer(arr, std=8):
     else:
         classifier_test = CNNModel(args.threshold, dim=args.dim).cuda()
 
-    disc = Models.CNNDiscriminator(args.threshold, dim=args.dim).cuda() #discriminator
+    train_x = []
+    train_y = []
+    with torch.no_grad():
+        for x,y in trainloader:
+            xdata= x.cuda()
+            shifted = shifter(xdata)
+            #train classifier
+            perturb = gen(shifted).view(shifted.size(0),-1)
+            perturbed_x = xdata[:,31:]+perturb
+            for p in perturbed_x:
+                train_x.append(p.cpu().numpy())
+            for y_i in y:
+                train_y.append(y_i.item())
+    clf = svm.SVC(gamma=0.02)
+    clf.fit(train_x, train_y)
+
+    disc = Models.SVMDiscriminator(args.threshold, clf, 0.02).cuda() #discriminator
 
     optim_disc = torch.optim.RMSprop(disc.parameters(), lr=args.lr)
-    optim_c = torch.optim.Adam(classifier.parameters())
+    optim_c = torch.optim.Adam(classifier.parameters(), lr=args.lr)
     optim_c2 = torch.optim.Adam(classifier_test.parameters(), lr=args.lr)
     optim_g = torch.optim.Adam(student.parameters())
-    optim_d = torch.optim.Adam(distiller.parameters())
+    optim_d = torch.optim.RMSprop(distiller.parameters())
 
     criterion = nn.CrossEntropyLoss()
     warmup = 20
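This hunk swaps the CNN critic for an SVM-based one: an sklearn SVC with an RBF kernel (gamma=0.02) is first fit on the teacher generator's perturbed traces, then handed to Models.SVMDiscriminator. That class is not part of this diff. As a rough sketch of the idea (class and buffer names are assumptions, not the repository's actual implementation), a fitted SVC's decision function can be re-evaluated in PyTorch so its score flows into the GAN losses:

import torch
import torch.nn as nn

class RBFSVMHead(nn.Module):
    """Sketch only: evaluates a fitted sklearn SVC (RBF kernel) in torch.
    The real SVMDiscriminator evidently also holds trainable parameters,
    since the diff passes disc.parameters() to RMSprop and calls disc.clip();
    this sketch covers only the frozen kernel evaluation."""
    def __init__(self, clf, gamma):
        super().__init__()
        # Support vectors, dual coefficients, and bias from the fitted SVC.
        self.register_buffer('sv', torch.tensor(clf.support_vectors_, dtype=torch.float32))
        self.register_buffer('coef', torch.tensor(clf.dual_coef_[0], dtype=torch.float32))
        self.register_buffer('bias', torch.tensor(clf.intercept_, dtype=torch.float32))
        self.gamma = gamma

    def forward(self, x):
        # decision_function(x) = sum_i coef_i * exp(-gamma * ||x - sv_i||^2) + b
        d2 = torch.cdist(x, self.sv).pow(2)
        return torch.exp(-self.gamma * d2) @ self.coef + self.bias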
@@ -130,7 +153,8 @@ def quantizer(arr, std=8):
         mloss = 0.0
         for x,y in trainloader:
             xdata, ydata = x.cuda(), y.cuda()
-            disc_label = 2*(ydata.float()-0.5) # 1 for ones, -1 for zeros
+            oneratio = ydata.sum().item()/len(ydata)
+            disc_label = 2*(ydata.float()-oneratio)
             shifted = shifter(xdata)
             #train classifier
             optim_c.zero_grad()
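The critic targets are now centered on each batch's fraction of ones rather than a fixed 0.5. With a balanced batch the two coincide; with imbalance, the new labels sum to zero over the batch, so a critic that outputs a constant earns zero from loss_disc = -torch.mean(fakes*disc_label). A small worked example (label values invented for illustration):

import torch

ydata = torch.tensor([1., 1., 1., 0., 1., 1., 1., 0.])   # 6 ones out of 8
oneratio = ydata.sum().item() / len(ydata)                # 0.75
disc_label = 2 * (ydata - oneratio)                       # ones -> +0.5, zeros -> -1.5
print(disc_label.sum())                                   # 0.0: targets are batch-balanced
# The old 2*(ydata - 0.5) always gave +1/-1, letting a constant-output
# critic score well whenever the batch skews toward one class.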
@@ -154,8 +178,9 @@ def quantizer(arr, std=8):
             loss_c.backward()
             optim_c.step()
             optim_disc.step()
-            for p in disc.parameters():
-                p.data.clamp_(-0.01, 0.01)
+            disc.clip()
+            #for p in disc.parameters():
+            #    p.data.clamp_(-0.01, 0.01)
         print("Warmup {} \t Distill loss {:.4f}".format(e+1, mloss))
     optim_c = torch.optim.Adam(classifier.parameters(), lr=args.lr)
     optim_g = torch.optim.Adam(student.parameters(), lr=args.lr)
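The inline weight-clipping loop is replaced by a disc.clip() method; the commented-out lines strongly suggest clip() wraps the same WGAN-style clamp. A plausible sketch of such a method (the actual body is not in this diff):

import torch
import torch.nn as nn

class ClippedCritic(nn.Module):
    """Sketch of the pattern disc.clip() presumably wraps; not the repo's code."""
    def clip(self, bound=0.01):
        # WGAN-style weight clipping: keep every trainable parameter inside
        # [-bound, bound] as a crude Lipschitz constraint on the critic.
        with torch.no_grad():
            for p in self.parameters():
                p.clamp_(-bound, bound)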
@@ -174,7 +199,8 @@ def quantizer(arr, std=8):
         trainstart = time.time()
         for x,y in trainloader:
             xdata, ydata = x.cuda(), y.cuda()
-            disc_label = 2*(ydata.float()-0.5) # 1 for ones, -1 for zeros
+            oneratio = ydata.sum().item()/len(ydata)
+            disc_label = 2*(ydata.float()-oneratio)
             shifted = shifter(xdata)
             #train classifier
             optim_c.zero_grad()
@@ -191,8 +217,9 @@ def quantizer(arr, std=8):
             loss_disc.backward()
             optim_c.step()
             optim_disc.step()
-            for p in disc.parameters():
-                p.data.clamp_(-0.01, 0.01)
+            disc.clip()
+            #for p in disc.parameters():
+            #    p.data.clamp_(-0.01, 0.01)
             #train student
             optim_g.zero_grad()
             optim_d.zero_grad()
@@ -206,7 +233,8 @@ def quantizer(arr, std=8):
             loss_disc = -torch.mean(fakes*disc_label)
             loss_adv1 = criterion(output, fake_target)
 
-            loss = loss_adv1 + loss_comp + loss_disc
+            #loss = 0.5*loss_adv1 + loss_comp + 0.001*loss_disc
+            loss = loss_adv1 + 0.002*loss_comp + 0.02*loss_disc
 
             loss.backward()
             optim_g.step()
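The student's objective is rebalanced here: the compression (distillation) and critic terms are scaled down by 0.002 and 0.02, so the classifier-fooling term loss_adv1 dominates the gradient signal. A toy illustration with invented magnitudes, not measured values:

import torch

loss_adv1 = torch.tensor(0.7)    # cross-entropy against the flipped targets
loss_comp = torch.tensor(50.0)   # distillation penalty, much larger raw scale
loss_disc = torch.tensor(-0.3)   # critic score, can be negative
loss = loss_adv1 + 0.002*loss_comp + 0.02*loss_disc   # 0.7 + 0.1 - 0.006 = 0.794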
@@ -215,7 +243,8 @@ def quantizer(arr, std=8):
         student.eval()
         for x,y in valloader:
             xdata, ydata = x.cuda(), y.cuda()
-            disc_label = 2*(ydata.float()-0.5) # 1 for ones, -1 for zeros
+            oneratio = ydata.sum().item()/len(ydata)
+            disc_label = 2*(ydata.float()-oneratio)
             shifted = shifter(xdata)
             #train classifier
             optim_c2.zero_grad()
@@ -316,13 +345,50 @@ def quantizer(arr, std=8):
         macc = float(totcorrect)/totcount
         zacc = float(zerocorrect)/zerocount
         oacc = float(onecorrect)/onecount
-        print("epoch {} \t zacc {:.6f}\t oneacc {:.6f}\t loss {:.6f}\t Avg perturb {:.6f}\n".format(e+1, zacc, oacc, mloss, mnorm))
+        if (e+1)%10 == 0:
+            print("epoch {} \t zacc {:.6f}\t oneacc {:.6f}\t loss {:.6f}\t Avg perturb {:.6f}\n".format(e+1, zacc, oacc, mloss, mnorm))
         if cooldown - e <= 10:
             lastacc += macc/10
             lastnorm += mnorm/10
     print("Last 10 acc: {:.6f}\t perturb: {:.6f}".format(lastacc,lastnorm))
+
+    train_x = []
+    train_y = []
+    test_x = []
+    test_y = []
+    with torch.no_grad():
+        for x,y in valloader:
+            xdata= x.cuda()
+            shifted = shifter(xdata)
+            #train classifier
+            perturb = halfstudent(shifted.half()).view(shifted.size(0),-1)
+            perturb = quantizer(perturb)
+            perturbed_x = xdata[:,31:]+perturb
+            for p in perturbed_x:
+                train_x.append(p.cpu().numpy())
+            for y_i in y:
+                train_y.append(y_i.item())
+        for x,y in testloader:
+            xdata= x.cuda()
+            shifted = shifter(xdata)
+            #train classifier
+            perturb = halfstudent(shifted.half()).view(shifted.size(0),-1)
+            perturb = quantizer(perturb)
+            perturbed_x = xdata[:,31:]+perturb
+            for p in perturbed_x:
+                test_x.append(p.cpu().numpy())
+            for y_i in y:
+                test_y.append(y_i.item())
+    clf = svm.SVC(gamma=0.02)
+    clf.fit(train_x, train_y)
+    pred_y = clf.predict(test_x)
+
+    svmacc = (pred_y == test_y).sum()/len(pred_y)
+    print("SVM acc: {:.6f}".format(svmacc))
+    lastacc = max(lastacc, svmacc)
+
     filename = "qgru_{}_{:.3f}_{:.3f}.pth".format(args.student,lastnorm, lastacc)
-    flist = os.listdir('models')
+    flist = os.listdir(path)
     best = 1.0
     rp = re.compile(r"qgru_{}_(\d\.\d+)_(\d\.\d+)\.pth".format(args.student))
     for fn in flist:
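After training, the quantized half-precision student (halfstudent) is attacked with a fresh SVM: the SVC is fit on perturbed validation traces and scored on perturbed test traces, and lastacc keeps the higher (worse-for-the-defense) of the CNN and SVM attack accuracies, so a checkpoint must defeat both attackers. One detail: pred_y is a NumPy array while test_y is a plain list, so the accuracy line relies on NumPy coercing the list during elementwise comparison, e.g.:

import numpy as np

pred_y = np.array([0, 1, 1, 0])                # SVM predictions (illustrative values)
test_y = [0, 1, 0, 0]                          # ground-truth labels as a plain list
acc = (pred_y == test_y).sum() / len(pred_y)   # list coerced to an array -> 0.75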
@@ -333,5 +399,5 @@ def quantizer(arr, std=8):
             best = facc
     if lastacc <= best:
         print('New best found')
-        torch.save(student.state_dict(), './models/'+filename)
-        torch.save(student.state_dict(), './models/'+'best_qgru_{}.pth'.format(args.student))
+        torch.save(student.state_dict(), './'+ path +'/'+filename)
+        torch.save(student.state_dict(), './'+ path +'/'+'best_qgru_{}.pth'.format(args.student))
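Checkpoint filenames encode the average perturbation and the final attack accuracy; the regex recovers the accuracy of every existing qgru checkpoint in the model directory, and the new student is saved only when its attack accuracy is at or below the best (lowest) seen so far, since a lower attack accuracy means a stronger defense. A quick illustration with a made-up filename:

import re

rp = re.compile(r"qgru_12_(\d\.\d+)_(\d\.\d+)\.pth")   # args.student == 12 assumed
m = rp.match("qgru_12_0.412_0.537.pth")                # hypothetical checkpoint name
norm, acc = float(m.group(1)), float(m.group(2))       # 0.412 avg perturb, 0.537 acc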