forked from rustformers/llm
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinference.rs
85 lines (75 loc) · 2.58 KB
/
inference.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
use clap::Parser;
use std::{convert::Infallible, io::Write, path::PathBuf};
/// Load a model and run a single inference pass over a prompt, printing tokens as they arrive.
#[derive(Parser)]
struct Args {
    /// Architecture of the model to load.
    model_architecture: llm::ModelArchitecture,
    /// Path to the model file.
    model_path: PathBuf,
    /// Prompt to feed to the model; a built-in sample prompt is used when omitted.
    #[arg(long, short = 'p')]
    prompt: Option<String>,
    /// Path to a local HuggingFace tokenizer file.
    // Declared as conflicting with `--tokenizer-repository` so that clap rejects the
    // combination with a regular usage error instead of the program panicking later.
    #[arg(long, short = 'v', conflicts_with = "tokenizer_repository")]
    pub tokenizer_path: Option<PathBuf>,
    /// Remote HuggingFace repository to load the tokenizer from.
    #[arg(long, short = 'r')]
    pub tokenizer_repository: Option<String>,
}
impl Args {
    /// Translates the tokenizer-related command-line options into a
    /// `llm::TokenizerSource`.
    ///
    /// * only `tokenizer_path` set: `HuggingFaceTokenizerFile` with a copy of the path
    /// * only `tokenizer_repository` set: `HuggingFaceRemote` with a copy of the name
    /// * neither set: `Embedded`
    ///
    /// # Panics
    ///
    /// Panics when both `tokenizer_path` and `tokenizer_repository` are set.
    pub fn to_tokenizer_source(&self) -> llm::TokenizerSource {
        let local_file = self.tokenizer_path.as_ref();
        let remote_repo = self.tokenizer_repository.as_ref();

        // The two sources are mutually exclusive; reject the combination up front.
        if local_file.is_some() && remote_repo.is_some() {
            panic!("Cannot specify both --tokenizer-path and --tokenizer-repository");
        }

        if let Some(path) = local_file {
            return llm::TokenizerSource::HuggingFaceTokenizerFile(path.clone());
        }

        match remote_repo {
            Some(repo) => llm::TokenizerSource::HuggingFaceRemote(repo.clone()),
            None => llm::TokenizerSource::Embedded,
        }
    }
}
fn main() {
let args = Args::parse();
let tokenizer_source = args.to_tokenizer_source();
let model_architecture = args.model_architecture;
let model_path = args.model_path;
let prompt = args
.prompt
.as_deref()
.unwrap_or("Rust is a cool programming language because");
let now = std::time::Instant::now();
let model = llm::load_dynamic(
Some(model_architecture),
&model_path,
tokenizer_source,
Default::default(),
llm::load_progress_callback_stdout,
)
.unwrap_or_else(|err| {
panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
});
println!(
"Model fully loaded! Elapsed: {}ms",
now.elapsed().as_millis()
);
let mut session = model.start_session(Default::default());
let res = session.infer::<Infallible>(
model.as_ref(),
&mut rand::thread_rng(),
&llm::InferenceRequest {
prompt: prompt.into(),
parameters: &llm::InferenceParameters::default(),
play_back_previous_tokens: false,
maximum_token_count: None,
},
// OutputRequest
&mut Default::default(),
|r| match r {
llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => {
print!("{t}");
std::io::stdout().flush().unwrap();
Ok(llm::InferenceFeedback::Continue)
}
_ => Ok(llm::InferenceFeedback::Continue),
},
);
match res {
Ok(result) => println!("\n\nInference stats:\n{result}"),
Err(err) => println!("\n{err}"),
}
}