forked from dome272/Diffusion-Models-pytorch
Showing 3 changed files with 222 additions and 0 deletions.
@@ -0,0 +1,107 @@
import os
import torch
import torch.nn as nn
from tqdm import tqdm
from torch import optim
from utils import get_data
from modules import UNet
import logging
from torchvision.utils import save_image

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

class Diffusion:
    def __init__(self, noise_steps=100, beta_start=0.0001, beta_end=0.05, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        # Keep the schedule on the target device so indexing it with CUDA
        # timestep tensors below does not hit a device mismatch.
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        # Linear beta schedule from beta_start to beta_end.
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        # Closed-form forward process: x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps.
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            # Start from pure Gaussian noise and denoise step by step.
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                # DDPM reverse step; the added noise is scaled by sigma_t = sqrt(beta_t), not beta_t.
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        return x

def train(args):
    device = args.device
    dataloader = get_data(args)
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    # img_size matches the 64x64 crops produced by get_data in utils.py.
    diffusion = Diffusion(img_size=64, device=device)  # change upper noise level
    os.makedirs("results", exist_ok=True)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, _) in enumerate(pbar):
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(MSE=loss.item())

        sampled_images = diffusion.sample(model, n=images.shape[0])
        # Map samples from [-1, 1] back to [0, 1] before saving.
        save_image(sampled_images.add(1).mul(0.5), os.path.join("results", f"{epoch}.jpg"))

def launch():
    import argparse
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.run_name = "DDPM_test"
    args.epochs = 300
    args.batch_size = 16  # 5
    args.dataset_path = r"C:\Users\dome\datasets\landscape_img_folder"
    args.device = "cuda"
    args.lr = 3e-4
    train(args)

if __name__ == '__main__':
    # n = Noise()
    # img = torch.Tensor(np.array(Image.open("./images/test.jpg").resize((256, 256)))).permute(2, 0, 1) / 127.5 - 1
    #
    # imgs = img.unsqueeze(0).expand(3, -1, -1, -1)
    # ts = torch.Tensor([100, 400, 300]).long()
    # noised_imgs = n.noise_images(imgs, ts)
    # plt.imshow(noised_imgs[0].add(1).mul(0.5).permute(1, 2, 0))
    # plt.show()
    launch()
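A quick sanity check for the shortened schedule (a sketch, not part of the commit): with noise_steps=100 and beta_end=0.05, the cumulative alpha_hat at the final step should be close to zero, otherwise x_T keeps a noticeable amount of signal and sampling does not start from approximately pure noise.

import torch

beta = torch.linspace(0.0001, 0.05, 100)       # the Diffusion defaults above
alpha_hat = torch.cumprod(1 - beta, dim=0)
print(f"signal weight at T: {alpha_hat[-1].item():.4f}")  # should be near 0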
@@ -0,0 +1,101 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetEncoderBlock(nn.Module):
    def __init__(self, c_in, c_out, k=3, pad=1, stride=2, first_layer=False, use_incr=False):
        super().__init__()
        # Optionally enlarge the first conv's kernel on the input layer.
        incr = 4 if (first_layer and use_incr) else 0
        self.encoder = nn.Sequential(
            nn.Identity() if first_layer else nn.ReLU(),
            nn.Conv2d(c_in, c_out, (k + 1 + incr), padding=(pad + incr // 2), stride=stride),  # 2x downsample
            nn.Identity() if first_layer else nn.InstanceNorm2d(c_out),
            nn.ReLU(),
            nn.Conv2d(c_out, c_out, k, padding=pad),
            nn.InstanceNorm2d(c_out)
        )

    def forward(self, x):
        return self.encoder(x)

class UNetDecoderBlock(nn.Module):
    def __init__(self, c_in, c_out, k=3, pad=1, stride=2, last_layer=False, use_incr=False):
        super().__init__()
        # Optionally enlarge the second conv's kernel on the output layer.
        incr = 4 if (last_layer and use_incr) else 0
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(c_in, c_in, (k + 1), padding=pad, stride=stride),  # 2x upsample
            nn.InstanceNorm2d(c_in),
            nn.ReLU(),
            nn.Conv2d(c_in, c_out, (k + incr), padding=(pad + incr // 2)),
            nn.Identity() if last_layer else nn.InstanceNorm2d(c_out),
            nn.Identity() if last_layer else nn.ReLU()
        )

    def forward(self, x):
        return self.decoder(x)

class UNet(nn.Module):
    def __init__(self, c_in=3, c_out=3, t_emb=128, c_t=128, hidden_size=256):
        super().__init__()
        self.c_t = c_t
        # 1x1 convs projecting the sinusoidal time embedding to the channels
        # concatenated in front of each encoder/decoder block.
        self.encoder_mappers = nn.ModuleList([
            nn.Conv2d(c_t, t_emb, 1) for _ in range(3)
        ])
        self.decoder_mappers = nn.ModuleList([
            nn.Conv2d(c_t, t_emb, 1) for _ in range(3)
        ])
        self.encoders = nn.ModuleList([
            UNetEncoderBlock(c_in, (hidden_size // 8), first_layer=True),
            UNetEncoderBlock(hidden_size // 8 + t_emb, hidden_size // 4),
            UNetEncoderBlock(hidden_size // 4 + t_emb, hidden_size // 2),
            UNetEncoderBlock(hidden_size // 2 + t_emb, hidden_size),
        ])
        self.decoders = nn.ModuleList([
            UNetDecoderBlock(hidden_size + t_emb, hidden_size // 2),
            UNetDecoderBlock(2 * hidden_size // 2 + t_emb, hidden_size // 4),
            UNetDecoderBlock(2 * hidden_size // 4 + t_emb, hidden_size // 8),
            UNetDecoderBlock((2 * hidden_size // 8), c_out, last_layer=True),
        ])

    def gen_t_embedding(self, t, max_positions=10000):
        # Standard sinusoidal embedding of the timestep t.
        half_dim = self.c_t // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=t.device).float().mul(-emb).exp()
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_t % 2 == 1:  # zero pad
            emb = F.pad(emb, (0, 1), mode='constant')
        return emb[:, :, None, None]  # bs x c_t x 1 x 1

    def forward(self, x, t):
        t = self.gen_t_embedding(t)  # bs x 128 x 1 x 1
        encodings = []
        for i, encoder in enumerate(self.encoders):
            if i > 0:
                # Broadcast the time embedding spatially and concatenate it as extra channels.
                c = self.encoder_mappers[i - 1](t).expand(-1, -1, *x.shape[2:])
                x = torch.cat([x, c], dim=1)
            x = encoder(x)
            encodings.insert(0, x)

        for i, decoder in enumerate(self.decoders):
            if i > 0:
                # Skip connection from the encoder output at the matching resolution.
                x = torch.cat((x, encodings[i]), dim=1)
            if i != len(self.decoders) - 1:
                # For i == 0 this wraps around to the last mapper (index -1).
                c = self.decoder_mappers[i - 1](t).expand(-1, -1, *x.shape[2:])
                x = torch.cat([x, c], dim=1)
            x = decoder(x)

        return x

if __name__ == '__main__':
    net = UNet()
    x = torch.randn(1, 3, 256, 256)
    t = x.new_tensor([500] * x.shape[0]).long()
    print(net(x, t).shape)
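The check above runs at 256x256; here is a companion check at the 64x64 crop size produced by get_data in utils.py (a sketch, assuming this module is importable as modules):

import torch
from modules import UNet

net = UNet()
x = torch.randn(4, 3, 64, 64)      # batch at the training crop size
t = torch.randint(1, 100, (4,))    # timesteps within Diffusion's noise_steps=100
print(net(x, t).shape)             # expected: torch.Size([4, 3, 64, 64])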
@@ -0,0 +1,14 @@
import torchvision
from torch.utils.data import DataLoader

def get_data(args):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(80),
        torchvision.transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
        torchvision.transforms.ToTensor(),
        # Scale to [-1, 1]; the training script maps samples back with .add(1).mul(0.5).
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transforms)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)  # shuffle for training
    return dataloader
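A minimal usage sketch (the path is a placeholder; ImageFolder expects at least one class subdirectory under it):

import argparse
from utils import get_data

args = argparse.Namespace(dataset_path="path/to/image_folder", batch_size=16)
images, _ = next(iter(get_data(args)))
print(images.shape, images.min().item(), images.max().item())  # values in [-1, 1]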