Skip to content

Commit

Permalink
Guidance (Stability-AI#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmarx authored Oct 14, 2022
1 parent b95c65a commit df5339e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 33 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name='stability-sdk',
version='0.2.4',
version='0.2.5',
author='Wes Brown',
author_email='[email protected]',
maintainer='David Marx',
Expand Down
104 changes: 72 additions & 32 deletions src/stability_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def __init__(

def generate(
self,
prompt: Union[List[str], str],
prompt: Union[str, List[str], generation.Prompt, List[generation.Prompt]],
init_image: Optional[Image.Image] = None,
mask_image: Optional[Image.Image] = None,
height: int = 512,
Expand All @@ -228,6 +228,11 @@ def generate(
samples: int = 1,
safety: bool = True,
classifiers: Optional[generation.ClassifierParameters] = None,
guidance_preset: generation.GuidancePreset = generation.GUIDANCE_PRESET_NONE,
guidance_cuts: int = 0,
guidance_strength: Optional[float] = None,
guidance_prompt: Union[str, generation.Prompt] = None,
guidance_models: List[str] = None,
) -> Generator[generation.Answer, None, None]:
"""
Generate images from a prompt.
Expand All @@ -246,6 +251,11 @@ def generate(
:param samples: Number of samples to generate.
:param safety: DEPRECATED/UNUSED - Cannot be disabled.
:param classifiers: DEPRECATED/UNUSED - Has no effect on image generation.
:param guidance_preset: Guidance preset to use. See generation.GuidancePreset for supported values.
:param guidance_cuts: Number of cuts to use for guidance.
:param guidance_strength: Strength of the guidance. We recommend values in range [0.0,1.0]. A good default is 0.25
:param guidance_prompt: Prompt to use for guidance, defaults to `prompt` argument (above) if not specified.
:param guidance_models: Models to use for guidance.
:return: Generator of Answer objects.
"""
if (prompt is None) and (init_image is None):
Expand All @@ -265,50 +275,80 @@ def generate(
else:
seed = list(seed)

prompt_ = []
if isinstance(prompt, str):
prompt_ = [generation.Prompt(text=prompt)]
elif isinstance(prompt, Sequence):
prompt_ = [generation.Prompt(text=p) for p in prompt]
else:
raise TypeError("prompt must be a string or a sequence")

prompts: List[generation.Prompt] = []
if any(isinstance(prompt, t) for t in (str, generation.Prompt)):
prompt = [prompt]
for p in prompt:
if isinstance(p, str):
p = generation.Prompt(text=p)
elif not isinstance(p, generation.Prompt):
raise TypeError("prompt must be a string or generation.Prompt object")
prompts.append(p)

step_parameters = dict(
scaled_step=0,
sampler=generation.SamplerParameters(cfg_scale=cfg_scale),
schedule=generation.ScheduleParameters(
start=start_schedule,
end=end_schedule,
),
)

if init_image is not None:
prompt_ += [image_to_prompt(init_image, init=True)]
parameters = (
generation.StepParameter(
scaled_step=0,
sampler=generation.SamplerParameters(
cfg_scale=cfg_scale,
),
schedule=generation.ScheduleParameters(
start=start_schedule,
end=end_schedule,
),
),
)
prompts += [image_to_prompt(init_image, init=True)]

if mask_image is not None:
prompt_ += [image_to_prompt(mask_image, mask=True)]
else:
parameters = (
generation.StepParameter(
scaled_step=0,
sampler=generation.SamplerParameters(cfg_scale=cfg_scale),
),
)
prompts += [image_to_prompt(mask_image, mask=True)]


if guidance_prompt:
if isinstance(guidance_prompt, str):
guidance_prompt = generation.Prompt(text=guidance_prompt)
elif not isinstance(guidance_prompt, generation.Prompt):
raise ValueError("guidance_prompt must be a string or Prompt object")
if guidance_strength == 0.0:
guidance_strength = None


# Build our CLIP parameters
if guidance_preset is not generation.GUIDANCE_PRESET_NONE:
# to do: make it so user can override this
step_parameters['sampler']=None

if guidance_models:
guiders = [generation.Model(alias=model) for model in guidance_models]
else:
guiders = None

if guidance_cuts:
cutouts = generation.CutoutParameters(count=guidance_cuts)
else:
cutouts = None

step_parameters["guidance"] = generation.GuidanceParameters(
guidance_preset=guidance_preset,
instances=[
generation.GuidanceInstanceParameters(
guidance_strength=guidance_strength,
models=guiders,
cutouts=cutouts,
prompt=guidance_prompt,
)
],
)

rq = generation.Request(
engine_id=self.engine,
request_id=request_id,
prompt=prompt_,
prompt=prompts,
image=generation.ImageParameters(
transform=generation.TransformType(diffusion=sampler),
height=height,
width=width,
seed=seed,
steps=steps,
samples=samples,
parameters=parameters,
parameters=[generation.StepParameter(**step_parameters)],
),
)

Expand Down

0 comments on commit df5339e

Please sign in to comment.