
Addition of max_new_tokens generation option and documentation update
guillaume-be committed Oct 28, 2021
1 parent 184ebc2 commit 1c42cfa
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions src/pipelines/generation_utils.rs
@@ -1423,26 +1423,51 @@ pub struct GeneratedIndicesOutput {
}

#[derive(Clone, Copy)]
/// # Generation options for text generation.
/// When provided to a `generate` method, these options will take priority over the `GenerateConfig` used to create the
/// `LanguageGenerator`. Some of these options may be left as `None`; options left unset individually fall back
/// to the corresponding `GenerateConfig` value.
pub struct GenerateOptions<'a> {
/// Minimum sequence length
pub min_length: Option<i64>,
/// Maximum sequence length
pub max_length: Option<i64>,
+ /// Maximum number of new tokens to generate (useful for causal generation models).
+ /// Only one of `max_length` and `max_new_tokens` should be provided.
+ /// When both are given, `max_new_tokens` is ignored and the `max_length` setting is used.
+ pub max_new_tokens: Option<i64>,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beams` hypotheses have been generated
pub early_stopping: Option<bool>,
/// Number of sequences to return for each prompt text
pub num_return_sequences: Option<i64>,
/// Number of beams for beam search
pub num_beams: Option<i64>,
/// Number of beam groups for diverse beam search
pub num_beam_groups: Option<i64>,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding
pub do_sample: Option<bool>,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance
pub temperature: Option<f64>,
/// Top_k value for sampling tokens. Values higher than 0 enable the feature
pub top_k: Option<i64>,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p
pub top_p: Option<f64>,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have already been generated.
pub repetition_penalty: Option<f64>,
/// Exponential penalty based on the length of the hypotheses generated
pub length_penalty: Option<f64>,
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature
pub no_repeat_ngram_size: Option<i64>,
/// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups
pub diversity_penalty: Option<f64>,
/// Decoder start token id
pub decoder_start_token_id: Option<i64>,
/// Forced first token generated
pub forced_bos_token_id: Option<i64>,
/// Function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and return a `Vec<i64>` of allowed tokens.
pub prefix_allowed_tokens_fn: Option<&'a dyn Fn(i64, &Tensor) -> Vec<i64>>,
/// List of bad word ids (each may be a sequence of word ids) that will be banned during generation
pub bad_word_ids: Option<&'a Vec<Vec<i64>>>,
/// Flag indicating if text generation scores should be returned
pub output_scores: bool,
}
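
A hedged usage sketch of the new option follows. The prompt, the `anyhow` error handling, the `generate(Some(prompts), Some(options))` call shape, and the `text` field on the output are assumptions based on this era of the crate, not shown in this diff. Every field is written out because `GenerateOptions` derives only `Clone` and `Copy` here, with no `Default` impl:

use rust_bert::gpt2::GPT2Generator;
use rust_bert::pipelines::generation_utils::{GenerateOptions, LanguageGenerator};

fn main() -> anyhow::Result<()> {
    let model = GPT2Generator::new(Default::default())?;

    // Generate at most 40 new tokens beyond the prompt; every other
    // option falls back to the generator's GenerateConfig.
    let generate_options = GenerateOptions {
        min_length: None,
        max_length: None,
        max_new_tokens: Some(40),
        early_stopping: None,
        num_return_sequences: None,
        num_beams: None,
        num_beam_groups: None,
        do_sample: None,
        temperature: None,
        top_k: None,
        top_p: None,
        repetition_penalty: None,
        length_penalty: None,
        no_repeat_ngram_size: None,
        diversity_penalty: None,
        decoder_start_token_id: None,
        forced_bos_token_id: None,
        prefix_allowed_tokens_fn: None,
        bad_word_ids: None,
        output_scores: false,
    };

    // Hypothetical call shape; the prompt text is illustrative.
    let output = model.generate(Some(&["The dog jumped"]), Some(generate_options));
    for generated in output {
        println!("{}", generated.text);
    }
    Ok(())
}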

@@ -1727,8 +1752,8 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// let device = Device::cuda_if_available();
///
/// let gpt2_generator = GPT2Generator::new(Default::default())?;
- /// let input_tensor = Tensor::randn(&[32,128], (Kind::Int64, device:: Cpu));
- /// let input_mask = Tensor::ones(&[32,128], (Kind::Int64, device:: Cpu));
+ /// let input_tensor = Tensor::randn(&[32,128], (Kind::Int64, Device::Cpu));
+ /// let input_mask = Tensor::ones(&[32,128], (Kind::Int64, Device::Cpu));
///
/// let generate_options = GenerateOptions {
/// min_length: Some(32),
@@ -1761,7 +1786,6 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let num_return_sequences = unpack_config!(num_return_sequences, generate_options, config);
let num_beams = unpack_config!(num_beams, generate_options, config);
let min_length = unpack_config!(min_length, generate_options, config);
- let max_length = unpack_config!(max_length, generate_options, config);
let early_stopping = unpack_config!(early_stopping, generate_options, config);
let temperature = unpack_config!(temperature, generate_options, config);
let top_k = unpack_config!(top_k, generate_options, config);
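
For readers unfamiliar with the macro: a hedged reconstruction of what each `unpack_config!` call boils down to, inferred from its call sites here rather than from the macro's definition (which is not shown in this diff). The per-call option wins when present; otherwise the builder-level `GenerateConfig` value is used:

// Roughly what `unpack_config!(top_k, generate_options, config)` expands to:
let top_k = match generate_options {
    Some(options) => options.top_k.unwrap_or(config.top_k),
    None => config.top_k,
};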
@@ -1874,6 +1898,16 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
(input_ids, attention_mask)
};

+ let max_length = if let Some(generate_options) = generate_options {
+     match (generate_options.max_length, generate_options.max_new_tokens) {
+         (Some(max_length), _) => max_length,
+         (None, Some(max_new_tokens)) => max_new_tokens + input_ids.size().last().unwrap(),
+         (None, None) => config.max_length,
+     }
+ } else {
+     config.max_length
+ };
+
let gen_opt = InternalGenerateOptions {
min_length,
max_length,
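The new precedence rule implemented above is easy to restate in isolation. Below is a hypothetical standalone helper (not part of the crate) together with a worked example of the arithmetic: an explicit `max_length` always wins, `max_new_tokens` is counted from the end of the prompt, and the `GenerateConfig` default applies when neither is set:

fn resolve_max_length(
    max_length: Option<i64>,
    max_new_tokens: Option<i64>,
    prompt_length: i64,
    config_max_length: i64,
) -> i64 {
    match (max_length, max_new_tokens) {
        (Some(max_length), _) => max_length,
        (None, Some(max_new_tokens)) => max_new_tokens + prompt_length,
        (None, None) => config_max_length,
    }
}

fn main() {
    // With a 10-token prompt: max_new_tokens = 20 yields an effective cap of 30,
    // while an explicit max_length = 25 overrides it (max_new_tokens is ignored).
    assert_eq!(resolve_max_length(None, Some(20), 10, 56), 30);
    assert_eq!(resolve_max_length(Some(25), Some(20), 10, 56), 25);
    assert_eq!(resolve_max_length(None, None, 10, 56), 56);
}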
