enable varying resolution
tianweiy committed May 26, 2024
1 parent 3f3a60e · commit 6f6b7da
Showing 4 changed files with 28 additions and 13 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -51,7 +51,7 @@ conda activate dmd2
 
 pip install --upgrade anyio
 pip install torch==2.0.1 torchvision==0.15.2
-pip install --upgrade diffusers peft wandb lmdb transformers accelerate==0.23.0 lmdb datasets evaluate scipy opencv-python matplotlib imageio piq==0.7.0 safetensors gradio
+pip install --upgrade diffusers peft wandb lmdb transformers accelerate==0.23.0 lmdb datasets evaluate scipy opencv-python matplotlib imageio piq==0.7.0 safetensors gradio huggingface-hub==0.22.0
 python setup.py develop
 ```
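If the install succeeds, a quick way to confirm the pins took effect (a sketch; the huggingface-hub pin is presumably there to stay API-compatible with the diffusers and accelerate releases above):

```python
# Sanity-check the pinned versions after `python setup.py develop` (sketch).
import accelerate
import huggingface_hub
import torch

assert huggingface_hub.__version__ == "0.22.0"
assert accelerate.__version__ == "0.23.0"
assert torch.__version__.startswith("2.0.1")  # may carry a +cuXXX suffix
```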
demo/text_to_image_sdxl.py (23 changes: 22 additions & 1 deletion)
@@ -56,6 +56,7 @@ def __init__(self, args, accelerator):
         self.image_resolution = args.image_resolution
         self.latent_resolution = args.latent_resolution
         self.num_train_timesteps = args.num_train_timesteps
+        self.vae_downsample_ratio = self.image_resolution // self.latent_resolution
 
         self.base_add_time_ids = self.build_condition_input()
         self.conditioning_timestep = args.conditioning_timestep
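For orientation, the new attribute records the VAE's spatial scale factor, which was previously implicit in the fixed latent_resolution. A minimal sketch, assuming the demo's usual SDXL defaults of 1024-pixel images and 128-pixel latents (the actual values come from the script's command-line arguments, which this hunk does not show):

```python
# Sketch of the new ratio under assumed defaults; SDXL's VAE compresses
# each 8x8 pixel patch into one latent cell.
image_resolution = 1024
latent_resolution = 128

vae_downsample_ratio = image_resolution // latent_resolution
print(vae_downsample_ratio)  # 8
```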
@@ -183,6 +184,8 @@ def inference(
         self,
         prompt: str,
         seed: int,
+        height: int,
+        width: int,
         num_images: int,
         fast_vae_decode: bool
     ):
@@ -196,7 +199,7 @@
         add_time_ids = self.base_add_time_ids.repeat(num_images, 1)
 
         noise = torch.randn(
-            num_images, 4, self.latent_resolution, self.latent_resolution,
+            num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio,
             generator=generator
         ).to(device=self.device, dtype=self.DTYPE)
 
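Pulled out of the class, the resized sampling step amounts to the standalone sketch below (not the repository's code verbatim; the 4 channels and the 8x ratio are SDXL's standard latent layout). A portrait request of 1536x1024 now yields a 192x128 latent instead of the fixed square:

```python
import torch

num_images = 2
height, width = 1536, 1024   # requested pixel resolution
vae_downsample_ratio = 8     # assumed SDXL VAE scale factor

generator = torch.Generator().manual_seed(0)
noise = torch.randn(
    num_images, 4, height // vae_downsample_ratio, width // vae_downsample_ratio,
    generator=generator,
)
print(noise.shape)  # torch.Size([2, 4, 192, 128])
# After denoising, the VAE decodes this back to num_images x 3 x 1536 x 1024.
```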
@@ -292,6 +295,22 @@ def create_demo():
                         label="Use Tiny VAE for faster decoding",
                         value=True
                     )
+                    height = gr.Slider(
+                        label="Image Height",
+                        minimum=512,
+                        maximum=1536,
+                        step=64,
+                        value=1024,
+                        info="Image height in pixels."
+                    )
+                    width = gr.Slider(
+                        label="Image Width",
+                        minimum=512,
+                        maximum=1536,
+                        step=64,
+                        value=1024,
+                        info="Image width in pixels."
+                    )
                 with gr.Column():
                     result = gr.Gallery(label="Generated Images", show_label=False, elem_id="gallery", height=1024)
 
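A note on the slider settings: step=64 appears chosen so that every selectable height and width divides cleanly by the VAE's 8x downsample ratio (and comfortably by the UNet's internal downsampling), so the integer division in inference() never truncates. A quick check of that assumption:

```python
# Sketch: every slider value (min=512, max=1536, step=64) maps to a whole
# number of latent cells under an assumed 8x VAE downsample.
vae_downsample_ratio = 8
for pixels in range(512, 1536 + 1, 64):
    assert pixels % vae_downsample_ratio == 0
# e.g. 512 -> 64, 1024 -> 128, 1536 -> 192 latent cells per side
```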
@@ -300,6 +319,8 @@
         inputs = [
             prompt,
             seed,
+            height,
+            width,
             num_images,
             fast_vae_decode
         ]
experiments/sdxl/README.md (2 changes: 1 addition & 1 deletion)
@@ -22,7 +22,7 @@ bash scripts/download_hf_checkpoint.sh $CHECKPOINT_NAME $OUTPUT_PATH
 
 ### Download Base Diffusion Models and Training Data
 ```bash
-export CHECKPOINT_PATH="" # change this to your own checkpoint folder
+export CHECKPOINT_PATH="" # change this to your own checkpoint folder (this should be a central directory shared across nodes)
 export WANDB_ENTITY="" # change this to your own wandb entity
 export WANDB_PROJECT="" # change this to your own wandb project
 export MASTER_IP="" # change this to your own master ip
main/sdxl/sdxl_text_encoder.py (14 changes: 4 additions & 10 deletions)
@@ -1,21 +1,15 @@
-from main.utils import import_model_class_from_model_name_or_path
+from transformers import CLIPTextModel, CLIPTextModelWithProjection
 import torch
 
 class SDXLTextEncoder(torch.nn.Module):
     def __init__(self, args, accelerator, dtype=torch.float32) -> None:
         super().__init__()
-        text_encoder_cls_one = import_model_class_from_model_name_or_path(
-            args.model_id, args.revision
-        )
-        text_encoder_cls_two = import_model_class_from_model_name_or_path(
-            args.model_id, args.revision, subfolder="text_encoder_2"
-        )
-
-        self.text_encoder_one = text_encoder_cls_one.from_pretrained(
+
+        self.text_encoder_one = CLIPTextModel.from_pretrained(
             args.model_id, subfolder="text_encoder", revision=args.revision
         ).to(accelerator.device).to(dtype=dtype)
 
-        self.text_encoder_two = text_encoder_cls_two.from_pretrained(
+        self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
             args.model_id, subfolder="text_encoder_2", revision=args.revision
         ).to(accelerator.device).to(dtype=dtype)
 
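The refactor hard-codes what the dynamic class lookup resolved to anyway for SDXL: text_encoder is a CLIPTextModel and text_encoder_2 is a CLIPTextModelWithProjection. For context, a sketch of the conventional SDXL recipe for combining the two encoders (this is the standard diffusers approach, not necessarily this repository's forward pass):

```python
# Sketch of the standard SDXL text-encoding recipe under the usual base
# checkpoint; DMD2's own forward() is not shown in this diff.
import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
tok_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tok_two = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
enc_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
enc_two = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2")

prompt = "a photo of an astronaut riding a horse"
ids_one = tok_one(prompt, padding="max_length", max_length=77, return_tensors="pt").input_ids
ids_two = tok_two(prompt, padding="max_length", max_length=77, return_tensors="pt").input_ids

with torch.no_grad():
    out_one = enc_one(ids_one, output_hidden_states=True)
    out_two = enc_two(ids_two, output_hidden_states=True)

# Per-token embeddings: penultimate hidden states of both encoders,
# concatenated channel-wise (768 + 1280 = 2048 dims for SDXL).
prompt_embeds = torch.cat([out_one.hidden_states[-2], out_two.hidden_states[-2]], dim=-1)
# Pooled embedding comes from the projection head of the second encoder only.
pooled_prompt_embeds = out_two.text_embeds
print(prompt_embeds.shape, pooled_prompt_embeds.shape)
# torch.Size([1, 77, 2048]) torch.Size([1, 1280])
```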
