
Commit

updating tilevae
Fanghua-Yu committed Feb 26, 2024
1 parent f874d90 commit f5d54aa
Showing 6 changed files with 1,148 additions and 5 deletions.
README.md: 3 additions & 0 deletions
@@ -108,6 +108,9 @@ CUDA_VISIBLE_DEVICES=0,1 python test.py --img_dir '/opt/data/private/LV_Dataset/
### Gradio Demo
```Shell
CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 --use_image_slider --log_history
# less VRAM & slower (12G for Diffusion, 16G for LLaVA)
CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 --use_image_slider --log_history --loading_half_params --use_tile_vae --load_8bit_llava
```
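
By their names, the low-VRAM flags trade speed for memory: `--loading_half_params` loads the diffusion weights in half precision, `--load_8bit_llava` runs LLaVA quantized to 8-bit, and `--use_tile_vae` enables the tiled VAE this commit adds. Below is a minimal sketch of the tiling idea, assuming a stand-in `encode_fn`; the actual `VAEHook` added by this commit is considerably more involved:

```python
# Illustrative sketch only, NOT the VAEHook from this commit: encode an image
# tile-by-tile so peak VRAM scales with the tile size, then stitch the latents.
import torch

def encode_in_tiles(image: torch.Tensor, encode_fn, tile_size: int = 512) -> torch.Tensor:
    """image: (B, C, H, W); encode_fn maps an image tile to its latents."""
    _, _, h, w = image.shape
    rows = []
    for y in range(0, h, tile_size):
        row = [encode_fn(image[:, :, y:y + tile_size, x:x + tile_size])
               for x in range(0, w, tile_size)]
        rows.append(torch.cat(row, dim=-1))  # stitch one row of tiles along width
    return torch.cat(rows, dim=-2)           # stack the rows along height
```

Naive tiling like this leaves seams at tile borders; a production implementation has to reconcile statistics across tile boundaries, which is part of why the tiled path is slower.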
<p align="center">
<img src="assets/DemoGuide.png">
SUPIR/models/SUPIR_model.py: 17 additions & 1 deletion
@@ -7,6 +7,7 @@
from SUPIR.utils.colorfix import wavelet_reconstruction, adaptive_instance_normalization
from pytorch_lightning import seed_everything
from torch.nn.functional import interpolate
from SUPIR.utils.tilevae import VAEHook

class SUPIRModel(DiffusionEngine):
def __init__(self, control_stage_config, ae_dtype='fp32', diffusion_dtype='fp32', p_p='', n_p='', *args, **kwargs):
@@ -131,7 +132,8 @@ def batchify_sample(self, x, p, p_p='default', n_p='default', num_steps=100, res
batch_uc = copy.deepcopy(batch)
batch_uc['txt'] = [n_p for _ in p]

-        c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)
+        with torch.cuda.amp.autocast(dtype=self.ae_dtype):
+            c, uc = self.conditioner.get_unconditional_conditioning(batch, batch_uc)

denoiser = lambda input, sigma, c, control_scale: self.denoiser(
self.model, input, sigma, c, control_scale, **kwargs
@@ -148,6 +150,20 @@ def batchify_sample(self, x, p, p_p='default', n_p='default', num_steps=100, res
samples = adaptive_instance_normalization(samples, x_stage1)
return samples
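
The one-line change above is worth pausing on: wrapping the conditioner call in `torch.cuda.amp.autocast(dtype=self.ae_dtype)` lets the text encoders run under the autoencoder dtype (e.g. fp16 when half-precision loading is used) without manually casting every input. As a general pattern, not SUPIR's exact code:

```python
# General autocast pattern; `dtype` stands in for self.ae_dtype and
# `conditioner` for any torch.nn.Module. Eligible ops are cast on the fly.
import torch

def get_conditioning(conditioner, batch, dtype=torch.float16):
    with torch.cuda.amp.autocast(dtype=dtype):
        return conditioner(batch)
```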

def init_tile_vae(self, encoder_tile_size=512, decoder_tile_size=64):
self.first_stage_model.denoise_encoder.original_forward = self.first_stage_model.denoise_encoder.forward
self.first_stage_model.encoder.original_forward = self.first_stage_model.encoder.forward
self.first_stage_model.decoder.original_forward = self.first_stage_model.decoder.forward
self.first_stage_model.denoise_encoder.forward = VAEHook(
self.first_stage_model.denoise_encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
fast_encoder=False, color_fix=False, to_gpu=True)
self.first_stage_model.encoder.forward = VAEHook(
self.first_stage_model.encoder, encoder_tile_size, is_decoder=False, fast_decoder=False,
fast_encoder=False, color_fix=False, to_gpu=True)
self.first_stage_model.decoder.forward = VAEHook(
self.first_stage_model.decoder, decoder_tile_size, is_decoder=True, fast_decoder=False,
fast_encoder=False, color_fix=False, to_gpu=True)
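
A hedged usage sketch: `init_tile_vae` monkey-patches the `forward` of the three VAE submodules with callable `VAEHook` wrappers, keeping the originals on `original_forward`. Everything below other than the method's own defaults is an assumption for illustration:

```python
# Illustrative only; create_model mirrors the __main__ block below, and the
# config path is a placeholder, not a real file in the repo.
from SUPIR.util import create_model

model = create_model('path/to/SUPIR_config.yaml')                 # placeholder path
model.init_tile_vae(encoder_tile_size=512, decoder_tile_size=64)  # the defaults above
# VAE encodes/decodes now run tile-by-tile, trading speed for lower
# peak VRAM (cf. the README's new --use_tile_vae flag).
```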


if __name__ == '__main__':
from SUPIR.util import create_model, load_state_dict
SUPIR/utils/devices.py: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
import sys
import contextlib
from functools import lru_cache

import torch
#from modules import errors

if sys.platform == "darwin":
from modules import mac_specific


def has_mps() -> bool:
if sys.platform != "darwin":
return False
else:
return mac_specific.has_mps


def get_cuda_device_string():
return "cuda"


def get_optimal_device_name():
if torch.cuda.is_available():
return get_cuda_device_string()

if has_mps():
return "mps"

return "cpu"


def get_optimal_device():
return torch.device(get_optimal_device_name())


def get_device_for(task):
return get_optimal_device()


def torch_gc():

if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

if has_mps():
mac_specific.torch_mps_gc()


def enable_tf32():
if torch.cuda.is_available():

# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
torch.backends.cudnn.benchmark = True

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


enable_tf32()
#errors.run(enable_tf32, "Enabling TF32")

cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
unet_needs_upcast = False


def cond_cast_unet(input):
return input.to(dtype_unet) if unet_needs_upcast else input


def cond_cast_float(input):
return input.float() if unet_needs_upcast else input


def randn(seed, shape):
torch.manual_seed(seed)
return torch.randn(shape, device=device)


def randn_without_seed(shape):
return torch.randn(shape, device=device)


def autocast(disable=False):
if disable:
return contextlib.nullcontext()

return torch.autocast("cuda")


def without_autocast(disable=False):
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()


class NansException(Exception):
pass


def test_for_nans(x, where):
if not torch.all(torch.isnan(x)).item():
return

if where == "unet":
message = "A tensor with all NaNs was produced in Unet."

elif where == "vae":
message = "A tensor with all NaNs was produced in VAE."

else:
message = "A tensor with all NaNs was produced."

message += " Use --disable-nan-check commandline argument to disable this check."

raise NansException(message)


@lru_cache
def first_time_calculation():
"""
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
spends about 2.7 seconds doing that, at least wih NVidia.
"""

x = torch.zeros((1, 1)).to(device, dtype)
linear = torch.nn.Linear(1, 1).to(device, dtype)
linear(x)

x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
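
A hedged sketch of how these helpers compose, assuming a CUDA machine (the module hard-codes `device = torch.device("cuda")`) and the import path where this commit places the file:

```python
# Illustrative usage of the module above; requires a CUDA-capable setup.
from SUPIR.utils import devices

x = devices.randn(seed=42, shape=(1, 4, 64, 64))  # seeded noise on `device`
with devices.autocast():                          # torch.autocast("cuda")
    y = x * 2.0
devices.test_for_nans(y, "vae")   # raises NansException only if EVERY element is NaN
devices.torch_gc()                # empty the CUDA cache between jobs
devices.first_time_calculation()  # warm-up pass (~700MB); cached via lru_cache
```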
