Skip to content

Commit

Permalink
Inference mode
Browse files Browse the repository at this point in the history
  • Loading branch information
caetas committed Sep 22, 2023
1 parent 455e2d6 commit 3ab15f7
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 4 deletions.
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Run the following commands:

```bash
chmod -X control_execute.sh
bash control_execute.sh
bash train_script.sh
```

**NOTE: You can skip the first command after the first execution.**
Expand All @@ -78,8 +78,13 @@ Although my input sketch is very rudimentary, the trained network can follow the

You can check more examples in the [reports/figures](reports/figures/)

## TO-DO
- Add inference-only mode.
## Inference Mode

With a Streamlit app, you can draw your own sketch of a Pokemon and ask the pretrained ControlNet to generate an image based on your sketch and a prompt. The influence of the prompt and of the ControlNet can also be adjusted via some sliders.

```bash
streamlit run app.py
```

## Documentation

Expand Down
4 changes: 3 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ psycopg2-binary==2.9.2
# Python-dotenv reads key-value pairs from a .env file and can set them as environment variables.
# It helps in the development of applications following the 12-factor (https://12factor.net/) principles.
python-dotenv==0.19.2
s3fs==2022.1.0
s3fs
zulip==0.8.2

streamlit
streamlit-drawable-canvas
datasets
opencv-python
matplotlib
Expand Down
64 changes: 64 additions & 0 deletions src/finetune_sd/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image

@st.cache_resource
def load_model():

base_model_path = "runwayml/stable-diffusion-v1-5"
controlnet_path = "./../../models/checkpoint-10000/controlnet"

controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
base_model_path, controlnet=controlnet, torch_dtype=torch.float16, use_safetensors=True, safety_checker = None
)

# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed
pipe.enable_xformers_memory_efficient_attention()

pipe.enable_model_cpu_offload()

return pipe

@st.cache_data
def generate_image(_pipe, control_image, guidance_scale, controlnet_conditioning_scale, prompt):
image = pipe(prompt, num_inference_steps=30, image=control_image, guidance_scale=guidance_scale, controlnet_conditioning_scale=controlnet_conditioning_scale).images[0]
return image

pipe = load_model()
# create title
st.title("Stable Diffusion ControlNet Demo for Pokemon Generation")

# create drawable canvas that is black, 512x512, and that you can draw in white
canvas_result = st_canvas(
fill_color='#000000',
stroke_width=10,
stroke_color='#FFFFFF',
background_color='#000000',
width=700,
height=700,
drawing_mode="freedraw",
key="canvas",
)

# add a slider to control guidance scale
guidance_scale = st.slider("Guidance Scale", min_value=0.0, max_value=10.0, value=7.5, step=0.5)
# add a slider to control the controlnet_conditioning_scale
controlnet_conditioning_scale = st.slider("ControlNet Conditioning Scale", min_value=0.0, max_value=1.0, value=0.8, step=0.1)
# add a box to enter the prompt
prompt = st.text_input("Prompt", value="a pokemon that looks like a dragon")

# get the image from the canvas when you click the button
if st.button('Generate Pokemon'):
if canvas_result.image_data is not None:
control_image = canvas_result.image_data.copy()
control_image = Image.fromarray(control_image, 'RGB')
# resize the Pillow image to 512x512
control_image = control_image.resize((512, 512))
image = generate_image(pipe, control_image, guidance_scale, controlnet_conditioning_scale, prompt)
st.image(image, caption="Generated Pokemon", use_column_width=True)
File renamed without changes.

0 comments on commit 3ab15f7

Please sign in to comment.