Skip to content

Commit

Permalink
detail argument GPU in action-da & bingdingdb
Browse files Browse the repository at this point in the history
  • Loading branch information
xianyuanliu committed Dec 9, 2021
1 parent 466d295 commit 6c2a39e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
12 changes: 6 additions & 6 deletions examples/action_dann_lightn/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ def arg_parse():
"""Parsing arguments"""
parser = argparse.ArgumentParser(description="Domain Adversarial Networks on Action Datasets")
parser.add_argument("--cfg", required=True, help="path to config file", type=str)
parser.add_argument("--gpus", default="0", help="gpu id(s) to use", type=str)
parser.add_argument(
"--gpus",
default="0",
help="gpu id(s) to use. None/int(0) for cpu. list[x,y] for xth, yth GPU. str(x) for the first x GPUs. str(-1)/int(-1) for all available GPUs",
)
parser.add_argument("--resume", default="", type=str)
parser.add_argument("--ckpt", default="", help="pre-trained parameters for the model (ckpt files)", type=str)
args = parser.parse_args()
Expand Down Expand Up @@ -64,11 +68,7 @@ def main():

# ---- setup model and logger ----
model, train_params = get_model(cfg, dataset, num_classes)
trainer = pl.Trainer(
logger=False,
resume_from_checkpoint=args.ckpt,
gpus=args.gpus,
)
trainer = pl.Trainer(logger=False, resume_from_checkpoint=args.ckpt, gpus=args.gpus,)

model_test = weights_update(model=model, checkpoint=torch.load(args.ckpt))

Expand Down
6 changes: 5 additions & 1 deletion examples/bindingdb_deepdta/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ def arg_parse():
"""Parsing arguments"""
parser = argparse.ArgumentParser(description="DeepDTA on BindingDB dataset")
parser.add_argument("--cfg", required=True, help="path to config file", type=str)
parser.add_argument("--gpus", default="0", help="gpu id(s) to use", type=str)
parser.add_argument(
"--gpus",
default="0",
help="gpu id(s) to use. None/int(0) for cpu. list[x,y] for xth, yth GPU. str(x) for the first x GPUs. str(-1)/int(-1) for all available GPUs",
)
parser.add_argument("--resume", default="", type=str)
args = parser.parse_args()
return args
Expand Down

0 comments on commit 6c2a39e

Please sign in to comment.