Commit

upload

patrickvonplaten committed Jul 19, 2022
1 parent 3f0b44b commit 37fe8e0
Showing 4 changed files with 137 additions and 526 deletions.
138 changes: 0 additions & 138 deletions conversion.py

This file was deleted.

17 changes: 16 additions & 1 deletion scripts/convert_ldm_original_checkpoint_to_diffusers.py
@@ -17,6 +17,7 @@
import argparse
import json
import torch
from diffusers import VQModel, DDPMScheduler, UNetUnconditionalModel, LatentDiffusionUncondPipeline


def shave_segments(path, n_shave_prefix_segments=1):
@@ -314,4 +315,18 @@ def convert_ldm_checkpoint(checkpoint, config):
config = json.loads(f.read())

converted_checkpoint = convert_ldm_checkpoint(checkpoint, config)
torch.save(checkpoint, args.dump_path)

if "ldm" in config:
del config["ldm"]

model = UNetUnconditionalModel(**config)
model.load_state_dict(converted_checkpoint)

try:
scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
vqvae = VQModel.from_pretrained("/".join(args.checkpoint_path.split("/")[:-1]))

pipe = LatentDiffusionUncondPipeline(unet=model, scheduler=scheduler, vae=vqvae)
pipe.save_pretrained(args.dump_path)
except Exception:  # fall back to saving just the UNet if the scheduler/VQ-VAE files are unavailable
model.save_pretrained(args.dump_path)
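
For reference, the pipeline this script saves can afterwards be reloaded from the dump directory. A minimal usage sketch ("./ldm-converted" is a placeholder for whatever --dump_path was used):

# Hypothetical usage sketch: reload the pipeline saved via save_pretrained() above.
from diffusers import LatentDiffusionUncondPipeline

pipe = LatentDiffusionUncondPipeline.from_pretrained("./ldm-converted")
# Sampling is then a matter of calling the pipeline; the exact call signature
# depends on the installed diffusers version.
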
97 changes: 58 additions & 39 deletions scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
@@ -17,25 +17,24 @@
import argparse
import json
import torch
from diffusers import UNetUnconditionalModel

from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNetUnconditionalModel


def convert_ncsnpp_checkpoint(checkpoint, config):
"""
Takes a state dict and the path to
"""
new_model_architecture = UNetUnconditionalModel(**config)
new_model_architecture.time_steps.W.data= checkpoint['all_modules.0.W'].data
new_model_architecture.time_steps.weight.data = checkpoint['all_modules.0.W'].data
new_model_architecture.time_embedding.linear_1.weight.data = checkpoint['all_modules.1.weight'].data
new_model_architecture.time_embedding.linear_1.bias.data = checkpoint['all_modules.1.bias'].data
new_model_architecture.time_embedding.linear_2.weight.data = checkpoint['all_modules.2.weight'].data
new_model_architecture.time_embedding.linear_2.bias.data= checkpoint['all_modules.2.bias'].data

new_model_architecture.conv_in.weight.data = checkpoint['all_modules.3.weight'].data
new_model_architecture.conv_in.bias.data = checkpoint['all_modules.3.bias'].data
new_model_architecture = UNetUnconditionalModel(**config)
new_model_architecture.time_steps.W.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_steps.weight.data = checkpoint["all_modules.0.W"].data
new_model_architecture.time_embedding.linear_1.weight.data = checkpoint["all_modules.1.weight"].data
new_model_architecture.time_embedding.linear_1.bias.data = checkpoint["all_modules.1.bias"].data

new_model_architecture.time_embedding.linear_2.weight.data = checkpoint["all_modules.2.weight"].data
new_model_architecture.time_embedding.linear_2.bias.data = checkpoint["all_modules.2.bias"].data

new_model_architecture.conv_in.weight.data = checkpoint["all_modules.3.weight"].data
new_model_architecture.conv_in.bias.data = checkpoint["all_modules.3.bias"].data

new_model_architecture.conv_norm_out.weight.data = checkpoint[list(checkpoint.keys())[-4]].data
new_model_architecture.conv_norm_out.bias.data = checkpoint[list(checkpoint.keys())[-3]].data
@@ -44,12 +43,11 @@ def convert_ncsnpp_checkpoint(checkpoint, config):

module_index = 4


def set_attention_weights(new_layer,old_checkpoint,index):
def set_attention_weights(new_layer, old_checkpoint, index):
new_layer.query.weight.data = old_checkpoint[f"all_modules.{index}.NIN_0.W"].data.T
new_layer.key.weight.data = old_checkpoint[f"all_modules.{index}.NIN_1.W"].data.T
new_layer.value.weight.data = old_checkpoint[f"all_modules.{index}.NIN_2.W"].data.T

new_layer.query.bias.data = old_checkpoint[f"all_modules.{index}.NIN_0.b"].data
new_layer.key.bias.data = old_checkpoint[f"all_modules.{index}.NIN_1.b"].data
new_layer.value.bias.data = old_checkpoint[f"all_modules.{index}.NIN_2.b"].data
@@ -60,7 +58,7 @@ def set_attention_weights(new_layer,old_checkpoint,index):
new_layer.group_norm.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
new_layer.group_norm.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.bias"].data

def set_resnet_weights(new_layer,old_checkpoint,index):
def set_resnet_weights(new_layer, old_checkpoint, index):
new_layer.conv1.weight.data = old_checkpoint[f"all_modules.{index}.Conv_0.weight"].data
new_layer.conv1.bias.data = old_checkpoint[f"all_modules.{index}.Conv_0.bias"].data
new_layer.norm1.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_0.weight"].data
@@ -70,7 +68,7 @@ def set_resnet_weights(new_layer,old_checkpoint,index):
new_layer.conv2.bias.data = old_checkpoint[f"all_modules.{index}.Conv_1.bias"].data
new_layer.norm2.weight.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.weight"].data
new_layer.norm2.bias.data = old_checkpoint[f"all_modules.{index}.GroupNorm_1.bias"].data

new_layer.time_emb_proj.weight.data = old_checkpoint[f"all_modules.{index}.Dense_0.weight"].data
new_layer.time_emb_proj.bias.data = old_checkpoint[f"all_modules.{index}.Dense_0.bias"].data

@@ -81,47 +79,47 @@ def set_resnet_weights(new_layer,old_checkpoint,index):
for i, block in enumerate(new_model_architecture.downsample_blocks):
has_attentions = hasattr(block, "attentions")
for j in range(len(block.resnets)):
set_resnet_weights(block.resnets[j],checkpoint, module_index)
set_resnet_weights(block.resnets[j], checkpoint, module_index)
module_index += 1
if has_attentions:
set_attention_weights(block.attentions[j],checkpoint, module_index)
set_attention_weights(block.attentions[j], checkpoint, module_index)
module_index += 1

if hasattr(block, "downsamplers") and block.downsamplers is not None:
set_resnet_weights(block.resnet_down,checkpoint, module_index)
set_resnet_weights(block.resnet_down, checkpoint, module_index)
module_index += 1
block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.Conv_0.weight"].data
block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.Conv_0.bias"].data
module_index += 1



set_resnet_weights(new_model_architecture.mid.resnets[0],checkpoint,module_index)
set_resnet_weights(new_model_architecture.mid.resnets[0], checkpoint, module_index)
module_index += 1
set_attention_weights(new_model_architecture.mid.attentions[0],checkpoint, module_index)
set_attention_weights(new_model_architecture.mid.attentions[0], checkpoint, module_index)
module_index += 1
set_resnet_weights(new_model_architecture.mid.resnets[1],checkpoint,module_index)
set_resnet_weights(new_model_architecture.mid.resnets[1], checkpoint, module_index)
module_index += 1

for i, block in enumerate(new_model_architecture.upsample_blocks):
has_attentions = hasattr(block, "attentions")
for j in range(len(block.resnets)):
set_resnet_weights(block.resnets[j],checkpoint, module_index)
set_resnet_weights(block.resnets[j], checkpoint, module_index)
module_index += 1
if has_attentions:
set_attention_weights(block.attentions[0],checkpoint, module_index) # why can there only be a single attention layer for up?
set_attention_weights(
block.attentions[0], checkpoint, module_index
) # why can there only be a single attention layer for up?
module_index += 1

if hasattr(block, "resnet_up") and block.resnet_up is not None:
block.skip_norm.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
block.skip_norm.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
module_index += 1
block.skip_conv.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
block.skip_conv.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
module_index += 1
set_resnet_weights(block.resnet_up,checkpoint, module_index)
set_resnet_weights(block.resnet_up, checkpoint, module_index)
module_index += 1

new_model_architecture.conv_norm_out.weight.data = checkpoint[f"all_modules.{module_index}.weight"].data
new_model_architecture.conv_norm_out.bias.data = checkpoint[f"all_modules.{module_index}.bias"].data
module_index += 1
@@ -130,11 +128,16 @@ def set_resnet_weights(new_layer,old_checkpoint,index):

return new_model_architecture.state_dict()
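
For orientation: the original NCSN++ checkpoint stores every submodule under a flat, sequentially numbered all_modules.N namespace, so the converter above simply walks the diffusers model's blocks in definition order while advancing module_index. A self-contained toy sketch of that mapping idea (key names and shapes are illustrative, not the real checkpoint):

# Toy illustration: a flat, sequentially numbered state dict is consumed in
# the same order the target modules are defined, mirroring module_index above.
import torch

flat = {f"all_modules.{i}.weight": torch.randn(4, 4) for i in range(3)}
targets = [torch.nn.Linear(4, 4, bias=False) for _ in range(3)]

module_index = 0
for layer in targets:  # iteration order must mirror the original module order
    layer.weight.data = flat[f"all_modules.{module_index}.weight"].data
    module_index += 1
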


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"--checkpoint_path", default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt", type=str, required=False, help="Path to the checkpoint to convert."
"--checkpoint_path",
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model.pt",
type=str,
required=False,
help="Path to the checkpoint to convert.",
)

parser.add_argument(
@@ -146,19 +149,35 @@ def set_resnet_weights(new_layer,old_checkpoint,index):
)

parser.add_argument(
"--dump_path", default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt", type=str, required=False, help="Path to the output model."
"--dump_path",
default="/Users/arthurzucker/Work/diffusers/ArthurZ/diffusion_model_new.pt",
type=str,
required=False,
help="Path to the output model.",
)

args = parser.parse_args()




checkpoint = torch.load(args.checkpoint_path, map_location="cpu")

with open(args.config_file) as f:
config = json.loads(f.read())

converted_checkpoint = convert_ncsnpp_checkpoint(
checkpoint,
config,
)

if "sde" in config:
del config["sde"]

model = UNetUnconditionalModel(**config)
model.load_state_dict(converted_checkpoint)

try:
scheduler = ScoreSdeVeScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))

converted_checkpoint = convert_ncsnpp_checkpoint(checkpoint, config,)
torch.save(converted_checkpoint, args.dump_path)
pipe = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
pipe.save_pretrained(args.dump_path)
except Exception:  # fall back to saving just the UNet if the scheduler config is unavailable
model.save_pretrained(args.dump_path)
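
As with the LDM script, the saved result can be reloaded afterwards. A minimal sketch ("./ncsnpp-converted" is a placeholder for --dump_path):

# Hypothetical usage sketch: reload the converted score-SDE pipeline.
from diffusers import ScoreSdeVePipeline

pipe = ScoreSdeVePipeline.from_pretrained("./ncsnpp-converted")
# Calling the pipeline then draws unconditional samples; the exact signature
# depends on the installed diffusers version.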
