Commit 12b10cb

finish refactor

patrickvonplaten committed Jun 12, 2022
1 parent 2d97544 commit 12b10cb
Showing 23 changed files with 288 additions and 216 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -3,7 +3,7 @@
 # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
 export PYTHONPATH = src
 
-check_dirs := models tests src utils
+check_dirs := tests src utils
 
 modified_only_fixup:
 	$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))
7 changes: 4 additions & 3 deletions src/diffusers/__init__.py
@@ -2,15 +2,16 @@
 # There's no way to ignore "F401 '...' imported but unused" warnings in this
 # module, but to preserve other warnings. So, don't check this module at all.
 
-__version__ = "0.0.1"
+__version__ = "0.0.3"
 
 from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
 from .schedulers import SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
-from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
+from .schedulers.ddim import DDIMScheduler
+from .schedulers.gaussian_ddpm import GaussianDDPMScheduler
 from .schedulers.glide_ddim import GlideDDIMScheduler
+from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
2 changes: 1 addition & 1 deletion src/diffusers/configuration_utils.py
@@ -213,7 +213,7 @@ def extract_init_dict(cls, config_dict, **kwargs):
 
         passed_keys = set(init_dict.keys())
         if len(expected_keys - passed_keys) > 0:
-            logger.warn(
+            logger.warning(
                 f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
             )

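For context on this rename (a note added here, not part of the diff): Logger.warn is a deprecated alias of Logger.warning in Python's standard logging module, so the rename above is the idiomatic fix; the same rename appears in modeling_utils.py below. A minimal stdlib-only illustration:

import logging

logger = logging.getLogger(__name__)
logger.warning("missing config keys")  # idiomatic; logger.warn(...) would emit a DeprecationWarning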
6 changes: 3 additions & 3 deletions src/diffusers/modeling_utils.py
@@ -490,7 +490,7 @@ def _find_mismatched_keys(
             raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
 
         if len(unexpected_keys) > 0:
-            logger.warn(
+            logger.warning(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                 f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                 f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
@@ -502,7 +502,7 @@ def _find_mismatched_keys(
         else:
             logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
         if len(missing_keys) > 0:
-            logger.warn(
+            logger.warning(
                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                 " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
@@ -521,7 +521,7 @@ def _find_mismatched_keys(
                     for key, shape1, shape2 in mismatched_keys
                 ]
             )
-            logger.warn(
+            logger.warning(
                 f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                 f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                 f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/__init__.py
@@ -1,4 +1,4 @@
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
-from .pipeline_latent_diffusion import LatentDiffusion
 from .pipeline_glide import GLIDE
+from .pipeline_latent_diffusion import LatentDiffusion
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/configuration_ldmbert.py
@@ -123,7 +123,7 @@ def __init__(
         scale_embedding=False,
         use_cache=True,
         pad_token_id=0,
-        **kwargs
+        **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
29 changes: 15 additions & 14 deletions src/diffusers/pipelines/modeling_vae.py
@@ -2,10 +2,10 @@
 import math
 
 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
 
+import tqdm
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
@@ -740,29 +740,30 @@ def sample(self):
 
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )
 
-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
 
     def mode(self):
         return self.mean
 
+
 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,
@@ -834,7 +835,7 @@ def __init__(
             give_pre_end=give_pre_end,
         )
 
-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
 
     def encode(self, x):
@@ -855,4 +856,4 @@ def forward(self, input, sample_posterior=True):
         else:
             z = posterior.mode()
         dec = self.decode(z)
-        return dec, posterior
\ No newline at end of file
+        return dec, posterior
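For reference (a note added here, not part of the commit): the kl method being reformatted above is the closed-form KL divergence between diagonal Gaussians, and nll is the diagonal Gaussian negative log-likelihood. In LaTeX:

\mathrm{KL}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)
    = \frac{1}{2}\sum\left[\frac{(\mu_1-\mu_2)^2}{\sigma_2^2} + \frac{\sigma_1^2}{\sigma_2^2} - 1 - \log\sigma_1^2 + \log\sigma_2^2\right]

-\log p(x) = \frac{1}{2}\sum\left[\log 2\pi + \log\sigma^2 + \frac{(x-\mu)^2}{\sigma^2}\right]

With other=None the reference distribution is the standard normal, and the first formula reduces to the 0.5 * sum(mean^2 + var - 1 - logvar) one-liner that the reformatted code now shows; both formulas match the code line for line.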
@@ -123,7 +123,7 @@ def __init__(
         scale_embedding=False,
         use_cache=True,
         pad_token_id=0,
-        **kwargs
+        **kwargs,
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -1,20 +1,31 @@
-import tqdm
 import torch
 
+import tqdm
 from diffusers import DiffusionPipeline
 
-from .configuration_ldmbert import LDMBertConfig  # NOQA
-from .modeling_ldmbert import LDMBertModel  # NOQA
-
 # add these relative imports here, so we can load from hub
-from .modeling_vae import AutoencoderKL  # NOQA
+from .configuration_ldmbert import LDMBertConfig  # NOQA
+from .modeling_ldmbert import LDMBertModel  # NOQA
+from .modeling_vae import AutoencoderKL  # NOQA
 
 
 class LatentDiffusion(DiffusionPipeline):
     def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
         super().__init__()
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
 
     @torch.no_grad()
-    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
+    def __call__(
+        self,
+        prompt,
+        batch_size=1,
+        generator=None,
+        torch_device=None,
+        eta=0.0,
+        guidance_scale=1.0,
+        num_inference_steps=50,
+    ):
         # eta corresponds to η in paper and should be between [0, 1]
 
         if torch_device is None:
@@ -23,16 +34,18 @@ def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)
 
         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
+            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
+                torch_device
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids)[0]
 
         # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
         text_embedding = self.bert(text_input.input_ids)[0]
 
         num_trained_timesteps = self.noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
 
@@ -41,7 +54,7 @@ def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=
             device=torch_device,
            generator=generator,
        )
-
+
        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper for an in-depth understanding
 
@@ -60,20 +73,20 @@ def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=
                 timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
             else:
                 # for classifier free guidance, we need to do two forward passes
-                # here we concanate embedding and unconditioned embedding in a single batch
+                # here we concatenate embedding and unconditioned embedding in a single batch
                 # to avoid doing two forward passes
                 image_in = torch.cat([image] * 2)
                 context = torch.cat([uncond_embeddings, text_embedding])
                 timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
 
             # 1. predict noise residual
             pred_noise_t = self.unet(image_in, timesteps, context=context)
 
             # perform guidance
             if guidance_scale != 1.0:
                 pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                 pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
 
             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
 
@@ -87,8 +100,8 @@ def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=
             image = pred_prev_image + variance
 
         # scale and decode image with vae
-        image = 1 / 0.18215 * image
+        image = 1 / 0.18215 * image
         image = self.vqvae.decode(image)
-        image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
 
         return image
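Side note on the guidance arithmetic above (an illustrative sketch added here, not part of the commit): this is standard classifier-free guidance (Ho & Salimans, 2022). The conditional and unconditional noise predictions come out of one batched forward pass, are split with chunk(2), and are recombined as follows; the tensor names are illustrative, not from this file:

import torch

def classifier_free_guidance(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, scale: float) -> torch.Tensor:
    # scale == 1.0 recovers the plain conditional prediction;
    # scale > 1.0 extrapolates away from the unconditional one,
    # trading sample diversity for prompt adherence
    return eps_uncond + scale * (eps_cond - eps_uncond)

The 1 / 0.18215 factor before vqvae.decode undoes the latent scaling constant that the CompVis latent-diffusion checkpoints apply when encoding images to latents.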
11 changes: 6 additions & 5 deletions src/diffusers/pipelines/old/latent_diffusion/modeling_ldmbert.py
@@ -43,6 +43,7 @@
     logging,
     replace_return_docstrings,
 )
+
 from .configuration_ldmbert import LDMBertConfig
 
 
@@ -662,7 +663,7 @@ def __init__(self, config):
         super().__init__(config)
         self.model = LDMBertEncoder(config)
         self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
-
+
     def forward(
         self,
         input_ids=None,
@@ -674,7 +675,7 @@ def forward(
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-    ):
+    ):
 
         outputs = self.model(
             input_ids,
@@ -689,15 +690,15 @@ def forward(
         sequence_output = outputs[0]
         # logits = self.to_logits(sequence_output)
         # outputs = (logits,) + outputs[1:]
-
+
         # if labels is not None:
         #     loss_fct = CrossEntropyLoss()
         #     loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
         #     outputs = (loss,) + outputs
-
+
         # if not return_dict:
         #     return outputs
-
+
         return BaseModelOutput(
             last_hidden_state=sequence_output,
             # hidden_states=outputs[1],
29 changes: 15 additions & 14 deletions src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py
@@ -2,10 +2,10 @@
 import math
 
 import numpy as np
-import tqdm
 import torch
 import torch.nn as nn
 
+import tqdm
 from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
@@ -740,29 +740,30 @@ def sample(self):
 
     def kl(self, other=None):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         else:
             if other is None:
-                return 0.5 * torch.sum(torch.pow(self.mean, 2)
-                                       + self.var - 1.0 - self.logvar,
-                                       dim=[1, 2, 3])
+                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
             else:
                 return 0.5 * torch.sum(
                     torch.pow(self.mean - other.mean, 2) / other.var
-                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
-                    dim=[1, 2, 3])
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )
 
-    def nll(self, sample, dims=[1,2,3]):
+    def nll(self, sample, dims=[1, 2, 3]):
         if self.deterministic:
-            return torch.Tensor([0.])
+            return torch.Tensor([0.0])
         logtwopi = np.log(2.0 * np.pi)
-        return 0.5 * torch.sum(
-            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
-            dim=dims)
+        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
 
     def mode(self):
         return self.mean
 
+
 class AutoencoderKL(ModelMixin, ConfigMixin):
     def __init__(
         self,
@@ -834,7 +835,7 @@ def __init__(
             give_pre_end=give_pre_end,
         )
 
-        self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
+        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
 
     def encode(self, x):
@@ -855,4 +856,4 @@ def forward(self, input, sample_posterior=True):
         else:
             z = posterior.mode()
         dec = self.decode(z)
-        return dec, posterior
\ No newline at end of file
+        return dec, posterior