Skip to content

Commit

Permalink
improve code quality
Browse files Browse the repository at this point in the history
  • Loading branch information
C43H66N12O12S2 authored and AUTOMATIC1111 committed Sep 29, 2022
1 parent b6f80bd commit 965dcf4
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions modules/sd_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
samplers_for_img2img.remove(samplers_for_img2img[6])
samplers_for_img2img.remove(samplers_for_img2img[6])
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]

sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
Expand Down Expand Up @@ -314,12 +312,12 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):

extra_params_kwargs = self.initialize(p)
if 'sigma_min' in inspect.signature(self.func).parameters:
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
if 'n' in inspect.signature(self.func).parameters:
samples = self.func(self.model_wrap_cfg, x, sigma_min=self.model_wrap.sigmas[0].item(), sigma_max=self.model_wrap.sigmas[-1].item(), n=steps, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
return samples
samples = self.func(self.model_wrap_cfg, x, sigma_min=self.model_wrap.sigmas[0].item(), sigma_max=self.model_wrap.sigmas[-1].item(), extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
return samples
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)

extra_params_kwargs['n'] = steps
else:
extra_params_kwargs['sigmas'] = sigmas
samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
return samples

0 comments on commit 965dcf4

Please sign in to comment.