(add) wandb in code-to-text run.py
graykode committed Feb 24, 2021
1 parent 4ee7187 commit c0a3d97
Showing 1 changed file with 27 additions and 7 deletions.
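The pattern this commit wires into run.py is, roughly, the standard wandb workflow: initialize a run, record the hyperparameters, log training metrics by step, and keep best-so-far values in the run summary. A minimal, self-contained sketch of that pattern (project/run names, argument values, and metric values below are placeholders, not taken from the diff):

import argparse
import wandb

# Hypothetical stand-ins for run.py's --lang, --name and --log_interval arguments.
args = argparse.Namespace(lang="python", name="example-run", log_interval=10)

# One wandb run per training job: the project is named after the programming
# language and the run after the new --name argument.
wandb.init(project=args.lang, name=args.name)
wandb.config.update(args)          # records every argument on the run page

global_step = 0
for step in range(100):
    train_loss = 1.0 / (step + 1)  # placeholder loss value
    global_step += 1
    if global_step % args.log_interval == 0:
        wandb.log({"train_loss": train_loss}, step=global_step)  # time series, keyed by step

wandb.run.summary["eval_best_ppl"] = 4.56  # single best value, not a time series
wandb.finish()
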
34 changes: 27 additions & 7 deletions experiment/code-to-text/code/run.py
@@ -23,6 +23,7 @@
import os
import sys
import bleu
import wandb
import pickle
import torch
import json
@@ -159,7 +160,9 @@ def set_seed(seed=42):
def main():
parser = argparse.ArgumentParser()

## Required parameters
## Required parameters
parser.add_argument("--lang", default=None, type=str, required=True,
help="programming language")
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type: e.g. roberta")
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
@@ -197,7 +200,7 @@ def main():
help="Set this flag if you are using an uncased model.")
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")

parser.add_argument("--log_interval", default=10, type=int)
parser.add_argument("--train_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for training.")
parser.add_argument("--eval_batch_size", default=8, type=int,
@@ -216,6 +219,8 @@
help="Max gradient norm.")
parser.add_argument("--num_train_epochs", default=3, type=int,
help="Total number of training epochs to perform.")
parser.add_argument("--num_decoder_layer", default=3, type=int,
help="Total number of decoder's layer")
parser.add_argument("--max_steps", default=-1, type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
parser.add_argument("--eval_steps", default=-1, type=int,
@@ -228,6 +233,11 @@
help="For distributed training: local_rank")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")

# experiment arguments
parser.add_argument("--name", type=str, required=True,
help="experiment name in showing wandb")

# print arguments
args = parser.parse_args()
logger.info(args)
@@ -257,7 +267,7 @@ def main():
# build model
encoder = model_class.from_pretrained(args.model_name_or_path,config=config)
decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=args.num_decoder_layer)
model=Seq2Seq(encoder=encoder,decoder=decoder,config=config,
beam_size=args.beam_size,max_length=args.max_target_length,
sos_id=tokenizer.cls_token_id,eos_id=tokenizer.sep_token_id)
@@ -279,6 +289,9 @@ def main():
model = torch.nn.DataParallel(model)

if args.do_train:
wandb.init(project=f"{args.lang}", name=args.name)
wandb.config.update(args)

# Prepare training data loader
train_examples = read_examples(args.train_filename)
train_features = convert_examples_to_features(train_examples, tokenizer,args,stage='train')
@@ -319,8 +332,8 @@ def main():
dev_dataset={}
nb_tr_examples, nb_tr_steps,tr_loss,global_step,best_bleu,best_loss = 0, 0,0,0,0,1e6
for epoch in range(args.num_train_epochs):
bar = tqdm(train_dataloader,total=len(train_dataloader))
for batch in bar:
# bar = tqdm(train_dataloader,total=len(train_dataloader))
for batch in train_dataloader:
batch = tuple(t.to(device) for t in batch)
source_ids,source_mask,target_ids,target_mask = batch
loss,_,_ = model(source_ids=source_ids,source_mask=source_mask,target_ids=target_ids,target_mask=target_mask)
@@ -331,7 +344,10 @@ def main():
loss = loss / args.gradient_accumulation_steps
tr_loss += loss.item()
train_loss=round(tr_loss*args.gradient_accumulation_steps/(nb_tr_steps+1),4)
bar.set_description("epoch {} loss {}".format(epoch,train_loss))
# bar.set_description("epoch {} loss {}".format(epoch,train_loss))
if global_step % args.log_interval == 0:
print("epoch {} step {} loss {}".format(epoch, nb_tr_steps, train_loss))
wandb.log({"train_loss": train_loss}, step=global_step)
nb_tr_examples += source_ids.size(0)
nb_tr_steps += 1
loss.backward()
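When reading the logged curve, note that train_loss is the running average of the loss over the epoch so far (rescaled by the gradient-accumulation factor), not the loss of the current batch. A small worked example of the same arithmetic, with placeholder numbers:

# Placeholder values for illustration only.
tr_loss = 12.8                       # sum of per-batch losses accumulated so far
nb_tr_steps = 31                     # batches already processed this epoch
gradient_accumulation_steps = 2
train_loss = round(tr_loss * gradient_accumulation_steps / (nb_tr_steps + 1), 4)
print(train_loss)                    # 0.8 -> the value passed to wandb.log
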
@@ -386,7 +402,8 @@ def main():
'train_loss': round(train_loss,5)}
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
logger.info(" "+"*"*20)
wandb.log({'eval_ppl': round(np.exp(eval_loss), 5)}, step=global_step)
logger.info(" "+"*"*20)

#save last checkpoint
last_output_dir = os.path.join(args.output_dir, 'checkpoint-last')
@@ -397,6 +414,7 @@ def main():
torch.save(model_to_save.state_dict(), output_model_file)
if eval_loss<best_loss:
logger.info(" Best ppl:%s",round(np.exp(eval_loss),5))
wandb.run.summary["eval_best_ppl"] = round(np.exp(eval_loss), 5)
logger.info(" "+"*"*20)
best_loss=eval_loss
# Save best checkpoint for best ppl
@@ -450,9 +468,11 @@ def main():
(goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(args.output_dir, "dev.gold"))
dev_bleu=round(bleu.bleuFromMaps(goldMap, predictionMap)[0],2)
logger.info(" %s = %s "%("bleu-4",str(dev_bleu)))
wandb.log({"eval_bleu_4": dev_bleu}, step=global_step)
logger.info(" "+"*"*20)
if dev_bleu>best_bleu:
logger.info(" Best bleu:%s",dev_bleu)
wandb.run.summary["eval_best_bleu_4"] = dev_bleu
logger.info(" "+"*"*20)
best_bleu=dev_bleu
# Save best checkpoint for best bleu
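A note on the two wandb calls used for the evaluation metrics: wandb.log appends a point to a per-step history, while an assignment to wandb.run.summary overwrites a single scalar, which is why the best-so-far numbers go there. A short sketch of the difference (hypothetical project/run names and values):

import wandb

wandb.init(project="python", name="log-vs-summary-demo")
best_bleu = 0.0
for global_step, dev_bleu in [(1000, 15.2), (2000, 17.6), (3000, 16.9)]:
    wandb.log({"eval_bleu_4": dev_bleu}, step=global_step)    # full history of the metric
    if dev_bleu > best_bleu:
        best_bleu = dev_bleu
        wandb.run.summary["eval_best_bleu_4"] = dev_bleu      # overwritten only on improvement
wandb.finish()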
