Skip to content

Commit

Permalink
added altdiffusion_m18
Browse files Browse the repository at this point in the history
Signed-off-by: ftgreat <[email protected]>
  • Loading branch information
ftgreat committed Mar 28, 2023
1 parent 6eecaa0 commit 66862a0
Show file tree
Hide file tree
Showing 26 changed files with 7,477 additions and 13 deletions.
19 changes: 13 additions & 6 deletions examples/AltDiffusion/generate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import sys;sys.path.append("/home/yanzhaodong/FlagAI")
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
Expand All @@ -13,10 +14,16 @@
model_dir="./checkpoints",
use_fp16=False)

loader2 = AutoLoader(task_name="text2img", model_name="AltDiffusion-m9")

model = loader.get_model()
model.eval()
model.to(device)
predictor = Predictor(model)
predictor.predict_generate_images(
"Anime portrait of natalie portman as an anime girl by stanley artgerm lau, wlop, rossdraws, james jean, andrei riabovitchev, marc simonetti, and sakimichan, trending on artstation"
)
import pdb;pdb.set_trace()
for name, param in model.named_parameters():
if name.startswith("cond_stage_model"):
print(name)
# model.eval()
# model.to(device)
# predictor = Predictor(model)
# predictor.predict_generate_images(
# "Anime portrait of natalie portman as an anime girl by stanley artgerm lau, wlop, rossdraws, james jean, andrei riabovitchev, marc simonetti, and sakimichan, trending on artstation"
# )
28 changes: 28 additions & 0 deletions examples/AltDiffusion/generate_18m.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import sys;sys.path.append("/home/yanzhaodong/FlagAI")
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor

# Initialize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loader = AutoLoader(task_name="text2img", #contrastive learning
model_name="AltDiffusion-m18",
model_dir="./checkpoints",
use_fp16=False)
# loader = AutoLoader(task_name="text2img", #contrastive learning
# model_name="AltDiffusion-m18")
model = loader.get_model()
# for name, param in model.named_parameters():
# if name.startswith("cond_stage_model"):
# print(name)
import pdb;pdb.set_trace()
# model.eval()
# model.to(device)
# predictor = Predictor(model)
# predictor.predict_generate_images_m18(
# "Anime portrait of natalie portman as an anime girl by stanley artgerm lau, wlop, rossdraws, james jean, andrei riabovitchev, marc simonetti, and sakimichan, trending on artstation"
# )
8 changes: 5 additions & 3 deletions flagai/auto_model/auto_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __getattr__(self, name):
"cpm3_lm": ("flagai.model.cpm3_model", "CPM3"),
"cpm3_train": ("flagai.model.cpm3_train_model", "CPM3"),
"diffusion_text2img": ("flagai.model.mm.AltDiffusion", "LatentDiffusion"),
"diffusion2_text2img": ("flagai.model.mm.AltDiffusion2", "LatentDiffusion"),
"altclip_txt_img_matching": ("flagai.model.mm.AltCLIP", "AltCLIP"),
"evaclip_txt_img_matching": ("flagai.model.mm.eva_clip_model", "EVA_CLIP"),
}
Expand Down Expand Up @@ -121,7 +122,9 @@ def __getattr__(self, name):
"altdiffusion":
["flagai.model.mm.diffusion", "LatentDiffusion", "diffusion", "mm","flagai.model.mm.AltCLIP", "AltCLIPProcess"],
"altdiffusion-m9":
["flagai.model.mm.diffusion", "LatentDiffusion", "diffusion", "mm","flagai.model.mm.AltCLIP", "AltCLIPProcess"],
["flagai.model.mm.diffusion2", "LatentDiffusion", "diffusion", "mm","flagai.model.mm.AltCLIP", "AltCLIPProcess"],
"altdiffusion-m18":
["flagai.model.mm.Altdiffusion2", "LatentDiffusion", "diffusion2", "mm","flagai.model.mm.AltCLIP", "AltCLIPProcess"],
"swinv1-base-patch4-window7-224":
["flagai.model.vision.swinv1", "SwinTransformer", "swinv1", "vision"],
"swinv2-base-patch4-window8-256":
Expand Down Expand Up @@ -198,7 +201,6 @@ def __init__(self,
f"For the model_name: {model_name}, these tasks are be supported: {tasks}"
)
return

download_path = os.path.join(model_dir, raw_model_name)
print("*" * 20, task_name, model_name)
model_name_ = self.is_exist_finetuned_model(raw_model_name, task_name)
Expand All @@ -211,7 +213,7 @@ def __init__(self,
**kwargs)
if kwargs.get("use_fp16", None):
self.model.half()

if model_type == "nlp":
if brief_model_name in ["galactica", ]:
self.tokenizer = getattr(LazyImport(MODEL_DICT[model_name][4]),
Expand Down
1 change: 1 addition & 0 deletions flagai/model/mm/AltCLIP.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(self,
if text_config_dict is None:
text_config_dict = {}
# when reload the config from local, we need name to select which class should be instanced.
import pdb;pdb.set_trace()
self.text_config = STUDENT_CONFIG_DICT[
kwargs['text_config']['model_type']](**kwargs.pop('text_config'))
self.num_layers = num_layers
Expand Down
2 changes: 1 addition & 1 deletion flagai/model/mm/AltDiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,4 +1932,4 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
]

return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
((mean1 - mean2)**2) * torch.exp(-logvar2))
((mean1 - mean2)**2) * torch.exp(-logvar2))
Loading

0 comments on commit 66862a0

Please sign in to comment.