Skip to content

Commit

Permalink
[Flax] Add Flax inpainting impl (huggingface#1966)
Browse files Browse the repository at this point in the history
* [Flax] Add Flax inpainting impl

* fixed copies, add README.md

* fixed README.md

* add test

* format

* update README.md
  • Loading branch information
xvjiarui authored Jan 17, 2023
1 parent f77ff56 commit a43bdd0
Show file tree
Hide file tree
Showing 7 changed files with 679 additions and 2 deletions.
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,53 @@ output = pipeline(
output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
```

Diffusers also has a Text-guided inpainting pipeline with Flax/Jax

```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
import PIL
import requests
from io import BytesIO


from diffusers import FlaxStableDiffusionInpaintPipeline

def download_image(url):
response = requests.get(url)
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))

pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained("xvjiarui/stable-diffusion-2-inpainting")

prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
init_image = num_samples * [init_image]
mask_image = num_samples * [mask_image]
prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)


# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)
processed_masked_images = shard(processed_masked_images)
processed_masks = shard(processed_masks)

images = pipeline(prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```

### Image-to-Image text-guided generation with Stable Diffusion

The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,8 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipelines import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline
from .pipelines import (
FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
)
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,8 @@
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .stable_diffusion import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline
from .stable_diffusion import (
FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
)
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,5 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
Loading

0 comments on commit a43bdd0

Please sign in to comment.