Skip to content

Commit

Permalink
FIX: num_workers = 1 for dots
Browse files Browse the repository at this point in the history
  • Loading branch information
fursovia committed Oct 21, 2019
1 parent e52b637 commit be2c6de
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions glf/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def _create_dots(root, is_train, num_objects=3, num_samples=10000):
images = np.stack(images, axis=0).transpose((0, 3, 1, 2))
np.save(path_fo_file, images)
else:
print('Loading dots...')
images = np.load(path_fo_file)

dataset = TensorDataset(torch.from_numpy(images), torch.from_numpy(np.array([num_objects] * num_samples)))
Expand Down
4 changes: 2 additions & 2 deletions glf/options/train/train_glf_original_dots.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ datasets:
name: dots
dataroot: datasets/dots
use_shuffle: true
n_workers: 8 # per GPU
n_workers: 1 # per GPU
batch_size: 256
val:
name: dots
dataroot: datasets/dots
n_workers: 8 # per GPU
n_workers: 1 # per GPU
batch_size: 32

#### network structures
Expand Down

0 comments on commit be2c6de

Please sign in to comment.