Skip to content

Commit

Permalink
Improving documentation Signed-off-by: Steven Zimmerman <SZim92@gmail…
Browse files Browse the repository at this point in the history
….com>
  • Loading branch information
SZim92 committed Aug 12, 2024
1 parent 677b8b2 commit 2251887
Showing 1 changed file with 29 additions and 8 deletions.
37 changes: 29 additions & 8 deletions service/schedulers_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import diffusers

# Dictionary mapping scheduler names to their configuration.
# This includes the class name of the scheduler and any keyword arguments.
scheduler_map = {
"DPM++ 2M": {"class_name": "DPMSolverMultistepScheduler", "kwargs": {}},
"DPM++ 2M Karras": {
Expand Down Expand Up @@ -49,26 +51,45 @@
}

# 从 scheduler_map 获取调度器清单
schedulers = list(scheduler_map.keys())
schedulers = list(scheduler_map.keys()) # List of available scheduler names

def set_scheduler(pipe: diffusers.DiffusionPipeline, name: str):
print("---------------------debug ", name)
scheduler_cfg = scheduler_map.get(name)
"""
Sets the scheduler for a diffusion pipeline based on the provided name.
If the name is "None", the pipeline's default scheduler is used.
Args:
pipe (diffusers.DiffusionPipeline): The diffusion pipeline to configure.
name (str): The name of the scheduler to set.
Raises:
Exception: If the specified scheduler name is unknown.
"""
print("---------------------debug ", name) # Debug print statement
scheduler_cfg = scheduler_map.get(name) # Retrieve the scheduler configuration

# Handle the case when no specific scheduler is requested (name is "None")
if name == "None":
if hasattr(pipe.scheduler, "scheduler_config"):
default_class_name = pipe.scheduler.scheduler_config["_class_name"]
else:
default_class_name = pipe.scheduler.config["_class_name"]
# same scheduler
if default_class_name == type(pipe.scheduler).__name__:
return
# If the pipeline already has the default scheduler, do nothing
return
else:
# Get the default scheduler class
scheduler_class = getattr(diffusers, default_class_name)
elif scheduler_cfg is None:
raise Exception(f"unkown scheduler name \"{name}\"")
else:
# If a scheduler configuration is found
elif scheduler_cfg is not None:
# Get the scheduler class from the diffusers library
scheduler_class = getattr(diffusers, scheduler_cfg["class_name"])
else:
# If the scheduler name is not found, raise an exception
raise Exception(f"unkown scheduler name \"{name}\"")
print(f"load scheduler {name}")
# Set the scheduler for the pipeline using its configuration and keyword arguments
pipe.scheduler = scheduler_class.from_config(
pipe.scheduler.config, **scheduler_cfg["kwargs"]
)
Expand Down Expand Up @@ -111,4 +132,4 @@ def set_scheduler(pipe: diffusers.DiffusionPipeline, name: str):
# guidance_scale=guidance_scale,
# generator=generator,
# ).images[0]
# image.show()
# image.show()

0 comments on commit 2251887

Please sign in to comment.