Gligen training (huggingface#7906)
* add training code of gligen

* fix code quality tests.

---------

Co-authored-by: Sayak Paul <[email protected]>
Authored by Hzzone and sayakpaul on Jun 5, 2024 · 1 parent 48207d6 · commit d3881f3
Showing 7 changed files with 1,312 additions and 0 deletions.
examples/research_projects/gligen/README.md (new file, 156 additions)
# GLIGEN: Open-Set Grounded Text-to-Image Generation

These scripts contain the code to prepare the grounding data and train the GLIGEN model on the COCO dataset.

### Install the requirements

```bash
conda create -n diffusers python==3.10
conda activate diffusers
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default Accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config

write_basic_config()
```

### Prepare the training data

If you want to make your own grounding data, you need a few additional models and dependencies.

I used [RAM](https://github.com/xinyu1205/recognize-anything) to tag images, [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) to detect objects, and [BLIP2](https://huggingface.co/docs/transformers/en/model_doc/blip-2) to caption instances.

Only RAM needs to be installed manually:

```bash
pip install git+https://github.com/xinyu1205/recognize-anything.git --no-deps
```

Download the pre-trained models:

```bash
huggingface-cli download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth
huggingface-cli download --resume-download IDEA-Research/grounding-dino-base
huggingface-cli download --resume-download Salesforce/blip2-flan-t5-xxl
huggingface-cli download --resume-download openai/clip-vit-large-patch14
huggingface-cli download --resume-download masterful/gligen-1-4-generation-text-box
```

Make the training data on 8 GPUs:

```bash
torchrun --master_port 17673 --nproc_per_node=8 make_datasets.py \
--data_root /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
--save_root /root/gligen_data \
--ram_checkpoint /root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth
```
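
For reference, here is a rough, untested sketch of what the per-image annotation step could look like with these models. RAM tagging is omitted (the tags are hard-coded below), the image path is a placeholder, and `make_datasets.py` remains the actual implementation:

```python
import torch
from PIL import Image
from transformers import (
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    CLIPTextModel,
    CLIPTokenizer,
)

device = "cuda"
dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
dino = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
blip = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xxl", torch_dtype=torch.float16).to(device)
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_text = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

image = Image.open("example.jpg").convert("RGB")
tags = "car. truck. bird."  # in practice these tags come from RAM

# 1. Ground the tags to boxes (xyxy, in pixels) with Grounding DINO.
inputs = dino_processor(images=image, text=tags, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = dino(**inputs)
detections = dino_processor.post_process_grounded_object_detection(
    outputs, inputs.input_ids, target_sizes=[image.size[::-1]]
)[0]

annos = []
for box in detections["boxes"]:
    x0, y0, x1, y1 = box.tolist()
    crop = image.crop((x0, y0, x1, y1))

    # 2. Caption the cropped instance with BLIP-2.
    blip_inputs = blip_processor(images=crop, return_tensors="pt").to(device, torch.float16)
    caption = blip_processor.decode(blip.generate(**blip_inputs)[0], skip_special_tokens=True)

    # 3. CLIP text embedding of the instance caption, taken before the text projection (768-d).
    tokens = clip_tokenizer(caption, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        emb = clip_text(**tokens).pooler_output[0].cpu()

    annos.append({"caption": caption, "bbox": [x0, y0, x1, y1], "text_embeddings_before_projection": emb})
```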

Alternatively, you can download the prepared COCO training data:

```bash
huggingface-cli download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth
```

It's in the following format:

```python
[
    ...
    {
        'file_path': Path,
        'annos': [
            {
                'caption': Instance Caption,
                'bbox': bbox in xyxy,
                'text_embeddings_before_projection': CLIP text embedding before linear projection
            }
        ]
    }
    ...
]
```
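
For a quick sanity check of the downloaded annotations (the path below is a placeholder):

```python
import torch

# Load the annotation list and inspect one entry.
data_list = torch.load("coco_train2017.pth", map_location="cpu")
print(len(data_list))  # number of annotated images
sample = data_list[0]
print(sample["file_path"])
for anno in sample["annos"]:
    print(anno["caption"], anno["bbox"], anno["text_embeddings_before_projection"].shape)
```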

### Training commands

The training script is heavily based on [`train_controlnet.py`](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py).

```bash
accelerate launch train_gligen_text.py \
--data_path /root/data/zhizhonghuang/coco_train2017.pth \
--image_path /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
--train_batch_size 8 \
--max_train_steps 100000 \
--checkpointing_steps 1000 \
--checkpoints_total_limit 10 \
--learning_rate 5e-5 \
--dataloader_num_workers 16 \
--mixed_precision fp16 \
--report_to wandb \
--tracker_project_name gligen \
--output_dir /root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO
```

I trained the model on 8 A100 GPUs for about 11 hours (each GPU needs at least 24GB of memory). The generated images should start to follow the layout at around 50k iterations.
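
If `accelerate config` has not already been set up for multi-GPU training, the same script can also be launched with explicit flags, for example (arguments elided):

```bash
accelerate launch --multi_gpu --num_processes 8 train_gligen_text.py \
  --data_path ... \
  --image_path ... \
  --mixed_precision fp16 \
  --output_dir ...
```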

Note that although the pre-trained GLIGEN model is loaded, the parameters of `fuser` and `position_net` are reset (see line 420 in `train_gligen_text.py`).
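
As a rough illustration only (the module names and init scheme here are assumptions, not the actual code at line 420), such a reset could look roughly like this in diffusers:

```python
import torch
from diffusers import UNet2DConditionModel

# Load the pre-trained GLIGEN UNet, then reinitialize the GLIGEN-specific modules.
unet = UNet2DConditionModel.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", subfolder="unet"
)
for name, module in unet.named_modules():
    if "fuser" in name or "position_net" in name:
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, torch.nn.LayerNorm):
            module.reset_parameters()
```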

The trained model can be downloaded with:

```bash
huggingface-cli download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors
```

You can run `demo.ipynb` to visualize the generated images.

Example prompts:

```python
prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'
boxes = [[0.041015625, 0.548828125, 0.453125, 0.859375],
[0.525390625, 0.552734375, 0.93359375, 0.865234375],
[0.12890625, 0.015625, 0.412109375, 0.279296875],
[0.578125, 0.08203125, 0.857421875, 0.27734375]]
gligen_phrases = ['a green car', 'a blue truck', 'a red air balloon', 'a bird']
```
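
Continuing from the example prompts above, a rough inference sketch with the diffusers GLIGEN pipeline (not part of this commit; point `from_pretrained` at your own checkpoint directory to test a retrained model):

```python
import torch
from diffusers import StableDiffusionGLIGENPipeline

pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt=prompt,
    gligen_phrases=gligen_phrases,
    gligen_boxes=boxes,  # normalized xyxy boxes in [0, 1]
    gligen_scheduled_sampling_beta=1.0,
    num_inference_steps=50,
).images[0]
image.save("gligen_example.png")
```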

Example images:
![alt text](generated-images-100000-00.png)

### Citation

```bibtex
@article{li2023gligen,
title={GLIGEN: Open-Set Grounded Text-to-Image Generation},
author={Li, Yuheng and Liu, Haotian and Wu, Qingyang and Mu, Fangzhou and Yang, Jianwei and Gao, Jianfeng and Li, Chunyuan and Lee, Yong Jae},
journal={CVPR},
year={2023}
}
```
examples/research_projects/gligen/dataset.py (new file, 110 additions)
import os
import random

import torch
import torchvision.transforms as transforms
from PIL import Image


# Map a COCO-style (x, y, w, h) box through the Resize(image_size) + CenterCrop(image_size)
# transform used by COCODataset below, and reject boxes whose cropped area is smaller than
# `min_box_size` (a fraction of the final image area).
def recalculate_box_and_verify_if_valid(x, y, w, h, image_size, original_image_size, min_box_size):
scale = image_size / min(original_image_size)
crop_y = (original_image_size[1] * scale - image_size) // 2
crop_x = (original_image_size[0] * scale - image_size) // 2
x0 = max(x * scale - crop_x, 0)
y0 = max(y * scale - crop_y, 0)
x1 = min((x + w) * scale - crop_x, image_size)
y1 = min((y + h) * scale - crop_y, image_size)
if (x1 - x0) * (y1 - y0) / (image_size * image_size) < min_box_size:
return False, (None, None, None, None)
return True, (x0, y0, x1, y1)


class COCODataset(torch.utils.data.Dataset):
def __init__(
self,
data_path,
image_path,
image_size=512,
min_box_size=0.01,
max_boxes_per_data=8,
tokenizer=None,
):
super().__init__()
self.min_box_size = min_box_size
self.max_boxes_per_data = max_boxes_per_data
self.image_size = image_size
self.image_path = image_path
self.tokenizer = tokenizer
self.transforms = transforms.Compose(
[
transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

self.data_list = torch.load(data_path, map_location="cpu")

def __getitem__(self, index):
if self.max_boxes_per_data > 99:
            assert False, "Are you sure you want to set such a large number of boxes per image?"

out = {}

data = self.data_list[index]
image = Image.open(os.path.join(self.image_path, data["file_path"])).convert("RGB")
original_image_size = image.size
out["pixel_values"] = self.transforms(image)

annos = data["annos"]

areas, valid_annos = [], []
for anno in annos:
            # Annotations store boxes in xyxy format; convert to (x, y, w, h) before rescaling.
            x0, y0, x1, y1 = anno["bbox"]
            x, y, w, h = x0, y0, x1 - x0, y1 - y0
valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(
x, y, w, h, self.image_size, original_image_size, self.min_box_size
)
if valid:
anno["bbox"] = [x0, y0, x1, y1]
areas.append((x1 - x0) * (y1 - y0))
valid_annos.append(anno)

# Sort according to area and choose the largest N objects
wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
wanted_idxs = wanted_idxs[: self.max_boxes_per_data]
valid_annos = [valid_annos[i] for i in wanted_idxs]

out["boxes"] = torch.zeros(self.max_boxes_per_data, 4)
out["masks"] = torch.zeros(self.max_boxes_per_data)
out["text_embeddings_before_projection"] = torch.zeros(self.max_boxes_per_data, 768)

for i, anno in enumerate(valid_annos):
out["boxes"][i] = torch.tensor(anno["bbox"]) / self.image_size
out["masks"][i] = 1
out["text_embeddings_before_projection"][i] = anno["text_embeddings_before_projection"]

        # Conditioning dropout: drop all grounding boxes with probability 0.1 (and the caption
        # with probability 0.5 below) so the model also learns the unconditioned cases.
        prob_drop_boxes = 0.1
if random.random() < prob_drop_boxes:
out["masks"][:] = 0

caption = random.choice(data["captions"])

prob_drop_captions = 0.5
if random.random() < prob_drop_captions:
caption = ""
caption = self.tokenizer(
caption,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
out["caption"] = caption

return out

def __len__(self):
return len(self.data_list)
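
A minimal sketch of how `COCODataset` could be wired to a `DataLoader` (the collate function below is an assumption for illustration; the actual training script may batch things differently, and the paths are placeholders):

```python
import torch
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer

from dataset import COCODataset


def collate_fn(batch):
    # Stack the per-sample tensors and concatenate the tokenized captions.
    return {
        "pixel_values": torch.stack([b["pixel_values"] for b in batch]),
        "boxes": torch.stack([b["boxes"] for b in batch]),
        "masks": torch.stack([b["masks"] for b in batch]),
        "text_embeddings_before_projection": torch.stack(
            [b["text_embeddings_before_projection"] for b in batch]
        ),
        "input_ids": torch.cat([b["caption"]["input_ids"] for b in batch], dim=0),
    }


tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
dataset = COCODataset(
    data_path="coco_train2017.pth",
    image_path="COCO/train2017",
    tokenizer=tokenizer,
)
loader = DataLoader(dataset, batch_size=8, num_workers=4, collate_fn=collate_fn)
batch = next(iter(loader))
print(batch["pixel_values"].shape, batch["boxes"].shape, batch["input_ids"].shape)
```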