[Pytorch] pytorch only timesteps (huggingface#724)
* pytorch timesteps

* style

* get rid of if-else

* fix test

Co-authored-by: Patrick von Platen <[email protected]>
kashif and patrickvonplaten authored Oct 5, 2022
1 parent 60c9634 commit 726aba0
Showing 12 changed files with 42 additions and 32 deletions.
2 changes: 1 addition & 1 deletion docs/source/api/schedulers.mdx
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
 To this end, the design of schedulers is such that:
 
 - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Numpy support currently exists).
+- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
 
 
 ## API
@@ -278,11 +278,8 @@ def __call__(
         self.scheduler.set_timesteps(num_inference_steps)
 
         # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = self.scheduler.timesteps.to(self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)
 
         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
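
Note: the deleted branch existed because some schedulers exposed `timesteps` as a numpy array while others already returned tensors. A minimal sketch of why a single call now suffices, assuming every scheduler stores `timesteps` as a `torch.Tensor` (which is what this commit enforces):

```python
import numpy as np
import torch

# Previously, timesteps could be a numpy array, forcing callers to branch.
timesteps_np = np.arange(0, 1000)[::-1]
timesteps_old = torch.tensor(timesteps_np.copy(), device="cpu")

# Now schedulers always store a torch.Tensor, so one .to(device) call
# covers every scheduler.
timesteps_new = torch.from_numpy(timesteps_np.copy()).to("cpu")

assert torch.equal(timesteps_old, timesteps_new)
```
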
@@ -304,7 +304,10 @@ def __call__(
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         for i, t in enumerate(self.progress_bar(timesteps)):
             t_index = t_start + i
@@ -342,7 +342,10 @@ def __call__(
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
 
         for i, t in tqdm(enumerate(timesteps)):
             t_index = t_start + i
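
Both image-to-image and inpainting slice the schedule before the denoising loop. Since `timesteps` is now a tensor, the slice is still a tensor and `.to(self.device)` chains directly. A small sketch with hypothetical values:

```python
import torch

timesteps = torch.arange(999, -1, -50)  # e.g. a 20-step schedule, counting down
t_start = 5  # hypothetical offset derived from init_timestep

# Slicing a tensor yields a tensor (a view), so device placement chains.
sliced = timesteps[t_start:].to("cpu")
print(sliced[:3])  # tensor([749, 699, 649])
```
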
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/README.md
@@ -2,7 +2,7 @@
 
 - Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps.
 - Schedulers can be used interchangeable between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are available in numpy, but can easily be transformed into PyTorch.
+- Schedulers are available in PyTorch and Jax.
 
 ## API
 
7 changes: 4 additions & 3 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -154,7 +154,7 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -166,7 +166,7 @@ def _get_variance(self, timestep, prev_timestep):
 
         return variance
 
-    def set_timesteps(self, num_inference_steps: int, **kwargs):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -183,7 +183,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs):
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
-        self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
+        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
         self.timesteps += offset
 
     def step(
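
With the extended signature, callers can place the schedule on the target device up front rather than moving it afterwards. A usage sketch, assuming the `DDIMScheduler` API as of this commit:

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
device = "cuda" if torch.cuda.is_available() else "cpu"

# The device argument lands the timesteps tensor where inference runs.
scheduler.set_timesteps(50, device=device)
print(scheduler.timesteps.device, scheduler.timesteps[:3])
```
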
9 changes: 5 additions & 4 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -142,11 +142,11 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
         self.variance_type = variance_type
 
-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -156,9 +156,10 @@ def set_timesteps(self, num_inference_steps: int):
         """
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(
+        timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
-        )[::-1]
+        )[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
 
     def _get_variance(self, t, predicted_variance=None, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
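
The `.copy()` calls added throughout this commit are not cosmetic: reversing a numpy array with `[::-1]` produces a view with a negative stride, and `torch.from_numpy` rejects negative strides. A short illustration:

```python
import numpy as np
import torch

reversed_view = np.arange(0, 1000, 100)[::-1]  # negative-stride view
try:
    torch.from_numpy(reversed_view)
except ValueError as err:
    print(err)  # negative strides are not currently supported

# A contiguous copy converts cleanly.
timesteps = torch.from_numpy(reversed_view.copy())
print(timesteps)  # tensor([900, 800, ..., 100, 0])
```
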
9 changes: 5 additions & 4 deletions src/diffusers/schedulers/scheduling_karras_ve.py
@@ -97,10 +97,10 @@ def __init__(
 
         # setable values
         self.num_inference_steps: int = None
-        self.timesteps: np.ndarray = None
+        self.timesteps: np.IntTensor = None
         self.schedule: torch.FloatTensor = None  # sigma(t_i)
 
-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -110,15 +110,16 @@ def set_timesteps(self, num_inference_steps: int):
         """
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
         schedule = [
             (
                 self.config.sigma_max**2
                 * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
             )
             for i in self.timesteps
         ]
-        self.schedule = torch.tensor(schedule, dtype=torch.float32)
+        self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
 
     def add_noise_to_input(
         self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
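
Because `self.timesteps` is now a tensor, the list comprehension above iterates over 0-d tensors; the same schedule can also be computed in one vectorized expression. A sketch with hypothetical sigma values:

```python
import torch

sigma_min, sigma_max, num_inference_steps = 0.02, 100.0, 50

# Reversed timesteps N-1, ..., 1, 0, matching np.arange(0, N)[::-1]
i = torch.arange(num_inference_steps - 1, -1, -1, dtype=torch.float32)
schedule = sigma_max**2 * (sigma_min**2 / sigma_max**2) ** (i / (num_inference_steps - 1))
print(schedule[0].item(), schedule[-1].item())  # sigma_min**2 first, sigma_max**2 last
```
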
5 changes: 3 additions & 2 deletions src/diffusers/schedulers/scheduling_pndm.py
@@ -147,7 +147,7 @@ def __init__(
         self.plms_timesteps = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -184,7 +184,8 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor
             ::-1
         ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy
 
-        self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps).to(device)
 
         self.ets = []
         self.counter = 0
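
PNDM keeps two integer sub-schedules (pseudo Runge-Kutta warm-up steps and pseudo linear multi-step steps); they are still assembled in numpy and converted once at the end. A sketch of the pattern, with hypothetical step values:

```python
import numpy as np
import torch

prk_timesteps = np.array([901, 868, 868, 834], dtype=np.int64)  # hypothetical warm-up steps
plms_timesteps = np.array([801, 601, 401, 201, 1], dtype=np.int64)  # hypothetical PLMS steps

# One concatenation, one conversion, one device move.
timesteps = torch.from_numpy(np.concatenate([prk_timesteps, plms_timesteps])).to("cpu")
print(timesteps.dtype)  # torch.int64
```
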
6 changes: 4 additions & 2 deletions src/diffusers/schedulers/scheduling_sde_ve.py
@@ -89,7 +89,9 @@ def __init__(
 
         self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
 
-    def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
+    def set_timesteps(
+        self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
+    ):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
 
@@ -101,7 +103,7 @@ def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
         """
         sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
 
-        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
+        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
 
     def set_sigmas(
         self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
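
Continuous-time schedulers need no numpy round-trip at all: `torch.linspace` accepts a `device` keyword, so the schedule is created in place. For example, with hypothetical values:

```python
import torch

sampling_eps, num_inference_steps = 1e-5, 1000  # hypothetical config values
timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device="cpu")
print(timesteps[0], timesteps[-1])  # tensor(1.) ... tensor(1.0000e-05)
```
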
7 changes: 3 additions & 4 deletions src/diffusers/schedulers/scheduling_sde_vp.py
@@ -14,9 +14,8 @@
 
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
-# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
-
 import math
+from typing import Union
 
 import torch
 
@@ -52,8 +51,8 @@ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling
         self.discrete_sigmas = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps):
-        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+    def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
+        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
 
     def step_pred(self, score, x, t, generator=None):
         if self.timesteps is None:
10 changes: 6 additions & 4 deletions tests/test_scheduler.py
@@ -354,7 +354,7 @@ def test_steps_offset(self):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(5)
-        assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all()
+        assert torch.equal(scheduler.timesteps, torch.LongTensor([801, 601, 401, 201, 1]))
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -568,10 +568,12 @@ def test_steps_offset(self):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(10)
-        assert np.equal(
+        assert torch.equal(
             scheduler.timesteps,
-            np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
-        ).all()
+            torch.LongTensor(
+                [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
+            ),
+        )
 
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
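
The updated assertions use `torch.equal`, which returns True only when size, dtype, and every element match; `torch.from_numpy` on an `int64` array yields a `LongTensor`, hence the expected values are spelled as `torch.LongTensor`. A condensed sketch of the check:

```python
import numpy as np
import torch

timesteps = torch.from_numpy(np.array([801, 601, 401, 201, 1], dtype=np.int64))

# torch.equal is True only for matching size, dtype, and elements.
assert torch.equal(timesteps, torch.LongTensor([801, 601, 401, 201, 1]))
```
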
