Skip to content

Commit

Permalink
[Chunk Teacher] Reset bug fix (facebookresearch#3244)
Browse files Browse the repository at this point in the history
* chunk teacher problemz

* f

* bug

* move counter to enqueue chunks

* typo

* check for output being none

* tot_samples_loaded has to be keyed on the reset count

* also only reset when we share validation agent

* added a test

* lint

* test testing

* update test

* init

* lkasjdf

* alksdfj
  • Loading branch information
Emily Dinan authored Nov 5, 2020
1 parent 3a908bc commit 5429dae
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 29 deletions.
78 changes: 50 additions & 28 deletions parlai/core/teachers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
``FixedDialogTeacher``, and ``DialogData``/``StreamDialogData``, data
structures for accessing textual dialog data and utilized by ``DialogTeacher``
"""
import copy
from typing import List, Tuple, Optional, TypeVar

from parlai.core.agents import Agent, create_agent_from_shared
from parlai.core.image_featurizers import ImageLoader
from parlai.core.loader import load_teacher_module
Expand All @@ -44,23 +41,25 @@
from parlai.core.opt import Opt
from parlai.utils.conversations import Conversations
from parlai.utils.data import DatatypeHelper
from parlai.utils.misc import AttrDict, no_lock, str_to_msg, warn_once
from parlai.utils.misc import AttrDict, no_lock, str_to_msg, warn_once, SimpleCounter
from parlai.utils.distributed import get_rank, num_workers, is_distributed
import parlai.utils.torch as torch_utils
import parlai.utils.logging as logging
from parlai.utils.io import PathManager

from abc import ABC, abstractmethod

import argparse
from collections import defaultdict
import concurrent.futures
from threading import Thread
import copy
import json
import os
import queue
import random
from threading import Thread
import time
import os
import torch
import json
import argparse
from typing import List, Tuple, Optional, TypeVar


ChunkOutput = TypeVar('ChunkOutput')
Expand Down Expand Up @@ -2135,19 +2134,21 @@ def __init__(self, opt, shared=None):
self.is_root_teacher = False
self.chunks = shared['chunks']
self.samples = shared['samples']
self.reset_counter = shared['reset_counter']
self.rng = shared['rng']
else:
self.is_root_teacher = True
self.samples = queue.Queue(maxsize=self.buffersize)
self.chunks = queue.Queue()
self.reset_counter = SimpleCounter() # track no. of resets
if self.is_train:
# TODO: possible need a fixed seed here in the future
self.rng = random.Random()
else:
self.rng = random.Random(42)
self._enqueue_chunks()
# launch queue loader on the main thread
self.tot_samples_loaded = 0
self.tot_samples_loaded = defaultdict(int)
if not opt.get("no_auto_enqueues", False):
self._enqueue_request()

Expand Down Expand Up @@ -2201,6 +2202,7 @@ def share(self):
shared = super().share()
shared['samples'] = self.samples
shared['chunks'] = self.chunks
shared['reset_counter'] = self.reset_counter
shared['rng'] = self.rng
return shared

Expand Down Expand Up @@ -2234,17 +2236,24 @@ def receive_data(self, future):
Load data into self.samples until buffersize is reached.
"""
data = future.result()
if data is None:
output = future.result()
if output is None:
return
while data:
chunk_output, chunk_reset_cnt = output
if chunk_output is None:
return
while chunk_output:
# self.samples is a queue with maxsize
# self.buffersize, so will block if the
# buffer gets full
sample = data.pop(0)
if self.is_train or self.tot_samples_loaded % self.dws == self.rank:
self.samples.put(sample)
self.tot_samples_loaded += 1
sample = chunk_output.pop(0)
if (
self.is_train
or self.tot_samples_loaded[chunk_reset_cnt] % self.dws == self.rank
):
# log the reset count at the time the chunk was queued
self.samples.put((sample, chunk_reset_cnt))
self.tot_samples_loaded[chunk_reset_cnt] += 1
# and start loading the next chunk
self._enqueue_request()

Expand All @@ -2254,8 +2263,10 @@ def _enqueue_chunks(self):
"""
if self.is_train:
self.rng.shuffle(self.fold_chunks)
# save the reset count at the time a chunk was queued
reset_cnt = self.reset_counter.value()
for c in self.fold_chunks:
self.chunks.put(c)
self.chunks.put((c, reset_cnt))

@abstractmethod
def load_from_chunk(self, chunk_idx: int) -> List[ChunkOutput]:
Expand Down Expand Up @@ -2287,21 +2298,32 @@ def get_chunk(self):
# if we're in valid/test, we need to actually signal the end
return None

next_chunk = self.chunks.get()
next_chunk, chunk_reset_cnt = self.chunks.get()
# abstract method `load_from_chunk` returns a list of tuples
output = self.load_from_chunk(next_chunk)

if self.is_train:
# randomize the samples
random.Random().shuffle(output)
return output
return output, chunk_reset_cnt

def get(self, episode_idx, entry_idx=0):
curr_reset_cnt = self.reset_counter.value()
if self._episode_done:
# Get the next episode or example
queue_output = self.samples.get()
if queue_output is None:
output = self.samples.get()
if output is None:
return None
queue_output, reset_cnt = output
stale_exs = 0
while curr_reset_cnt > reset_cnt:
stale_exs += 1
output = self.samples.get()
if output is None:
return None
queue_output, reset_cnt = output
if stale_exs > 0:
logging.info(f"Removed {stale_exs} stale examples from the queue.")

# Update the last queue output in the case
# of multi-turn episodes
Expand All @@ -2314,21 +2336,21 @@ def get(self, episode_idx, entry_idx=0):
return msg

def _drain(self, q):
# Empty the queue ``q`` in place.
# This is a rendered diff hunk with the +/- markers and indentation
# stripped, so two alternative bodies appear back to back below.
#
# Body 1 (presumably the removed version): pop items one at a time while
# q.empty() is False; the queue.Empty handler returns early.
while not q.empty():
try:
q.get()
except queue.Empty:
return
# Body 2 (presumably the added version): hold the queue's own mutex and
# clear its underlying container in one step.
# NOTE(review): clearing q.queue directly does not appear to notify a
# producer that is blocked in put() on a full bounded queue -- confirm the
# loader cannot stay blocked after a reset.
with q.mutex:
q.queue.clear()

def reset(self):
# Reset the teacher; only the root teacher touches the queues and the
# loader, since those objects are handed to copies through share().
super().reset()
if self.is_root_teacher:
# bump the reset count before refilling, so chunks and samples that were
# tagged with the previous count can be told apart from the new ones
self.reset_counter.increment()
# drain the queues and refill the chunk queue with a new epoch.
# additionally, we have to relaunch the loader
self._drain(self.samples)
self._drain(self.chunks)
self._enqueue_chunks()
# NOTE: rendered diff without +/- markers -- the integer assignment on the
# next line is presumably the removed version; the defaultdict(int) that
# follows (a per-reset-count tally) is presumably its replacement.
self.tot_samples_loaded = 0 # reset the count of samples loaded
self.tot_samples_loaded = defaultdict(
int
) # reset the count of samples loaded
# kick off loading again
self._enqueue_request()


Expand Down
3 changes: 2 additions & 1 deletion parlai/scripts/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,8 @@ def _run_single_eval(self, opt, valid_world, max_exs):
cnt = valid_world.report().get('exs') or 0

valid_report = valid_world.report()
valid_world.reset() # make sure world doesn't remember valid data
if opt.get('validation_share_agent', False):
valid_world.reset() # make sure world doesn't remember valid data

return valid_report

Expand Down
16 changes: 16 additions & 0 deletions parlai/tasks/integration_tests/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import json
from abc import ABC
from typing import Tuple, List
import time
from parlai.utils.io import PathManager

# default parameters
Expand Down Expand Up @@ -492,6 +493,21 @@ def get_num_samples(self, opt) -> Tuple[int, int]:
return NUM_TEST, NUM_TEST


class ChunkyUniqueSlowTeacher(ChunkyTeacher):
"""
Chunk teacher whose examples are unique and which loads slowly.

Chunk ``chunk_idx`` produces ten ``(text, resp)`` string pairs numbered
``chunk_idx * 10`` through ``chunk_idx * 10 + 9``; text and resp are equal.
"""

def load_from_chunk(self, chunk_idx: int):
# Build and return the list of ten (text, resp) pairs for this chunk.
output = []
for i in range(10):
# offset by the chunk index so distinct chunk indices give disjoint values
text = str(i + chunk_idx * 10)
resp = str(i + chunk_idx * 10)
output.append((text, resp))
# artificial delay that makes loading slow
# NOTE(review): indentation is lost in this view, so it is unclear whether
# the sleep runs once per example or once per chunk -- confirm.
time.sleep(0.1)
return output


class ShortFixedTeacher(FixedDialogCandidateTeacher):
"""
Fixed Dialog Candidate teacher with only 10 training examples.
Expand Down
15 changes: 15 additions & 0 deletions parlai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,21 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
pass


class SimpleCounter:
    """
    Minimal mutable integer counter.

    The count is stored on the object (attribute ``val``), so every holder of
    a reference to the same counter observes the same, current value.

    :param value:
        starting count; defaults to 0.
    """

    def __init__(self, value: int = 0):
        self.val = value

    def __repr__(self) -> str:
        # show the live count when the counter turns up in logs or a debugger
        return f'{type(self).__name__}({self.val!r})'

    def increment(self, value: int = 1) -> None:
        """
        Add ``value`` (default 1) to the current count.
        """
        self.val += value

    def value(self) -> int:
        """
        Return the current count.
        """
        return self.val


def _report_sort_key(report_key: str) -> Tuple[str, str]:
"""
Sorting name for reports.
Expand Down
50 changes: 50 additions & 0 deletions tests/test_teachers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from parlai.core.opt import Opt
import parlai.utils.logging as logging
from parlai.utils.io import PathManager
from parlai.core.loader import register_agent
from collections import defaultdict
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent


class TestAbstractImageTeacher(unittest.TestCase):
Expand Down Expand Up @@ -249,6 +252,33 @@ def test_both_label(self):
self.assertEqual(num_episodes, 2)


@register_agent("unique_examples")
class UniqueExamplesAgent(RepeatLabelAgent):
    """
    Debugging agent that refuses to be shown the same example text twice.

    Subclass of RepeatLabelAgent: ``act`` defers to the parent, but first
    raises if the observed ``text`` was already seen since the last reset.
    """

    def __init__(self, opt, shared=None):
        super().__init__(opt)
        # texts seen so far, keyed by text
        self.unique_examples = defaultdict(int)

    def reset(self):
        super().reset()
        # start over with an empty record of seen texts
        self.unique_examples = defaultdict(int)

    def act(self):
        text = self.observation.get('text')
        seen = self.unique_examples
        if text in seen:
            raise RuntimeError(f'Already saw example: {text}')
        seen[text] += 1
        return super().act()


class TestChunkTeacher(unittest.TestCase):
"""
Test chunked teacher.
Expand Down Expand Up @@ -314,6 +344,26 @@ def test_stream_only(self):
test_datatype='test',
)

def test_slow_loading(self):
    """
    Train briefly against a slow-loading chunk teacher, validating often.

    The ``unique_examples`` model raises when it is shown a text it has
    already seen since its last reset, so the run finishing without an
    error is the check.
    """
    with testing_utils.tempdir() as tmpdir:
        train_opt = dict(
            task='integration_tests:chunky_unique_slow',
            model='unique_examples',
            model_file=os.path.join(tmpdir, 'model'),
            datatype='train:stream',
            num_epochs=0.5,
            validation_every_n_epochs=0.1,
            batchsize=1,
            dynamic_batching='full',
            dict_maxexs=0,
        )
        valid, test = testing_utils.train_model(train_opt)


class CustomEvaluationTeacher(DialogTeacher):
def __init__(self, opt, shared=None):
Expand Down

0 comments on commit 5429dae

Please sign in to comment.