
Commit

minor fixes
suragnair committed Jul 18, 2017
1 parent ddf9f16 commit 0f40922
Showing 4 changed files with 66 additions and 59 deletions.
53 changes: 27 additions & 26 deletions .idea/workspace.xml

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion generator.py
@@ -56,7 +56,7 @@ def sample(self, num_samples, start_letter=0):
             - samples: num_samples x max_seq_length (a sampled sequence in each row)
         """
 
-        samples = torch.zeros(num_samples, self.max_seq_len)
+        samples = torch.zeros(num_samples, self.max_seq_len).type(torch.LongTensor)
 
         h = self.init_hidden(num_samples)
         inp = autograd.Variable(torch.LongTensor([start_letter]*num_samples))
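
Note on the generator.py change: torch.zeros returns a FloatTensor by default, while the sampled token ids are later used as indices into an nn.Embedding lookup, which only accepts integral (Long) indices. A minimal illustrative sketch of that constraint, not part of the commit, written in the PyTorch 0.x style of this codebase:

import torch
import torch.nn as nn
from torch.autograd import Variable

emb = nn.Embedding(10, 4)                       # vocab size 10, embedding dim 4
idx = torch.zeros(2, 3).type(torch.LongTensor)  # token ids must be Long
out = emb(Variable(idx))                        # OK: 2 x 3 x 4 embeddings
# emb(Variable(torch.zeros(2, 3))) would fail: Float indices are rejected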
8 changes: 4 additions & 4 deletions helpers.py
@@ -44,7 +44,7 @@ def prepare_discriminator_data(pos_samples, neg_samples, gpu=False):
         - target: pos_size + neg_size (boolean 1/0)
     """
 
-    inp = torch.cat((pos_samples, neg_samples), 0)
+    inp = torch.cat((pos_samples, neg_samples), 0).type(torch.LongTensor)
     target = torch.ones(pos_samples.size()[0] + neg_samples.size()[0])
     target[pos_samples.size()[0]:] = 0
 
@@ -53,7 +53,7 @@ def prepare_discriminator_data(pos_samples, neg_samples, gpu=False):
     target = target[perm]
     inp = inp[perm]
 
-    inp = Variable(inp).type(torch.LongTensor)
+    inp = Variable(inp)
     target = Variable(target)
 
     if gpu:
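
Note on the two hunks above: the .type(torch.LongTensor) cast moves from the Variable wrap up to the torch.cat site, so inp is integer-typed before it is shuffled and wrapped. torch.cat also expects both inputs to share a dtype, and after this commit both sample sources (the loaded oracle samples and the gen.sample output) already arrive as LongTensors. A small sketch of the invariant, with made-up shapes:

import torch

pos = torch.ones(2, 5).type(torch.LongTensor)   # stand-in for oracle samples
neg = torch.zeros(2, 5).type(torch.LongTensor)  # stand-in for generator samples

inp = torch.cat((pos, neg), 0).type(torch.LongTensor)  # Long from the start
perm = torch.randperm(4)
inp = inp[perm]   # the row shuffle preserves the Long dtype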
@@ -70,7 +70,7 @@ def batchwise_sample(gen, num_samples, batch_size):
     """
 
     samples = []
-    for i in range(ceil(num_samples/float(batch_size))):
+    for i in range(int(ceil(num_samples/float(batch_size)))):
         samples.append(gen.sample(batch_size))
 
-    return torch.cat(samples, 0)
+    return torch.cat(samples, 0)[:num_samples]
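
Note on batchwise_sample: the int(...) wrap guards against Python 2, where math.ceil returns a float that range() rejects, and the [:num_samples] slice trims the overshoot when num_samples is not a multiple of batch_size. A quick worked example with illustrative values:

from math import ceil

num_samples, batch_size = 100, 32
n_batches = int(ceil(num_samples / float(batch_size)))  # 4
print(n_batches * batch_size)  # 128 rows sampled in full batches
# torch.cat(samples, 0)[:num_samples] then trims the 128 rows back to 100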
62 changes: 34 additions & 28 deletions main.py
@@ -31,15 +31,44 @@
 # MAIN
 oracle = generator.Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
 oracle.load_state_dict(torch.load(oracle_state_dict_path))
-oracle_samples = torch.load(oracle_samples_path)
+oracle_samples = torch.load(oracle_samples_path).type(torch.LongTensor)
 
 gen = generator.Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
-dis = discriminator.Discriminator(DIS_EMBEDDING_DIM, DIS_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN)
+dis = discriminator.Discriminator(DIS_EMBEDDING_DIM, DIS_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
+
+
+def train_discriminator(discriminator, dis_opt, real_data_samples, generator, d_steps, epochs):
+    for d_step in range(d_steps):
+        s = helpers.batchwise_sample(generator, POS_NEG_SAMPLES, BATCH_SIZE)
+        dis_inp, dis_target = helpers.prepare_discriminator_data(real_data_samples, s, gpu=CUDA)
+        for epoch in range(epochs):
+            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
+            sys.stdout.flush()
+            total_loss = 0
+
+            for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
+                inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]
+                dis_opt.zero_grad()
+                loss = discriminator.batchBCELoss(inp, target)
+                loss.backward()
+                dis_opt.step()
+
+                total_loss += loss.data[0]
+
+                if (i / BATCH_SIZE) % ceil(ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
+                    print('.', end='')
+                    sys.stdout.flush()
+
+            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
+            print(' average_loss = %.4f' % total_loss)
+
 
 if CUDA:
     oracle = oracle.cuda()
     gen = gen.cuda()
     dis = dis.cuda()
     oracle_samples = oracle_samples.cuda()
 
 # GENERATOR MLE TRAINING
 print('Starting Generator MLE Training...')
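
Note on the new helper: train_discriminator still reads POS_NEG_SAMPLES, BATCH_SIZE, and CUDA from module scope, and loss.data[0] is the PyTorch 0.x idiom for pulling a Python scalar out of a one-element loss Variable. In PyTorch 0.4 and later the same line would be written as:

total_loss += loss.item()  # modern equivalent of loss.data[0]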
@@ -68,7 +97,7 @@
 # total_loss = total_loss/ceil(POS_NEG_SAMPLES/float(BATCH_SIZE))/MAX_SEQ_LEN
 #
 # # sample from generator and compute oracle NLL
-# s = gen.sample(1000)
+# s = gen.sample(POS_NEG_SAMPLES/10)
 # inp, target = helpers.prepare_generator_data(s, start_letter=START_LETTER, gpu=CUDA)
 # oracle_loss = oracle.batchNLLLoss(inp, target)/MAX_SEQ_LEN
 #
@@ -79,30 +108,7 @@
 # TRAIN DISCRIMINATOR
 print('\nStarting Discriminator Training...')
 dis_optimizer = optim.Adam(dis.parameters())
-
-for d_step in range(50):
-    s = helpers.batchwise_sample(gen, POS_NEG_SAMPLES, BATCH_SIZE)
-    dis_inp, dis_target = helpers.prepare_discriminator_data(oracle_samples, s, gpu=CUDA)
-    for epoch in range(3):
-        print('d-step %d epoch %d : ' % (d_step+1, epoch+1), end='')
-        sys.stdout.flush()
-        total_loss = 0
-
-        for i in range(0, 2*POS_NEG_SAMPLES, BATCH_SIZE):
-            inp, target = dis_inp[i:i+BATCH_SIZE], dis_target[i:i+BATCH_SIZE]
-            dis_optimizer.zero_grad()
-            loss = dis.batchBCELoss(inp, target)
-            loss.backward()
-            dis_optimizer.step()
-
-            total_loss += loss.data[0]
-
-            if (i / BATCH_SIZE) % ceil(ceil(2*POS_NEG_SAMPLES / float(BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
-                print('.', end='')
-                sys.stdout.flush()
-
-        total_loss = total_loss/ceil(2*POS_NEG_SAMPLES/float(BATCH_SIZE))
-        print(' average_loss = %.4f' % total_loss)
+train_discriminator(dis, dis_optimizer, oracle_samples, gen, 50, 3)
 
 # ADVERSARIAL TRAINING
 
 print('\nStarting Adversarial Training...')
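
The single call above replaces the removed 50 x 3 inline loop with identical hyperparameters. Mapped onto the new signature, with keywords added here purely for readability:

train_discriminator(discriminator=dis,
                    dis_opt=dis_optimizer,
                    real_data_samples=oracle_samples,
                    generator=gen,
                    d_steps=50,
                    epochs=3)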
