
[Warning] Merge lora module to 4-bit linear may get different generations #2321

Open · steveepreston opened this issue Jan 11, 2025 · 10 comments

@steveepreston

System Info

peft 0.14.0
transformers 4.48.0
bitsandbytes 0.45.0

Who can help?

@BenjaminBossan @sayakpaul

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

code:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

base_model_id = "gemma-2-27b-it"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=quantization_config,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    use_cache=True,
)

# adapter_path points to the saved LoRA adapter
peft_model = PeftModel.from_pretrained(base_model, adapter_path)

# this line triggers the warning
merged_model = peft_model.merge_and_unload()

Warning:


UserWarning: Merge lora module to 4-bit linear may get different generations due to rounding errors.

Expected behavior

merge_and_unload() completes correctly and without a warning.

@steveepreston steveepreston changed the title [Bug] Merge lora module to 4-bit linear may get different generations [Warning] Merge lora module to 4-bit linear may get different generations Jan 11, 2025
@BenjaminBossan
Member

There is no way to avoid this. When merging weights in such a low precision regime, rounding errors are unavoidable. We give this warning purely to make users aware of that fact, not because they did anything wrong.

What you can try for your use case is to load the base model without quantization, merge the LoRA weights into the unquantized model, and then quantize the merged model to 4 bit. Please verify afterwards whether this gives better results or not. But the overall issue with low precision will still remain.
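For illustration, that workflow could look roughly like this (a sketch, not a tested recipe; it reuses base_model_id and adapter_path from the reproduction above, and the output directory name is just a placeholder):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model in bfloat16, without any quantization_config
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
)

# Attach the LoRA adapter and merge it into the full-precision weights
peft_model = PeftModel.from_pretrained(base_model, adapter_path)
merged_model = peft_model.merge_and_unload()  # no 4-bit rounding involved here

# Save the merged full-precision model; quantize it only when loading for inference
merged_model.save_pretrained("gemma-2-27b-it-merged")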

@steveepreston
Author

Hey @BenjaminBossan

Do you mean this:

  1. Load Quantized Model
  2. SFT Train
  3. Save Adapters
  4. Load Unquantized Model
  5. Merge and unload
  6. Save Merged Model
  7. Run inference on the quantized merged model

@BenjaminBossan
Member

Exactly, just make sure in step 7 that you load the merged model with the intended quantization applied. Some users have reported that this yields better results for them than merging into the quantized weights. But please verify that this is true for your use case (and report back please!).
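For example, step 7 could then look like this (a sketch, assuming the merged model from steps 4–6 was saved to the placeholder directory used above and that the same BitsAndBytesConfig from training is reused):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the merged checkpoint and apply 4-bit quantization only at inference time
quantized_merged_model = AutoModelForCausalLM.from_pretrained(
    "gemma-2-27b-it-merged",  # placeholder path from step 6
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)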

@steveepreston
Author

@BenjaminBossan Thanks for the note.

I will try it and report back here.

@benjamin-marie

Step 7 can significantly degrade the results compared to just loading the LoRA adapter on top of the quantized model. That's because the merged LoRA weights are themselves quantized.
I found that quantizing the model at step 6 with AWQ or GPTQ works significantly better than with bitsandbytes.

@BenjaminBossan
Member

Thanks for sharing your findings @benjamin-marie. It is true that merging will degrade precision, but it improves runtime performance, so it's a trade-off.

> I found that quantizing the model at step 6 with AWQ or GPTQ works significantly better than with bitsandbytes.

Do you mean without step 7 (merging) or do you mean that AWQ and GPTQ are better when merging the LoRA weights?

@benjamin-marie

>   1. Load Quantized Model (bitsandbytes)
>   2. SFT Train
>   3. Save Adapters
>   4. Load Unquantized Model
>   5. Merge and unload
>   6. Save Merged Model

I agree that all these steps are correct and yield a model that should perform the same as the adapter obtained at the end of SFT.
But this merged model is much larger than the quantized one used during fine-tuning. Intuitively, quantizing it again with bnb should be optimal, since bnb was used during SFT, but in my experiments it can severely degrade the results. I don't know why; maybe after merging, the weight distribution is difficult for bnb to handle.

However, quantizing the merged model with GPTQ or AWQ instead of bnb usually yields better results, with a perplexity much closer to that of the unquantized merged model.
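As a rough illustration of that alternative (a sketch, assuming the merged model was saved to the placeholder directory above; GPTQ quantization through transformers additionally requires the optimum and auto-gptq packages):

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

tokenizer = AutoTokenizer.from_pretrained("gemma-2-27b-it-merged")

# Calibrate on a standard dataset and quantize the merged weights to 4-bit GPTQ
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)

quantized_model = AutoModelForCausalLM.from_pretrained(
    "gemma-2-27b-it-merged",
    quantization_config=gptq_config,
    device_map="auto",
)
quantized_model.save_pretrained("gemma-2-27b-it-merged-gptq")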

@BenjaminBossan
Member

Thanks for explaining further. This is probably a topic we should explore further, as it has come up a few times in the past. Ideally, we can collect some best practices and share them in the docs. I'm very interested in running some experiments with different steps and quantization techniques. If you have any code to share (or checkpoints), please feel free to do so.

@benjamin-marie

My experiments are almost one year old. I'll rerun some experiments with the updated packages and reevaluate everything. And I'll share the results and a notebook.

@BenjaminBossan
Member

Fantastic, thanks a lot!
