
Commit 3dc8f4b

[FEATURE] Support Merging LoRA Weights Into Base Model (Issue-3603) (ludwig-ai#3649)
1 parent: a3b7709 · commit: 3dc8f4b

File tree: 5 files changed (+248, -15 lines)

  ludwig/api.py
  ludwig/constants.py
  ludwig/models/llm.py
  ludwig/schema/llms/peft.py
  tests/integration_tests/test_llm.py

ludwig/api.py (+24, -2)
@@ -429,8 +429,9 @@ def train(
             will contain the training statistics, TensorBoard logs, the saved
             model and the training progress files.
         :param random_seed: (int, default: `42`) a random seed that will be
-            used anywhere there is a call to a random number generator: data
-            splitting, parameter initialization and training set shuffling
+            used anywhere there is a call to a random number generator: data
+            splitting, parameter initialization and training set shuffling
+        :param kwargs: (dict, default: {}) a dictionary of optional parameters.

         # Return
@@ -645,6 +646,9 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
                 )
             (self.model, train_trainset_stats, train_valiset_stats, train_testset_stats) = train_stats

+            # For an LLM model trained with a LoRA adapter, handle merge and unload postprocessing directives.
+            self._merge_and_unload()
+
             # Calibrates output feature probabilities on validation set if calibration is enabled.
             # Must be done after training, and before final model parameters are saved.
             if self.backend.is_coordinator():
@@ -807,6 +811,21 @@ def train_online(

         self.model = self._online_trainer.train_online(training_dataset)

+    def _merge_and_unload(self) -> None:
+        """For an LLM model trained with a LoRA adapter, handle merge and unload postprocessing directives.
+
+        First, check that the model is of the "llm" type. Then check if the "adapter" configuration section contains
+        the "postprocessor" subsection and apply the "merge_adapter_into_base_model" and "progressbar" directives.
+        """
+        if (
+            self.config_obj.model_type == "llm"
+            and self.config_obj.adapter is not None
+            and self.config_obj.adapter.postprocessor is not None
+            and self.config_obj.adapter.postprocessor.merge_adapter_into_base_model
+            and hasattr(self.model, "merge_and_unload")
+        ):
+            self.model.merge_and_unload(progressbar=self.config_obj.adapter.postprocessor.progressbar)
+
     def _tune_batch_size(self, trainer, dataset, random_seed: int = default_random_seed):
         """Sets AUTO batch-size-related parameters based on the trainer, backend type, and number of workers.
@@ -1643,6 +1662,9 @@ def load(
         # load model weights
         ludwig_model.load_weights(model_dir)

+        # The LoRA layers appear to be loaded again (perhaps due to a potential bug); hence, we merge and unload again.
+        ludwig_model._merge_and_unload()
+
         # load train set metadata
         ludwig_model.training_set_metadata = backend.broadcast_return(
             lambda: load_metadata(os.path.join(model_dir, TRAIN_SET_METADATA_FILE_NAME))
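For context, the merge is driven entirely by the model config: users opt in via the adapter's postprocessor section and train as usual. A minimal usage sketch is shown below; the base model, feature definitions, and dataset path are illustrative placeholders, not part of this commit.

# Usage sketch (placeholder base model, features, and data): opting into the merge.
from ludwig.api import LudwigModel

config = {
    "model_type": "llm",
    "base_model": "hf-internal-testing/tiny-random-GPTJForCausalLM",  # placeholder base model
    "input_features": [{"name": "instruction", "type": "text"}],
    "output_features": [{"name": "response", "type": "text"}],
    "adapter": {
        "type": "lora",
        "postprocessor": {
            "merge_adapter_into_base_model": True,  # fold LoRA weights into the base LLM after training
            "progressbar": True,
        },
    },
    "trainer": {"type": "finetune", "epochs": 1, "batch_size": 1},
}

model = LudwigModel(config)
# _merge_and_unload() runs right after training, so the saved model is already merged.
model.train(dataset="train.csv")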

ludwig/constants.py (+2)

@@ -283,6 +283,8 @@
 PROMPT = "prompt"
 ADAPTER = "adapter"
 PRETRAINED_ADAPTER_WEIGHTS = "pretrained_adapter_weights"
+MERGE_ADAPTER_INTO_BASE_MODEL = "merge_adapter_into_base_model"
+PROGRESSBAR = "progressbar"

 # CrossEntropyLoss for LLMs
 IGNORE_INDEX_TOKEN_ID = -100

ludwig/models/llm.py (+15)

@@ -451,6 +451,21 @@ def generate(

         return outputs

+    def merge_and_unload(self, progressbar: bool = False) -> None:
+        """This method merges the LoRA layers into the base model. This is needed if someone wants to use the base
+        model as a standalone model. The implementation calls merge_and_unload() of the underlying LoraModel class
+        (in peft).
+
+        Args:
+            progressbar (bool): whether to show a progress bar indicating the unload and merge process
+        """
+        from peft import LoraModel
+
+        if isinstance(self.model.base_model, LoraModel):
+            self.model.base_model.merge_and_unload(progressbar=progressbar)
+        else:
+            raise ValueError("This operation requires an LLM model trained with a LoRA adapter.")
+
     def _unpack_inputs(
         self,
         inputs: Union[
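For background, peft's LoraModel.merge_and_unload() folds each low-rank update back into the frozen base weights (roughly W <- W + B @ A, scaled) and strips the adapter wrappers, leaving a plain transformers model. Outside of Ludwig, the equivalent standalone workflow looks roughly like the sketch below; the paths are placeholders, and the progressbar argument assumes a peft version that exposes it (as the code above does).

# Standalone peft sketch (paths are placeholders): merge trained LoRA weights into the base model.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/base-model")
peft_model = PeftModel.from_pretrained(base, "path/to/lora-adapter")

# Folds each B @ A update into the targeted Linear layers and removes the LoRA modules,
# returning a model that can be saved and served without peft installed.
merged = peft_model.merge_and_unload(progressbar=True)
merged.save_pretrained("path/to/merged-model")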

ludwig/schema/llms/peft.py (+28)

@@ -25,6 +25,32 @@ def wrap(config: BaseAdapterConfig):
     return wrap


+@DeveloperAPI
+@ludwig_dataclass
+class LoraPostprocessorConfig(schema_utils.BaseMarshmallowConfig):
+    """This Dataclass is a schema for the nested postprocessing config under adapter of type "lora"."""
+
+    merge_adapter_into_base_model: bool = schema_utils.Boolean(
+        default=False,
+        description="""Instructs whether or not the fine-tuned LoRA weights are to be merged into the base LLM model so
+that the complete fine-tuned model is available to be used and/or persisted, and then reused upon loading as a single
+model (rather than having to load base and fine-tuned models separately).""",
+    )
+    progressbar: bool = schema_utils.Boolean(
+        default=False,
+        description="Instructs whether or not to show a progress bar indicating the unload and merge process.",
+    )
+
+
+@DeveloperAPI
+class LoraPostprocessorConfigField(schema_utils.DictMarshmallowField):
+    def __init__(self):
+        super().__init__(LoraPostprocessorConfig)
+
+    def _jsonschema_type_mapping(self):
+        return schema_utils.unload_jsonschema_from_marshmallow_class(LoraPostprocessorConfig, title="LoraPostprocessor")
+
+
 @DeveloperAPI
 @ludwig_dataclass
 class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC):
@@ -34,6 +60,8 @@ class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC):
         default=None, description="Path to pretrained weights.", allow_none=True
     )

+    postprocessor: LoraPostprocessorConfig = LoraPostprocessorConfigField().get_default_field()
+
     @abstractmethod
     def to_config(self, **kwargs) -> "PeftConfig":
         pass
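Because LoraPostprocessorConfig defaults both flags to False, the postprocessor section is strictly opt-in: omitting it from an adapter config leaves training behavior unchanged. A rough check of that assumption, instantiating the dataclass directly with its declared defaults:

# Sketch of the defaults when the postprocessor subsection is omitted from the adapter config.
from ludwig.schema.llms.peft import LoraPostprocessorConfig

postprocessor = LoraPostprocessorConfig()
assert postprocessor.merge_adapter_into_base_model is False  # merging is opt-in
assert postprocessor.progressbar is False                    # progress bar off by default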

tests/integration_tests/test_llm.py (+179, -13)

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Union

 import numpy as np
 import pandas as pd
@@ -15,11 +15,14 @@
     EPOCHS,
     GENERATION,
     INPUT_FEATURES,
+    MERGE_ADAPTER_INTO_BASE_MODEL,
     MODEL_LLM,
     MODEL_TYPE,
     OUTPUT_FEATURES,
+    POSTPROCESSOR,
     PREPROCESSING,
     PRETRAINED_ADAPTER_WEIGHTS,
+    PROGRESSBAR,
     PROMPT,
     TRAINER,
     TYPE,
@@ -352,6 +355,84 @@ def _prepare_finetuning_test(
     return train_df, prediction_df, config


+def _finetune_strategy_requires_cuda(finetune_strategy_name: str, quantization_args: Union[dict, None]) -> bool:
+    """This method returns whether or not a given finetune_strategy requires CUDA.
+
+    For all finetune strategies, except "qlora", the decision is based just on the name of the finetune_strategy; in
+    the case of qlora, if the quantization dictionary is non-empty (i.e., contains quantization specifications), then
+    the original finetune_strategy name of "lora" is interpreted as "qlora" and used in the lookup, based on the list
+    of finetune strategies requiring CUDA.
+    """
+    cuda_only_finetune_strategy_names: list[str] = [
+        "prompt_tuning",
+        "prefix_tuning",
+        "p_tuning",
+        "qlora",
+    ]
+
+    if finetune_strategy_name == "lora" and quantization_args:
+        finetune_strategy_name = "qlora"
+
+    return finetune_strategy_name in cuda_only_finetune_strategy_names
+
+
+def _verify_lm_lora_finetuning_layers(
+    attention_layer: torch.nn.Module,
+    merge_adapter_into_base_model: bool,
+    expected_lora_in_features: int,
+    expected_lora_out_features: int,
+) -> bool:
+    """This method verifies that LoRA finetuning layers have correct types and shapes, depending on whether or not
+    the optional "model.merge_and_unload()" method (based on the "merge_adapter_into_base_model" directive) was
+    executed.
+
+    If merge_adapter_into_base_model is True, then both LoRA projection layers, V and Q, in the attention layer must
+    contain square weight matrices (with the dimensions expected_lora_in_features by expected_lora_in_features).
+    However, if merge_adapter_into_base_model is False, then the LoRA part of the attention layer must include lora_A
+    and lora_B children layers for each of the V and Q projections, such that the product of the lora_B and lora_A
+    matrices is a square matrix (with the dimensions expected_lora_in_features by expected_lora_in_features) for each
+    of the V and Q projections.
+    """
+    success: bool = True
+    success = success and isinstance(attention_layer.v_proj, torch.nn.Linear)
+    success = success and isinstance(attention_layer.q_proj, torch.nn.Linear)
+    if merge_adapter_into_base_model:
+        success = success and (attention_layer.v_proj.in_features, attention_layer.v_proj.out_features) == (
+            expected_lora_in_features,
+            expected_lora_out_features,
+        )
+        success = success and (attention_layer.q_proj.in_features, attention_layer.q_proj.out_features) == (
+            expected_lora_in_features,
+            expected_lora_out_features,
+        )
+        success = success and not list(attention_layer.v_proj.children())
+        success = success and not list(attention_layer.q_proj.children())
+    else:
+        v_proj_named_children: dict[str, torch.nn.Module] = dict(attention_layer.v_proj.named_children())
+        assert isinstance(v_proj_named_children["lora_A"]["default"], torch.nn.Linear)
+        assert (
+            v_proj_named_children["lora_A"]["default"].in_features,
+            v_proj_named_children["lora_A"]["default"].out_features,
+        ) == (expected_lora_in_features, expected_lora_out_features)
+        assert isinstance(v_proj_named_children["lora_B"]["default"], torch.nn.Linear)
+        assert (
+            v_proj_named_children["lora_B"]["default"].in_features,
+            v_proj_named_children["lora_B"]["default"].out_features,
+        ) == (expected_lora_out_features, expected_lora_in_features)
+        q_proj_named_children: dict[str, torch.nn.Module] = dict(attention_layer.q_proj.named_children())
+        assert isinstance(q_proj_named_children["lora_A"]["default"], torch.nn.Linear)
+        assert (
+            q_proj_named_children["lora_A"]["default"].in_features,
+            q_proj_named_children["lora_A"]["default"].out_features,
+        ) == (expected_lora_in_features, expected_lora_out_features)
+        assert isinstance(q_proj_named_children["lora_B"]["default"], torch.nn.Linear)
+        assert (
+            q_proj_named_children["lora_B"]["default"].in_features,
+            q_proj_named_children["lora_B"]["default"].out_features,
+        ) == (expected_lora_out_features, expected_lora_in_features)
+
+    return success
+
+
 # TODO(arnav): p-tuning and prefix tuning have errors when enabled that seem to stem from DDP:
 #
 # prefix tuning:
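For intuition about the shapes checked above: with LoRA rank r = 8 on a 32-dimensional projection, the unmerged adapter keeps two rectangular layers (lora_A: 32 -> 8, lora_B: 8 -> 32), while merging folds their product back into the projection so only a square 32 x 32 Linear remains, with no LoRA children. A small illustrative check (the dimensions mirror the test parametrization below, not any particular model):

# Illustrative shape check for an r=8 LoRA update on a 32-dimensional projection.
import torch

hidden_size, rank = 32, 8
lora_A = torch.nn.Linear(hidden_size, rank, bias=False)  # unmerged: 32 -> 8
lora_B = torch.nn.Linear(rank, hidden_size, bias=False)  # unmerged: 8 -> 32

# The merged update has the same shape as the base q_proj/v_proj weight matrix,
# which is why the merged layers are square and carry no LoRA submodules.
delta_w = lora_B.weight @ lora_A.weight
assert delta_w.shape == (hidden_size, hidden_size)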
@@ -376,8 +457,12 @@ def _prepare_finetuning_test(
         (None, {}),
         ("lora", {}),
         ("lora", {"r": 4, "dropout": 0.1}),
+        ("lora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: True, PROGRESSBAR: True}}),
+        ("lora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: False}}),
         ("adalora", {}),
         ("adalora", {"init_r": 8, "beta1": 0.8}),
+        ("adalora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: True, PROGRESSBAR: True}}),
+        ("adalora", {POSTPROCESSOR: {MERGE_ADAPTER_INTO_BASE_MODEL: False}}),
         ("adaption_prompt", {}),
         ("adaption_prompt", {"adapter_len": 6, "adapter_layers": 1}),
         # (
@@ -403,8 +488,12 @@ def _prepare_finetuning_test(
         "full",
         "lora-defaults",
         "lora-modified-defaults",
+        "lora_merged",
+        "lora_not_merged",
         "adalora-defaults",
         "adalora-modified-defaults",
+        "adalora_merged",
+        "adalora_not_merged",
         "adaption_prompt-defaults",
         "adaption_prompt-modified-defaults",
         # "prompt_tuning_init_random",
@@ -445,7 +534,10 @@ def test_llm_finetuning_strategies(tmpdir, csv_filename, backend, finetune_strat
     ],
 )
 def test_llm_finetuning_strategies_quantized(tmpdir, csv_filename, finetune_strategy, adapter_args, quantization):
-    if quantization and (not torch.cuda.is_available() or torch.cuda.device_count() == 0):
+    if (
+        _finetune_strategy_requires_cuda(finetune_strategy_name=finetune_strategy, quantization_args=quantization)
+        and not (torch.cuda.is_available() and torch.cuda.device_count() > 0)
+    ):
         pytest.skip("Skip: quantization requires GPU and none are available.")

     backend = LOCAL_BACKEND
@@ -469,6 +561,66 @@ def test_llm_finetuning_strategies_quantized(tmpdir, csv_filename, finetune_stra
     assert preds


+@pytest.mark.llm
+@pytest.mark.parametrize(
+    "backend",
+    [
+        pytest.param(LOCAL_BACKEND, id="local"),
+        # TODO: Re-enable once we can run tests on GPUs
+        # This is because fine-tuning requires Ray with the deepspeed strategy, and deepspeed
+        # only works with GPUs
+        # pytest.param(RAY_BACKEND, id="ray"),
+    ],
+)
+@pytest.mark.parametrize(
+    "merge_adapter_into_base_model,expected_lora_in_features,expected_lora_out_features",
+    [
+        pytest.param(
+            False,
+            32,
+            8,
+            id="lora_not_merged",
+        ),
+        pytest.param(
+            True,
+            32,
+            32,
+            id="lora_merged",
+        ),
+    ],
+)
+def test_llm_lora_finetuning_merge_and_unload(
+    tmpdir, csv_filename, backend, merge_adapter_into_base_model, expected_lora_in_features, expected_lora_out_features
+):
+    finetune_strategy: str = "lora"
+    adapter_args: dict = {
+        POSTPROCESSOR: {
+            MERGE_ADAPTER_INTO_BASE_MODEL: merge_adapter_into_base_model,
+        },
+    }
+    train_df, prediction_df, config = _prepare_finetuning_test(
+        csv_filename=csv_filename, finetune_strategy=finetune_strategy, backend=backend, adapter_args=adapter_args
+    )
+
+    model = LudwigModel(config)
+    model.train(dataset=train_df, output_directory=str(tmpdir), skip_save_processed_input=False)
+    assert _verify_lm_lora_finetuning_layers(
+        attention_layer=model.model.model.base_model.model.transformer.h[1].attn,
+        merge_adapter_into_base_model=merge_adapter_into_base_model,
+        expected_lora_in_features=expected_lora_in_features,
+        expected_lora_out_features=expected_lora_out_features,
+    )
+
+    # Make sure we can load the saved model and verify that the LoRA layers have expected shapes.
+    model = LudwigModel.load(os.path.join(str(tmpdir), "api_experiment_run", "model"), backend=backend)
+    assert _verify_lm_lora_finetuning_layers(
+        attention_layer=model.model.model.base_model.model.transformer.h[1].attn,
+        merge_adapter_into_base_model=merge_adapter_into_base_model,
+        expected_lora_in_features=expected_lora_in_features,
+        expected_lora_out_features=expected_lora_out_features,
+    )
+
+
 @pytest.mark.llm
 @pytest.mark.parametrize("use_adapter", [True, False], ids=["with_adapter", "without_adapter"])
 def test_llm_training_with_gradient_checkpointing(tmpdir, csv_filename, use_adapter):
@@ -628,23 +780,37 @@ def test_load_pretrained_adapter_weights(adapter):

 def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool:
     # For a full explanation of this 8-bit workaround, see https://github.com/ludwig-ai/ludwig/pull/3606
-    def filter_for_weight_format(i):
-        """Remove bitsandbytes metadata keys added on state dict creation.

-        8-bit quantized models that have been put on gpu will have a set of `weight_format` keys in their state dict.
-        These contain strings that are used to reshape quantized tensors, however these have no impact until the state
-        dict is loaded into a model. These keys were causing `torch.equal` to raise an exception, so we skip them in the
-        evaluation.
-        """
-        return "weight_format" not in i[0]
+    # TODO: Uncomment "filter_for_weight_format()" method definition and enable its usage once GPU tests are set up.
+    # def filter_for_weight_format(i):
+    #     """Remove bitsandbytes metadata keys added on state dict creation.
+    #
+    #     8-bit quantized models that have been put on gpu will have a set of `weight_format` keys in their state dict.
+    #     These contain strings that are used to reshape quantized tensors, however these have no impact until the state
+    #     dict is loaded into a model. These keys were causing `torch.equal` to raise an exception, so we skip them in
+    #     the evaluation.
+    #     """
+    #     return "weight_format" not in i[0]

-    model_1_filtered_state_dict = filter(filter_for_weight_format, model_1.state_dict().items())
-    model_2_filtered_state_dict = filter(filter_for_weight_format, model_2.state_dict().items())
+    # model_1_filtered_state_dict = filter(filter_for_weight_format, model_1.state_dict().items())
+    # model_2_filtered_state_dict = filter(filter_for_weight_format, model_2.state_dict().items())

     # Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
-    for key_item_1, key_item_2 in zip(model_1_filtered_state_dict, model_2_filtered_state_dict):
+
+    if model_1.__class__.__name__ != model_2.__class__.__name__:
+        return False
+
+    if (
+        hasattr(model_1, "model")
+        and hasattr(model_2, "model")
+        and not _compare_models(model_1=model_1.model, model_2=model_2.model)
+    ):
+        return False
+
+    for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
         if not torch.equal(key_item_1[1], key_item_2[1]):
             return False
+
     return True