forked from guillaume-be/rust-bert
-
Notifications
You must be signed in to change notification settings - Fork 0
/
marian.rs
74 lines (62 loc) · 2.66 KB
/
marian.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
use rust_bert::marian::{
MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
MarianTargetLanguages, MarianVocabResources,
};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{
Language, TranslationConfig, TranslationModel, TranslationModelBuilder,
};
use rust_bert::resources::RemoteResource;
use tch::Device;
/// English → French translation through a Marian model whose
/// `TranslationConfig` is assembled by hand from the pretrained
/// `ENGLISH2ROMANCE` remote resources.
///
/// The model is placed on `Device::cuda_if_available()`. The test checks the
/// exact decoded text for two input sentences; note that both expected strings
/// begin with a space.
#[test]
// To skip this test unless the `all-tests` feature is enabled, restore:
// #[cfg_attr(not(feature = "all-tests"), ignore)]
fn test_translation() -> anyhow::Result<()> {
    // Pretrained English → Romance artefacts. The SPM resource is supplied in
    // the config's optional "merges" position.
    let weights = RemoteResource::from_pretrained(MarianModelResources::ENGLISH2ROMANCE);
    let config = RemoteResource::from_pretrained(MarianConfigResources::ENGLISH2ROMANCE);
    let vocab = RemoteResource::from_pretrained(MarianVocabResources::ENGLISH2ROMANCE);
    let spm_model = RemoteResource::from_pretrained(MarianSpmResources::ENGLISH2ROMANCE);

    let translator = TranslationModel::new(TranslationConfig::new(
        ModelType::Marian,
        ModelResource::Torch(Box::new(weights)),
        config,
        vocab,
        Some(spm_model),
        MarianSourceLanguages::ENGLISH2ROMANCE,
        MarianTargetLanguages::ENGLISH2ROMANCE,
        Device::cuda_if_available(),
    ))?;

    let sentences = [
        "The quick brown fox jumps over the lazy dog",
        "The dog did not wake up",
    ];
    let expected = [
        " Le rapide renard brun saute sur le chien paresseux",
        " Le chien ne s'est pas réveillé",
    ];

    // No explicit source language (`None`); French is the requested target.
    let translations = translator.translate(&sentences, None, Language::French)?;

    assert_eq!(translations.len(), expected.len());
    for (actual, wanted) in translations.iter().zip(expected.iter()) {
        assert_eq!(actual, wanted);
    }
    Ok(())
}
/// English → French translation through a Marian model obtained from
/// `TranslationModelBuilder`, with the model type and language pair given
/// explicitly instead of a hand-built config.
///
/// The model is placed on `Device::cuda_if_available()`. The expected outputs
/// match those of `test_translation`, including the leading space on each.
#[test]
// To skip this test unless the `all-tests` feature is enabled, restore:
// #[cfg_attr(not(feature = "all-tests"), ignore)]
fn test_translation_builder() -> anyhow::Result<()> {
    let sources = vec![Language::English];
    let targets = vec![Language::French];

    let translator = TranslationModelBuilder::new()
        .with_device(Device::cuda_if_available())
        .with_model_type(ModelType::Marian)
        .with_source_languages(sources)
        .with_target_languages(targets)
        .create_model()?;

    let sentences = [
        "The quick brown fox jumps over the lazy dog",
        "The dog did not wake up",
    ];
    let expected = [
        " Le rapide renard brun saute sur le chien paresseux",
        " Le chien ne s'est pas réveillé",
    ];

    // No explicit source language (`None`); French is the requested target.
    let translations = translator.translate(&sentences, None, Language::French)?;

    assert_eq!(translations.len(), expected.len());
    for (actual, wanted) in translations.iter().zip(expected.iter()) {
        assert_eq!(actual, wanted);
    }
    Ok(())
}