@@ -25,6 +25,7 @@
 import os
 import sys
 import tempfile
+import time
 import traceback
 from collections import OrderedDict
 from pprint import pformat

@@ -103,6 +104,7 @@
     set_saved_weights_in_checkpoint_flag,
 )
 from ludwig.utils.print_utils import print_boxed
+from ludwig.utils.tokenizers import HFTokenizer
 from ludwig.utils.torch_utils import DEVICE
 from ludwig.utils.trainer_utils import get_training_report
 from ludwig.utils.types import DataFrame, TorchDevice

@@ -332,6 +334,27 @@ def __init__(
         # online training state
         self._online_trainer = None

+        # Zero-shot LLM usage.
+        if (
+            self.config_obj.model_type == MODEL_LLM
+            and self.config_obj.trainer.type == "none"
+            # Category output features require a vocabulary, so for those the LLM LudwigModel should be
+            # initialized with model.train(dataset) instead.
+            and self.config_obj.output_features[0].type == "text"
+        ):
+            self._initialize_llm()
+
+    def _initialize_llm(self, random_seed: int = default_random_seed):
+        """Initialize the LLM.
+
+        Should only be used in a zero-shot (NoneTrainer) setting.
+        """
+        self.model = LudwigModel.create_model(self.config_obj, random_seed=random_seed)
+
+        if self.model.model.device.type == "cpu":
+            logger.warning(f"LLM was initialized on {self.model.model.device}. Moving to GPU for inference.")
+            self.model.model.to(torch.device("cuda"))
+
     def train(
         self,
         dataset: Optional[Union[str, dict, pd.DataFrame]] = None,
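For reference, a minimal sketch (not part of this PR) of how the zero-shot path above might be exercised. The model name, feature names, and config values are illustrative assumptions, not taken from the diff:

```python
from ludwig.api import LudwigModel

# Hypothetical zero-shot LLM config: no trainer, text-to-text features.
config = {
    "model_type": "llm",
    "base_model": "facebook/opt-350m",  # illustrative; any decoder-only HF model
    "input_features": [{"name": "prompt", "type": "text"}],
    "output_features": [{"name": "answer", "type": "text"}],
    "trainer": {"type": "none"},  # zero-shot: no fine-tuning step
}

# Because trainer.type == "none" and the first output feature is text, __init__
# takes the _initialize_llm() path above and moves the model to GPU if it was
# created on CPU.
model = LudwigModel(config)
```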
@@ -891,6 +914,53 @@ def _tune_batch_size(self, trainer, dataset, random_seed: int = default_random_seed):
         trainer.eval_batch_size = self.config_obj.trainer.eval_batch_size
         trainer.gradient_accumulation_steps = self.config_obj.trainer.gradient_accumulation_steps

+    def generate(
+        self,
+        input_strings: Union[str, List[str]],
+        generation_config: Optional[dict] = None,
+    ) -> Union[str, List[str]]:
+        """A simple generate() method that directly uses the underlying transformers library to generate text."""
+        if self.config_obj.model_type != MODEL_LLM:
+            raise ValueError(
+                f"Model type {self.config_obj.model_type} is not supported by this method. Only `llm` model type is "
+                "supported."
+            )
+        if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
+            # A GPU is generally advisable for working with LLMs and is required for loading quantized models, see
+            # https://github.com/ludwig-ai/ludwig/issues/3695.
+            raise ValueError("GPU is not available.")
+
+        # TODO(Justin): Decide if it's worth folding padding_side handling into llm.py's tokenizer initialization.
+        # For batch inference with models like facebook/opt-350m, if the tokenizer padding side is off, HF prints a
+        # warning, e.g.:
+        # "A decoder-only architecture is being used, but right-padding was detected! For correct generation results,
+        # please set `padding_side='left'` when initializing the tokenizer."
+        if not self.model.model.config.is_encoder_decoder:
+            padding_side = "left"
+        else:
+            padding_side = "right"
+        tokenizer = HFTokenizer(self.config_obj.base_model, padding_side=padding_side)
+
+        with self.model.use_generation_config(generation_config):
+            start_time = time.time()
+            inputs = tokenizer.tokenizer(input_strings, return_tensors="pt", padding=True)
+            input_ids = inputs["input_ids"].to("cuda")
+            attention_mask = inputs["attention_mask"].to("cuda")
+            with torch.no_grad():
+                outputs = self.model.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    # NOTE: self.model.model.generation_config is not used here because it is the default
+                    # generation config that the CausalLM was initialized with, rather than the one set within the
+                    # context manager.
+                    generation_config=self.model.generation,
+                )
+            decoded_outputs = tokenizer.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        logger.info(f"Finished generating in: {(time.time() - start_time):.2f}s.")
+        if len(decoded_outputs) == 1:
+            return decoded_outputs[0]
+        return decoded_outputs
+
     def predict(
         self,
         dataset: Optional[Union[str, dict, pd.DataFrame]] = None,
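A hedged usage sketch of the new generate() entry point; the prompts and generation parameters below are illustrative examples, not values from the PR:

```python
# Illustrative call pattern for the generate() method added above; assumes `model`
# is an `llm`-type LudwigModel (see the earlier sketch) and that a CUDA GPU is
# available, since generate() raises a ValueError otherwise.
single = model.generate("What is the capital of France?")
print(single)

batch = model.generate(
    [
        "Translate to French: good morning",
        "Summarize: Ludwig is a declarative machine learning framework.",
    ],
    # Applied via the model's generation-config context manager; keys follow
    # Hugging Face GenerationConfig (values here are arbitrary examples).
    generation_config={"max_new_tokens": 64, "temperature": 0.1, "do_sample": True},
)
for text in batch:
    print(text)
```

A single input string returns a single decoded string; a list of strings returns a list of decoded outputs in the same order.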
@@ -946,6 +1016,7 @@ def predict(
         self._check_initialization()

         # preprocessing
+        start_time = time.time()
         logger.debug("Preprocessing")
         dataset, _ = preprocess_for_prediction(  # TODO (Connor): Refactor to use self.config_obj
             self.config_obj.to_dict(),

@@ -992,6 +1063,7 @@ def predict(

         logger.info(f"Saved to: {output_directory}")

+        logger.info(f"Finished predicting in: {(time.time() - start_time):.2f}s.")
         return converted_postproc_predictions, output_directory

     def evaluate(