
Commit 9f1691f
remove self. from model references in seq2seq init
alexholdenmiller committed Aug 12, 2019
1 parent ad429be commit 9f1691f
Showing 1 changed file with 6 additions and 6 deletions.
parlai/agents/seq2seq/seq2seq.py (6 additions, 6 deletions)
@@ -176,23 +176,23 @@ def build_model(self, states=None):
             print('skipping preinitialization of embeddings for bpe')
         elif not states and opt['embedding_type'] != 'random':
             # `not states`: only set up embeddings if not loading model
-            self._copy_embeddings(self.model.decoder.lt.weight, opt['embedding_type'])
+            self._copy_embeddings(model.decoder.lt.weight, opt['embedding_type'])
             if opt['lookuptable'] in ['unique', 'dec_out']:
                 # also set encoder lt, since it's not shared
                 self._copy_embeddings(
-                    self.model.encoder.lt.weight, opt['embedding_type'], log=False
+                    model.encoder.lt.weight, opt['embedding_type'], log=False
                 )
 
         if states:
             # set loaded states if applicable
-            self.model.load_state_dict(states['model'])
+            model.load_state_dict(states['model'])
 
         if opt['embedding_type'].endswith('fixed'):
             print('Seq2seq: fixing embedding weights.')
-            self.model.decoder.lt.weight.requires_grad = False
-            self.model.encoder.lt.weight.requires_grad = False
+            model.decoder.lt.weight.requires_grad = False
+            model.encoder.lt.weight.requires_grad = False
             if opt['lookuptable'] in ['dec_out', 'all']:
-                self.model.decoder.e2s.weight.requires_grad = False
+                model.output.weight.requires_grad = False
 
         return model

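The hunk ends with return model, which is why the intermediate references need the local variable: build_model constructs the network and returns it, and the surrounding agent code is expected to attach the result to self.model only after the method finishes, so self.model is not yet bound while these lines run. A minimal sketch of that call order, using illustrative stand-in names rather than the actual ParlAI classes:

# Minimal sketch of the assumed call order; Agent and Model are
# illustrative stand-ins, not ParlAI's real classes.


class Model:
    """Stand-in for the Seq2seq network built inside build_model()."""


class Agent:
    def __init__(self):
        # The model is attached to the agent only after build_model()
        # returns, so code inside build_model() cannot go through self.model.
        self.model = self.build_model()

    def build_model(self):
        model = Model()
        # Any setup at this point (copying embeddings, loading states,
        # freezing weights) must use the local `model`, as in the diff
        # above; self.model does not exist yet and would raise AttributeError.
        return model


if __name__ == '__main__':
    agent = Agent()
    assert agent.model is not None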
