Skip to content

Commit

Permalink
Pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
kumquatexpress committed Oct 19, 2022
1 parent 873bf7d commit 62a1335
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 18 deletions.
3 changes: 2 additions & 1 deletion cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ build:
cuda: "11.6.2"
python_version: "3.10"
python_packages:
- "diffusers==0.4.0"
- "git+https://github.com/harvestlabs/diffusers@39b181e0cbbc1e3960d523c5115e957e224a5c80"
- "einops==0.5.0"
- "torch==1.12.1 --extra-index-url=https://download.pytorch.org/whl/cu116"
- "ftfy==6.1.1"
- "scipy==1.9.0"
Expand Down
15 changes: 8 additions & 7 deletions pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,11 @@ def __init__(
feature_extractor=feature_extractor,
)

inpaint = StableDiffusionInpaintPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
inpaint = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
revision="fp16",
torch_dtype=torch.float16,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)

self.register_modules(
Expand Down Expand Up @@ -93,6 +90,7 @@ def disable_nsfw_filter(self):
def __call__(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
init_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
mask_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
strength: float = 0.8,
Expand All @@ -110,6 +108,7 @@ def __call__(
if init_image is None:
result = self.text2img(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
Expand All @@ -124,6 +123,7 @@ def __call__(
if mask_image is None:
result = self.img2img(
prompt=prompt,
negative_prompt=negative_prompt,
init_image=init_image,
strength=strength,
num_inference_steps=num_inference_steps,
Expand All @@ -136,6 +136,7 @@ def __call__(
else:
result = self.inpaint(
prompt=prompt,
negative_prompt=negative_prompt,
init_image=init_image,
mask_image=mask_image,
strength=strength,
Expand Down
14 changes: 4 additions & 10 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
import pipelines


MODEL_CACHE = "diffusers-cache"


def patch_conv(**patch):
cls = torch.nn.Conv2d
init = cls.__init__
Expand All @@ -38,9 +35,6 @@ def setup(self):
self.pipe = pipelines.StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
scheduler=scheduler,
revision="fp16",
torch_dtype=torch.float16,
cache_dir=MODEL_CACHE,
local_files_only=True,
).to("cuda")
self.pipe.disable_nsfw_filter()
Expand Down Expand Up @@ -74,7 +68,7 @@ def predict(
default=0.8,
),
num_outputs: int = Input(
description="Number of images to output", choices=[1, 2, 3, 4], default=1
description="Number of images to output", choices=[1, 2, 3, 4, 5, 10], default=1
),
num_inference_steps: int = Input(
description="Number of denoising steps", ge=1, le=500, default=50
Expand Down Expand Up @@ -112,9 +106,9 @@ def predict(

generator = torch.Generator("cuda").manual_seed(seed)
output = self.pipe(
prompt=[prompt] * num_outputs if prompt is not None else None,
negative_prompt=[negative_prompt] *
num_outputs if negative_prompt is not None else None,
prompt=prompt if prompt is not None else None,
negative_prompt=negative_prompt if negative_prompt is not None else None,
num_images_per_prompt=num_outputs,
height=height,
width=width,
init_image=init_image,
Expand Down

0 comments on commit 62a1335

Please sign in to comment.