improve ddpm conversion script
patrickvonplaten committed Jul 19, 2022
1 parent cb90fd6 commit 3f0b44b
Showing 1 changed file with 14 additions and 11 deletions.
scripts/convert_ddpm_original_checkpoint_to_diffusers.py (14 additions, 11 deletions)
@@ -1,4 +1,4 @@
-from diffusers import UNetUnconditionalModel
+from diffusers import UNetUnconditionalModel, DDPMScheduler, DDPMPipeline
 import argparse
 import json
 import torch
@@ -56,7 +56,7 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s

     if attention_paths_to_split is not None:
         if config is None:
-            raise ValueError(f"Please specify the config if setting 'attention_paths_to_split' to 'True'.")
+            raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")

         for path, path_map in attention_paths_to_split.items():
             old_tensor = old_checkpoint[path]
@@ -86,7 +86,6 @@ def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_s
             for replacement in additional_replacements:
                 new_path = new_path.replace(replacement['old'], replacement['new'])

-
         if 'attentions' in new_path:
             checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
         else:
@@ -97,7 +96,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
-
     new_checkpoint = {}

     new_checkpoint['time_embedding.linear_1.weight'] = checkpoint['temb.dense.0.weight']
@@ -121,7 +119,6 @@ def convert_ddpm_checkpoint(checkpoint, config):

     for i in range(num_downsample_blocks):
         block_id = (i - 1) // (config['num_res_blocks'] + 1)
-        layer_in_block_id = (i - 1) % (config['num_res_blocks'] + 1)

         if any('downsample' in layer for layer in downsample_blocks[i]):
             new_checkpoint[f'downsample_blocks.{i}.downsamplers.0.conv.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
@@ -138,7 +135,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
                 paths = renew_resnet_paths(blocks[j])
                 assign_to_checkpoint(paths, new_checkpoint, checkpoint)

-
         if any('attn' in layer for layer in downsample_blocks[i]):
             num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in downsample_blocks[i] if 'attn' in layer})
             attns = {layer_id: [key for key in downsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
@@ -148,7 +144,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
                 paths = renew_attention_paths(attns[j])
                 assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)

-
     mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
     mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
     mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
@@ -186,7 +181,6 @@ def convert_ddpm_checkpoint(checkpoint, config):
                 paths = renew_resnet_paths(blocks[j])
                 assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])

-
         if any('attn' in layer for layer in upsample_blocks[i]):
             num_attn = len({'.'.join(shave_segments(layer, 2).split('.')[:2]) for layer in upsample_blocks[i] if 'attn' in layer})
             attns = {layer_id: [key for key in upsample_blocks[i] if f'attn.{layer_id}' in key] for layer_id in range(num_blocks)}
@@ -220,12 +214,21 @@ def convert_ddpm_checkpoint(checkpoint, config):
         "--dump_path", default=None, type=str, required=True, help="Path to the output model."
     )

-
     args = parser.parse_args()
     checkpoint = torch.load(args.checkpoint_path)

     with open(args.config_file) as f:
         config = json.loads(f.read())

-    converted_checkpoint = convert_ddpm_checkpoint(args.checkpoint_path, args.config_file)
-    torch.save(converted_checkpoint, args.dump_path)
+    converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
+
+    if "ddpm" in config:
+        del config["ddpm"]
+
+    model = UNetUnconditionalModel(**config)
+    model.load_state_dict(converted_checkpoint)
+
+    scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
+
+    pipe = DDPMPipeline(unet=model, scheduler=scheduler)
+    pipe.save_pretrained(args.dump_path)
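
With this change the script no longer dumps the raw converted state dict via torch.save. It instantiates UNetUnconditionalModel from the config, loads the converted weights, builds a DDPMScheduler from the directory containing the checkpoint (DDPMScheduler.from_config is pointed at the checkpoint's parent directory, so the scheduler config is expected to live next to the checkpoint file), and saves a complete DDPMPipeline. It also fixes a bug where the file paths, rather than the loaded checkpoint and config, were passed to convert_ddpm_checkpoint. A minimal sketch of how the updated script might be invoked and the result reloaded; the local paths below are hypothetical placeholders:

    # Hypothetical layout: the original DDPM checkpoint, its JSON config, and
    # the scheduler config are assumed to sit together in ./ddpm_original/.
    #
    #   python scripts/convert_ddpm_original_checkpoint_to_diffusers.py \
    #       --checkpoint_path ./ddpm_original/model.ckpt \
    #       --config_file ./ddpm_original/config.json \
    #       --dump_path ./ddpm_converted

    from diffusers import DDPMPipeline

    # Reload the saved pipeline (model + scheduler) from the dump path.
    pipe = DDPMPipeline.from_pretrained("./ddpm_converted")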
