enabled fp16 in AltDiffusion
Signed-off-by: Anhforth <[email protected]>
Anhforth authored and ftgreat committed Jan 13, 2023
1 parent f1ea4fe commit c1e230c
Showing 6 changed files with 32 additions and 22 deletions.
6 changes: 4 additions & 2 deletions examples/AltDiffusion/generate.py
@@ -4,13 +4,15 @@
 import torch
 from flagai.auto_model.auto_loader import AutoLoader
 from flagai.model.predictor.predictor import Predictor
-
+from flagai.fp16 import FP16_Module
 # Initialize
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+
 loader = AutoLoader(task_name="text2img", #contrastive learning
                     model_name="AltDiffusion-m9",
-                    model_dir="./checkpoints")
+                    model_dir="./checkpoints",
+                    **{"use_fp16":False})
 
 model = loader.get_model()
 model.eval()
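The hunk above threads a new use_fp16 kwarg through AutoLoader (the example script ships with it off). A minimal usage sketch with the flag turned on, assuming the AltDiffusion-m9 checkpoint lives under ./checkpoints and a CUDA device is available:

import torch
from flagai.auto_model.auto_loader import AutoLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# use_fp16=True asks the loader to call model.half() after from_pretrain
# (see the auto_loader.py hunk below).
loader = AutoLoader(task_name="text2img",
                    model_name="AltDiffusion-m9",
                    model_dir="./checkpoints",
                    **{"use_fp16": True})
model = loader.get_model()
model.eval()
model.to(device)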
4 changes: 2 additions & 2 deletions flagai/auto_model/auto_loader.py
@@ -169,7 +169,6 @@ def __init__(self,
                           class_num=2)
         """
 
-
         raw_model_name = copy.deepcopy(model_name)
 
         model_name = model_name.lower()
@@ -195,7 +194,6 @@ def __init__(self,
 
         download_path = os.path.join(model_dir, raw_model_name)
         print("*" * 20, task_name, model_name)
-
         model_name_ = self.is_exist_finetuned_model(raw_model_name, task_name)
         self.model = getattr(LazyImport(self.model_name[0]),
                              self.model_name[1]).from_pretrain(
@@ -204,6 +202,8 @@
                              only_download_config=only_download_config,
                              device=device,
                              **kwargs)
+        if kwargs.get("use_fp16", None):
+            self.model.half()
 
         if model_type == "nlp":
             tokenizer_class = getattr(LazyImport("flagai.data.tokenizer"),
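model.half() here is plain PyTorch: it casts every floating-point parameter and buffer to torch.float16 in place. A self-contained sketch of the kwarg-gated pattern the loader now applies (nn.Linear stands in for the loaded model):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)           # stand-in for the model from from_pretrain
kwargs = {"use_fp16": True}       # as forwarded through AutoLoader(**kwargs)

if kwargs.get("use_fp16", None):  # same guard as in the hunk above
    model.half()                  # in-place cast of parameters to float16

print(next(model.parameters()).dtype)  # torch.float16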
8 changes: 4 additions & 4 deletions flagai/model/base_model.py
@@ -115,7 +115,7 @@ def load_local(checkpoint_path):
             model.load_weights(checkpoint_path)
             return model
 
-        def load_diffusion_local(yaml_path, only_download_config=False):
+        def load_diffusion_local(yaml_path, only_download_config=False, **kwargs):
             """
             Now only diffusion models requires yaml
             """
@@ -126,8 +126,8 @@ def load_diffusion_local(yaml_path, only_download_config=False):
             config = OmegaConf.load(f"{yaml_path}")
             model_config = config.model
             model_config.params.cond_stage_config.params.download_path = raw_download_path
-
-            model = cls(**model_config.get("params", dict()))
+            kwargs.update(model_config.get("params", dict()))
+            model = cls(**kwargs)
             if not only_download_config:
                 model = cls._load_state_dict_into_model(
                     model,
@@ -140,7 +140,7 @@
             """
             Now only diffusion models requires yaml
             """
-            return load_diffusion_local(yaml_path, only_download_config=only_download_config)
+            return load_diffusion_local(yaml_path, only_download_config=only_download_config, **kwargs)
         elif os.path.exists(config_path):
             """
             It is fine when checkpoint_path does not exist, for the case that only_download_config=True
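One subtlety in the hunk above: kwargs.update(model_config.get("params", dict())) merges the YAML params into the caller's kwargs, so when both supply the same key, the YAML value wins. A tiny sketch with hypothetical keys, purely to illustrate the merge direction:

kwargs = {"use_fp16": True, "timesteps": 500}   # caller-supplied
params = {"timesteps": 1000}                    # from the model YAML

kwargs.update(params)  # clashing keys are overwritten by the YAML side
print(kwargs)          # {'use_fp16': True, 'timesteps': 1000}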
7 changes: 4 additions & 3 deletions flagai/model/mm/AltDiffusion.py
@@ -14,6 +14,7 @@
 from flagai.model.mm.utils import make_beta_schedule, extract_into_tensor, noise_like
 from flagai.model.mm.Sampler import DDIMSampler
 from flagai.model.base_model import BaseModel
+from torch.cuda.amp import autocast as autocast
 
 __conditioning_keys__ = {
     'concat': 'c_concat',
@@ -66,7 +67,7 @@ def __init__(
         **kwargs,
     ):
         super(DDPM, self).__init__(unet_config, **kwargs)
-
+        unet_config.params.update(kwargs)
         assert parameterization in [
             "eps", "x0"
         ], 'currently only supporting "eps" and "x0"'
@@ -854,7 +855,6 @@ def get_input(self,
 
         if not self.cond_stage_trainable or force_c_encode:
             if isinstance(xc, dict) or isinstance(xc, list):
-                # import pudb; pudb.set_trace()
                 # Determine if learning the cond, otherwise it will return
                 # if the image editing is driven by texts, we will process the init_img and caption
                 if cond_key == "img_and_caption":
@@ -1269,7 +1269,8 @@ def apply_model(self, x_noisy, t, cond, return_ids=False):
             x_recon = fold(o) / normalization
 
         else:
-            x_recon = self.model(x_noisy, t, **cond)
+            with autocast():
+                x_recon = self.model(x_noisy, t, **cond)
 
         if isinstance(x_recon, tuple) and not return_ids:
             return x_recon[0]
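torch.cuda.amp.autocast runs the wrapped region in mixed precision: eligible CUDA ops such as matmuls and convolutions execute in float16, while precision-sensitive ops stay in float32. A minimal sketch, assuming a CUDA device:

import torch
from torch.cuda.amp import autocast

layer = torch.nn.Linear(8, 8).cuda()  # float32 weights
x = torch.randn(2, 8, device="cuda")  # float32 input

with autocast():  # same context as around self.model(...) above
    y = layer(x)  # the matmul is executed in float16

print(y.dtype)  # torch.float16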
18 changes: 12 additions & 6 deletions flagai/model/mm/Unets/Unet.py
@@ -16,7 +16,7 @@
     timestep_embedding,
 )
 from flagai.model.mm.attentions.attention import SpatialTransformer
-
+from torch.cuda.amp import autocast as autocast
 
 # dummy replace
 def convert_module_to_f16(x):
@@ -78,11 +78,12 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
     def forward(self, x, emb, context=None):
         for layer in self:
             if isinstance(layer, TimestepBlock):
-                x = layer(x, emb)
+                with autocast():
+                    x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer):
                 x = layer(x, context)
             else:
-                x = layer(x)
+                x = layer(x.half())
         return x
 
 
@@ -465,12 +466,13 @@ def __init__(
         n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
     ):
+
         super().__init__()
         if use_spatial_transformer:
-            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+            assert context_dim is not None, 'You forgot to include the dimension of your cross-attention conditioning...'
 
         if context_dim is not None:
-            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+            assert use_spatial_transformer, 'You forgot to use the spatial transformer for your cross-attention conditioning...'
             from omegaconf.listconfig import ListConfig
             if type(context_dim) == ListConfig:
                 context_dim = list(context_dim)
@@ -500,6 +502,7 @@ def __init__(
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
         self.predict_codebook_ids = n_embed is not None
+        self.use_fp16 = use_fp16
 
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
@@ -719,7 +722,10 @@ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
         ), "must specify y if and only if the model is class-conditional"
         hs = []
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
-        emb = self.time_embed(t_emb)
+        if self.use_fp16:
+            emb = self.time_embed(t_emb.half())
+        else:
+            emb = self.time_embed(t_emb)
 
         if self.num_classes is not None:
             assert y.shape == (x.shape[0],)
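The use_fp16 branch exists because timestep_embedding builds its sinusoidal table in float32, while the loader's model.half() has already cast the time_embed Linear weights to float16; without the cast PyTorch raises a dtype mismatch. A small sketch of the failure mode and the fix (the width 320 is illustrative, not the real model_channels):

import torch
import torch.nn as nn

time_embed = nn.Linear(320, 1280).half()  # float16 weights, as after model.half()
t_emb = torch.randn(2, 320)               # float32, like the sinusoidal embedding

# time_embed(t_emb) would raise "expected scalar type Half but found Float";
# casting the input first, as the diff does, keeps the dtypes aligned:
emb = time_embed(t_emb.half())
print(emb.dtype)  # torch.float16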
11 changes: 6 additions & 5 deletions flagai/model/predictor/predictor.py
@@ -17,6 +17,7 @@
 import time
 from contextlib import contextmanager, nullcontext
 from einops import rearrange
+from torch.cuda.amp import autocast as autocast
 
 class Predictor:
     def __init__(self, model, tokenizer=None):
@@ -367,7 +368,8 @@ def predict_generate_images(self,
                                 f: int = 8,
                                 scale: float = 7.5,
                                 from_file: str = None,
-                                seed: int = 34234):
+                                seed: int = 34234,
+                                fp16: bool = False):
         from torchvision.utils import make_grid
         from pytorch_lightning import seed_everything
         from flagai.model.predictor.utils import chunk, check_safety, get_safety_checker
@@ -389,7 +391,6 @@
             C: channels of images, 4 for colored images
         """
         seed_everything(seed)
-
         assert "diffusion" in self.class_name.lower()
         device = next(self.model.parameters()).device
         if plms:
@@ -447,9 +448,9 @@
                         unconditional_conditioning=uc,
                         eta=ddim_eta,
                         x_T=start_code)
-
-                    x_samples_ddim = self.model.decode_first_stage(
-                        samples_ddim)
+                    with autocast():
+                        x_samples_ddim = self.model.decode_first_stage(
+                            samples_ddim)
                     x_samples_ddim = torch.clamp(
                         (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                     x_samples_ddim = x_samples_ddim.cpu().permute(
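Putting the pieces together, a hedged end-to-end sketch of fp16 generation. Parameter names follow this diff; the first positional argument to predict_generate_images is assumed to be the text prompt, and the prompt itself is illustrative:

from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor

loader = AutoLoader(task_name="text2img",
                    model_name="AltDiffusion-m9",
                    model_dir="./checkpoints",
                    **{"use_fp16": True})
model = loader.get_model()
model.eval()
model.to("cuda")

predictor = Predictor(model)
# fp16 was added to the signature in this commit; decode_first_stage now
# runs under autocast either way.
predictor.predict_generate_images("a painting of a fox in the snow",
                                  seed=34234, fp16=True)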
