[Community Pipeline] Diffusion Posterior Sampling for General Noisy Inverse Problems (huggingface#5939)

* [community pipeline] dps impl

* add type checking

* pass ruff check

* ruff formatter
tongdaxu authored Nov 27, 2023
1 parent ebf581e commit 14a0d21
Showing 2 changed files with 609 additions and 1 deletion.
144 changes: 143 additions & 1 deletion examples/community/README.md
@@ -2480,4 +2480,146 @@ images = pipe(
).images
images[0].save("controlnet_and_adapter_inpaint.png")

```

## Diffusion Posterior Sampling Pipeline
* Reference paper
```
@article{chung2022diffusion,
title={Diffusion posterior sampling for general noisy inverse problems},
author={Chung, Hyungjin and Kim, Jeongsol and Mccann, Michael T and Klasky, Marc L and Ye, Jong Chul},
journal={arXiv preprint arXiv:2209.14687},
year={2022}
}
```
* This pipeline allows zero-shot conditional sampling from the posterior distribution $p(x|y)$, given an observation $y$, an unconditional generative model $p(x)$, and a differentiable operator $f(.)$ such that $y=f(x)$.
* For example, $f(.)$ can be a downsampling operator; then $y$ is a downsampled image, and the pipeline becomes a super-resolution pipeline.
* To use this pipeline, you need to know your operator $f(.)$ and the corrupted image $y$, and pass both in the call. For example, as in the main function of dps_pipeline.py, first define the Gaussian blurring operator $f(.)$. The operator should be a callable nn.Module with gradients disabled for all of its parameters:
```python
import numpy as np
import scipy.ndimage
import torch
from torch import nn


# define the Gaussian blurring operator first
class GaussialBlurOperator(nn.Module):
    def __init__(self, kernel_size, intensity):
        super().__init__()

        class Blurkernel(nn.Module):
            def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0):
                super().__init__()
                self.blur_type = blur_type
                self.kernel_size = kernel_size
                self.std = std
                # reflection padding keeps the output the same size as the input;
                # groups=3 applies one kernel per RGB channel
                self.seq = nn.Sequential(
                    nn.ReflectionPad2d(self.kernel_size // 2),
                    nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
                )
                self.weights_init()

            def forward(self, x):
                return self.seq(x)

            def weights_init(self):
                if self.blur_type == "gaussian":
                    # blur a unit impulse to obtain the Gaussian kernel
                    n = np.zeros((self.kernel_size, self.kernel_size))
                    n[self.kernel_size // 2, self.kernel_size // 2] = 1
                    k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
                    k = torch.from_numpy(k)
                    self.k = k
                    for name, f in self.named_parameters():
                        f.data.copy_(k)
                elif self.blur_type == "motion":
                    # NOTE: `Kernel` is not defined in this snippet; this branch
                    # needs an external motion-blur kernel class to run
                    k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
                    k = torch.from_numpy(k)
                    self.k = k
                    for name, f in self.named_parameters():
                        f.data.copy_(k)

            def update_weights(self, k):
                if not torch.is_tensor(k):
                    k = torch.from_numpy(k)
                for name, f in self.named_parameters():
                    f.data.copy_(k)

            def get_kernel(self):
                return self.k

        self.kernel_size = kernel_size
        self.conv = Blurkernel(blur_type='gaussian',
                               kernel_size=kernel_size,
                               std=intensity)
        self.kernel = self.conv.get_kernel()
        self.conv.update_weights(self.kernel.type(torch.float32))
        # the operator is fixed, so disable gradients for all of its parameters
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, data, **kwargs):
        return self.conv(data)

    def transpose(self, data, **kwargs):
        return data

    def get_kernel(self):
        return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
```
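* The downsampling case mentioned earlier can be sketched in the same style. This is a minimal illustration only: `DownsampleOperator` and its `scale_factor` argument are hypothetical names, and average pooling is an assumed choice of downsampling, not necessarily what the super-resolution example in dps_pipeline.py uses:
```python
import torch
import torch.nn.functional as F
from torch import nn


class DownsampleOperator(nn.Module):
    """Hypothetical f(.) that shrinks an image by an integer factor."""

    def __init__(self, scale_factor=4):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, data, **kwargs):
        # average pooling is differentiable, so gradients can flow back to the input
        return F.avg_pool2d(data, kernel_size=self.scale_factor)


operator = DownsampleOperator(scale_factor=4)
x = torch.rand(1, 3, 256, 256, requires_grad=True)
y = operator(x)
# check that the operator is differentiable with respect to its input
y.sum().backward()
print(y.shape)  # torch.Size([1, 3, 64, 64])
```
* This sketch has no parameters, so there is nothing to freeze. The pipeline only needs gradients with respect to the input image.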
* Next, obtain the corrupted image $y$ by applying the operator. In this example, we generate $y$ from the source image $x$. In practice, however, having the operator $f(.)$ and the corrupted image $y$ is enough:
```python
import numpy as np
import torch
from PIL import Image
from torchvision.utils import save_image

# set up source image
src = Image.open('sample.png')
# read image into [1,3,H,W]
src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2,0,1)[None]
# normalize image to [-1,1]
src = (src / 127.5) - 1.0
src = src.to("cuda")
# set up operator and measurement
operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
measurement = operator(src)
# save the source and corrupted images (save_image expects values in [0,1])
save_image((src+1.0)/2.0, "dps_src.png")
save_image((measurement+1.0)/2.0, "dps_mea.png")
```
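* If you only have a corrupted image on disk, the same value-range conversions apply to it. The helpers below are a small sketch of those conversions; `to_model_range` and `to_unit_range` are hypothetical names, and a synthetic 2x2 array stands in for a real image file:
```python
import numpy as np
import torch


def to_model_range(image_uint8):
    """HWC array in [0, 255] -> float tensor of shape [1, C, H, W] in [-1, 1]."""
    x = torch.from_numpy(np.asarray(image_uint8, dtype=np.float32))
    return x.permute(2, 0, 1)[None] / 127.5 - 1.0


def to_unit_range(x):
    """Tensor in [-1, 1] -> [0, 1], the range used when saving above."""
    return (x + 1.0) / 2.0


# synthetic 2x2 RGB image standing in for a real measurement file (assumption)
fake = np.array([[[0, 0, 0], [255, 255, 255]],
                 [[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)
y = to_model_range(fake)
y01 = to_unit_range(y)
print(y.shape, y.min().item(), y.max().item())
```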
* We provide an example pair of saved source and corrupted images, produced with the Gaussian blur operator above:
* Source image:
* ![sample](https://github.com/tongdaxu/Images/assets/22267548/4d2a1216-08d1-4aeb-9ce3-7a2d87561d65)
* Gaussian blurred image:
* ![ddpm_generated_image](https://github.com/tongdaxu/Images/assets/22267548/65076258-344b-4ed8-b704-a04edaade8ae)
* You can download these images to run the example on your own.
* Next, we need to define the loss function used for diffusion posterior sampling. In most cases, the RMSE is fine:
```python
def RMSELoss(yhat, y):
    # note: this sums (rather than averages) the squared errors
    return torch.sqrt(torch.sum((yhat-y)**2))
```
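* Strictly speaking, the function above computes the L2 norm of the residual, which differs from a root *mean* squared error by a factor of $\sqrt{N}$ for $N$ elements. A small sketch of the difference, where `l2_norm_loss` and `rmse_loss` are illustrative names only:
```python
import torch


def l2_norm_loss(yhat, y):
    # square root of the SUM of squared errors (same formula as RMSELoss above)
    return torch.sqrt(torch.sum((yhat - y) ** 2))


def rmse_loss(yhat, y):
    # square root of the MEAN of squared errors
    return torch.sqrt(torch.mean((yhat - y) ** 2))


yhat = torch.ones(4)
y = torch.zeros(4)
print(l2_norm_loss(yhat, y).item())  # 2.0
print(rmse_loss(yhat, y).item())     # 1.0
```
* Because the gradients of the two losses differ by the same factor, swapping one for the other would likely require retuning `zeta`.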
* Next, as with any other diffusion model, we need the score estimator and scheduler. As we are working with $256 \times 256$ face images, we use ddpm-celebahq-256:
```python
from diffusers import DDPMScheduler, UNet2DModel

# set up scheduler
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(1000)
# set up model
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
```
* And finally, run the pipeline:
```python
# finally, the pipeline
# DPSPipeline is the class defined in dps_pipeline.py
dpspipe = DPSPipeline(model, scheduler)
image = dpspipe(
    measurement=measurement,
    operator=operator,
    loss_fn=RMSELoss,
    zeta=1.0,
).images[0]
image.save("dps_generated_image.png")
```
* `zeta` is a hyperparameter in the range $[0,1]$ that needs to be tuned for best results. With `zeta=1`, you should obtain the following reconstruction:
* Reconstructed image:
* ![sample](https://github.com/tongdaxu/Images/assets/22267548/0ceb5575-d42e-4f0b-99c0-50e69c982209)
* The reconstruction is perceptually similar to the source image, but different in details.
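* To give some intuition for what `zeta` scales, here is a toy sketch of one guidance step in the spirit of the reference paper: the unconditional update is corrected by the gradient of the measurement loss, multiplied by `zeta`. This is an illustration under stated assumptions, not the code in dps_pipeline.py: `guided_update`, `denoise_fn` and `prev_fn` are hypothetical names, identity functions stand in for the model and scheduler, and average pooling stands in for $f(.)$:
```python
import torch


def l2_loss(yhat, y):
    # same distance as the loss function defined earlier in this section
    return torch.sqrt(torch.sum((yhat - y) ** 2))


def guided_update(x_t, denoise_fn, prev_fn, operator, measurement, loss_fn, zeta):
    """One guidance step: x_prev - zeta * d(loss)/d(x_t).

    denoise_fn and prev_fn are hypothetical stand-ins for the model's
    clean-image estimate and the scheduler's unconditional update.
    """
    x_t = x_t.detach().requires_grad_(True)
    x0_pred = denoise_fn(x_t)  # estimate of the clean image from x_t
    x_prev = prev_fn(x_t)      # unconditional sample for the previous timestep
    loss = loss_fn(operator(x0_pred), measurement)
    grad = torch.autograd.grad(loss, x_t)[0]
    return (x_prev - zeta * grad).detach(), loss.item()


def identity(t):
    # toy stand-in for both the denoiser and the scheduler step
    return t


torch.manual_seed(0)
operator = torch.nn.AvgPool2d(2)                 # toy differentiable f(.)
measurement = operator(torch.zeros(1, 3, 8, 8))  # y for an all-zero "true" image
x = torch.randn(1, 3, 8, 8)

losses = []
for _ in range(3):
    x, loss = guided_update(x, identity, identity, operator, measurement, l2_loss, zeta=1.0)
    losses.append(loss)
print(losses)  # the measurement loss shrinks from step to step in this toy setup
```
* In this toy setup a larger `zeta` pulls the sample toward the measurement faster. With a real model, too large a value may overpower the prior and too small a value may ignore the measurement, which is why it needs tuning.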
* In dps_pipeline.py, we also provide a super-resolution example, which should produce:
* Downsampled image:
* ![dps_mea](https://github.com/tongdaxu/Images/assets/22267548/ff6a33d6-26f0-42aa-88ce-f8a76ba45a13)
* Reconstructed image:
* ![dps_generated_image](https://github.com/tongdaxu/Images/assets/22267548/b74f084d-93f4-4845-83d8-44c0fa758a5f)