Skip to content

Commit

Permalink
integrate the new samplers PR
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Oct 6, 2022
1 parent a971e4a commit 5993df2
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 87 deletions.
7 changes: 4 additions & 3 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
self.firstphase_height_truncated = int(scale * self.height)

def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
Expand Down Expand Up @@ -520,7 +520,8 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs

shared.state.nextjob()

self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

# GC now before running the next img2img to prevent running out of memory
Expand Down Expand Up @@ -555,7 +556,7 @@ def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mas
self.nmask = None

def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
crop_region = None

if self.image_mask is not None:
Expand Down
59 changes: 31 additions & 28 deletions modules/sd_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,46 @@
import modules.shared as shared


SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases'])
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a']),
('Euler', 'sample_euler', ['k_euler']),
('LMS', 'sample_lms', ['k_lms']),
('Heun', 'sample_heun', ['k_heun']),
('DPM2', 'sample_dpm_2', ['k_dpm_2']),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
]

if opts.show_karras_scheduler_variants:
k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2
k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral
k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms
samplers_k_diffusion_ka = [
('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']),
('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']),
('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']),
]
samplers_k_diffusion.extend(samplers_k_diffusion_ka)

samplers_data_k_diffusion = [
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
for label, funcname, aliases in samplers_k_diffusion
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
for label, funcname, aliases, options in samplers_k_diffusion
if hasattr(k_diffusion.sampling, funcname)
]

all_samplers = [
*samplers_data_k_diffusion,
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]

samplers = []
samplers_for_img2img = []


def create_sampler_with_index(list_of_configs, index, model):
config = list_of_configs[index]
sampler = config.constructor(model)
sampler.config = config

return sampler


def set_samplers():
global samplers, samplers_for_img2img

Expand Down Expand Up @@ -130,6 +130,7 @@ def __init__(self, constructor, sd_model):
self.step = 0
self.eta = None
self.default_eta = 0.0
self.config = None

def number_of_needed_noises(self, p):
return 0
Expand Down Expand Up @@ -291,6 +292,7 @@ def __init__(self, funcname, sd_model):
self.stop_at = None
self.eta = None
self.default_eta = 1.0
self.config = None

def callback_state(self, d):
store_latent(d["denoised"])
Expand Down Expand Up @@ -355,11 +357,12 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps

if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.funcname.endswith('ka'):
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
sigmas = self.model_wrap.get_sigmas(steps)

x = x * sigmas[0]

extra_params_kwargs = self.initialize(p)
Expand Down
1 change: 0 additions & 1 deletion modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def options_section(section_identifer, options_dict):
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_karras_scheduler_variants": OptionInfo(True, "Show Karras scheduling variants for select samplers. Try these variants if your K sampled images suffer from excessive noise."),
}))

options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
Expand Down
53 changes: 0 additions & 53 deletions scripts/alternate_sampler_noise_schedules.py

This file was deleted.

3 changes: 1 addition & 2 deletions scripts/img2imgalt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from modules import processing, shared, sd_samplers, prompt_parser
from modules.processing import Processed
from modules.sd_samplers import samplers
from modules.shared import opts, cmd_opts, state

import torch
Expand Down Expand Up @@ -159,7 +158,7 @@ def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subs

combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)

sampler = samplers[p.sampler_index].constructor(p.sd_model)
sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)

sigmas = sampler.model_wrap.get_sigmas(p.steps)

Expand Down

0 comments on commit 5993df2

Please sign in to comment.