Commit bc27d70: Initial commit
eliahuhorwitz committed Feb 15, 2024
Showing 20 changed files with 1,372 additions and 0 deletions.

4 changes: 4 additions & 0 deletions .gitignore
.DS_STORE
.idea
.cache/

182 changes: 182 additions & 0 deletions README.md
# Recovering the Pre-Fine-Tuning Weights of Generative Models
Official PyTorch Implementation for the "Recovering the Pre-Fine-Tuning Weights of Generative Models" paper.
<p align="center">
🌐 <a href="https://vision.huji.ac.il/spectral_detuning/" target="_blank">Project</a> | 📃 <a href="http://arxiv.org/abs/" target="_blank">Paper - Later Today</a> | 🤗 <a href="https://huggingface.co/datasets/Eliahu/LoWRA-Bench" target="_blank">Dataset</a> <br>
</p>

![](imgs/header.gif)


***Pre-Fine-Tuning Weight Recovery Attack Setting:*** We uncover a vulnerability in LoRA fine-tuned models wherein an attacker is
able to undo the fine-tuning process and recover the weights of the original pre-trained model.
The setting for the vulnerability is as follows:

(a) The attacker only has access to n different LoRA fine-tuned models.

(b) The attacker assumes that all n models originated from the same source model.

(c) Using only the n visible models, the attacker attempts to recover the original source model.

Our method, *Spectral DeTuning*, can perform the attack in an unsupervised and data-free manner on real models such as Stable Diffusion and Mistral.
For simplicity, we illustrate the attack on a single layer; in reality, the attack is carried out independently on all the fine-tuned layers.

**Note: The attacker has no access to the low-rank decomposition of the fine-tuned models.**
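
To make the setting concrete, here is a small toy sketch (illustrative only; the shapes, rank, and names are hypothetical and not taken from the paper or this repo) of what the attacker observes:

```python
import torch

# Toy illustration of the threat model (hypothetical shapes and rank).
d, r, n = 64, 4, 5                 # layer size, LoRA rank, number of fine-tuned models
W_pre = torch.randn(d, d)          # pre-fine-tuning weight, hidden from the attacker

# Each fine-tuned model i merges an unknown low-rank update B_i @ A_i into W_pre.
finetuned = [W_pre + torch.randn(d, r) @ torch.randn(r, d) for _ in range(n)]

# The attacker sees only the n merged matrices, not W_pre or the B_i, A_i factors.
print(len(finetuned), finetuned[0].shape)
```
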
___

> **Recovering the Pre-Fine-Tuning Weights of Generative Models**<br>
> Eliahu Horwitz, Jonathan Kahana, Yedid Hoshen<br>
> <a href="http://arxiv.org/abs/" target="_blank">http://arxiv.org/abs/ </a> <br>
>
>**Abstract:** The dominant paradigm in generative modeling consists of two steps:
> i) pre-training on a large-scale but unsafe dataset, ii) aligning the pre-trained model with human values via fine-tuning.
> This practice is considered safe, as no current method can recover the unsafe, *pre-fine-tuning* model weights.
> In this paper, we demonstrate that this assumption is often false. Concretely, we present *Spectral DeTuning*,
> a method that can recover the weights of the pre-fine-tuning model using a few low-rank (LoRA) fine-tuned models.
> In contrast to previous attacks that attempt to recover pre-fine-tuning capabilities,
> our method aims to recover the exact pre-fine-tuning weights.
> Our approach exploits this new vulnerability against large-scale models such as a personalized Stable Diffusion and an aligned Mistral.
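
The paper and `spectral_detuning.py` define the actual Spectral DeTuning algorithm; purely as intuition, one can picture an alternating scheme that re-fits a low-rank correction per fine-tuned model and re-estimates the shared source weights from the residuals. A minimal, unofficial sketch of that idea (hypothetical function and parameter names, fixed rank, no rank scheduling) is:

```python
import torch

def naive_recovery_sketch(finetuned, rank, n_iters=300):
    """Illustrative alternating scheme -- NOT the official Spectral DeTuning code.

    finetuned: list of merged weight matrices (source weights + unknown low-rank updates).
    rank: assumed rank of the unknown LoRA updates.
    """
    W_est = torch.stack(finetuned).mean(dim=0)  # crude initial guess for the source weights
    for _ in range(n_iters):
        deltas = []
        for W_i in finetuned:
            # Re-fit a rank-`rank` correction for this model via truncated SVD.
            U, S, Vh = torch.linalg.svd(W_i - W_est, full_matrices=False)
            deltas.append(U[:, :rank] @ torch.diag(S[:rank]) @ Vh[:rank, :])
        # Re-estimate the shared source weights from the residuals.
        W_est = torch.stack([W_i - M_i for W_i, M_i in zip(finetuned, deltas)]).mean(dim=0)
    return W_est
```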

## Project Structure
This project consists of:
- `spectral_detuning.py` - the main script for recovering the Pre-FT weights using Spectral DeTuning.
- `distributed_spectral_detuning.py` - distributes Spectral DeTuning across multiple CPU cores of a single machine.
- `increase_rank_on_plateau_scheduler.py` - the rank scheduler class.
- [`slurm`](./slurm/) - examples for distributing Spectral DeTuning across a slurm cluster.
- [`lowra_bench`](./lowra_bench/) - scripts for running inference and evaluation of the recovered weights.


## Installation
1. Clone the repo:
```bash
git clone https://github.com/eliahuhorwitz/spectral_detuning.git
cd spectral_detuning
```
2. Create a new environment and install the libraries:
```bash
python3 -m venv spectral_detuning_venv
source spectral_detuning_venv/bin/activate
pip install -r requirements.txt
```



## Running Spectral DeTuning for Pre-Fine-Tuning Weight Recovery
The `spectral_detuning.py` script is the main entry point of this project.
It handles downloading the LoWRA Bench dataset, which is hosted
on Hugging Face.
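
If you want to inspect the dataset yourself first, a minimal standalone peek might look like the sketch below (assuming the Hub id from the dataset link above; the exact per-layer fields depend on the subset):

```python
from datasets import load_dataset

# Standalone peek at the LoWRA Bench dataset; the recovery script performs
# an equivalent load internally.
dataset = load_dataset("Eliahu/LoWRA-Bench", name="vit")
print(dataset)                    # available splits and number of layer records
print(dataset["train"].features)  # per-layer fields (names vary by subset)
```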

Below are examples of running Spectral DeTuning for Pre-FT weight recovery on the
LoWRA Bench dataset subsets using different distribution strategies.

### Single GPU Execution
These examples use a single GPU to recover all the layers *sequentially*, one by one.

#### ViT
```bash
python spectral_detuning.py --subset="vit" --output_path="./recovered_weights/vit/" \
--start_layer=0 --n_layers_to_recover=-1 --sched_end_rank=16 --n_loras=5
```
> [!TIP]
> ViT contains 24 layers to recover, all of which can be recovered *sequentially* in a few minutes on a desktop-grade GPU.
#### Stable Diffusion
```bash
python spectral_detuning.py --subset="stable-diffusion-1.5" \
--output_path="./recovered_weights/stable_diffusion_15/" --start_layer=0 \
--n_layers_to_recover=-1 --sched_end_rank=32 --n_loras=5
```
> [!IMPORTANT]
> Stable Diffusion contains 264 layers to recover. See below for a faster option.
#### Mistral SFT
```bash
python spectral_detuning.py --subset="mistral-7b-v0.1-sft" \
--output_path="./recovered_weights/mistral7b_01_sft/" --start_layer=0 \
--n_layers_to_recover=-1 --sched_end_rank=64 --n_loras=12 --n_iters=1000
```

#### Mistral DPO
```bash
python spectral_detuning.py --subset="mistral-7b-v0.1-dpo" \
--output_path="./recovered_weights/mistral7b_01_dpo/" --start_layer=0 \
--n_layers_to_recover=-1 --sched_end_rank=64 --n_loras=8 --n_iters=1000
```
> [!IMPORTANT]
> Mistral contains 128 layers to recover, some of which are high-dimensional (up to 4096x4096). See below for a faster option.

### Distributed Multiprocess CPU Execution
Since Spectral DeTuning does not require gradients or running
inference on the model, it can run quickly even on a CPU.
Below are options for distributing Spectral DeTuning across the CPU cores
of a single machine using multiple processes.

To use this strategy, run `distributed_spectral_detuning.py` with the same arguments as above.
To control the number of CPU cores to distribute across, use the `--n_cpus` argument;
set `--n_cpus=-1` to use all available cores.
> [!TIP]
> ViT contains 24 layers to recover, all of which can be recovered in minutes when the work is distributed across desktop CPU cores.

### Distributed Execution on a Compute Cluster
In cases where the model has many layers (e.g., Stable Diffusion and Mistral),
it is recommended to distribute the recovery across a compute cluster (GPU or CPU).
We provide example slurm scripts under the [`slurm`](./slurm/) dir.

The main difference is the `--n_layers_to_recover` argument, which controls how many layers
each machine will recover.
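
For example, Stable Diffusion's 264 layers could be split into chunks of 24 across 11 machines. A hypothetical helper (not part of the repo) for printing the per-machine arguments might look like:

```python
# Hypothetical helper (not part of the repo): print per-machine recovery arguments.
def chunk_layers(total_layers, layers_per_machine):
    for start in range(0, total_layers, layers_per_machine):
        print(f"--start_layer={start} --n_layers_to_recover={layers_per_machine}")

chunk_layers(total_layers=264, layers_per_machine=24)  # Stable Diffusion: 11 machines
```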

> [!TIP]
> Spectral DeTuning can recover a *single layer* of a large model (e.g. Mistral-7B)
> in under 5 minutes on a *single desktop GPU* (e.g. RTX2080).
> The recovery speed of the entire model is a function of the number of machines in your cluster.


## Using the Recovered Pre-Fine-Tuning Weights
To run inference on the recovered Pre-FT weights, use the following scripts:
#### ViT:
```bash
python lowra_bench/inference/vit_inference.py --input_path="./recovered_weights/vit/"
```

#### Stable Diffusion:
```bash
python lowra_bench/inference/stable_diffusion_inference.py \
--input_path="./recovered_weights/stable_diffusion_15/"
```

#### Mistral SFT:
```bash
python lowra_bench/inference/mistral_inference.py \
--input_path="./recovered_weights/mistral7b_01_sft/" --subset="mistral-7b-v0.1-sft"
```

#### Mistral DPO:
```bash
python lowra_bench/inference/mistral_inference.py \
--input_path="./recovered_weights/mistral7b_01_dpo/" --subset="mistral-7b-v0.1-dpo"
```


## Using a Custom Dataset of Fine-tuned LoRAs and Pre-FT Models
Coming soon...
- [ ] Preprocessing scripts for constructing a LoRA dataset similar to the LoWRA Bench one.



## Citation
If you find this useful for your research, please use the following citation.

```
```


## Acknowledgments
- The project makes extensive use of several Hugging Face libraries (e.g. [Diffusers](https://huggingface.co/docs/diffusers/en/index), [PEFT](https://huggingface.co/docs/peft/en/index), [Transformers](https://huggingface.co/docs/transformers/en/index)).
- The [LoWRA Bench dataset](https://huggingface.co/datasets/Eliahu/LoWRA-Bench) is hosted on Hugging Face.
- The fine-tuning of Mistral was based on the Zephyr model, as seen [here](https://github.com/huggingface/alignment-handbook/tree/main).
- The fine-tuned LoRA models for Stable Diffusion were taken from Civitai and fine-tuned by [RalFinger](https://civitai.com/user/RalFinger).
- The rank scheduler is based on the PyTorch [ReduceLROnPlateau Scheduler](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html).
39 changes: 39 additions & 0 deletions distributed_spectral_detuning.py
import multiprocessing
from spectral_detuning import *

def get_recovery_args(args):
    # Note: Load the huggingface dataset
    dataset = load_dataset(args.dataset, name=args.subset, cache_dir=args.cache_dir)
    dataset = dataset.with_format("torch")["train"]
    layer_file_ids = list(range(0, len(dataset)))
    if args.n_layers_to_recover == -1:
        distributed_end_idx = len(layer_file_ids)
    else:
        distributed_end_idx = min(args.start_layer + args.n_layers_to_recover, len(layer_file_ids))
    layer_file_ids = layer_file_ids[args.start_layer: distributed_end_idx]
    device = torch.device("cpu")  # Note: Force CPU for distributed execution on the local CPU
    recovery_args = [(args, layer_idx, device) for layer_idx in layer_file_ids]
    return recovery_args



if __name__ == '__main__':
    parser = define_args()
    parser.add_argument("--n_cpus", type=int, default=-1, help="number of CPU cores to distribute across, -1 to use all available cores")
    args = parser.parse_args()

    fix_seeds(args)
    os.makedirs(args.output_path, exist_ok=True)

    total_n_loras = 15  # Note: In the LoWRA Bench dataset, each subset has 15 different LoRAs
    if len(args.lora_ids) == 0:
        args.lora_ids = random.sample(range(total_n_loras), args.n_loras)

    recovery_args = get_recovery_args(args)
    if args.n_cpus == -1:  # Note: Use all available CPU cores (leaving one free)
        args.n_cpus = multiprocessing.cpu_count() - 1
    print(f"Starting multiprocessing pool with {args.n_cpus} processes...")

    pool = multiprocessing.Pool(processes=args.n_cpus)
    pool.starmap(func=recover_layer, iterable=recovery_args)
    pool.close()
Binary file added imgs/header.gif
117 changes: 117 additions & 0 deletions increase_rank_on_plateau_scheduler.py
# Based on PyTorch's ReduceLROnPlateau scheduler (https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html)
from torch import inf


class IncreaseRankOnPlateau:
    def __init__(self, n_iters, end_rank, start_rank=1, factor=2, patience=15, factor_type='mult',
                 force_end_rank_percent=0.5, threshold=0, threshold_mode='rel', cooldown=0, mode='min', verbose=False, logger=None):
        self.n_iters = n_iters
        self.force_end_rank_percent = force_end_rank_percent
        self.force_end_rank_step = n_iters - (n_iters * force_end_rank_percent)
        self.factor_type = factor_type
        self.factor = factor
        self.start_rank = start_rank
        self.curr_rank = start_rank
        self.end_rank = end_rank
        self.logger = logger

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.last_epoch = 0
        self._init_is_better(mode=mode, threshold=threshold, threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        if epoch < self.force_end_rank_step:
            if self.is_better(current, self.best):
                self.best = current
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.in_cooldown:
                self.cooldown_counter -= 1
                self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

            if self.num_bad_epochs > self.patience:
                self._increase_rank(epoch, current)
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
        else:
            old_rank = self.curr_rank
            new_rank = self.end_rank
            self.curr_rank = new_rank
            if new_rank > old_rank and self.verbose:
                epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
                if self.logger is not None:
                    self.logger.info(f'Epoch {epoch_str}: forcing end rank (end_rank={new_rank}). Current loss={current}')
                else:
                    print(f'Epoch {epoch_str}: forcing end rank (end_rank={new_rank}). Current loss={current}')

    def _increase_rank(self, epoch, loss):
        old_rank = self.curr_rank
        if self.factor_type == "mult":
            new_rank = min(old_rank * self.factor, self.end_rank)
        elif self.factor_type == "add":
            new_rank = min(old_rank + 1, self.end_rank)
        self.curr_rank = new_rank
        if new_rank > old_rank and self.verbose:
            epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
            if self.logger is not None:
                self.logger.info(f'Epoch {epoch_str}: increasing rank to {new_rank}. Current loss={loss}')
            else:
                print(f'Epoch {epoch_str}: increasing rank to {new_rank}. Current loss={loss}')

    @property
    def in_cooldown(self):
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and epsilon_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
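
Not part of the committed file: a minimal usage sketch of the scheduler above, with illustrative values (the main script presumably wires it up via the `--sched_end_rank` argument).

```python
# Minimal usage sketch with a dummy loss curve (illustrative values only).
from increase_rank_on_plateau_scheduler import IncreaseRankOnPlateau

sched = IncreaseRankOnPlateau(n_iters=300, end_rank=16, verbose=True)
for it in range(300):
    dummy_loss = 1.0 / (it + 1) + 0.01 * (it % 7)  # stand-in for the per-iteration recovery loss
    sched.step(dummy_loss)                          # raises sched.curr_rank when the loss plateaus
    rank_for_next_step = sched.curr_rank            # rank to use in the next optimization step
```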