Zhennan Chen1* · Yajie Li1* · Haofan Wang2,3 · Zhibo Chen3 · Zhengkai Jiang4 · Jun Li1 · Qian Wang5 · Jian Yang1 · Ying Tai1†
1Nanjing University · 2InstantX Team · 3Liblib AI · 4HKUST · 5China Mobile
We present RAG, a Regional-Aware text-to-image Generation method conditioned on regional descriptions for precise layout composition. Regional prompting, or compositional generation, which enables fine-grained spatial control, has gained increasing attention for its practicality in real-world applications. However, previous methods either introduce additional trainable modules, and are therefore applicable only to specific models, or manipulate score maps within cross-attention layers using attention masks, which limits control strength as the number of regions grows. To address these limitations, we decouple multi-region generation into two sub-tasks: the construction of each individual region (Regional Hard Binding), which ensures the regional prompt is properly executed, and the overall detail refinement (Regional Soft Refinement) across regions, which dissolves visual boundaries and enhances interactions between adjacent regions. Furthermore, RAG makes repainting feasible: users can modify specific unsatisfactory regions of the last generation while keeping all other regions unchanged, without relying on additional inpainting models. Our approach is tuning-free and applicable to other frameworks as an enhancement to their prompt-following ability. Quantitative and qualitative experiments demonstrate that RAG achieves superior performance in attribute binding and object relationships over previous tuning-free methods.
- 2024.11.29: RAG-Diffusion now supports FLUX.1 Redux!
- 2024.11.27: Repainting code is released.
- 2024.11.20: RAG Online Demo is live. Try it now! (Link)
- 2024.11.12: Our code and technical report are released.
conda create -n RAG python=3.9
conda activate RAG
pip install xformers==0.0.28.post1 diffusers peft torchvision==0.19.1 opencv-python==4.10.0.84 sentencepiece==0.2.0 protobuf==5.28.1 scipy==1.13.1
import torch
from RAG_pipeline_flux import RAG_FluxPipeline
pipe = RAG_FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
prompt = "a balloon on the bottom of a dog"
HB_replace = 2
HB_prompt_list = ["Balloon", "Dog"]
HB_m_offset_list = [0.1, 0.1]
HB_n_offset_list = [0.55, 0.05]
HB_m_scale_list = [0.8, 0.8]
HB_n_scale_list = [0.4, 0.45]
SR_delta = 1.0
SR_hw_split_ratio = "0.5; 0.5"
SR_prompt = "A playful dog, perhaps a golden retriever, with its ears perked up, sitting on the balloon, giving an enthusiastic demeanor. BREAK A colorful balloon floating gently, its string dangling gracefully, just beneath the dog."
height, width = 1024, 1024
seed = 1234
image = pipe(
    SR_delta=SR_delta,
    SR_hw_split_ratio=SR_hw_split_ratio,
    SR_prompt=SR_prompt,
    HB_prompt_list=HB_prompt_list,
    HB_m_offset_list=HB_m_offset_list,
    HB_n_offset_list=HB_n_offset_list,
    HB_m_scale_list=HB_m_scale_list,
    HB_n_scale_list=HB_n_scale_list,
    HB_replace=HB_replace,
    seed=seed,
    prompt=prompt,
    height=height,
    width=width,
    num_inference_steps=20,
    guidance_scale=3.5,
).images[0]
filename = "RAG.png"
image.save(filename)
print(f"Image saved as {filename}")
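All `HB_*` lists are index-aligned, one entry per region. Before calling the pipeline, a minimal sanity check along these lines can catch mismatched lists early (the helper name is ours, not part of the pipeline):

```python
def check_hb_lists(prompts, m_offsets, n_offsets, m_scales, n_scales):
    """Sanity-check hard-binding inputs: the lists must be index-aligned,
    and each normalized region box must stay inside [0, 1] x [0, 1]."""
    lists = [prompts, m_offsets, n_offsets, m_scales, n_scales]
    if len({len(lst) for lst in lists}) != 1:
        raise ValueError("all HB_* lists must have the same length")
    for mo, no, ms, ns in zip(m_offsets, n_offsets, m_scales, n_scales):
        if not (0.0 <= mo and 0.0 <= no and mo + ms <= 1.0 and no + ns <= 1.0):
            raise ValueError(f"region ({mo}, {no}, {ms}, {ns}) is out of bounds")

# The balloon/dog configuration above passes:
check_hb_lists(["Balloon", "Dog"], [0.1, 0.1], [0.55, 0.05], [0.8, 0.8], [0.4, 0.45])
```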
- `HB_replace` (int): The number of hard-binding steps. More steps make position control more precise, but may introduce visible region boundaries.
- `HB_prompt_list` (List[str]): Fundamental descriptions for each individual region or object.
- `HB_m_offset_list`, `HB_n_offset_list`, `HB_m_scale_list`, `HB_n_scale_list` (List[float]): The normalized region coordinates (offsets and scales) for each fundamental prompt in `HB_prompt_list`.
- `SR_delta` (float): The fusion strength between the global image latent and the regional-aware local latents. This is a flexible parameter; try 0.25, 0.5, 0.75, or 1.0.
- `SR_prompt` (str): Highly descriptive sub-prompts for each individual region or object, separated by `BREAK`.
- `SR_hw_split_ratio` (str): The global region divisions, corresponding one-to-one to the sub-prompts in `SR_prompt`.
The following shows several schematic diagrams of `HB_m_offset_list`, `HB_n_offset_list`, `HB_m_scale_list`, `HB_n_scale_list`, `SR_hw_split_ratio`.
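The diagrams can also be read numerically. The examples in this README suggest that `m` indexes the horizontal axis and `n` the vertical axis, both normalized to [0, 1], and that commas in `SR_hw_split_ratio` separate side-by-side regions while semicolons separate stacked rows. Under those assumptions (the helper names below are ours, not part of the pipeline), each hard-binding region maps to a pixel box like this:

```python
def hb_box_to_pixels(m_offset, n_offset, m_scale, n_scale, width, height):
    """Convert a normalized hard-binding region into a pixel bounding box
    (left, top, right, bottom), assuming m = horizontal, n = vertical."""
    left = int(m_offset * width)
    top = int(n_offset * height)
    right = int((m_offset + m_scale) * width)
    bottom = int((n_offset + n_scale) * height)
    return left, top, right, bottom

def parse_sr_split(ratio):
    """Parse SR_hw_split_ratio: ';' separates rows, ',' separates columns."""
    return [[float(v) for v in row.split(",") if v.strip()]
            for row in ratio.split(";")]

# The "Balloon" region from the quick start on a 1024x1024 canvas:
print(hb_box_to_pixels(0.1, 0.55, 0.8, 0.4, 1024, 1024))  # (102, 563, 921, 972)
print(parse_sr_split("0.5; 0.5"))  # two stacked rows: [[0.5], [0.5]]
```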
import torch
from RAG_pipeline_flux import RAG_FluxPipeline
from RAG_MLLM import local_llm, GPT4
pipe = RAG_FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
prompt = "A small elephant on the left and a huge rabbit on the right."
para_dict = GPT4(prompt, key='')  # supply your API key here
print(para_dict)
HB_replace = 2
HB_prompt_list = para_dict["HB_prompt_list"]
import ast  # parse list-valued strings from the MLLM without executing arbitrary code

HB_m_offset_list = ast.literal_eval(para_dict["HB_m_offset_list"])
HB_n_offset_list = ast.literal_eval(para_dict["HB_n_offset_list"])
HB_m_scale_list = ast.literal_eval(para_dict["HB_m_scale_list"])
HB_n_scale_list = ast.literal_eval(para_dict["HB_n_scale_list"])
SR_delta = 1.0
SR_hw_split_ratio = para_dict["SR_hw_split_ratio"]
SR_prompt = para_dict["SR_prompt"]
height = 1024
width = 1024
seed = 1234
image = pipe(
    SR_delta=SR_delta,
    SR_hw_split_ratio=SR_hw_split_ratio,
    SR_prompt=SR_prompt,
    HB_prompt_list=HB_prompt_list,
    HB_m_offset_list=HB_m_offset_list,
    HB_n_offset_list=HB_n_offset_list,
    HB_m_scale_list=HB_m_scale_list,
    HB_n_scale_list=HB_n_scale_list,
    HB_replace=HB_replace,
    seed=seed,
    prompt=prompt,
    height=height,
    width=width,
    num_inference_steps=20,
    guidance_scale=3.5,
).images[0]
filename = "RAG.png"
image.save(filename)
print(f"Image saved as {filename}")
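For orientation, the dictionary returned by the `GPT4` helper carries the keys consumed above. The sketch below shows a plausible shape for the elephant/rabbit prompt; the values are illustrative, not real model output. Note that the coordinate lists arrive as strings, which is why the script parses them:

```python
# Illustrative shape of para_dict; actual values depend on the MLLM's answer.
para_dict = {
    "HB_prompt_list": ["A small elephant", "A huge rabbit"],
    "HB_m_offset_list": "[0.05, 0.55]",  # list-valued strings, parsed by the script
    "HB_n_offset_list": "[0.35, 0.05]",
    "HB_m_scale_list": "[0.35, 0.4]",
    "HB_n_scale_list": "[0.5, 0.85]",
    "SR_hw_split_ratio": "0.5, 0.5",
    "SR_prompt": "A small elephant stands on the left, trunk curled. "
                 "BREAK A huge rabbit towers on the right, ears upright.",
}
```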
Examples
from RAG_pipeline_flux import RAG_FluxPipeline
import argparse
import torch
from PIL import Image
import json
pipe = RAG_FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
prompt = "A vase and an apple."
HB_replace = 2
HB_prompt_list = ["Vase", "Apple"]
HB_m_offset_list = [0.05, 0.65]
HB_n_offset_list = [0.1, 0.25]
HB_m_scale_list = [0.5, 0.3]
HB_n_scale_list = [0.8, 0.5]
SR_delta = 0.5
SR_hw_split_ratio = "0.6, 0.4"
SR_prompt = "A beautifully crafted vase, its elegant curves and floral embellishments standing prominently on the left side. Its delicate design echoes a sense of timeless artistry. BREAK On the right, a shiny apple with vibrant red skin, enticing with its perfectly smooth surface and hints of green around the stem."
height = 1024
width = 1024
seed = 1202
Repainting_prompt = "A vase and a Rubik's Cube."
Repainting_SR_prompt = "A beautifully crafted vase, its elegant curves and floral embellishments standing prominently on the left side. Its delicate design echoes a sense of timeless artistry. BREAK On the right, a vibrant Rubik's Cube, with its distinct colorful squares, sitting next to the vase, adding a playful and dynamic contrast to the still life composition."
Repainting_HB_prompt = "Rubik's Cube"
Repainting_mask = Image.open("data/Repainting_mask/mask_6.png").convert("L")
Repainting_HB_replace = 3
Repainting_seed = 100
Repainting_single = 0
image, Repainting_image_output = pipe(
    SR_delta=SR_delta,
    SR_hw_split_ratio=SR_hw_split_ratio,
    SR_prompt=SR_prompt,
    HB_prompt_list=HB_prompt_list,
    HB_m_offset_list=HB_m_offset_list,
    HB_n_offset_list=HB_n_offset_list,
    HB_m_scale_list=HB_m_scale_list,
    HB_n_scale_list=HB_n_scale_list,
    HB_replace=HB_replace,
    seed=seed,
    Repainting_mask=Repainting_mask,
    Repainting_prompt=Repainting_prompt,
    Repainting_SR_prompt=Repainting_SR_prompt,
    Repainting_HB_prompt=Repainting_HB_prompt,
    Repainting_HB_replace=Repainting_HB_replace,
    Repainting_seed=Repainting_seed,
    Repainting_single=Repainting_single,
    prompt=prompt,
    height=height,
    width=width,
    num_inference_steps=20,
    guidance_scale=3.5,
)
image.images[0].save("RAG_Original.png")
Repainting_image_output.images[0].save("RAG_Repainting.png")
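`Repainting_mask` is a single-channel ("L") image in which white marks the area to regenerate and black is kept from the original generation. If you do not have a ready-made mask PNG, one can be sketched with PIL; the rectangle below is our own choice, roughly covering the apple's hard-binding box from the example (m_offset=0.65, n_offset=0.25, m_scale=0.3, n_scale=0.5 on a 1024x1024 canvas):

```python
from PIL import Image, ImageDraw

# White (255) = repaint this area; black (0) = keep the original content.
mask = Image.new("L", (1024, 1024), 0)
draw = ImageDraw.Draw(mask)
draw.rectangle([665, 256, 972, 768], fill=255)  # right-hand region of the canvas
mask.save("my_repainting_mask.png")
```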
RAG showcases its image repainting capabilities, achieving results competitive with the latest Flux.1-Fill-Dev and BrushNet.
Example
| Text prompt | Repainting prompt | Command |
| --- | --- | --- |
| "A vase and an apple." | "A vase and a Rubik's Cube." | `python RAG_Repainting.py --idx=0` |
| "Three plush toys on the table." | "Two plush toys and one balloon on the table." | `python RAG_Repainting.py --idx=1` |
| "A Siamese cat lying on the grass." | "A Corgi lying on the grass." | `python RAG_Repainting.py --idx=2` |
| "A boy holding a basketball in one hand." | "A boy holding a soccer in one hand." | `python RAG_Repainting.py --idx=3` |
| "A brown curly hair African woman in blue puffy skirt." | "A brown curly hair African woman in pink suit." | `python RAG_Repainting.py --idx=4` |
| "A man on the left, a woman on the right." | "A man on the left, an anime woman on the right." | `python RAG_Repainting.py --idx=5` |
from RAG_pipeline_flux import RAG_FluxPipeline
from diffusers import FluxPriorReduxPipeline
from diffusers.utils import load_image
import torch
import argparse
import json
pipe = RAG_FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained("black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16).to("cuda")
prompt = "A man is holding a sign that says RAG-Diffusion, and another man is holding a sign that says flux-redux."
HB_replace = 8
HB_m_offset_list = [0.05, 0.55]
HB_n_offset_list = [0.2, 0.2]
HB_m_scale_list = [0.4, 0.4]
HB_n_scale_list = [0.4, 0.4]
SR_delta = 0.2
SR_hw_split_ratio = "0.5,0.5"
SR_prompt = "A man is holding a sign that says RAG-Diffusion BREAK another man is holding a sign that says flux-redux."
height = 1024
width = 1024
seed = 2272
Redux_list = ["data/Redux/Lecun.jpg", "data/Redux/Hinton.jpg"]
Redux_list = [pipe_prior_redux(load_image(Redux)) for Redux in Redux_list]
del pipe_prior_redux
torch.cuda.empty_cache()
image = pipe(
    SR_delta=SR_delta,
    SR_hw_split_ratio=SR_hw_split_ratio,
    SR_prompt=SR_prompt,
    HB_m_offset_list=HB_m_offset_list,
    HB_n_offset_list=HB_n_offset_list,
    HB_m_scale_list=HB_m_scale_list,
    HB_n_scale_list=HB_n_scale_list,
    Redux_list=Redux_list,
    HB_replace=HB_replace,
    seed=seed,
    prompt=prompt,
    height=height,
    width=width,
    num_inference_steps=20,
    guidance_scale=3.5,
)
image.images[0].save("RAG_with_Redux.png")
| Prompt | Command |
| --- | --- |
| "The left side is a skeleton with fire, and the right side is an ice dragon." | `python RAG_with_Redux.py --idx=0` |
| "Four ceramic mugs are placed on a wooden table." | `python RAG_with_Redux.py --idx=2` |
| "Two women in an illustration style." | `python RAG_with_Redux.py --idx=3` |
import torch
from RAG_pipeline_flux import RAG_FluxPipeline
pipe = RAG_FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# 8steps
# pipe.load_lora_weights('ByteDance/Hyper-SD', weight_name='Hyper-FLUX.1-dev-8steps-lora.safetensors')
# pipe.fuse_lora(lora_scale=0.125)
# MiaoKa-Yarn-World
# pipe.load_lora_weights('Shakker-Labs/FLUX.1-dev-LoRA-MiaoKa-Yarn-World', weight_name='FLUX-dev-lora-MiaoKa-Yarn-World.safetensors')
# pipe.fuse_lora(lora_scale=1.0)
# Black-Myth-Wukong
pipe.load_lora_weights('Shakker-Labs/FLUX.1-dev-LoRA-collections', weight_name='FLUX-dev-lora-Black_Myth_Wukong_hyperrealism_v1.safetensors')
pipe.fuse_lora(lora_scale=0.7)
pipe = pipe.to("cuda")
prompt = "A mountain on the left, a crouching man in the middle, and an ancient architecture on the right."
HB_replace = 3
HB_prompt_list = ["Mountain", "Crouching man", "Ancient architecture"]
HB_m_offset_list = [0.02, 0.35, 0.68]
HB_n_offset_list = [0.1, 0.1, 0.0]
HB_m_scale_list = [0.29, 0.3, 0.29]
HB_n_scale_list = [0.8, 0.8, 1.0]
SR_delta = 0.0
SR_hw_split_ratio = "0.33, 0.34, 0.33"
SR_prompt = "A mountain towering on the left, its peaks reaching into the sky, the steep slopes inviting exploration and wonder. BREAK In the middle, a crouching man is focused, his posture suggesting thoughtfulness or a momentary pause in action. BREAK On the right, an ancient architecture, its stone walls and archways revealing stories of the past, stands firmly, offering a glimpse into historical grandeur."
height = 1024
width = 1024
seed = 1236
image = pipe(
    prompt=prompt,
    HB_replace=HB_replace,
    HB_prompt_list=HB_prompt_list,
    HB_m_offset_list=HB_m_offset_list,
    HB_n_offset_list=HB_n_offset_list,
    HB_m_scale_list=HB_m_scale_list,
    HB_n_scale_list=HB_n_scale_list,
    SR_delta=SR_delta,
    SR_hw_split_ratio=SR_hw_split_ratio,
    SR_prompt=SR_prompt,
    seed=seed,
    height=height,
    width=width,
    num_inference_steps=20,
    guidance_scale=3.5,
).images[0]
filename = "RAG_with_LoRA.png"
image.save(filename)
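One easy mistake when editing these configurations is letting `SR_prompt` and `SR_hw_split_ratio` fall out of sync. For single-row or single-column layouts like the ones in this README, each `BREAK`-separated sub-prompt should pair with exactly one ratio entry; a quick consistency check (our own helper, not part of the pipeline):

```python
def count_regions(sr_prompt, sr_hw_split_ratio):
    """Return (number of BREAK-separated sub-prompts, number of ratio
    entries). For single-row/column layouts these should be equal."""
    n_prompts = len(sr_prompt.split("BREAK"))
    entries = sr_hw_split_ratio.replace(";", ",").split(",")
    n_ratios = len([v for v in entries if v.strip()])
    return n_prompts, n_ratios

sr = ("A mountain on the left. BREAK A crouching man in the middle. "
      "BREAK An ancient architecture on the right.")
print(count_regions(sr, "0.33, 0.34, 0.33"))  # (3, 3)
```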
- Hyper-Flux
- FLUX.1-dev-LoRA-collections
- FLUX.1-dev-LoRA-MiaoKa-Yarn-World
Our work is sponsored by HuggingFace and fal.ai, and is built on diffusers, FLUX.1-dev, and RPG.
@article{chen2024region,
title={Region-Aware Text-to-Image Generation via Hard Binding and Soft Refinement},
author={Chen, Zhennan and Li, Yajie and Wang, Haofan and Chen, Zhibo and Jiang, Zhengkai and Li, Jun and Wang, Qian and Yang, Jian and Tai, Ying},
journal={arXiv preprint arXiv:2411.06558},
year={2024}
}