
Commit

Merge branch 'FlagAI-Open:master' into add_eva_clip
Quan-Sun authored Nov 24, 2022
2 parents 1c953a4 + e232dae commit 5c7daeb
Showing 4 changed files with 83 additions and 78 deletions.
1 change: 0 additions & 1 deletion flagai/model/base_model.py
@@ -56,7 +56,6 @@ def _load_state_dict_into_model(cls,
         if len(u) > 0 and verbose:
             print("unexpected keys:")
             print(u)
-        model.cuda()
         model.eval()
         return model

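With the hard-coded model.cuda() removed from _load_state_dict_into_model, device placement becomes the caller's decision. A minimal sketch of the resulting call-site pattern (the nn.Linear is a stand-in for whatever module FlagAI's loader returns; the names here are illustrative, not FlagAI API):

    import torch
    from torch import nn

    model = nn.Linear(4, 2)  # placeholder for the loaded model

    # The caller now picks the device instead of the loader forcing CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()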
2 changes: 1 addition & 1 deletion flagai/model/mm/AltCLIP.py
@@ -246,7 +246,7 @@ def encode(self,
                padding="max_length",
                truncation=True,
                max_length=77):
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        device = next(self.text_model.parameters()).device
         text = tokenizer(text,
                          truncation=True,
                          max_length=77,
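The replacement line reads the device off the model's own parameters, so encode() follows the model wherever it lives instead of assuming CUDA. A short sketch of the idiom (a generic nn.Linear stands in for self.text_model):

    import torch
    from torch import nn

    encoder = nn.Linear(8, 8)

    # Infer the device from the module's parameters; inputs then match it.
    device = next(encoder.parameters()).device
    tokens = torch.zeros(1, 8, device=device)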
10 changes: 8 additions & 2 deletions flagai/model/mm/AltDiffusion.py
@@ -67,8 +67,8 @@ def __init__(
         **kwargs,
     ):
         super(DDPM, self).__init__(unet_config, **kwargs)
-        self.device = torch.device(
-            "cuda" if torch.cuda.is_available() else "cpu")
+        # self.device = torch.device(
+        #     "cuda" if torch.cuda.is_available() else "cpu")
         assert parameterization in [
             "eps", "x0"
         ], 'currently only supporting "eps" and "x0"'
@@ -513,6 +513,11 @@ def log_images(self,
 class LatentDiffusion(DDPM):
     """main class"""
 
+    def to(self, device):
+        self.device = device
+        self.cond_stage_model.to(device)
+        super().to(device)
+
     def __init__(self,
                  first_stage_config,
                  cond_stage_config,
@@ -527,6 +532,7 @@ def __init__(self,
                  tokenizer=None,
                  *args,
                  **kwargs):
+        self.device = "cpu"
         self.tokenizer = tokenizer
         self.num_timesteps_cond = default(num_timesteps_cond, 1)
         self.scale_by_std = scale_by_std
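Taken together, these hunks stop pinning DDPM to CUDA at construction: self.device now defaults to "cpu", and the new LatentDiffusion.to() keeps the attribute and the conditioning-stage model in step with wherever the module is moved. A minimal sketch of the same pattern (illustrative class, not the FlagAI code; returning the module from to() is added here for idiomatic chaining):

    import torch
    from torch import nn

    class TracksDevice(nn.Module):
        def __init__(self):
            super().__init__()
            self.device = "cpu"           # default, as in the new __init__
            self.layer = nn.Linear(4, 4)  # stands in for cond_stage_model

        def to(self, device):
            self.device = device          # keep the attribute in sync
            return super().to(device)

    m = TracksDevice()
    if torch.cuda.is_available():         # move only when CUDA exists
        m = m.to("cuda")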
148 changes: 74 additions & 74 deletions flagai/model/predictor/predictor.py
@@ -389,81 +389,81 @@ def predict_generate_images(self,
         start_code = torch.randn([n_samples, C, H // f, W // f],
                                  device=device)
 
-        precision_scope = nullcontext
+        # precision_scope = nullcontext
         with torch.no_grad():
-            with precision_scope("cuda"):
-                with self.model.ema_scope():
-                    tic = time.time()
-                    all_samples = list()
-                    for n in trange(n_iter, desc="Sampling"):
-                        for prompts in tqdm(data, desc="data"):
-                            uc = None
-                            if scale != 1.0:
-                                uc = self.model.get_learned_conditioning(
-                                    batch_size * [""])
-                            if isinstance(prompts, tuple):
-                                prompts = list(prompts)
-                            c = self.model.get_learned_conditioning(prompts)
-                            shape = [C, H // f, W // f]
-                            samples_ddim, _ = sampler.sample(
-                                S=ddim_steps,
-                                conditioning=c,
-                                batch_size=n_samples,
-                                shape=shape,
-                                verbose=False,
-                                unconditional_guidance_scale=scale,
-                                unconditional_conditioning=uc,
-                                eta=ddim_eta,
-                                x_T=start_code)
-
-                            x_samples_ddim = self.model.decode_first_stage(
-                                samples_ddim)
-                            x_samples_ddim = torch.clamp(
-                                (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                            x_samples_ddim = x_samples_ddim.cpu().permute(
-                                0, 2, 3, 1).numpy()
-
-                            x_checked_image, has_nsfw_concept = check_safety(safety_checker, safety_feature_extractor, x_samples_ddim)
-
-                            x_checked_image_torch = torch.from_numpy(
-                                x_checked_image).permute(0, 3, 1, 2)
-
-                            prompt_count = 0
-                            if not skip_save:
-                                for x_sample in x_checked_image_torch:
-                                    x_sample = 255. * rearrange(
-                                        x_sample.cpu().numpy(),
-                                        'c h w -> h w c')
-                                    img = Image.fromarray(
-                                        x_sample.astype(np.uint8))
-                                    img.save(
-                                        os.path.join(sample_path,
-                                                     f"{base_count:05}.png"))
-                                    #img.save(os.path.join(sample_path, f"{prompts[prompt_count]}.png"))
-
-                                    base_count += 1
-                                    prompt_count = prompt_count % batch_size
-                                    prompt_count += 1
-
-                            if not skip_grid:
-                                all_samples.append(x_checked_image_torch)
-
-                    if not skip_grid:
-                        # additionally, save as grid
-                        grid = torch.stack(all_samples, 0)
-                        grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-                        grid = make_grid(grid, nrow=n_rows)
-
-                        # to image
-                        grid = 255. * rearrange(
-                            grid, 'c h w -> h w c').cpu().numpy()
-                        img = Image.fromarray(grid.astype(np.uint8))
-                        # img = put_watermark(img, wm_encoder)
-                        img.save(
-                            os.path.join(outpath, f'grid-{grid_count:04}.png'))
-                        grid_count += 1
-
-                    toc = time.time()
+            # with precision_scope("cuda"):
+            with self.model.ema_scope():
+                tic = time.time()
+                all_samples = list()
+                for n in trange(n_iter, desc="Sampling"):
+                    for prompts in tqdm(data, desc="data"):
+                        uc = None
+                        if scale != 1.0:
+                            uc = self.model.get_learned_conditioning(
+                                batch_size * [""])
+                        if isinstance(prompts, tuple):
+                            prompts = list(prompts)
+                        c = self.model.get_learned_conditioning(prompts)
+                        shape = [C, H // f, W // f]
+                        samples_ddim, _ = sampler.sample(
+                            S=ddim_steps,
+                            conditioning=c,
+                            batch_size=n_samples,
+                            shape=shape,
+                            verbose=False,
+                            unconditional_guidance_scale=scale,
+                            unconditional_conditioning=uc,
+                            eta=ddim_eta,
+                            x_T=start_code)
+
+                        x_samples_ddim = self.model.decode_first_stage(
+                            samples_ddim)
+                        x_samples_ddim = torch.clamp(
+                            (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+                        x_samples_ddim = x_samples_ddim.cpu().permute(
+                            0, 2, 3, 1).numpy()
+
+                        x_checked_image, has_nsfw_concept = check_safety(safety_checker, safety_feature_extractor, x_samples_ddim)
+
+                        x_checked_image_torch = torch.from_numpy(
+                            x_checked_image).permute(0, 3, 1, 2)
+
+                        prompt_count = 0
+                        if not skip_save:
+                            for x_sample in x_checked_image_torch:
+                                x_sample = 255. * rearrange(
+                                    x_sample.cpu().numpy(),
+                                    'c h w -> h w c')
+                                img = Image.fromarray(
+                                    x_sample.astype(np.uint8))
+                                img.save(
+                                    os.path.join(sample_path,
+                                                 f"{base_count:05}.png"))
+                                #img.save(os.path.join(sample_path, f"{prompts[prompt_count]}.png"))
+
+                                base_count += 1
+                                prompt_count = prompt_count % batch_size
+                                prompt_count += 1
+
+                        if not skip_grid:
+                            all_samples.append(x_checked_image_torch)
+
+                if not skip_grid:
+                    # additionally, save as grid
+                    grid = torch.stack(all_samples, 0)
+                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                    grid = make_grid(grid, nrow=n_rows)
+
+                    # to image
+                    grid = 255. * rearrange(
+                        grid, 'c h w -> h w c').cpu().numpy()
+                    img = Image.fromarray(grid.astype(np.uint8))
+                    # img = put_watermark(img, wm_encoder)
+                    img.save(
+                        os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                    grid_count += 1
+
+                toc = time.time()
 
         print(
             f"Your samples are ready and waiting for you here: \n{outpath} \n"
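The dropped precision_scope wrapper appears to follow the stable-diffusion convention of swapping torch.autocast for contextlib.nullcontext; since both accept one positional argument, the with-block did not need to change between modes. A hedged sketch of that convention (the use_autocast toggle is illustrative; the deleted code hard-coded nullcontext):

    from contextlib import nullcontext

    import torch

    use_autocast = False  # illustrative toggle, not from this commit
    precision_scope = torch.autocast if use_autocast else nullcontext

    with torch.no_grad():
        # nullcontext accepts one positional argument and does nothing on
        # enter/exit, so the "cuda" tag is harmless when autocast is off.
        with precision_scope("cuda"):
            x = torch.ones(2, 2) * 3
    print(x)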
