diff --git a/parlai/agents/seq2seq/seq2seq.py b/parlai/agents/seq2seq/seq2seq.py
index 3e050e450fc..42a0d20c2a6 100644
--- a/parlai/agents/seq2seq/seq2seq.py
+++ b/parlai/agents/seq2seq/seq2seq.py
@@ -176,23 +176,23 @@ def build_model(self, states=None):
             print('skipping preinitialization of embeddings for bpe')
         elif not states and opt['embedding_type'] != 'random':
             # `not states`: only set up embeddings if not loading model
-            self._copy_embeddings(self.model.decoder.lt.weight, opt['embedding_type'])
+            self._copy_embeddings(model.decoder.lt.weight, opt['embedding_type'])
             if opt['lookuptable'] in ['unique', 'dec_out']:
                 # also set encoder lt, since it's not shared
                 self._copy_embeddings(
-                    self.model.encoder.lt.weight, opt['embedding_type'], log=False
+                    model.encoder.lt.weight, opt['embedding_type'], log=False
                 )
 
         if states:
             # set loaded states if applicable
-            self.model.load_state_dict(states['model'])
+            model.load_state_dict(states['model'])
 
         if opt['embedding_type'].endswith('fixed'):
             print('Seq2seq: fixing embedding weights.')
-            self.model.decoder.lt.weight.requires_grad = False
-            self.model.encoder.lt.weight.requires_grad = False
+            model.decoder.lt.weight.requires_grad = False
+            model.encoder.lt.weight.requires_grad = False
             if opt['lookuptable'] in ['dec_out', 'all']:
-                self.model.decoder.e2s.weight.requires_grad = False
+                model.output.weight.requires_grad = False
 
         return model
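
The hunk changes `build_model` to configure the locally constructed `model` (and finally `return model`) instead of reaching for `self.model`, which the caller only assigns after `build_model` returns. A minimal sketch of that build-then-assign pattern, assuming a simplified agent and module (`SketchAgent` and `TinySeq2seq` are illustrative names, not the actual ParlAI classes):

```python
import torch.nn as nn


class TinySeq2seq(nn.Module):
    # stand-in for the real Seq2seq module; just holds an embedding table
    def __init__(self, vocab_size=10, emb_size=4):
        super().__init__()
        self.lt = nn.Embedding(vocab_size, emb_size)


class SketchAgent:
    def __init__(self, opt, states=None):
        self.opt = opt
        # self.model does not exist yet at this point, so build_model()
        # must operate on its local `model` and return it
        self.model = self.build_model(states)

    def build_model(self, states=None):
        model = TinySeq2seq()
        if self.opt.get('embedding_type', 'random').endswith('fixed'):
            # freeze embeddings on the local object, not self.model
            model.lt.weight.requires_grad = False
        if states:
            model.load_state_dict(states['model'])
        return model


agent = SketchAgent({'embedding_type': 'glove-fixed'})
print(agent.model.lt.weight.requires_grad)  # False
```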