Skip to content

Commit

Permalink
Add --prefetch-factor option in trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
nagadomi committed Feb 7, 2023
1 parent fd8f8d1 commit 42eecf6
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions imagenet/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def create_dataloader(self, type):
shuffle=False,
pin_memory=True,
num_workers=self.args.num_workers,
prefetch_factor=self.args.prefetch_factor,
drop_last=True)
return loader
else:
Expand All @@ -45,6 +46,7 @@ def create_dataloader(self, type):
shuffle=False,
pin_memory=True,
num_workers=self.args.num_workers,
prefetch_factor=self.args.prefetch_factor,
drop_last=False)
return loader

Expand Down
2 changes: 2 additions & 0 deletions nunif/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def create_trainer_default_parser():
help="momentum for sgd")
parser.add_argument("--num-workers", type=int, default=num_workers,
help="number of worker processes for data loader")
parser.add_argument("--prefetch-factor", type=int, default=4,
help="number of batches loaded in advance by each worker")
parser.add_argument("--max-epoch", type=int, default=200,
help="max epoch")
parser.add_argument("--gpu", type=int, nargs="+", default=[0],
Expand Down
2 changes: 2 additions & 0 deletions waifu2x/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def create_dataloader(self, type):
pin_memory=True,
sampler=dataset.sampler(),
num_workers=self.args.num_workers,
prefetch_factor=self.args.prefetch_factor,
drop_last=True)
elif type == "eval":
dataset = Waifu2xDataset(
Expand All @@ -94,6 +95,7 @@ def create_dataloader(self, type):
worker_init_fn=dataset.worker_init,
shuffle=False,
num_workers=self.args.num_workers,
prefetch_factor=self.args.prefetch_factor,
drop_last=False)

def create_env(self):
Expand Down

0 comments on commit 42eecf6

Please sign in to comment.