Commit: add training code of gligen; fix code quality tests.
Co-authored-by: Sayak Paul <[email protected]>
Showing 7 changed files with 1,312 additions and 0 deletions.
@@ -0,0 +1,156 @@
# GLIGEN: Open-Set Grounded Text-to-Image Generation

These scripts contain the code to prepare the grounding data and train the GLIGEN model on the COCO dataset.

### Install the requirements

```bash
conda create -n diffusers python==3.10
conda activate diffusers
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default Accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g. a notebook):

```python
from accelerate.utils import write_basic_config

write_basic_config()
```

### Prepare the training data

If you want to make your own grounding data, you need to install a few additional tools.

I used [RAM](https://github.com/xinyu1205/recognize-anything) to tag images, [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO/issues?q=refer) to detect objects, and [BLIP2](https://huggingface.co/docs/transformers/en/model_doc/blip-2) to caption instances.

Only RAM needs to be installed manually:

```bash
pip install git+https://github.com/xinyu1205/recognize-anything.git --no-deps
```

Download the pre-trained models:

```bash
huggingface-cli download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth
huggingface-cli download --resume-download IDEA-Research/grounding-dino-base
huggingface-cli download --resume-download Salesforce/blip2-flan-t5-xxl
huggingface-cli download --resume-download openai/clip-vit-large-patch14
huggingface-cli download --resume-download masterful/gligen-1-4-generation-text-box
```

Make the training data on 8 GPUs:

```bash
torchrun --master_port 17673 --nproc_per_node=8 make_datasets.py \
  --data_root /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
  --save_root /root/gligen_data \
  --ram_checkpoint /root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth
```

Alternatively, you can download the prepared COCO training data with:

```bash
huggingface-cli download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth
```

It's in the following format:

```json
[
    ...
    {
        'file_path': Path,
        'annos': [
            {
                'caption': Instance Caption,
                'bbox': bbox in xyxy,
                'text_embeddings_before_projection': CLIP text embedding before linear projection
            }
        ]
    }
    ...
]
```

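For reference, here is a minimal sketch for inspecting the downloaded file (assuming it has been fetched to the current directory; note that the dataset loader in this commit also reads a per-image `captions` list, which the listing above omits):

```python
import torch

# Load the prepared grounding annotations; CPU is fine for inspection.
data_list = torch.load("coco_train2017.pth", map_location="cpu")

sample = data_list[0]
print(sample["file_path"])                # relative image path, joined with --image_path by the loader
for anno in sample["annos"]:
    print(anno["caption"], anno["bbox"])  # instance caption and xyxy box
    print(anno["text_embeddings_before_projection"].shape)  # CLIP text embedding, e.g. torch.Size([768])
```
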
### Training commands

The training script is heavily based on https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py.

```bash
accelerate launch train_gligen_text.py \
  --data_path /root/data/zhizhonghuang/coco_train2017.pth \
  --image_path /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
  --train_batch_size 8 \
  --max_train_steps 100000 \
  --checkpointing_steps 1000 \
  --checkpoints_total_limit 10 \
  --learning_rate 5e-5 \
  --dataloader_num_workers 16 \
  --mixed_precision fp16 \
  --report_to wandb \
  --tracker_project_name gligen \
  --output_dir /root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO
```

I trained the model on 8 A100 GPUs for about 11 hours (it requires at least 24GB of GPU memory). The generated images should start to follow the layout at around 50k iterations.

Note that although the pre-trained GLIGEN model is loaded, the parameters of `fuser` and `position_net` are reset (see line 420 in `train_gligen_text.py`), so these modules are trained from scratch.

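For intuition, here is a rough sketch of what such a reset could look like; this is an assumption for illustration, not the exact code at line 420 (module names follow the GLIGEN UNet in diffusers):

```python
def reset_gligen_modules(unet):
    # Re-initialize the gated self-attention ("fuser") and grounding-token
    # ("position_net") modules while leaving the rest of the UNet untouched.
    for name, module in unet.named_modules():
        if "fuser" in name or "position_net" in name:
            for layer in module.modules():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()
```
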
The trained model can be downloaded with:

```bash
huggingface-cli download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors
```

You can run `demo.ipynb` to visualize the generated images.

Example prompts:

```python
prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'
boxes = [
    [0.041015625, 0.548828125, 0.453125, 0.859375],
    [0.525390625, 0.552734375, 0.93359375, 0.865234375],
    [0.12890625, 0.015625, 0.412109375, 0.279296875],
    [0.578125, 0.08203125, 0.857421875, 0.27734375],
]
gligen_phrases = ['a green car', 'a blue truck', 'a red air balloon', 'a bird']
```

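If you prefer a quick script over the notebook, the following is a rough sketch of how these prompts could be fed to the `StableDiffusionGLIGENPipeline` from diffusers; the checkpoint and generation settings are assumptions, not values taken from `demo.ipynb` (swap in your retrained weights from `--output_dir` to test your own model):

```python
import torch
from diffusers import StableDiffusionGLIGENPipeline

pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")

# `prompt`, `boxes` and `gligen_phrases` are the example values defined above.
image = pipe(
    prompt=prompt,
    gligen_phrases=gligen_phrases,
    gligen_boxes=boxes,
    gligen_scheduled_sampling_beta=1.0,
    num_inference_steps=50,
).images[0]
image.save("gligen_example.png")
```
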
Example images:

![Example generated image](generated-images-100000-00.png)

### Citation

```
@article{li2023gligen,
  title={GLIGEN: Open-Set Grounded Text-to-Image Generation},
  author={Li, Yuheng and Liu, Haotian and Wu, Qingyang and Mu, Fangzhou and Yang, Jianwei and Gao, Jianfeng and Li, Chunyuan and Lee, Yong Jae},
  journal={CVPR},
  year={2023}
}
```
@@ -0,0 +1,110 @@
import os
import random

import torch
import torchvision.transforms as transforms
from PIL import Image


def recalculate_box_and_verify_if_valid(x, y, w, h, image_size, original_image_size, min_box_size):
    # Map an xywh box from the original image into the resized-and-center-cropped
    # image, and reject boxes whose area falls below `min_box_size` (given as a
    # fraction of the cropped image area).
    scale = image_size / min(original_image_size)
    crop_y = (original_image_size[1] * scale - image_size) // 2
    crop_x = (original_image_size[0] * scale - image_size) // 2
    x0 = max(x * scale - crop_x, 0)
    y0 = max(y * scale - crop_y, 0)
    x1 = min((x + w) * scale - crop_x, image_size)
    y1 = min((y + h) * scale - crop_y, image_size)
    if (x1 - x0) * (y1 - y0) / (image_size * image_size) < min_box_size:
        return False, (None, None, None, None)
    return True, (x0, y0, x1, y1)


class COCODataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_path,
        image_path,
        image_size=512,
        min_box_size=0.01,
        max_boxes_per_data=8,
        tokenizer=None,
    ):
        super().__init__()
        self.min_box_size = min_box_size
        self.max_boxes_per_data = max_boxes_per_data
        self.image_size = image_size
        self.image_path = image_path
        self.tokenizer = tokenizer
        self.transforms = transforms.Compose(
            [
                transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.data_list = torch.load(data_path, map_location="cpu")

    def __getitem__(self, index):
        if self.max_boxes_per_data > 99:
            assert False, "Are you sure you want to set such a large number of boxes per image?"

        out = {}

        data = self.data_list[index]
        image = Image.open(os.path.join(self.image_path, data["file_path"])).convert("RGB")
        original_image_size = image.size
        out["pixel_values"] = self.transforms(image)

        annos = data["annos"]

        areas, valid_annos = [], []
        for anno in annos:
            # Annotations are stored as xyxy; convert to xywh for the recalculation helper.
            x0, y0, x1, y1 = anno["bbox"]
            x, y, w, h = x0, y0, x1 - x0, y1 - y0
            valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(
                x, y, w, h, self.image_size, original_image_size, self.min_box_size
            )
            if valid:
                anno["bbox"] = [x0, y0, x1, y1]
                areas.append((x1 - x0) * (y1 - y0))
                valid_annos.append(anno)

        # Sort according to area and choose the largest N objects
        wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
        wanted_idxs = wanted_idxs[: self.max_boxes_per_data]
        valid_annos = [valid_annos[i] for i in wanted_idxs]

        out["boxes"] = torch.zeros(self.max_boxes_per_data, 4)
        out["masks"] = torch.zeros(self.max_boxes_per_data)
        out["text_embeddings_before_projection"] = torch.zeros(self.max_boxes_per_data, 768)

        for i, anno in enumerate(valid_annos):
            out["boxes"][i] = torch.tensor(anno["bbox"]) / self.image_size
            out["masks"][i] = 1
            out["text_embeddings_before_projection"][i] = anno["text_embeddings_before_projection"]

        # With probability 0.1, drop all grounding boxes for this sample.
        prob_drop_boxes = 0.1
        if random.random() < prob_drop_boxes:
            out["masks"][:] = 0

        caption = random.choice(data["captions"])

        # With probability 0.5, drop the caption for this sample.
        prob_drop_captions = 0.5
        if random.random() < prob_drop_captions:
            caption = ""
        caption = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        out["caption"] = caption

        return out

    def __len__(self):
        return len(self.data_list)
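

if __name__ == "__main__":
    # Usage sketch (illustrative addition, not part of the original training code).
    # Assumes the CLIP tokenizer matching the text encoder that produced
    # `text_embeddings_before_projection`, plus the data paths from the README.
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = COCODataset(
        data_path="coco_train2017.pth",
        image_path="/mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017",
        tokenizer=tokenizer,
    )
    sample = dataset[0]
    print(sample["pixel_values"].shape)  # torch.Size([3, 512, 512])
    print(sample["boxes"].shape, int(sample["masks"].sum()))  # torch.Size([8, 4]), number of kept boxes (0 if dropped)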