Skip to content

Commit

Permalink
Clean up pool properly
Browse files — browse the repository at this point in the history
  • Loading branch information
peastman committed Sep 22, 2017
1 parent 4b9d9af commit 5cb7f9a
Showing 1 changed file with 42 additions and 41 deletions.
83 changes: 42 additions & 41 deletions deepchem/data/datasets.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -644,50 +644,51 @@ def iterate(dataset):
shard_perm = np.random.permutation(num_shards)
else:
shard_perm = np.arange(num_shards)
pool = Pool(1)
next_shard = pool.apply_async(dataset.get_shard, (shard_perm[0],))
for i in range(num_shards):
X, y, w, ids = next_shard.get()
if i < num_shards - 1:
next_shard = pool.apply_async(dataset.get_shard, (shard_perm[i + 1],))
n_samples = X.shape[0]
# TODO(rbharath): This happens in tests sometimes, but don't understand why?
# Handle edge case.
if n_samples == 0:
continue
if not deterministic:
sample_perm = np.random.permutation(n_samples)
else:
sample_perm = np.arange(n_samples)
if batch_size is None:
shard_batch_size = n_samples
else:
shard_batch_size = batch_size
interval_points = np.linspace(
0,
n_samples,
np.ceil(float(n_samples) / shard_batch_size) + 1,
dtype=int)
for j in range(len(interval_points) - 1):
indices = range(interval_points[j], interval_points[j + 1])
perm_indices = sample_perm[indices]
X_batch = X[perm_indices]

if y is not None:
y_batch = y[perm_indices]
with Pool(1) as pool:
next_shard = pool.apply_async(dataset.get_shard, (shard_perm[0],))
for i in range(num_shards):
X, y, w, ids = next_shard.get()
if i < num_shards - 1:
next_shard = pool.apply_async(dataset.get_shard,
(shard_perm[i + 1],))
n_samples = X.shape[0]
# TODO(rbharath): This happens in tests sometimes, but don't understand why?
# Handle edge case.
if n_samples == 0:
continue
if not deterministic:
sample_perm = np.random.permutation(n_samples)
else:
y_batch = None

if w is not None:
w_batch = w[perm_indices]
sample_perm = np.arange(n_samples)
if batch_size is None:
shard_batch_size = n_samples
else:
w_batch = None
shard_batch_size = batch_size
interval_points = np.linspace(
0,
n_samples,
np.ceil(float(n_samples) / shard_batch_size) + 1,
dtype=int)
for j in range(len(interval_points) - 1):
indices = range(interval_points[j], interval_points[j + 1])
perm_indices = sample_perm[indices]
X_batch = X[perm_indices]

if y is not None:
y_batch = y[perm_indices]
else:
y_batch = None

ids_batch = ids[perm_indices]
if pad_batches:
(X_batch, y_batch, w_batch, ids_batch) = pad_batch(
shard_batch_size, X_batch, y_batch, w_batch, ids_batch)
yield (X_batch, y_batch, w_batch, ids_batch)
if w is not None:
w_batch = w[perm_indices]
else:
w_batch = None

ids_batch = ids[perm_indices]
if pad_batches:
(X_batch, y_batch, w_batch, ids_batch) = pad_batch(
shard_batch_size, X_batch, y_batch, w_batch, ids_batch)
yield (X_batch, y_batch, w_batch, ids_batch)

return iterate(self)

Expand Down

0 comments on commit 5cb7f9a

Please sign in to comment.