Commit: low_vram

lllyasviel committed Feb 11, 2023
1 parent 5be9c5e · commit 1e5c75a
Showing 20 changed files with 253 additions and 55 deletions.
10 changes: 7 additions & 3 deletions README.md
@@ -26,7 +26,7 @@ This is also friendly to merge/replacement/offsetting of models/weights/blocks/l

**Q:** But wait, if the weight of a conv layer is zero, the gradient will also be zero, and the network will not learn anything. Why "zero convolution" works?

**A:** This is not true. [See an explanation here](FAQ.md).
**A:** This is not true. [See an explanation here](docs/faq.md).

# Stable Diffusion + ControlNet

@@ -49,6 +49,10 @@ We provide 9 Gradio apps with these models.

All test images can be found at the folder "test_imgs".

### News

2023/02/11 - [Low VRAM mode](docs/low_vram.md) has been added. Please use this mode if you are using an 8GB GPU.

## ControlNet with Canny Edge

Stable Diffusion 1.5 + ControlNet (using simple Canny edge detection)
@@ -206,13 +210,13 @@ This model is not available right now. We need to evaluate the potential risks b

We provide simple python scripts to process images.

[See a gradio example here](annotator.md).
[See a gradio example here](docs/annotator.md).

# Train with Your Own Data

Training a ControlNet is as easy as (or even easier than) training a simple pix2pix.

[See the steps here](train.md).
[See the steps here](docs/train.md).

# Citation

12 changes: 12 additions & 0 deletions cldm/cldm.py
@@ -415,3 +415,15 @@ def configure_optimizers(self):
        params += list(self.model.diffusion_model.out.parameters())
        opt = torch.optim.AdamW(params, lr=lr)
        return opt

    def low_vram_shift(self, is_diffusing):
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()
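The new `low_vram_shift` method keeps only the networks needed for the current phase in VRAM: while diffusing, the wrapped UNet (`self.model`) and the ControlNet copy (`self.control_model`) sit on the GPU while the VAE (`self.first_stage_model`) and the CLIP text encoder (`self.cond_stage_model`) are parked on the CPU; outside the sampling loop the roles are swapped. A minimal driver sketch of the intended call sequence follows; it is illustrative only, and the names (`prompt`, `control`, `ddim_steps`, `num_samples`, `shape`, `eta`) are borrowed from the gradio script diffs later in this commit.

```python
# Illustrative sketch only; the real call sites are in the gradio_*.py diffs below.
model.low_vram_shift(is_diffusing=False)            # CLIP + VAE on GPU, UNet + ControlNet on CPU
cond = {"c_concat": [control],
        "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)]}

model.low_vram_shift(is_diffusing=True)             # UNet + ControlNet on GPU for sampling
samples, _ = ddim_sampler.sample(ddim_steps, num_samples, shape, cond,
                                 verbose=False, eta=eta)

model.low_vram_shift(is_diffusing=False)            # VAE back on GPU to decode the latents
x_samples = model.decode_first_stage(samples)
```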
55 changes: 55 additions & 0 deletions cldm/hack.py
@@ -2,21 +2,33 @@
import einops

import ldm.modules.encoders.modules
import ldm.modules.attention

from transformers import logging
from ldm.modules.attention import default


def disable_verbosity():
    logging.set_verbosity_error()
    print('logging.set_verbosity_error()')
    return


def enable_sliced_attention():
    ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
    print('Enabled sliced_attention.')
    return


def hack_everything(clip_skip=0):
    disable_verbosity()
    ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
    ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
    print('Enabled clip hacks.')
    return


# Written by Lvmin
def _hacked_clip_forward(self, text):
    PAD = self.tokenizer.pad_token_id
    EOS = self.tokenizer.eos_token_id
@@ -54,3 +66,46 @@ def pad(x, p, i):
    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)

    return z
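The visible tail of `_hacked_clip_forward` folds three separately encoded 77-token windows back into one conditioning sequence (`f=3` in the rearrange), presumably so that prompts longer than CLIP's 77-token window can be encoded. A shape-only illustration, with 77 and 768 taken as assumed CLIP defaults:

```python
# Shape walkthrough for the final rearrange; 77 tokens and 768 channels are assumed CLIP defaults.
import torch
import einops

b = 2                                   # prompts in the batch
y = torch.randn(b * 3, 77, 768)         # three 77-token windows encoded per prompt
z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
assert z.shape == (b, 231, 768)         # one fused 231-token conditioning per prompt
```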


# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    limit = k.shape[0]
    att_step = 1
    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))

    q_chunks.reverse()
    k_chunks.reverse()
    v_chunks.reverse()
    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    del k, q, v
    for i in range(0, limit, att_step):
        q_buffer = q_chunks.pop()
        k_buffer = k_chunks.pop()
        v_buffer = v_chunks.pop()
        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale

        del k_buffer, q_buffer
        # attention, what we cannot get enough of, by chunks

        sim_buffer = sim_buffer.softmax(dim=-1)

        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
        del v_buffer
        sim[i:i + att_step, :, :] = sim_buffer

        del sim_buffer
    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(sim)
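`_hacked_sliced_attentin_forward` (the misspelling is in the source) avoids materializing the full `(batch x heads, n, m)` attention-score tensor at once: with `att_step = 1` it computes softmax attention one batch-head slice at a time and writes each result into the preallocated `sim` output buffer, trading a Python loop for a much lower peak memory. The commit does not show where `enable_sliced_attention()` is called from; the wiring below is an assumption, sketching one way a launch script could tie it to the new flag.

```python
# Assumed wiring; this commit only defines enable_sliced_attention() and does not show its call site.
import config
from cldm.hack import disable_verbosity, enable_sliced_attention

disable_verbosity()
if config.save_memory:
    # Monkey-patches ldm.modules.attention.CrossAttention.forward with the sliced version,
    # so every CrossAttention layer (existing or newly built) uses it.
    enable_sliced_attention()
```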
1 change: 1 addition & 0 deletions config.py
@@ -0,0 +1 @@
save_memory = False
File renamed without changes.
File renamed without changes.
9 changes: 9 additions & 0 deletions docs/low_vram.md
@@ -0,0 +1,9 @@
# Enable Low VRAM Mode

If you are using an 8GB GPU card, please open "config.py" and set

```python
save_memory = True
```

Note that this feature is still being tested - not all graphics cards are guaranteed to work.
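The flag is consumed by the updated gradio scripts (diffs below): each one imports `config` and wraps its device shifts in a `config.save_memory` guard, so flipping the value in `config.py` is the only change needed before launching a script. The pattern, shown in isolation for reference:

```python
# The guard pattern added to every gradio_*.py script in this commit (shown out of context);
# when save_memory is False the scripts behave exactly as before.
import config

if config.save_memory:
    model.low_vram_shift(is_diffusing=True)   # shift devices only when the flag is on
```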
File renamed without changes.
21 changes: 16 additions & 5 deletions gradio_canny2image.py
@@ -1,21 +1,22 @@
from share import *
import config

import cv2
import einops
import gradio as gr
import numpy as np
import torch

from cldm.hack import disable_verbosity
disable_verbosity()

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import apply_canny
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler


model = create_model('./models/cldm_v15.yaml').cuda()
model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cpu'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)
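Note the changed loading path in each gradio script: the model is now created on the CPU, the checkpoint is loaded with `location='cpu'`, and only the assembled model is moved to the GPU. The commit does not state the motivation, but presumably this avoids holding both the checkpoint tensors and the model parameters in VRAM at the same time during start-up, which matters on 8GB cards.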


@@ -33,14 +34,24 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti

    seed_everything(seed)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=False)

    cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    shape = (4, H // 8, W // 8)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=True)

    samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                 shape, cond, verbose=False, eta=eta,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=un_cond)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=False)

    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

21 changes: 16 additions & 5 deletions gradio_depth2image.py
@@ -1,21 +1,22 @@
from share import *
import config

import cv2
import einops
import gradio as gr
import numpy as np
import torch

from cldm.hack import disable_verbosity
disable_verbosity()

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.midas import apply_midas
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler


model = create_model('./models/cldm_v15.yaml').cuda()
model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location='cuda'))
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location='cpu'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)


@@ -35,14 +36,24 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti

    seed_everything(seed)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=False)

    cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    shape = (4, H // 8, W // 8)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=True)

    samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                 shape, cond, verbose=False, eta=eta,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=un_cond)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=False)

    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

21 changes: 16 additions & 5 deletions gradio_fake_scribble2image.py
@@ -1,21 +1,22 @@
from share import *
import config

import cv2
import einops
import gradio as gr
import numpy as np
import torch

from cldm.hack import disable_verbosity
disable_verbosity()

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.hed import apply_hed, nms
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler


model = create_model('./models/cldm_v15.yaml').cuda()
model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cpu'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)


@@ -39,14 +40,24 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti

    seed_everything(seed)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=False)

    cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    shape = (4, H // 8, W // 8)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=True)

    samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                 shape, cond, verbose=False, eta=eta,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=un_cond)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=False)

    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

21 changes: 16 additions & 5 deletions gradio_hed2image.py
@@ -1,21 +1,22 @@
from share import *
import config

import cv2
import einops
import gradio as gr
import numpy as np
import torch

from cldm.hack import disable_verbosity
disable_verbosity()

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.hed import apply_hed
from cldm.model import create_model, load_state_dict
from ldm.models.diffusion.ddim import DDIMSampler


model = create_model('./models/cldm_v15.yaml').cuda()
model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location='cuda'))
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location='cpu'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)


@@ -35,14 +36,24 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti

    seed_everything(seed)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=False)

    cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    shape = (4, H // 8, W // 8)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=True)

    samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                 shape, cond, verbose=False, eta=eta,
                                                 unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=un_cond)

    if config.save_memory:
        model.low_vram_shift(is_diffusing=False)

    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
