Commit cd20458

Define initialize_llm() and generate() methods. Remove extra logging in llm.py (ludwig-ai#3711)
1 parent a0c42d8
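
To illustrate what this change enables, here is a minimal zero-shot usage sketch (not part of the commit). The config keys mirror the checks added to __init__ in the diff below; the base model name, prompts, and generation settings are illustrative placeholders.

# Hedged sketch of the zero-shot LLM path added by this commit. Requires a GPU;
# the base_model and prompts below are placeholders, not values from the commit.
from ludwig.api import LudwigModel

config = {
    "model_type": "llm",
    "base_model": "facebook/opt-350m",  # hypothetical small model for illustration
    "input_features": [{"name": "prompt", "type": "text"}],
    "output_features": [{"name": "response", "type": "text"}],
    "trainer": {"type": "none"},  # NoneTrainer: __init__ calls _initialize_llm()
}

model = LudwigModel(config)
# generate() accepts a single string or a list of strings; a dict generation_config is optional.
print(model.generate("What is the capital of France?"))
print(model.generate(
    ["Translate to French: hello", "Translate to French: goodbye"],
    generation_config={"max_new_tokens": 32},
))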

File tree: 4 files changed (+74, -11 lines)

ludwig/api.py (+72)

@@ -25,6 +25,7 @@
 import os
 import sys
 import tempfile
+import time
 import traceback
 from collections import OrderedDict
 from pprint import pformat
@@ -103,6 +104,7 @@
     set_saved_weights_in_checkpoint_flag,
 )
 from ludwig.utils.print_utils import print_boxed
+from ludwig.utils.tokenizers import HFTokenizer
 from ludwig.utils.torch_utils import DEVICE
 from ludwig.utils.trainer_utils import get_training_report
 from ludwig.utils.types import DataFrame, TorchDevice
@@ -332,6 +334,27 @@ def __init__(
         # online training state
         self._online_trainer = None
 
+        # Zero-shot LLM usage.
+        if (
+            self.config_obj.model_type == MODEL_LLM
+            and self.config_obj.trainer.type == "none"
+            # Category output features require a vocabulary. The LLM LudwigModel should be initialized with
+            # model.train(dataset).
+            and self.config_obj.output_features[0].type == "text"
+        ):
+            self._initialize_llm()
+
+    def _initialize_llm(self, random_seed: int = default_random_seed):
+        """Initialize the LLM model.
+
+        Should only be used in a zero-shot (NoneTrainer) setting.
+        """
+        self.model = LudwigModel.create_model(self.config_obj, random_seed=random_seed)
+
+        if self.model.model.device == "cpu":
+            logger.warning(f"LLM was initialized on {self.model.model.device}. Moving to GPU for inference.")
+            self.model.model.to(torch.device("cuda"))
+
     def train(
         self,
         dataset: Optional[Union[str, dict, pd.DataFrame]] = None,
@@ -891,6 +914,53 @@ def _tune_batch_size(self, trainer, dataset, random_seed: int = default_random_s
         trainer.eval_batch_size = self.config_obj.trainer.eval_batch_size
         trainer.gradient_accumulation_steps = self.config_obj.trainer.gradient_accumulation_steps
 
+    def generate(
+        self,
+        input_strings: Union[str, List[str]],
+        generation_config: Optional[dict] = None,
+    ) -> Union[str, List[str]]:
+        """A simple generate() method that directly uses the underlying transformers library to generate text."""
+        if self.config_obj.model_type != MODEL_LLM:
+            raise ValueError(
+                f"Model type {self.config_obj.model_type} is not supported by this method. Only `llm` model type is "
+                "supported."
+            )
+        if not torch.cuda.is_available() or torch.cuda.device_count() == 0:
+            # GPU is generally well-advised for working with LLMs and is required for loading quantized models, see
+            # https://github.com/ludwig-ai/ludwig/issues/3695.
+            raise ValueError("GPU is not available.")
+
+        # TODO(Justin): Decide if it's worth folding padding_side handling into llm.py's tokenizer initialization.
+        # For batch inference with models like facebook/opt-350m, if the tokenizer padding side is off, HF prints a
+        # warning, e.g.:
+        # "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, "
+        # "please set `padding_side='left'` when initializing the tokenizer.
+        if not self.model.model.config.is_encoder_decoder:
+            padding_side = "left"
+        else:
+            padding_side = "right"
+        tokenizer = HFTokenizer(self.config_obj.base_model, padding_side=padding_side)
+
+        with self.model.use_generation_config(generation_config):
+            start_time = time.time()
+            inputs = tokenizer.tokenizer(input_strings, return_tensors="pt", padding=True)
+            input_ids = inputs["input_ids"].to("cuda")
+            attention_mask = inputs["attention_mask"].to("cuda")
+            with torch.no_grad():
+                outputs = self.model.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    # NOTE: self.model.model.generation_config is not used here because it is the default
+                    # generation config that the CausalLM was initialized with, rather than the one set within the
+                    # context manager.
+                    generation_config=self.model.generation,
+                )
+            decoded_outputs = tokenizer.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            logger.info(f"Finished generating in: {(time.time() - start_time):.2f}s.")
+        if len(decoded_outputs) == 1:
+            return decoded_outputs[0]
+        return decoded_outputs
+
     def predict(
         self,
         dataset: Optional[Union[str, dict, pd.DataFrame]] = None,
@@ -946,6 +1016,7 @@ def predict(
         self._check_initialization()
 
         # preprocessing
+        start_time = time.time()
         logger.debug("Preprocessing")
         dataset, _ = preprocess_for_prediction(  # TODO (Connor): Refactor to use self.config_obj
             self.config_obj.to_dict(),
@@ -992,6 +1063,7 @@
 
         logger.info(f"Saved to: {output_directory}")
 
+        logger.info(f"Finished predicting in: {(time.time() - start_time):.2f}s.")
         return converted_postproc_predictions, output_directory
 
     def evaluate(
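
As background for the padding_side handling in the new generate() method above, here is a standalone sketch (an assumption, not part of the commit) of the same behavior expressed with the transformers API directly; facebook/opt-350m is the decoder-only model named in the TODO comment, and the prompts and max_new_tokens value are placeholders.

# Decoder-only models should be left-padded for batched generation; otherwise HF emits the
# right-padding warning quoted in the diff above and completions can be misaligned.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-350m"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_name)

batch = tokenizer(
    ["Hello, my name is", "The capital of France is"],
    return_tensors="pt",
    padding=True,
)
with torch.no_grad():
    outputs = model.generate(**batch, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))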

ludwig/models/llm.py (-8)

@@ -417,10 +417,6 @@ def generate(
         sequences_list = []
         for input_ids_sample in input_ids:
             input_ids_sample_no_padding = remove_left_padding(input_ids_sample, self.tokenizer)
-            logger.info(
-                "Decoded text inputs for the first example in batch: "
-                f"{self.tokenizer.decode(input_ids_sample_no_padding[0], skip_special_tokens=True)}"
-            )
 
             if input_ids_sample_no_padding.shape[1] > self.max_input_length:
                 logger.warning(
@@ -443,10 +439,6 @@
                     return_dict_in_generate=True,
                     output_scores=True,
                 )
-                logger.info(
-                    "Decoded generated output for the first example in batch: "
-                    f"{self.tokenizer.batch_decode(model_outputs.sequences, skip_special_tokens=True)[0]}"
-                )
 
                 sequences_list.append(model_outputs.sequences[0])
 
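If the per-batch decode logging removed above is still useful for debugging, an equivalent can live outside the model. A small sketch (an assumed helper, not part of the commit) that mirrors the deleted logger.info calls using the same tokenizer methods:

import logging

logger = logging.getLogger(__name__)


def log_decoded_batch(tokenizer, input_ids_no_padding, sequences) -> None:
    # Mirrors the removed calls: decode the first example's inputs and generated output.
    logger.debug(
        "Decoded text inputs for the first example in batch: %s",
        tokenizer.decode(input_ids_no_padding[0], skip_special_tokens=True),
    )
    logger.debug(
        "Decoded generated output for the first example in batch: %s",
        tokenizer.batch_decode(sequences, skip_special_tokens=True)[0],
    )
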
tests/integration_tests/test_api.py (+1, -3)

@@ -768,6 +768,7 @@ def test_constant_metadata(tmpdir):
     assert metadata1 == metadata2
 
 
+@pytest.mark.integration_tests_e
 @pytest.mark.parametrize(
     "input_max_sequence_length, global_max_sequence_length, expect_raise",
     [
@@ -797,9 +798,6 @@ def test_llm_template_too_long(tmpdir, input_max_sequence_length, global_max_seq
 
         preprocessing:
           global_max_sequence_length: {global_max_sequence_length}
-
-        quantization:
-          bits: 4
         """
     )
     zero_shot_config["prompt"] = {}
tests/integration_tests/test_peft.py (+1)

@@ -6,6 +6,7 @@
 from tests.integration_tests.utils import binary_feature, generate_data, run_test_suite, text_feature
 
 
+@pytest.mark.integration_tests_e
 @pytest.mark.parametrize(
     "backend",
     [