Skip to content

Commit

Permalink
Image Loading (facebookresearch#63)
Browse files Browse the repository at this point in the history
first draft of loading images in dialog data
  • Loading branch information
alexholdenmiller authored May 10, 2017
1 parent 7597af2 commit 1854066
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
1 change: 0 additions & 1 deletion parlai/core/build_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def download(path, url, redownload=True):
download tar file again if it is present (default true).
"""
if redownload or not os.path.isfile(path):
# only download if file hasn't been downloaded already
filename = wget.download(url, out=path)
print() # wget prints download status, without newline

Expand Down
80 changes: 56 additions & 24 deletions parlai/core/dialog_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .thread_utils import SharedTable
from .metrics import Metrics

import copy
from PIL import Image
import random
import sys
import time
Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(self, opt, shared=None):
if shared and shared.get('data'):
self.data = shared['data']
else:
self.data = DialogData(self.setup_data(opt['datafile']),
self.data = DialogData(opt, self.setup_data(opt['datafile']),
cands=self.label_candidates())

# for ordered data in batch mode (especially, for validation and
Expand Down Expand Up @@ -153,12 +153,14 @@ class DialogData(object):
(x, ...), new_episode?
Where...
x is a query and possibly context
- x is a query and possibly context
... can contain additional fields, specifically
y is an iterable of label(s) for that query
r is the str reward for getting that query correct
c is an iterable of label candidates that the student can choose from
new_episode? is a boolean value specifying whether that example is the start
- y is an iterable of label(s) for that query
- r is the str reward for getting that query correct
- c is an iterable of label candidates that the student can choose from
- i is a str path to an image on disk, which will be loaded by the data
class at request-time. should always point to the raw image file.
- new_episode? is a boolean value specifying whether that example is the start
of a new episode. If you don't use episodes set this to True every time.
cands can be set to provide a list of candidate labels for every example
Expand All @@ -169,7 +171,11 @@ class DialogData(object):
or randomly when returning examples to the caller.
"""

def __init__(self, data_loader, cands=None):
def __init__(self, opt, data_loader, cands=None):
# self.data is a list of episodes
# each episode is a tuple of entries
# each entry is a tuple of values for the action/observation table
self.opt = opt
self.data = []
self._load(data_loader)
self.cands = None if cands == None else set(sys.intern(c) for c in cands)
Expand All @@ -191,41 +197,48 @@ def _load(self, data_loader):
episode = []
last_cands = None
for entry, new in data_loader:
if new:
if len(episode) > 0:
self.data.append(tuple(episode))
episode = []
last_cands = None
if new and len(episode) > 0:
self.data.append(tuple(episode))
episode = []
last_cands = None

# intern all strings so we don't store them more than once
new_entry = []
if len(entry) > 0:
# process text
# process text if available
if entry[0] is not None:
new_entry.append(sys.intern(entry[0]))
else:
new_entry.append(None)
if len(entry) > 1:
# process labels
# process labels if available
if entry[1] is not None:
new_entry.append(tuple(sys.intern(e) for e in entry[1]))
else:
new_entry.append(None)
if len(entry) > 2:
# process reward
# process reward if available
if entry[2] is not None:
new_entry.append(sys.intern(entry[2]))
else:
new_entry.append(None)
if len(entry) > 3 and entry[3] is not None:
# process label candidates
if last_cands and entry[3] is last_cands:
new_entry.append(
sys.intern('same as last time'))
if len(entry) > 3:
if entry[3] is not None:
# process label candidates if available
if last_cands and entry[3] is last_cands:
# if cands are shared, say "same" so we
# don't store them again
new_entry.append(
sys.intern('same as last time'))
else:
last_cands = entry[3]
new_entry.append(tuple(
sys.intern(e) for e in entry[3]))
else:
last_cands = entry[3]
new_entry.append(tuple(
sys.intern(e) for e in entry[3]))
new_entry.append(None)
if len(entry) > 4 and entry[4] is not None:
new_entry.append(sys.intern(entry[4]))

episode.append(tuple(new_entry))

if len(episode) > 0:
Expand All @@ -252,6 +265,9 @@ def get(self, episode_idx, entry_idx=0):
table['reward'] = entry[2]
if len(entry) > 3:
table['label_candidates'] = entry[3]
if len(entry) > 4 and not opt.get('no_images', False):
table['image'] = load_image(self.opt, entry[4])


if (table.get('labels', None) is not None
and self.cands is not None):
Expand All @@ -273,3 +289,19 @@ def get(self, episode_idx, entry_idx=0):
# last entry in this episode
table['episode_done'] = episode_done
return table, end_of_data

def load_image(opt, path):
    """Load the image stored at ``path`` as requested by the options in ``opt``.

    :param opt: dict-like options. Reads ``no_images`` (truthy disables image
        loading) and ``image_preprocessor`` (name of the preprocessing mode;
        missing, ``None`` or ``'raw'`` means "load the raw image").
    :param path: str path to the raw image file on disk; may be empty/None.

    :returns: ``None`` if images are disabled or no path is given; the raw
        image converted to RGB in raw mode; otherwise the image opened from
        ``<dir of path>/<mode>/<file name of path>``.

    :raises NotImplementedError: if a non-raw mode is requested and no
        preprocessed file exists for this image.
    """
    # Local import: the visible module imports do not include ``os``.
    import os

    if opt.get('no_images', False) or not path:
        return None

    # The option can be present with a value of None (e.g. an argparse
    # default), in which case ``opt.get(key, 'raw')`` would return None and
    # break the path join below. Treat any falsy value as raw mode.
    mode = opt.get('image_preprocessor') or 'raw'
    if mode == 'raw':
        return Image.open(path).convert('RGB')

    # Preprocessed versions are looked up in a sub-directory named after the
    # mode, next to the raw image.
    prepath, imagefn = os.path.split(path)
    new_path = os.path.join(prepath, mode, imagefn)
    if not os.path.isfile(new_path):
        raise NotImplementedError(
            'image preprocessing mode {} not supported yet'.format(mode))
    # Open the preprocessed file that was just checked, not the raw one.
    return Image.open(new_path)
3 changes: 3 additions & 0 deletions parlai/core/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def add_parlai_args(self):
help='choose from: train, train:ordered, valid, test. ' +
'by default: train is random with replacement, ' +
'valid is ordered, test is ordered.')
self.parser.add_argument(
'-ip', '--image_preprocessor', default=None, type=str,
help='image preprocessor to use. default is raw (none).')
self.parser.add_argument(
'-nt', '--numthreads', default=1, type=int,
help='number of threads, e.g. for hogwild')
Expand Down

0 comments on commit 1854066

Please sign in to comment.