Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/UKPLab/EasyNMT into main
Browse files Browse the repository at this point in the history
  • Loading branch information
nreimers committed Feb 1, 2021
2 parents 3c9f959 + 37ec20f commit fe52d2d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions easynmt/models/OpusMT.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def load_model(self, model_name):
self.models[model_name] = {'tokenizer': tokenizer, 'model': model, 'last_loaded': time.time()}
return tokenizer, model

def translate_sentences(self, sentences: List[str], source_lang: str, target_lang: str, device: str, beam_size: int = 5, max_length: int = None):
def translate_sentences(self, sentences: List[str], source_lang: str, target_lang: str, device: str, beam_size: int = 5, max_length: int = None, do_sample=False, top_k=50, top_p=1.0):
model_name = 'Helsinki-NLP/opus-mt-{}-{}'.format(source_lang, target_lang)
tokenizer, model = self.load_model(model_name)
model.to(device)
Expand All @@ -52,7 +52,7 @@ def translate_sentences(self, sentences: List[str], source_lang: str, target_lan
inputs[key] = inputs[key].to(device)

with torch.no_grad():
translated = model.generate(**inputs, num_beams=beam_size)
translated = model.generate(**inputs, num_beams=beam_size, do_sample=do_sample, top_k=top_k, top_p=top_p)
output = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]

return output
Expand Down

0 comments on commit fe52d2d

Please sign in to comment.