Skip to content

Commit

Permalink
add lock for SynchronizedSeedDataset; add additional os level close s…
Browse files Browse the repository at this point in the history
…tderr for tests that launch failing process (#4463)
  • Loading branch information
ssnl authored and soumith committed Jan 4, 2018
1 parent cc70a33 commit cc9dc3f
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions test/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import sys
import os
import ctypes
import torch
import time
Expand Down Expand Up @@ -153,15 +154,16 @@ class SynchronizedSeedDataset(Dataset):

def __init__(self, size, num_workers):
assert size >= num_workers
self.count = multiprocessing.Value('i', 0)
self.count = multiprocessing.Value('i', 0, lock=True)
self.barrier = multiprocessing.Semaphore(0)
self.num_workers = num_workers
self.size = size

def __getitem__(self, idx):
self.count.value += 1
if self.count.value == self.num_workers:
self.barrier.release()
with self.count.get_lock():
self.count.value += 1
if self.count.value == self.num_workers:
self.barrier.release()
self.barrier.acquire()
self.barrier.release()
return torch.initial_seed()
Expand Down Expand Up @@ -249,6 +251,7 @@ def test_multiple_dataloaders(self):
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
def test_segfault(self):
def _test_segfault():
os.close(sys.stderr.fileno())
sys.stderr.close()
dataset = SegfaultDataset(10)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
Expand All @@ -266,6 +269,7 @@ def _test_segfault():
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
def test_timeout(self):
def _test_timeout():
os.close(sys.stderr.fileno())
sys.stderr.close()
dataset = SleepDataset(10, 10)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
Expand Down

0 comments on commit cc9dc3f

Please sign in to comment.