forked from levihsu/OOTDiffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path inference_ootd_hd.py
132 lines (113 loc) · 4.43 KB
/
inference_ootd_hd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import pdb
from pathlib import Path
import sys
PROJECT_ROOT = Path(__file__).absolute().parents[0].absolute()
sys.path.insert(0, str(PROJECT_ROOT))
import os
import torch
import numpy as np
from PIL import Image
import cv2
import random
import time
import pdb
from pipelines_ootd.pipeline_ootd import OotdPipeline
from pipelines_ootd.unet_garm_2d_condition import UNetGarm2DConditionModel
from pipelines_ootd.unet_vton_2d_condition import UNetVton2DConditionModel
from diffusers import UniPCMultistepScheduler
from diffusers import AutoencoderKL
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoProcessor, CLIPVisionModelWithProjection
from transformers import CLIPTextModel, CLIPTokenizer
# Checkpoint locations used by OOTDiffusionHD.
# NOTE: these are plain relative paths, so they are resolved against the
# process's current working directory at load time, not against this file.
_CHECKPOINT_ROOT = "../checkpoints"
VIT_PATH = f"{_CHECKPOINT_ROOT}/clip-vit-large-patch14"
MODEL_PATH = f"{_CHECKPOINT_ROOT}/ootd"
VAE_PATH = MODEL_PATH  # the VAE lives in the "vae" subfolder of the main model dir
UNET_PATH = f"{MODEL_PATH}/ootd_hd/checkpoint-36000"
class OOTDiffusionHD:
    """Try-on inference wrapper around ``OotdPipeline``.

    The constructor loads the VAE, the garment and try-on UNets, the CLIP
    image encoder, the CLIP tokenizer and the CLIP text encoder from the
    module-level checkpoint paths, and moves the models to one CUDA device.
    Calling the instance runs a single inference and returns the pipeline's
    ``images``.
    """

    def __init__(self, gpu_id):
        """Load every model and move it to ``'cuda:' + str(gpu_id)``.

        Args:
            gpu_id: CUDA device index. It is turned into the device string
                stored as ``self.gpu_id`` (e.g. ``'cuda:0'``).
        """
        # Despite its name, ``self.gpu_id`` holds the full device string.
        self.gpu_id = 'cuda:' + str(gpu_id)

        vae = AutoencoderKL.from_pretrained(
            VAE_PATH,
            subfolder="vae",
            torch_dtype=torch.float16,
        )
        unet_garm = UNetGarm2DConditionModel.from_pretrained(
            UNET_PATH,
            subfolder="unet_garm",
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        unet_vton = UNetVton2DConditionModel.from_pretrained(
            UNET_PATH,
            subfolder="unet_vton",
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        self.pipe = OotdPipeline.from_pretrained(
            MODEL_PATH,
            unet_garm=unet_garm,
            unet_vton=unet_vton,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
            safety_checker=None,
            requires_safety_checker=False,
        ).to(self.gpu_id)
        # Replace the checkpoint's scheduler with UniPC, reusing its config.
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)

        # The CLIP encoders are loaded without an explicit torch_dtype
        # (unlike the fp16 pipeline components above).
        self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            MODEL_PATH,
            subfolder="tokenizer",
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            MODEL_PATH,
            subfolder="text_encoder",
        ).to(self.gpu_id)

    def tokenize_captions(self, captions, max_length):
        """Tokenize ``captions``, padded/truncated to exactly ``max_length`` tokens.

        Returns:
            The tokenizer's ``input_ids`` as a PyTorch tensor.
        """
        inputs = self.tokenizer(
            captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    @staticmethod
    def _resolve_seed(seed):
        """Return ``seed``, or a fresh random seed in ``[0, 2147483647]`` when it is -1."""
        if seed == -1:
            # Draw from a private generator so the process-wide ``random``
            # state (which a caller may have seeded for reproducibility) is
            # left untouched.
            seed = random.Random().randint(0, 2147483647)
        return seed

    def _encode_prompt(self, model_type, category, image_garm):
        """Build the conditioning ``prompt_embeds`` for the pipeline.

        The garment image is encoded with the CLIP image encoder. For
        ``'hd'``, the image embedding overwrites every position after the
        first in a 2-token empty-caption text embedding. For ``'dc'``, it is
        appended after a 3-token embedding of ``category``.

        Must be called inside ``torch.no_grad()`` (``__call__`` does this).

        Raises:
            ValueError: if ``model_type`` is neither ``'hd'`` nor ``'dc'``.
        """
        pixel_inputs = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id)
        prompt_image = self.image_encoder(pixel_inputs.data['pixel_values']).image_embeds
        # Add a sequence axis so the image embedding can sit alongside token embeddings.
        prompt_image = prompt_image.unsqueeze(1)

        if model_type == 'hd':
            prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0]
            prompt_embeds[:, 1:] = prompt_image
        elif model_type == 'dc':
            prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0]
            prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
        else:
            raise ValueError("model_type must be 'hd' or 'dc'!")
        return prompt_embeds

    def __call__(self,
                 model_type='hd',
                 category='upperbody',
                 image_garm=None,
                 image_vton=None,
                 mask=None,
                 image_ori=None,
                 num_samples=1,
                 num_steps=20,
                 image_scale=1.0,
                 seed=-1,
                 ):
        """Run one try-on inference.

        Args:
            model_type: ``'hd'`` or ``'dc'``; selects how the prompt
                embedding is built. Anything else raises ``ValueError``.
            category: caption text, used only when ``model_type == 'dc'``.
            image_garm: garment image, fed to the CLIP processor and to the
                pipeline (presumably a PIL image — confirm with callers).
            image_vton, mask, image_ori: forwarded unchanged to the pipeline.
            num_samples: forwarded as ``num_images_per_prompt``.
            num_steps: forwarded as ``num_inference_steps``.
            image_scale: forwarded as ``image_guidance_scale``.
            seed: RNG seed; ``-1`` picks a random one. The seed actually
                used is printed.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            ValueError: if ``model_type`` is not ``'hd'`` or ``'dc'``. This is
                checked before any GPU work is done.
        """
        seed = self._resolve_seed(seed)
        print('Initial seed: ' + str(seed))

        # Fail fast, before seeding torch or running the image encoder.
        if model_type not in ('hd', 'dc'):
            raise ValueError("model_type must be 'hd' or 'dc'!")

        # torch.manual_seed seeds torch's global RNG and returns the default
        # generator, which is handed to the pipeline for reproducible sampling.
        generator = torch.manual_seed(seed)

        with torch.no_grad():
            prompt_embeds = self._encode_prompt(model_type, category, image_garm)
            images = self.pipe(prompt_embeds=prompt_embeds,
                               image_garm=image_garm,
                               image_vton=image_vton,
                               mask=mask,
                               image_ori=image_ori,
                               num_inference_steps=num_steps,
                               image_guidance_scale=image_scale,
                               num_images_per_prompt=num_samples,
                               generator=generator,
                               ).images
        return images