Skip to content

Commit

Permalink
add pipeline_class_name argument to Stable Diffusion conversion script (
Browse files Browse the repository at this point in the history
huggingface#4461)

* add pipeline class

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* style

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
3 people authored Aug 7, 2023
1 parent 71c8224 commit aef11cb
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions scripts/convert_original_stable_diffusion_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
""" Conversion script for the LDM checkpoints. """

import argparse
import importlib

import torch

Expand Down Expand Up @@ -133,8 +134,22 @@
required=False,
help="Set to a path, hub id to an already converted vae to not convert it again.",
)
parser.add_argument(
"--pipeline_class_name",
type=str,
default=None,
required=False,
help="Specify the pipeline class name",
)

args = parser.parse_args()

if args.pipeline_class_name is not None:
library = importlib.import_module("diffusers")
class_obj = getattr(library, args.pipeline_class_name)
else:
pipeline_class = None

pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=args.checkpoint_path,
original_config_file=args.original_config_file,
Expand All @@ -152,6 +167,7 @@
clip_stats_path=args.clip_stats_path,
controlnet=args.controlnet,
vae_path=args.vae_path,
pipeline_class=pipeline_class,
)

if args.half:
Expand Down

0 comments on commit aef11cb

Please sign in to comment.