-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit bc27d70
Showing
20 changed files
with
1,372 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
.DS_STORE | ||
.idea | ||
.cache/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
# Recovering the Pre-Fine-Tuning Weights of Generative Models | ||
Official PyTorch Implementation for the "Recovering the Pre-Fine-Tuning Weights of Generative Models" paper. | ||
<p align="center"> | ||
🌐 <a href="https://vision.huji.ac.il/spectral_detuning/" target="_blank">Project</a> | 📃 <a href="http://arxiv.org/abs/" target="_blank">Paper - Later Today</a> | 🤗 <a href="https://huggingface.co/datasets/Eliahu/LoWRA-Bench" target="_blank">Dataset</a> <br> | ||
</p> | ||
|
||
![](imgs/header.gif) | ||
|
||
|
||
***Pre-Fine-Tuning Weight Recovery Attack Setting:*** We uncover a vulnerability in LoRA fine-tuned models wherein an attacker is | ||
able to undo the fine-tuning process and recover the weights of the original pre-trained model. | ||
The setting for the vulnerability is as follows: | ||
|
||
(a) The attacker only has access to n different LoRA fine-tuned models. | ||
|
||
(b) The attacker assumes that all n models originated from the same source model. | ||
|
||
(c) Using only the n visible models, the attacker attempts to recover the original source model. | ||
|
||
Our method, *Spectral DeTuning*, can perform the attack in an unsupervised and data-free manner on real models such as Stable Diffusion and Mistral. | ||
For simplicity, we illustrate the attack on a single layer, in reality, the attack is carried out independently on all the fine-tuned layers. | ||
|
||
**Note: The attacker has no access to the low-rank decomposition of the fine-tuned models.** | ||
___ | ||
|
||
> **Recovering the Pre-Fine-Tuning Weights of Generative Models**<br> | ||
> Eliahu Horwitz, Jonathan Kahana, Yedid Hoshen<br> | ||
> <a href="http://arxiv.org/abs/" target="_blank">http://arxiv.org/abs/ </a> <br> | ||
> | ||
>**Abstract:** The dominant paradigm in generative modeling consists of two steps: | ||
> i) pre-training on a large-scale but unsafe dataset, ii) aligning the pre-trained model with human values via fine-tuning. | ||
> This practice is considered safe, as no current method can recover the unsafe, *pre-fine-tuning* model weights. | ||
> In this paper, we demonstrate that this assumption is often false. Concretely, we present *Spectral DeTuning*, | ||
> a method that can recover the weights of the pre-fine-tuning model using a few low-rank (LoRA) fine-tuned models. | ||
> In contrast to previous attacks that attempt to recover pre-fine-tuning capabilities, | ||
> our method aims to recover the exact pre-fine-tuning weights. | ||
> Our approach exploits this new vulnerability against large-scale models such as a personalized Stable Diffusion and an aligned Mistral. | ||
|
||
## Project Structure | ||
This project consists of: | ||
- `spectral_detuning.py` - main file for recovering the Pre-FT weights using Spectral DeTuning. | ||
- `distributed_spectral_detuning.py` - Distributing Spectral DeTuning across multiple CPU cores of a single machine. | ||
- `increase_rank_on_plateau_scheduler.py` - rank scheduler class. | ||
- [`slurm`](./slurm/) - Examples for distributing Spectral DeTuning across a slurm cluster. | ||
- [`lowra_bench`](./lowra_bench/) - Scripts for running inference and evaluation of the recovered weights. | ||
|
||
|
||
## Installation | ||
1. Clone the repo: | ||
```bash | ||
git clone https://github.com/eliahuhorwitz/spectral_detuning.git | ||
cd spectral_detuning | ||
``` | ||
2. Create a new environment and install the libraries: | ||
```bash | ||
python3 -m venv spectral_detuning_venv | ||
source spectral_detuning_venv/bin/activate | ||
pip install -r requirements.txt | ||
``` | ||
|
||
|
||
|
||
## Running Spectral DeTuning for Pre-Fine-Tuning Weight Recovery | ||
The `spectral_detuning.py` script is the main script in this project. | ||
It handles the downloading of the LoWRA Bench dataset that is hosted | ||
on Hugging Face. | ||
|
||
Below are examples for running runs Spectral DeTuning for Pre-FT weight recovery on the | ||
LoWRA Bench dataset subset using different distribution strategies. | ||
|
||
### Single GPU Execution | ||
These use a single GPU to recover all the layers one by one *sequentially*. | ||
|
||
#### ViT | ||
```bash | ||
python spectral_detuning.py --subset="vit" --output_path="./recovered_weights/vit/" \ | ||
--start_layer=0 --n_layers_to_recover=-1 --sched_end_rank=16 --n_loras=5 | ||
``` | ||
> [!TIP] | ||
> ViT contains 24 layers to recover and can be recovered *sequentially* in a few minutes on a desktop grade GPU. | ||
#### Stable Diffusion | ||
```bash | ||
python spectral_detuning.py --subset="stable-diffusion-1.5" \ | ||
--output_path="./recovered_weights/stable_diffusion_15/" --start_layer=0 \ | ||
--n_layers_to_recover=-1 --sched_end_rank=32 --n_loras=5 | ||
``` | ||
> [!IMPORTANT] | ||
> Stable Diffusion contains 264 layers to recover. See below for a faster option. | ||
#### Mistral SFT | ||
```bash | ||
python spectral_detuning.py --subset="mistral-7b-v0.1-sft" \ | ||
--output_path="./recovered_weights/mistral7b_01_sft/" --start_layer=0 \ | ||
--n_layers_to_recover=-1 --sched_end_rank=64 --n_loras=12 --n_iters=1000 | ||
``` | ||
|
||
#### Mistral DPO | ||
```bash | ||
python spectral_detuning.py --subset="mistral-7b-v0.1-dpo" \ | ||
--output_path="./recovered_weights/mistral7b_01_dpo/" --start_layer=0 \ | ||
--n_layers_to_recover=-1 --sched_end_rank=64 --n_loras=8 --n_iters=1000 | ||
``` | ||
> [!IMPORTANT] | ||
> Mistral contains 128 layers to recover, some of them are of high dimensions (up to 4096x4096), see below for a faster option. | ||
|
||
### Distributed Multiprocess CPU Execution | ||
Since Spectral DeTuning does not require gradients or running | ||
inference on the model, it can run quickly even on a CPU. | ||
Below are options for distributing Spectral DeTuning across the CPU cores | ||
of a single machine using multiple processes. | ||
|
||
To run using this strategy, run `distributed_spectral_detuning.py` with the same arguments as above. | ||
To control the number of CPU cores to distribute across use the `--n_cpus` argument, | ||
set `--n_cpus=-1` to use all available core. | ||
> [!TIP] | ||
> ViT contains 24 layers to recover and can be recovered in minutes when distributed across desktop CPU cores. | ||
|
||
### Distributed Execution on a Compute Cluster | ||
In cases where the model has many layers (e.g., Stable Diffusion and Mistral), | ||
it is recommended to distribute the recovery across a compute cluster (GPU or CPU). | ||
We provide example slurm scripts under the [`slurm`](./slurm/) dir. | ||
|
||
The main difference is the `--n_layers_to_recover` argument which controls how many layers | ||
each machine will recover. | ||
|
||
> [!TIP] | ||
> Spectral DeTuning can recover a *single layer* of a large model (e.g. Mistral-7B) | ||
> in under 5 minutes on a *single desktop GPU* (e.g. RTX2080). | ||
> The recovery speed of the entire model is a function of the number of machines in your cluster. | ||
|
||
|
||
## Using the Recovered Pre-Fine-Tuning Weights | ||
To run inference on the Pre-FT recovered weights use the following scripts: | ||
#### ViT: | ||
```bash | ||
python lowra_bench/inference/vit_inference.py --input_path="./recovered_weights/vit/" | ||
``` | ||
|
||
#### Stable Diffusion: | ||
```bash | ||
python lowra_bench/inference/stable_diffusion_inference.py \ | ||
--input_path="./recovered_weights/stable_diffusion/" | ||
``` | ||
|
||
#### Mistral SFT: | ||
```bash | ||
python lowra_bench/inference/mistral_inference.py \ | ||
--input_path="./recovered_weights/mistral7b_01_sft/" --subset="mistral-7b-v0.1-sft" | ||
``` | ||
|
||
#### Mistral DPO: | ||
```bash | ||
python lowra_bench/inference/mistral_inference.py \ | ||
--input_path="./recovered_weights/mistral7b_01_dpo/" --subset="mistral-7b-v0.1-dpo" | ||
``` | ||
|
||
|
||
## Using a Custom Dataset of Fine-tuned LoRAs and Pre-FT Models | ||
Coming soon... | ||
- [ ] Preprocessing scripts for constructing a LoRA dataset similar to the LoWRA Bench one. | ||
|
||
|
||
|
||
## Citation | ||
If you find this useful for your research, please use the following. | ||
|
||
``` | ||
``` | ||
|
||
|
||
## Acknowledgments | ||
- The project makes extensive use of the different Hugging Face libraries (e.g. [Diffusers](https://huggingface.co/docs/diffusers/en/index), [PEFT](https://huggingface.co/docs/peft/en/index), [Transformers](https://huggingface.co/docs/transformers/en/index)). | ||
- The [LoWRA Bench dataset](https://huggingface.co/datasets/Eliahu/LoWRA-Bench) is hosted on Hugging Face. | ||
- The fine-tuning of Mistral was performed based on the Zephyr model as seen [here](https://github.com/huggingface/alignment-handbook/tree/main). | ||
- The fine-tuned LoRA models for Stable Diffusion are taken from civitai and were fine-tuned by [RalFinger](https://civitai.com/user/RalFinger). | ||
- The rank scheduler is based on the PyTorch [ReduceLROnPlateau Scheduler](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import multiprocessing | ||
from spectral_detuning import * | ||
|
||
def get_recovery_args(args): | ||
# Note: Load the huggingface dataset | ||
dataset = load_dataset(args.dataset, name=args.subset, cache_dir=args.cache_dir) | ||
dataset = dataset.with_format("torch")["train"] | ||
layer_file_ids = list(range(0, len(dataset))) | ||
if args.n_layers_to_recover == -1: | ||
distributed_end_idx = len(layer_file_ids) | ||
else: | ||
distributed_end_idx = min(args.start_layer + args.n_layers_to_recover, len(layer_file_ids)) | ||
layer_file_ids = layer_file_ids[args.start_layer: distributed_end_idx] | ||
device = torch.device("cpu") # Note: Force CPU for distributed execution on the local CPU | ||
recovery_args = [(args, layer_idx, device) for layer_idx in layer_file_ids] | ||
return recovery_args | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
parser = define_args() | ||
parser.add_argument("--n_cpus", type=int, default=-1, help="number of CPU cores to distribute across, -1 to use all available core") | ||
args = parser.parse_args() | ||
|
||
fix_seeds(args) | ||
os.makedirs(args.output_path, exist_ok=True) | ||
|
||
total_n_loras = 15 # Note: In the LoWRA Bench dataset, each subset has 15 different loras | ||
if len(args.lora_ids) == 0: | ||
args.lora_ids = random.sample(range(total_n_loras), args.n_loras) | ||
|
||
recovery_args = get_recovery_args(args) | ||
if args.n_cpus == -1: # Note: Use all available CPU cores | ||
args.n_cpus = multiprocessing.cpu_count() - 1 | ||
print(f"Starting multiprocessing pool with {args.n_cpus} processes...") | ||
|
||
pool = multiprocessing.Pool(processes=args.n_cpus) | ||
pool.starmap(func=recover_layer, iterable=recovery_args) | ||
pool.close() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# Based on PyTorch's ReduceLROnPlateau scheduler (https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html) | ||
from torch import inf | ||
|
||
|
||
class IncreaseRankOnPlateau: | ||
def __init__(self, n_iters, end_rank, start_rank=1, factor=2, patience=15, factor_type='mult', | ||
force_end_rank_percent=0.5, threshold=0, threshold_mode='rel', cooldown=0, mode='min', verbose=False, logger=None): | ||
self.n_iters = n_iters | ||
self.force_end_rank_percent = force_end_rank_percent | ||
self.force_end_rank_step = n_iters - (n_iters * force_end_rank_percent) | ||
self.factor_type = factor_type | ||
self.factor = factor | ||
self.start_rank = start_rank | ||
self.curr_rank = start_rank | ||
self.end_rank = end_rank | ||
self.logger = logger | ||
|
||
self.patience = patience | ||
self.verbose = verbose | ||
self.cooldown = cooldown | ||
self.cooldown_counter = 0 | ||
self.mode = mode | ||
self.threshold = threshold | ||
self.threshold_mode = threshold_mode | ||
self.best = None | ||
self.num_bad_epochs = None | ||
self.mode_worse = None # the worse value for the chosen mode | ||
self.last_epoch = 0 | ||
self._init_is_better(mode=mode, threshold=threshold, threshold_mode=threshold_mode) | ||
self._reset() | ||
|
||
def _reset(self): | ||
"""Resets num_bad_epochs counter and cooldown counter.""" | ||
self.best = self.mode_worse | ||
self.cooldown_counter = 0 | ||
self.num_bad_epochs = 0 | ||
|
||
def step(self, metrics, epoch=None): | ||
# convert `metrics` to float, in case it's a zero-dim Tensor | ||
|
||
current = float(metrics) | ||
if epoch is None: | ||
epoch = self.last_epoch + 1 | ||
self.last_epoch = epoch | ||
if epoch < self.force_end_rank_step: | ||
if self.is_better(current, self.best): | ||
self.best = current | ||
self.num_bad_epochs = 0 | ||
else: | ||
self.num_bad_epochs += 1 | ||
|
||
if self.in_cooldown: | ||
self.cooldown_counter -= 1 | ||
self.num_bad_epochs = 0 # ignore any bad epochs in cooldown | ||
|
||
if self.num_bad_epochs > self.patience: | ||
self._increase_rank(epoch, current) | ||
self.cooldown_counter = self.cooldown | ||
self.num_bad_epochs = 0 | ||
else: | ||
old_rank = self.curr_rank | ||
new_rank = self.end_rank | ||
self.curr_rank = new_rank | ||
if new_rank > old_rank and self.verbose: | ||
epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch | ||
if self.logger is not None: | ||
self.logger.info(f'Epoch {epoch_str}: forcing end rank (end_rank={new_rank}). Current loss={current}') | ||
else: | ||
print(f'Epoch {epoch_str}: forcing end rank (end_rank={new_rank}). Current loss={current}') | ||
|
||
def _increase_rank(self, epoch, loss): | ||
old_rank = self.curr_rank | ||
if self.factor_type == "mult": | ||
new_rank = min(old_rank * self.factor, self.end_rank) | ||
elif self.factor_type == "add": | ||
new_rank = min(old_rank + 1, self.end_rank) | ||
self.curr_rank = new_rank | ||
if new_rank > old_rank and self.verbose: | ||
epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch | ||
if self.logger is not None: | ||
self.logger.info(f'Epoch {epoch_str}: increasing rank to {new_rank}. Current loss={loss}') | ||
else: | ||
print(f'Epoch {epoch_str}: increasing rank to {new_rank}. Current loss={loss}') | ||
|
||
@property | ||
def in_cooldown(self): | ||
return self.cooldown_counter > 0 | ||
|
||
def is_better(self, a, best): | ||
if self.mode == 'min' and self.threshold_mode == 'rel': | ||
rel_epsilon = 1. - self.threshold | ||
return a < best * rel_epsilon | ||
|
||
elif self.mode == 'min' and self.threshold_mode == 'abs': | ||
return a < best - self.threshold | ||
|
||
elif self.mode == 'max' and self.threshold_mode == 'rel': | ||
rel_epsilon = self.threshold + 1. | ||
return a > best * rel_epsilon | ||
|
||
else: # mode == 'max' and epsilon_mode == 'abs': | ||
return a > best + self.threshold | ||
|
||
def _init_is_better(self, mode, threshold, threshold_mode): | ||
if mode not in {'min', 'max'}: | ||
raise ValueError('mode ' + mode + ' is unknown!') | ||
if threshold_mode not in {'rel', 'abs'}: | ||
raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') | ||
|
||
if mode == 'min': | ||
self.mode_worse = inf | ||
else: # mode == 'max': | ||
self.mode_worse = -inf | ||
|
||
self.mode = mode | ||
self.threshold = threshold | ||
self.threshold_mode = threshold_mode |
Oops, something went wrong.