Skip to content

Commit

Permalink
Add Singlestep DPM-Solver (singlestep high-order schedulers) (hugging…
Browse files Browse the repository at this point in the history
…face#1442)

* add singlestep dpmsolver

* fix a style typo

* fix a style typo

* add docs

* finish

Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
LuChengTHU and patrickvonplaten authored Dec 7, 2022
1 parent 6a7f1f0 commit 8e74efa
Show file tree
Hide file tree
Showing 7 changed files with 800 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/source/api/schedulers.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ Original paper can be found [here](https://arxiv.org/abs/2010.02502).

[[autodoc]] DDPMScheduler

#### Singlestep DPM-Solver

Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).

[[autodoc]] DPMSolverSinglestepScheduler

#### Multistep DPM-Solver

Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from .scheduling_euler_discrete import EulerDiscreteScheduler
from .scheduling_heun_discrete import HeunDiscreteScheduler
Expand Down
599 changes: 599 additions & 0 deletions src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"HeunDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
"DPMSolverSinglestepScheduler",
]


Expand Down
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class DPMSolverSinglestepScheduler(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
177 changes: 177 additions & 0 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
Expand Down Expand Up @@ -870,6 +871,182 @@ def test_full_loop_with_no_set_alpha_to_one(self):
assert abs(result_mean.item() - 0.1941) < 1e-3


class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DPMSolverSinglestepScheduler,)
forward_default_kwargs = (("num_inference_steps", 25),)

def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"solver_order": 2,
"prediction_type": "epsilon",
"thresholding": False,
"sample_max_value": 1.0,
"algorithm_type": "dpmsolver++",
"solver_type": "midpoint",
}

config.update(**kwargs)
return config

def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

output, new_output = sample, sample
for t in range(time_step, time_step + scheduler.config.solver_order + 1):
output = scheduler.step(residual, t, output, **kwargs).prev_sample
new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

def test_from_save_pretrained(self):
pass

def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)

# copy over dummy past residuals (must be after setting timesteps)
scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps)

# copy over dummy past residual (must be after setting timesteps)
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)

num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)

for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample

return sample

def test_timesteps(self):
for timesteps in [25, 50, 100, 999, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)

def test_thresholding(self):
self.check_over_configs(thresholding=False)
for order in [1, 2, 3]:
for solver_type in ["midpoint", "heun"]:
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
algorithm_type="dpmsolver++",
solver_order=order,
solver_type=solver_type,
)

def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)

def test_solver_order_and_type(self):
for algorithm_type in ["dpmsolver", "dpmsolver++"]:
for solver_type in ["midpoint", "heun"]:
for order in [1, 2, 3]:
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
sample = self.full_loop(
solver_order=order,
solver_type=solver_type,
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
assert not torch.isnan(sample).any(), "Samples have nan numbers"

def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)

def test_inference_steps(self):
for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

def test_full_loop_no_noise(self):
sample = self.full_loop()
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.2791) < 1e-3

def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.1453) < 1e-3

def test_fp16_support(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
scheduler = scheduler_class(**scheduler_config)

num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter.half()
scheduler.set_timesteps(num_inference_steps)

for i, t in enumerate(scheduler.timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample

assert sample.dtype == torch.float16


class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DPMSolverMultistepScheduler,)
forward_default_kwargs = (("num_inference_steps", 25),)
Expand Down

0 comments on commit 8e74efa

Please sign in to comment.