
Commit: Add files via upload
dome272 authored Jun 7, 2022
1 parent d02aff1 commit a4cfa7c
Showing 3 changed files with 222 additions and 0 deletions.
107 changes: 107 additions & 0 deletions ddpm.py
@@ -0,0 +1,107 @@
import os
import torch
import torch.nn as nn
from tqdm import tqdm
from torch import optim
from utils import get_data
from modules import UNet
import logging
from torchvision.utils import save_image
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")


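# Denoising Diffusion Probabilistic Model (DDPM, Ho et al. 2020) utilities:
# the noise schedule, the closed-form forward (noising) process, and the
# iterative reverse (sampling) process.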
class Diffusion:
    def __init__(self, noise_steps=100, beta_start=0.0001, beta_end=0.05, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        # Keep the schedule tensors on the target device so they can be
        # indexed with CUDA timestep tensors during training and sampling.
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        # Linear schedule: beta grows linearly from beta_start to beta_end.
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

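    # Closed-form forward process: x_t is sampled directly from x_0 rather than
    # by iterating, using q(x_t | x_0) = N(sqrt(alpha_hat_t) * x_0, (1 - alpha_hat_t) * I).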
    def noise_images(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None].to(self.device)
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None].to(self.device)
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

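    # Reverse process (Algorithm 2 in Ho et al. 2020): start from pure Gaussian
    # noise and denoise step by step with
    #   x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_hat_t) * eps_theta(x_t, t)) + sigma_t * z,
    # where sigma_t = sqrt(beta_t) and z ~ N(0, I) (z = 0 on the last step).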
    def sample(self, model, n):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    # No noise is added on the final step.
                    noise = torch.zeros_like(x)
                # The noise term is scaled by sigma_t = sqrt(beta_t), not beta_t.
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        return x


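# Training loop (Algorithm 1 in Ho et al. 2020): draw a random timestep per
# image, noise the images to that step in closed form, and regress the UNet's
# output onto the true noise with an MSE loss.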
def train(args):
    device = args.device
    dataloader = get_data(args)
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=64, device=device)  # 64 matches the crop size in get_data; change upper noise level
    os.makedirs("results", exist_ok=True)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, _) in enumerate(pbar):
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(MSE=loss.item())

        sampled_images = diffusion.sample(model, n=images.shape[0])
        save_image(sampled_images.add(1).mul(0.5), os.path.join("results", f"{epoch}.jpg"))


def launch():
    import argparse
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.run_name = "DDPM_test"
    args.epochs = 300
    args.batch_size = 16  # 5
    args.dataset_path = r"C:\Users\dome\datasets\landscape_img_folder"
    args.device = "cuda"
    args.lr = 3e-4
    train(args)


if __name__ == '__main__':
    # n = Noise()
    # img = torch.Tensor(np.array(Image.open("./images/test.jpg").resize((256, 256)))).permute(2, 0, 1) / 127.5 - 1
    #
    # imgs = img.unsqueeze(0).expand(3, -1, -1, -1)
    # ts = torch.Tensor([100, 400, 300]).long()
    # noised_imgs = n.noise_images(imgs, ts)
    # plt.imshow(noised_imgs[0].add(1).mul(0.5).permute(1, 2, 0))
    # plt.show()
    launch()
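
A minimal usage sketch for sampling from a trained model (not part of this commit; the checkpoint file is an assumption, since this script never saves weights):

import torch
from modules import UNet
from ddpm import Diffusion
from torchvision.utils import save_image

model = UNet().to("cuda")
model.load_state_dict(torch.load("ddpm.pt"))  # assumed checkpoint, e.g. saved via torch.save(model.state_dict(), "ddpm.pt")
diffusion = Diffusion(img_size=64, device="cuda")
samples = diffusion.sample(model, n=4)        # (4, 3, 64, 64), roughly in [-1, 1]
save_image(samples.add(1).mul(0.5), "samples.jpg")
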
101 changes: 101 additions & 0 deletions modules.py
@@ -0,0 +1,101 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


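# Encoder block: a stride-2 conv halves the spatial resolution, a second conv
# refines features; the first layer skips the pre-activation and norm since it
# receives raw pixels.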
class UNetEncoderBlock(nn.Module):
    def __init__(self, c_in, c_out, k=3, pad=1, stride=2, first_layer=False, use_incr=False):
        super().__init__()
        # Optionally enlarge the first conv's kernel (by incr) on the first layer.
        incr = 4 if (first_layer and use_incr) else 0
        self.encoder = nn.Sequential(
            nn.Identity() if first_layer else nn.ReLU(),
            nn.Conv2d(c_in, c_out, (k + 1 + incr), padding=(pad + incr // 2), stride=stride),
            nn.Identity() if first_layer else nn.InstanceNorm2d(c_out),
            nn.ReLU(),
            nn.Conv2d(c_out, c_out, k, padding=pad),
            nn.InstanceNorm2d(c_out)
        )

    def forward(self, x):
        return self.encoder(x)


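# Decoder block: a stride-2 transposed conv doubles the spatial resolution,
# then a conv refines features; norm and activation are skipped on the final
# output layer.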
class UNetDecoderBlock(nn.Module):
    def __init__(self, c_in, c_out, k=3, pad=1, stride=2, last_layer=False, use_incr=False):
        super().__init__()
        incr = 4 if (last_layer and use_incr) else 0
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(c_in, c_in, (k + 1), padding=pad, stride=stride),
            nn.InstanceNorm2d(c_in),
            nn.ReLU(),
            nn.Conv2d(c_in, c_out, (k + incr), padding=(pad + incr // 2)),
            nn.Identity() if last_layer else nn.InstanceNorm2d(c_out),
            nn.Identity() if last_layer else nn.ReLU()
        )

    def forward(self, x):
        return self.decoder(x)


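# A compact UNet: four stride-2 encoder blocks, four stride-2 decoder blocks
# with skip connections, and timestep conditioning via channel concatenation
# at every scale.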
class UNet(nn.Module):
    def __init__(self, c_in=3, c_out=3, t_emb=128, c_t=128, hidden_size=256):
        super().__init__()
        self.c_t = c_t
        # 1x1 convs that project the timestep embedding to t_emb channels
        # before it is concatenated to the feature maps at each scale.
        self.encoder_mappers = nn.ModuleList([
            nn.Conv2d(c_t, t_emb, 1) for _ in range(3)
        ])
        self.decoder_mappers = nn.ModuleList([
            nn.Conv2d(c_t, t_emb, 1) for _ in range(3)
        ])
        self.encoders = nn.ModuleList([
            UNetEncoderBlock(c_in, (hidden_size // 8), first_layer=True),
            UNetEncoderBlock(hidden_size // 8 + t_emb, hidden_size // 4),
            UNetEncoderBlock(hidden_size // 4 + t_emb, hidden_size // 2),
            UNetEncoderBlock(hidden_size // 2 + t_emb, hidden_size),
        ])
        self.decoders = nn.ModuleList([
            UNetDecoderBlock(hidden_size + t_emb, hidden_size // 2),
            UNetDecoderBlock(2 * hidden_size // 2 + t_emb, hidden_size // 4),
            UNetDecoderBlock(2 * hidden_size // 4 + t_emb, hidden_size // 8),
            UNetDecoderBlock((2 * hidden_size // 8), c_out, last_layer=True),
        ])

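    # Sinusoidal timestep embedding (transformer-style): maps each scalar t to
    # a c_t-dimensional vector of sines and cosines at geometrically spaced
    # frequencies, shaped (bs, c_t, 1, 1) so it broadcasts over spatial dims.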
    def gen_t_embedding(self, t, max_positions=10000):
        half_dim = self.c_t // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=t.device).float().mul(-emb).exp()
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_t % 2 == 1:  # zero pad to c_t dimensions
            emb = F.pad(emb, (0, 1), mode='constant')
        return emb[:, :, None, None]

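    # The timestep embedding is injected at every scale: a 1x1 conv maps it to
    # t_emb channels, which are expanded spatially and concatenated to the
    # feature maps before each block (except the first encoder and last decoder).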
    def forward(self, x, t):
        t = self.gen_t_embedding(t)  # bs x 128 x 1 x 1
        encodings = []
        for i, encoder in enumerate(self.encoders):
            if i > 0:
                c = self.encoder_mappers[i - 1](t).expand(-1, -1, *x.shape[2:])
                x = torch.cat([x, c], dim=1)
            x = encoder(x)
            encodings.insert(0, x)  # deepest feature map ends up at index 0

        for i, decoder in enumerate(self.decoders):
            if i > 0:
                x = torch.cat((x, encodings[i]), dim=1)  # skip connection
            if i != len(self.decoders) - 1:
                # Note: i - 1 is -1 when i == 0, which wraps around to the last mapper.
                c = self.decoder_mappers[i - 1](t).expand(-1, -1, *x.shape[2:])
                x = torch.cat([x, c], dim=1)
            x = decoder(x)

        return x


if __name__ == '__main__':
    net = UNet()
    x = torch.randn(1, 3, 256, 256)
    t = x.new_tensor([500] * x.shape[0]).long()
    print(net(x, t).shape)
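    # Prints torch.Size([1, 3, 256, 256]): four stride-2 downsamples take 256
    # down to 16, then four stride-2 upsamples restore 256.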
14 changes: 14 additions & 0 deletions utils.py
@@ -0,0 +1,14 @@
import torchvision
from torch.utils.data import DataLoader


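# ImageFolder pipeline: resize to 80, random-crop to 64x64, and normalize
# pixels to [-1, 1] (the range the diffusion code assumes).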
def get_data(args):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(80),
        torchvision.transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
        torchvision.transforms.ToTensor(),
        # Map pixels from [0, 1] to [-1, 1]; the diffusion code assumes this
        # range (sampling undoes it with .add(1).mul(0.5)).
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transforms)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)  # shuffle for training
    return dataloader
