Add Birch-san's sub-quadratic attention implementation
brkirch committed Jan 6, 2023
1 parent 4af3ca5 commit d782a95
Showing 6 changed files with 312 additions and 35 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -139,6 +139,7 @@ The documentation was moved from this README over to the project's [wiki](https:
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san), Amin Rezaei (https://github.com/AminRezaei0x443)
- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
15 changes: 6 additions & 9 deletions modules/sd_hijack.py
@@ -7,8 +7,6 @@
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet

from modules.sd_hijack_optimizations import invokeAI_mps_available

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
@@ -40,17 +38,16 @@ def apply_optimizations():
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
if not invokeAI_mps_available and shared.device.type == 'mps':
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
else:
print("Applying cross attention optimization (InvokeAI).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
print("Applying cross attention optimization (InvokeAI).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
124 changes: 99 additions & 25 deletions modules/sd_hijack_optimizations.py
@@ -1,7 +1,7 @@
import math
import sys
import traceback
import importlib
import psutil

import torch
from torch import einsum
@@ -12,6 +12,8 @@
from modules import shared
from modules.hypernetworks import hypernetwork

from .sub_quadratic_attention import efficient_dot_product_attention


if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
@@ -22,6 +24,19 @@
        print(traceback.format_exc(), file=sys.stderr)


def get_available_vram():
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available
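
For orientation, the new helper can be probed directly; this snippet is illustrative and not part of the commit:

budget = get_available_vram()
print(f"memory budget visible to the attention optimizations: {budget / (1 << 30):.2f} GiB ({shared.device.type})")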


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    h = self.heads
@@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    mem_free_total = get_available_vram()

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
    return self.to_out(r2)


def check_for_psutil():
    try:
        spec = importlib.util.find_spec('psutil')
        return spec is not None
    except ModuleNotFoundError:
        return False

invokeAI_mps_available = check_for_psutil()

# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available:
    import psutil
    mem_total_gb = psutil.virtual_memory().total // (1 << 30)
mem_total_gb = psutil.virtual_memory().total // (1 << 30)

def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
@@ -215,6 +214,70 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):

# -- End of code from https://github.com/invoke-ai/InvokeAI --


# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
def sub_quad_attention_forward(self, x, context=None, mask=None):
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x
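
The unflatten/transpose/flatten round-trip above moves the attention heads into the batch dimension before calling sub_quad_attention, and the inverse restores them afterwards. A minimal shape sketch, assuming batch 2, 4096 tokens and 8 heads of dim 40 (values chosen for illustration, not taken from this commit):

import torch

h = 8
q = torch.randn(2, 4096, 8 * 40)   # (batch, tokens, heads * dim_head)
q = q.unflatten(-1, (h, -1))       # (2, 4096, 8, 40)
q = q.transpose(1, 2)              # (2, 8, 4096, 40)
q = q.flatten(end_dim=1)           # (16, 4096, 40) = (batch * heads, tokens, dim_head)
# inverse applied to the attention output:
x = q.unflatten(0, (-1, h)).transpose(1, 2).flatten(start_dim=2)   # back to (2, 4096, 320)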

def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True):
    bytes_per_token = torch.finfo(q.dtype).bits//8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)

    if chunk_threshold_bytes is None:
        chunk_threshold_bytes = available_vram
    elif chunk_threshold_bytes == 0:
        chunk_threshold_bytes = None

    if kv_chunk_size_min is None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        query_chunk_size = q_tokens
        kv_chunk_size = k_tokens

    return efficient_dot_product_attention(
        q,
        k,
        v,
        query_chunk_size=q_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min = kv_chunk_size_min,
        use_checkpoint=use_checkpoint,
    )
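
To make the chunking decision concrete, a rough back-of-the-envelope calculation with assumed SD 1.x numbers (a 512x512 image, i.e. a 64x64 latent, 8 heads, batch 1, fp16; none of these values come from this commit):

batch_x_heads = 1 * 8                  # batch 1, 8 attention heads
q_tokens = k_tokens = 64 * 64          # 4096 latent tokens
bytes_per_token = 2                    # fp16
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
print(qk_matmul_size_bytes / 2**20)    # 256.0 MiB for the full attention matrix
# If this fits under chunk_threshold_bytes (by default about 70% of free memory, 90% on MPS),
# kv_chunk_size is widened to the full key length and efficient_dot_product_attention takes
# its unchunked fast path; otherwise attention is computed in chunks of the configured sizes.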


def xformers_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
@@ -252,12 +315,7 @@ def cross_attention_attnblock_forward(self, x):

    h_ = torch.zeros_like(k, device=q.device)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
@@ -312,3 +370,19 @@ def xformers_attnblock_forward(self, x):
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)

def sub_quad_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out
4 changes: 4 additions & 0 deletions modules/shared.py
@@ -56,6 +56,10 @@
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")