Merge pull request facebookresearch#1221 from bhancock8/torch_ranker
[wip] TorchRankerAgent abstract class
bhancock8 authored Oct 23, 2018
2 parents bc3857b + f978597 commit 38a8c16
Showing 7 changed files with 486 additions and 270 deletions.
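Note: the bulk of the 486 added lines is presumably the new parlai/core/torch_ranker_agent.py, whose diff did not load on this page. As a rough sketch of the subclass contract implied by the MemnnAgent changes below (a subclass supplies build_model() and score_candidates(), while the base class is assumed to own checkpoint loading, sharing, metrics, and the train/eval steps that this commit deletes from memnn.py), consider the following hypothetical agent; the name, embedding model, and scoring scheme are illustrative, not part of the PR:

import torch
from torch import nn

from parlai.core.torch_ranker_agent import TorchRankerAgent


class DotProductRankerAgent(TorchRankerAgent):  # hypothetical example agent
    def build_model(self):
        # assumed to be invoked by the base __init__ to construct self.model
        self.model = nn.Embedding(len(self.dict), self.opt['embedding_size'])

    def score_candidates(self, batch, cand_vecs):
        # return one score per candidate, shape (batchsize x num_cands);
        # loss, ranking metrics, and train/eval steps stay in the base class
        text = self.model(batch.text_vec).mean(dim=1)  # bsz x esz
        cands = self.model(cand_vecs).mean(dim=2)      # bsz x num_cands x esz
        return torch.bmm(cands, text.unsqueeze(-1)).squeeze(-1)

From the caller's side, training should be unchanged, e.g. something like python examples/train_model.py -m memnn -t babi:task10k:1 -bs 32; per the model_version() docstring below, the pre-refactor agent remains reachable via --model legacy:memnn:0.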
3 changes: 2 additions & 1 deletion .travis.yml
@@ -3,12 +3,13 @@ cache: pip
 python:
 - 3.6
 install:
+- pip install --upgrade pip
 - pip install flake8
 - python setup.py develop
 - pip install gitpython # needed for some tests
 - pip install regex scipy scikit-learn pexpect # tfidf-retriever dependencies
 # install pytorch non-cuda version
-- pip install torch torchvision
+- pip install torch torchvision --progress-bar off
 script:
 - ./tests/lint_changed.sh
 - python setup.py test -s tests.suites.travis -q

286 changes: 21 additions & 265 deletions parlai/agents/memnn/memnn.py
@@ -6,20 +6,16 @@
 # LICENSE file in the root directory of this source tree. An additional grant
 # of patent rights can be found in the PATENTS file in the same directory.

-from parlai.core.torch_agent import TorchAgent, Output
-from parlai.core.thread_utils import SharedTable
-from parlai.core.utils import round_sigfigs, padded_3d, padded_tensor
+from functools import lru_cache

 import torch
-from torch import nn

-from functools import lru_cache
-import os
+from parlai.core.torch_ranker_agent import TorchRankerAgent

 from .modules import MemNN, opt_to_kwargs


-class MemnnAgent(TorchAgent):
+class MemnnAgent(TorchRankerAgent):
     """Memory Network agent.

     Tips:
@@ -51,7 +47,7 @@ def add_cmdline_args(argparser):
         arg_group.add_argument(
             '-pe', '--position-encoding', type='bool', default=False,
             help='use position encoding instead of bag of words embedding')
-        TorchAgent.add_cmdline_args(argparser)
+        TorchRankerAgent.add_cmdline_args(argparser)
         MemnnAgent.dictionary_class().add_cmdline_args(argparser)
         return arg_group

@@ -64,114 +60,44 @@ def model_version():
         To use version 0, use --model legacy:memnn:0
         (legacy agent code is located in parlai/agents/legacy_agents).
         """
-        return 1
+        # TODO: Update date that Version 2 split and move version 1 to legacy
+        return 2

     def __init__(self, opt, shared=None):
-        init_model = None
-        if not shared:  # only do this on first setup
-            # first check load path in case we need to override paths
-            if opt.get('init_model') and os.path.isfile(opt['init_model']):
-                # check first for 'init_model' for loading model from file
-                init_model = opt['init_model']
-
-            if opt.get('model_file') and os.path.isfile(opt['model_file']):
-                # next check for 'model_file', this would override init_model
-                init_model = opt['model_file']
-
-            if init_model is not None:
-                # if we are loading a model, should load its dict too
-                if (os.path.isfile(init_model + '.dict') or
-                        opt['dict_file'] is None):
-                    opt['dict_file'] = init_model + '.dict'
+        # all instances may need some params
         super().__init__(opt, shared)

-        # all instances may need some params
         self.id = 'MemNN'
         self.memsize = opt['memsize']
         if self.memsize < 0:
             self.memsize = 0
         self.use_time_features = opt['time_features']

-        if shared:
-            # set up shared properties
-            self.model = shared['model']
-            self.metrics = shared['metrics']
-        else:
-            self.metrics = {'loss': 0.0, 'batches': 0, 'rank': 0}
-
         if not shared:
             if opt['time_features']:
                 for i in range(self.memsize):
                     self.dict[self._time_feature(i)] = 100000000 + i
-
-            # initialize model from scratch
-            self._init_model()
-            if init_model is not None:
-                print('Loading existing model parameters from ' + init_model)
-                self.load(init_model)
-
-        # set up criteria
-        self.rank_loss = nn.CrossEntropyLoss()  # TODO: rank loss option?
-
-        if self.use_cuda:
-            self.model.cuda()
-            self.rank_loss.cuda()
-
-        if 'train' in self.opt.get('datatype', ''):
-            # set up optimizer
-            optim_params = [p for p in self.model.parameters() if
-                            p.requires_grad]
-            self._init_optim(optim_params)

-    def _init_model(self):
-        """Initialize MemNN model."""
-        opt = self.opt
-        kwargs = opt_to_kwargs(opt)
-        self.model = MemNN(len(self.dict), opt['embedding_size'],
+    def build_model(self):
+        """Build MemNN model."""
+        kwargs = opt_to_kwargs(self.opt)
+        self.model = MemNN(len(self.dict), self.opt['embedding_size'],
                            padding_idx=self.NULL_IDX, **kwargs)

+    def score_candidates(self, batch, cand_vecs):
+        mems = self._build_mems(batch.memory_vecs)
+        scores = self.model(batch.text_vec, mems, cand_vecs)
+        return scores
+
     @lru_cache(maxsize=None)  # bounded by opt['memsize'], cache string concats
     def _time_feature(self, i):
         """Return time feature token at specified index."""
         return '__tf{}__'.format(i)

-    def share(self):
-        """Share model parameters."""
-        shared = super().share()
-        shared['model'] = self.model
-        if self.opt.get('numthreads', 1) > 1 and isinstance(self.metrics, dict):
-            torch.set_num_threads(1)
-            # move metrics and model to shared memory
-            self.metrics = SharedTable(self.metrics)
-            self.model.share_memory()
-        shared['metrics'] = self.metrics
-        return shared
-
-    def update_params(self):
-        """Do optim step and clip gradients if needed."""
-        if self.clip > 0:
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
-        self.optimizer.step()
-
-    def reset_metrics(self):
-        """Reset metrics for reporting loss and perplexity."""
-        super().reset_metrics()
-        self.metrics['loss'] = 0.0
-        self.metrics['batches'] = 0
-        self.metrics['rank'] = 0
-
-    def report(self):
-        """Report loss and mean_rank from model's perspective."""
-        m = {}
-        batches = self.metrics['batches']
-        if batches > 0:
-            if self.metrics['loss'] > 0:
-                m['loss'] = self.metrics['loss']
-            if self.metrics['rank'] > 0:
-                m['mean_rank'] = self.metrics['rank'] / batches
-        for k, v in m.items():
-            # clean up: rounds to sigfigs and converts tensors to floats
-            m[k] = round_sigfigs(v, 4)
-        return m
+    def get_dialog_history(self, *args, **kwargs):
+        """Override options in get_dialog_history from parent."""
+        kwargs['add_p1_after_newln'] = True  # will only happen if -pt True
+        return super().get_dialog_history(*args, **kwargs)

     def vectorize(self, *args, **kwargs):
         """Override options in vectorize from parent."""
@@ -180,85 +106,6 @@ def vectorize(self, *args, **kwargs):
         kwargs['split_lines'] = True
         return super().vectorize(*args, **kwargs)

-    def get_dialog_history(self, *args, **kwargs):
-        """Override options in get_dialog_history from parent."""
-        kwargs['add_p1_after_newln'] = True  # will only happen if -pt True
-        return super().get_dialog_history(*args, **kwargs)
-
-    def _warn_once(self, flag, msg):
-        flag = '_warning_' + flag
-        if not hasattr(self, flag):
-            setattr(self, flag, True)
-            print(msg)
-
-    def _build_train_cands(self, labels, label_cands=None):
-        """Build candidates from batch labels.
-
-        When the batchsize is 1, first we look for label_cands to be filled
-        (from batch.candidate_vecs). If available, we'll use those candidates.
-        Otherwise, we'll rank each token in the dictionary except NULL.
-
-        For batches of labels of a single token, we use torch.unique to return
-        only the unique tokens.
-
-        For batches of label sequences of length greater than one, we keep them
-        all so as not to waste too much time calculating uniqueness.
-
-        :param labels: (bsz x seqlen) LongTensor.
-        :param label_cands: default None. if bsz is 1 and label_cands is not
-            None, will use label_cands for training.
-
-        :return: tuple of tensors (cands, indices)
-            cands is (num_cands <= bsz x seqlen) candidates
-            indices is (bsz) index in cands of each original label
-        """
-        assert labels.dim() == 2
-        if labels.size(0) == 1:
-            # we can't rank the batch of labels, see if there are label_cands
-            label = labels[0]  # there's just one
-            if label_cands is not None:
-                self._warn_once(
-                    'ranking_labelcands',
-                    '[ Training using label_candidates fields as cands. ]')
-                label_cands, _ = padded_tensor(label_cands[0],
-                                               use_cuda=self.use_cuda)
-
-                if label_cands.size(1) == 1:
-                    # use unique if cands are 1D
-                    label_cands = label_cands.unique(return_inverse=False)
-                    label_cands.unsqueeze_(1)
-                label_inds = (label_cands == label).all(1).nonzero()
-                if label_inds.size(0) > 1:
-                    label_inds = self.random.choice(label_inds)
-                else:
-                    label_inds.squeeze_(1)
-                return label_cands, label_inds
-            else:
-                self._warn_once(
-                    'ranking_dict',
-                    '[ Training using dictionary of tokens as cands. ]')
-                dict_size = len(self.dict)
-                full_dict = labels.new(range(1, dict_size))
-                # pick random token from label
-                if len(label) > 1:
-                    token = self.random.choice(label)
-                else:
-                    token = label[0] - 1
-                return full_dict.unsqueeze_(1), token.unsqueeze(0)
-        elif labels.size(1) == 1:
-            self._warn_once(
-                'ranking_unique',
-                '[ Training using unique labels in batch as cands. ]')
-            # use unique if input is 1D
-            cands, label_inds = labels.unique(return_inverse=True)
-            cands.unsqueeze_(1)
-            label_inds.squeeze_(1)
-            return cands, label_inds
-        else:
-            self._warn_once(
-                'ranking_batch',
-                '[ Training using other labels in batch as cands. ]')
-            return labels, labels.new(range(labels.size(0)))

     def _build_mems(self, mems):
         """Build memory tensors.
@@ -312,94 +159,3 @@ def _build_mems(self, mems):
             padded = padded.cuda()

         return padded
-
-    def train_step(self, batch):
-        """Train on a single batch of examples."""
-        if batch.text_vec is None:
-            return
-        batchsize = batch.text_vec.size(0)
-        self.model.train()
-        self.optimizer.zero_grad()
-        mems = self._build_mems(batch.memory_vecs)
-
-        cands, label_inds = self._build_train_cands(batch.label_vec,
-                                                    batch.candidate_vecs)
-
-        scores = self.model(batch.text_vec, mems, cands)
-        loss = self.rank_loss(scores, label_inds)
-
-        self.metrics['loss'] += loss.item()
-        self.metrics['batches'] += batchsize
-        _, ranks = scores.sort(1, descending=True)
-        for b in range(batchsize):
-            rank = (ranks[b] == label_inds[b]).nonzero().item()
-            self.metrics['rank'] += 1 + rank
-        loss.backward()
-        self.update_params()
-
-        # get predictions but not full rankings--too slow to get hits@1 score
-        preds = [self._v2t(cands[row[0]]) for row in ranks]
-        return Output(preds)
-
-    def _build_label_cands(self, batch):
-        """Convert batch.candidate_vecs to 3D padded vector."""
-        if not batch.candidates:
-            return None, None
-        cand_inds = [i for i in range(len(batch.candidates))
-                     if batch.candidates[i]]
-        cands = padded_3d(batch.candidate_vecs, pad_idx=self.NULL_IDX,
-                          use_cuda=self.use_cuda)
-        return cands, cand_inds
-
-    def eval_step(self, batch):
-        """Evaluate a single batch of examples."""
-        if batch.text_vec is None:
-            return
-        batchsize = batch.text_vec.size(0)
-        self.model.eval()
-
-        mems = self._build_mems(batch.memory_vecs)
-        cands, cand_inds = self._build_label_cands(batch)
-        scores = self.model(batch.text_vec, mems, cands)
-
-        self.metrics['batches'] += batchsize
-        _, ranks = scores.sort(1, descending=True)
-
-        # calculate loss and mean rank
-        if batch.label_vec is not None and cands is not None:
-            label_inds = []
-            noproblems = True
-            for b in range(batchsize):
-                label_ind = (cands[b] == batch.label_vec[b]).all(1).nonzero()
-                if label_ind.size(0) == 1:
-                    li = label_ind.item()
-                else:
-                    # don't calculate loss
-                    self._warn_once(
-                        'eval_loss',
-                        '[ WARNING: duplicate multitoken candidates detected. '
-                        'This batch\'s loss and mean_rank calculation will be '
-                        'skipped, skewing their totals. hits@k, accuracy, and '
-                        'f1 will be unaffected so use those instead.')
-                    noproblems = False
-                    continue
-                label_inds.append(label_ind)
-                rank = (ranks[b] == li).nonzero().item()
-                self.metrics['rank'] += 1 + rank
-
-            if noproblems:
-                label_inds = torch.cat(label_inds, dim=0).squeeze(1)
-                loss = self.rank_loss(scores, label_inds)
-                self.metrics['loss'] += loss.item()
-
-        preds, cand_preds = None, None
-        if batch.candidates:
-            cand_preds = [[batch.candidates[b][i.item()] for i in row]
-                          for b, row in enumerate(ranks)]
-            preds = [row[0] for row in cand_preds]
-        else:
-            cand_preds = [[self.dict[i.item()] for i in row]
-                          for row in ranks]
-            preds = [row[0] for row in cand_preds]
-
-        return Output(preds, cand_preds)
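For reference, the batch-labels-as-candidates trick that the deleted _build_train_cands implemented (and which presumably now lives in TorchRankerAgent) hinges on torch.unique with return_inverse=True: the unique labels become the candidate set, and the inverse indices recover each example's gold position for the cross-entropy rank loss. A standalone illustration:

import torch

labels = torch.tensor([3, 7, 3, 9])  # bsz=4, single-token labels
cands, label_inds = labels.unique(return_inverse=True)
print(cands)       # tensor([3, 7, 9])    -> deduplicated candidate set to rank
print(label_inds)  # tensor([0, 1, 0, 2]) -> gold candidate index per example
# a (bsz x num_cands) score matrix scored against label_inds feeds
# nn.CrossEntropyLoss, exactly as the deleted train_step did with rank_loss
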
2 changes: 1 addition & 1 deletion parlai/agents/seq2seq/seq2seq.py
@@ -192,7 +192,7 @@ def __init__(self, opt, shared=None):
             self.criterion.cuda()

         if 'train' in opt.get('datatype', ''):
-            self._init_optim(
+            self.init_optim(
                 [p for p in self.model.parameters() if p.requires_grad],
                 optim_states=states.get('optimizer'),
                 saved_optim_type=states.get('optimizer_type'))

2 changes: 1 addition & 1 deletion parlai/core/torch_agent.py
@@ -253,7 +253,7 @@ def __init__(self, opt, shared=None):
         self.rank_candidates = opt['rank_candidates']
         self.add_person_tokens = opt.get('person_tokens', False)

-    def _init_optim(self, params, optim_states=None, saved_optim_type=None):
+    def init_optim(self, params, optim_states=None, saved_optim_type=None):
         """Initialize optimizer with model parameters.

         :param params: parameters from the model, for example:
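With init_optim now public, downstream agents call it the way the seq2seq change above does. A minimal sketch of the pattern (the agent name, placeholder model, and states dict here are hypothetical):

from torch import nn

from parlai.core.torch_agent import TorchAgent


class MyRankingAgent(TorchAgent):  # hypothetical agent
    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        self.model = nn.Linear(16, 16)  # placeholder model
        states = {}  # optimizer state loaded from a checkpoint, if any
        if 'train' in opt.get('datatype', ''):
            self.init_optim(
                [p for p in self.model.parameters() if p.requires_grad],
                optim_states=states.get('optimizer'),
                saved_optim_type=states.get('optimizer_type'))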