diff --git a/OpenNMT/translate.py b/OpenNMT/translate.py index f5b5713121..dd7f1d3f98 100644 --- a/OpenNMT/translate.py +++ b/OpenNMT/translate.py @@ -49,7 +49,7 @@ def reportScore(name, scoreTotal, wordsTotal): def main(): opt = parser.parse_args() - opt.cuda = True if opt.gpu > -1 + opt.cuda = opt.gpu > -1 torch.cuda.set_device(opt.gpu) translator = onmt.Translator(opt)