Skip to content

Commit

Permalink
Perp-Neg algorithm to avoid the Janus problem (ashawkey#307)
Browse files Browse the repository at this point in the history
Co-authored-by: root <root@bolt-dpryeu793f-wb3zt6c4w5.bolt-pods.turi-bolt.svc.kube.us-west-2c.k8s.cloud.apple.com>
  • Loading branch information
rezaarmand and root authored Jun 11, 2023
1 parent 8fb3613 commit f4b5c3b
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 35 deletions.
42 changes: 42 additions & 0 deletions guidance/if_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.nn.functional as F

from torch.cuda.amp import custom_bwd, custom_fwd
from .perpneg_utils import weighted_perpendicular_aggregator


class SpecifyGradient(torch.autograd.Function):
Expand Down Expand Up @@ -121,6 +122,47 @@ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1

return loss

def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, grad_scale=1):
    """Run one guidance step that aggregates several prompts with Perp-Neg.

    The rendered image is resized to 64x64, noised at a random timestep and
    passed through the unet once per embedding chunk.  The per-prompt noise
    deltas are merged by ``weighted_perpendicular_aggregator`` and the
    resulting gradient is attached to the image via ``SpecifyGradient``.

    Args:
        text_embeddings: embeddings stacked along dim 0; the first
            ``pred_rgb.shape[0]`` rows are used as the unconditional branch
            and the remaining rows as the prompt branches.
        weights: per-prompt weights, forwarded unchanged to
            ``weighted_perpendicular_aggregator``.
        pred_rgb: rendered images, batch on dim 0 (the original comment
            states values are in [0, 1]).
        guidance_scale: multiplier applied to the aggregated delta.
        grad_scale: extra multiplier applied to the final gradient.

    Returns:
        The value produced by ``SpecifyGradient.apply(images, grad)``.
    """
    batch = pred_rgb.shape[0]
    # number of prompt chunks that follow the unconditional chunk
    num_prompts = text_embeddings.shape[0] // batch - 1
    copies = num_prompts + 1

    # [0, 1] -> [-1, 1], and force the spatial size to 64x64
    resized = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False)
    images = resized * 2 - 1

    # one integer timestep per sample, drawn from [min_step, max_step]
    t = torch.randint(self.min_step, self.max_step + 1, (images.shape[0],), dtype=torch.long, device=self.device)

    # the unet is only queried, never trained: no gradients in here
    with torch.no_grad():
        noise = torch.randn_like(images)
        images_noisy = self.scheduler.add_noise(images, noise, t)

        # replicate the noisy batch once per embedding chunk
        model_input = self.scheduler.scale_model_input(torch.cat([images_noisy] * copies), t)
        timesteps = torch.cat([t] * copies)
        unet_output = self.unet(model_input, timesteps, encoder_hidden_states=text_embeddings).sample

        # each half of the channel axis is split off; only the first half
        # (the noise prediction) is used, the second half is discarded
        channels = model_input.shape[1]
        noise_pred_uncond, _ = unet_output[:batch].split(channels, dim=1)
        noise_pred_text, _ = unet_output[batch:].split(channels, dim=1)

        delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(num_prompts, 1, 1, 1)
        aggregated = weighted_perpendicular_aggregator(delta_noise_preds, weights, batch)
        noise_pred = noise_pred_uncond + guidance_scale * aggregated

    # w(t), sigma_t^2
    w = 1 - self.alphas[t]
    grad = torch.nan_to_num(grad_scale * w[:, None, None, None] * (noise_pred - noise))

    # a term was omitted from grad, so the gradient is injected through the
    # custom autograd function instead of being derived from a real loss
    return SpecifyGradient.apply(images, grad)

@torch.no_grad()
def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5):

Expand Down
48 changes: 48 additions & 0 deletions guidance/perpneg_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch

# Please refer to https://perp-neg.github.io/ for details about the paper and the algorithm.
def get_perpendicular_component(x, y):
    """Return the part of ``x`` that is perpendicular to ``y``.

    Both tensors are treated as flat vectors: the projection coefficient is
    the sum of their element-wise product divided by the squared norm of
    ``y``, and that multiple of ``y`` is subtracted from ``x``.

    Args:
        x: tensor to project.
        y: tensor with the same shape as ``x`` giving the direction to remove.

    Returns:
        A tensor shaped like ``x``.  The squared norm of ``y`` is clamped to
        at least ``1e-6``, so an all-zero ``y`` returns ``x`` unchanged
        instead of dividing by zero.

    Raises:
        AssertionError: if ``x`` and ``y`` differ in shape.
    """
    assert x.shape == y.shape
    # Clamp on the tensor instead of using Python's builtin ``max``: the
    # builtin must turn the 0-dim tensor into a bool, which forces a
    # device-to-host sync for GPU tensors and yields a tensor-or-float result.
    squared_norm = torch.clamp(torch.norm(y) ** 2, min=1e-6)
    return x - (torch.mul(x, y).sum() / squared_norm) * y


def batch_get_perpendicular_component(x, y):
    """Remove, sample by sample, the component of ``x`` that lies along ``y``.

    Each slice along dim 0 is treated as one flat vector.  For every sample
    the projection coefficient is the sum of the element-wise product of the
    two slices divided by the squared norm of the ``y`` slice (clamped to at
    least ``1e-6``); that multiple of ``y`` is subtracted from ``x``.

    The computation is vectorised over the batch rather than looping in
    Python, and an empty batch yields an empty tensor instead of failing.

    Args:
        x: tensor with the batch on dim 0.
        y: tensor with the same shape as ``x``.

    Returns:
        A tensor shaped like ``x``.

    Raises:
        AssertionError: if ``x`` and ``y`` differ in shape.
    """
    assert x.shape == y.shape
    batch = x.shape[0]
    # The trailing unsqueeze guarantees there is a dim 1 to flatten into, so
    # 1-D inputs (one scalar per sample) and empty batches are handled too.
    x_flat = x.unsqueeze(-1).flatten(start_dim=1)
    y_flat = y.unsqueeze(-1).flatten(start_dim=1)

    dots = (x_flat * y_flat).sum(dim=1)
    squared_norms = torch.clamp(torch.norm(y_flat, dim=1) ** 2, min=1e-6)

    # one coefficient per sample, shaped to broadcast over the sample's dims
    coefficients = (dots / squared_norms).reshape((batch,) + (1,) * (x.dim() - 1))
    return x - coefficients * y


def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size):
    """Merge per-prompt noise deltas using their perpendicular components.

    ``delta_noise_preds`` and ``weights`` are split along dim 0 into
    consecutive chunks of ``batch_size`` rows.  The first chunk is the main
    positive prompt.  For every later chunk, the component perpendicular to
    the main chunk is computed per sample, scaled by that sample's weight and
    accumulated.  Samples whose weight magnitude is at most ``1e-4`` are left
    out of that chunk's contribution entirely.

    Args:
        delta_noise_preds: tensor of K stacked chunks, each with
            ``batch_size`` samples (original note: [B x K, 4, 64, 64],
            K = max_prompts_per_dir).
        weights: tensor holding one weight per row of ``delta_noise_preds``,
            in the same chunk order.
        batch_size: number of samples per chunk.

    Returns:
        The main chunk plus the weighted perpendicular contributions; same
        shape as one chunk.

    Raises:
        AssertionError: if the weights of the main chunk are not all 1.0.
    """
    delta_chunks = delta_noise_preds.split(batch_size, dim=0)  # K x [B, ...]
    weight_chunks = weights.split(batch_size, dim=0)  # K x [B]

    assert torch.all(weight_chunks[0] == 1.0)

    main_positive = delta_chunks[0]
    # shape that lets one weight per sample broadcast over the sample's dims
    weight_shape = (-1,) + (1,) * (main_positive.dim() - 1)

    accumulated_output = torch.zeros_like(main_positive)
    for i, complementary_noise_pred in enumerate(delta_chunks[1:], start=1):
        chunk_weights = weight_chunks[i]
        active = torch.abs(chunk_weights) > 1e-4

        # nothing to add for this chunk; also avoids handing an empty batch
        # to the projection helper
        if not active.any():
            continue

        perpendicular = batch_get_perpendicular_component(
            complementary_noise_pred[active], main_positive[active]
        )
        accumulated_output[active] += chunk_weights[active].reshape(weight_shape) * perpendicular

    return accumulated_output + main_positive
86 changes: 86 additions & 0 deletions guidance/sd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torchvision.utils import save_image

from torch.cuda.amp import custom_bwd, custom_fwd
from .perpneg_utils import weighted_perpendicular_aggregator

class SpecifyGradient(torch.autograd.Function):
@staticmethod
Expand Down Expand Up @@ -174,6 +175,91 @@ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=Fa
loss = SpecifyGradient.apply(latents, grad)

return loss


def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1,
                       save_guidance_path:Path=None):
    """Run one latent-space guidance step that aggregates prompts with Perp-Neg.

    The input is turned into latents, noised at a random timestep and passed
    through the unet once per embedding chunk.  The per-prompt noise deltas
    are merged by ``weighted_perpendicular_aggregator`` and the resulting
    gradient is attached to the latents via ``SpecifyGradient``.

    Args:
        text_embeddings: embeddings stacked along dim 0; the first
            ``pred_rgb.shape[0]`` rows are used as the unconditional branch
            and the remaining rows as the prompt branches.
        weights: per-prompt weights, forwarded unchanged to
            ``weighted_perpendicular_aggregator``.
        pred_rgb: rendered batch; interpreted as latents when ``as_latent``
            is true, otherwise as images to be encoded.
        guidance_scale: multiplier applied to the aggregated delta.
        as_latent: if true, ``pred_rgb`` is resized to 64x64 and rescaled
            with ``* 2 - 1`` instead of being encoded with ``encode_imgs``.
        grad_scale: extra multiplier applied to the final gradient.
        save_guidance_path: if truthy, a debug image grid is written there
            with ``save_image``.

    Returns:
        The value produced by ``SpecifyGradient.apply(latents, grad)``.
    """

    B = pred_rgb.shape[0]
    K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts

    if as_latent:
        latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
    else:
        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        # encode image into latents with vae, requires grad!
        latents = self.encode_imgs(pred_rgb_512)

    # one integer timestep per sample, drawn from [min_step, max_step]
    # (original note: timestep ~ U(0.02, 0.98) to avoid very high/low noise level)
    t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)

    # predict the noise residual with unet, NO grad!
    with torch.no_grad():
        # add noise
        noise = torch.randn_like(latents)
        latents_noisy = self.scheduler.add_noise(latents, noise, t)
        # pred noise: replicate the noisy batch once per embedding chunk
        latent_model_input = torch.cat([latents_noisy] * (1 + K))
        tt = torch.cat([t] * (1 + K))
        unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample

        # perform guidance (high scale from paper!)
        # first B rows: unconditional prediction; remaining rows: one chunk per prompt
        noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:]
        delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)
        noise_pred = noise_pred_uncond + guidance_scale * weighted_perpendicular_aggregator(delta_noise_preds, weights, B)

    # w(t), sigma_t^2
    w = (1 - self.alphas[t])
    grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
    # replace NaN/inf entries so a bad sample cannot poison the update
    grad = torch.nan_to_num(grad)

    if save_guidance_path:
        with torch.no_grad():
            if as_latent:
                # no RGB render exists in latent mode, so decode one for the grid
                pred_rgb_512 = self.decode_latents(latents)

            # visualize predicted denoised image
            # The following block of code is equivalent to `predict_start_from_noise`...
            # see zero123_utils.py's version for a simpler implementation.
            # NOTE(review): the index arithmetic below assumes a particular ordering of
            # `scheduler.alphas` relative to `t` - confirm against the scheduler in use.
            alphas = self.scheduler.alphas.to(latents)
            total_timesteps = self.max_step - self.min_step + 1
            index = total_timesteps - t.to(latents.device) - 1
            b = len(noise_pred)
            a_t = alphas[index].reshape(b,1,1,1).to(self.device)
            sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
            sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device)
            pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0
            result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))

            # visualize noisier image
            result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))

            # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
            viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)
            save_image(viz_images, save_guidance_path)

    # since we omitted an item in grad, we need to use the custom function to specify the gradient
    loss = SpecifyGradient.apply(latents, grad)
    return loss


@torch.no_grad()
def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
Expand Down
6 changes: 6 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def __call__ (self, parser, namespace, values, option_string = None):
parser.add_argument('--init_with', type=str, default='', help="ckpt to init dmtet")
parser.add_argument('--lock_geo', action='store_true', help="disable dmtet to learn geometry")

## Perp-Neg options
# --perpneg is a boolean switch; the other three flags are float-valued knobs.
# How they are consumed is not visible in this part of the file.
parser.add_argument('--perpneg', action='store_true', help="use perp_neg")
# NOTE(review): the help text says "the larger the better" while the default and the
# suggested range are negative - presumably it refers to the magnitude; confirm.
parser.add_argument('--negative_w', type=float, default=-2, help="scale of the weight of the negative prompt, the larger the better at avoiding janus problem, but may cause flat faces, vary between 0 to -4")
parser.add_argument('--front_decay_factor', type=float, default=2, help="decay factor for the front prompt")
parser.add_argument('--side_decay_factor', type=float, default=10, help="decay factor for the side prompt")

### training options
parser.add_argument('--iters', type=int, default=10000, help="training iters")
parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate")
Expand Down
Loading

0 comments on commit f4b5c3b

Please sign in to comment.