diff --git a/projects/controllable_dialogue/controllable_seq2seq/controllable_seq2seq.py b/projects/controllable_dialogue/controllable_seq2seq/controllable_seq2seq.py index 289f032b846..f3e9673290b 100644 --- a/projects/controllable_dialogue/controllable_seq2seq/controllable_seq2seq.py +++ b/projects/controllable_dialogue/controllable_seq2seq/controllable_seq2seq.py @@ -428,10 +428,10 @@ def build_model(self, states=None): if opt['embedding_type'].endswith('fixed'): print('Seq2seq: fixing embedding weights.') - self.model.decoder.lt.weight.requires_grad = False - self.model.encoder.lt.weight.requires_grad = False + model.decoder.lt.weight.requires_grad = False + model.encoder.lt.weight.requires_grad = False if opt['lookuptable'] in ['dec_out', 'all']: - self.model.decoder.e2s.weight.requires_grad = False + model.decoder.output.e2s.weight.requires_grad = False return model