CLIP interrogator
AUTOMATIC1111 committed Sep 11, 2022
1 parent 13008ba commit f194457
Showing 13 changed files with 204 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -13,3 +13,4 @@ __pycache__
/embeddings
/styles.csv
/webui-user.bat
/interrogate
2 changes: 2 additions & 0 deletions README.md
@@ -40,6 +40,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- Styles
- Variations
- Seed resizing
- CLIP interrogator

## Installing and running

@@ -289,5 +290,6 @@ After that follow the instructions in the `Manual instructions` section starting
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- CLIP interrogator idea and some borrowed code - https://github.com/pharmapsychotic/clip-interrogator
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
16 changes: 10 additions & 6 deletions modules/devices.py
@@ -1,12 +1,16 @@
import torch


# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False)

cpu = torch.device("cpu")


def get_optimal_device():
if torch.cuda.is_available():
return torch.device("cuda")
if has_mps:
return torch.device("mps")
return torch.device("cpu")
if torch.cuda.is_available():
return torch.device("cuda")

if has_mps:
return torch.device("mps")

return cpu
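
The refactor above promotes `cpu` to a module-level device and keeps the CUDA > MPS > CPU preference order. A minimal sketch of how callers elsewhere in the codebase might use it (the tensor is purely illustrative):

import torch
from modules import devices

# Pick the best available backend once, then move models/tensors to it.
device = devices.get_optimal_device()

x = torch.randn(4, 3).to(device)   # hypothetical tensor, just for illustration
y = x.to(devices.cpu)              # the new module-level `cpu` device, handy for offloading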
142 changes: 142 additions & 0 deletions modules/interrogate.py
@@ -0,0 +1,142 @@
import os
import sys
import traceback
from collections import namedtuple
import re

import torch

from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import modules.shared as shared
from modules import devices, paths

blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'

Category = namedtuple("Category", ["name", "topn", "items"])

re_topn = re.compile(r"\.top(\d+)\.")

class InterrogateModels:
blip_model = None
clip_model = None
clip_preprocess = None
categories = None

def __init__(self, content_dir):
self.categories = []

if os.path.exists(content_dir):
for filename in os.listdir(content_dir):
m = re_topn.search(filename)
topn = 1 if m is None else int(m.group(1))

with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()]

self.categories.append(Category(name=filename, topn=topn, items=lines))

def load_blip_model(self):
import models.blip

blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
blip_model.eval()

return blip_model

def load_clip_model(self):
import clip

model, preprocess = clip.load(clip_model_name)
model.eval()
model = model.to(shared.device)

return model, preprocess

def load(self):
if self.blip_model is None:
self.blip_model = self.load_blip_model()

self.blip_model = self.blip_model.to(shared.device)

if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()

self.clip_model = self.clip_model.to(shared.device)

def unload(self):
if not shared.opts.interrogate_keep_models_in_memory:
if self.clip_model is not None:
self.clip_model = self.clip_model.to(devices.cpu)

if self.blip_model is not None:
self.blip_model = self.blip_model.to(devices.cpu)


def rank(self, image_features, text_array, top_count=1):
import clip

top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array]).to(shared.device)
with torch.no_grad():
text_features = self.clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)

similarity = torch.zeros((1, len(text_array))).to(shared.device)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]

top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]


def generate_caption(self, pil_image):
gpu_image = transforms.Compose([
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(pil_image).unsqueeze(0).to(shared.device)

with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)

return caption[0]

def interrogate(self, pil_image):
res = None

try:
self.load()

caption = self.generate_caption(pil_image)
res = caption

images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device)

with torch.no_grad():
image_features = self.clip_model.encode_image(images).float()

image_features /= image_features.norm(dim=-1, keepdim=True)

if shared.opts.interrogate_use_builtin_artists:
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]

res += ", " + artist[0]

for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
res += ", " + match

except Exception:
print(f"Error interrogating", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)

self.unload()

return res
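
The category files are plain text lists dropped into the content directory; a `.topN.` infix in the file name controls how many matches are reported for that category. A minimal sketch of the naming convention and of driving the interrogator directly (the file names and image path are hypothetical; `shared.interrogator` is created in modules/shared.py later in this commit):

import re
from PIL import Image
import modules.shared as shared

re_topn = re.compile(r"\.top(\d+)\.")

# "flavors.top3.txt" -> report the 3 best matches; no infix -> report only the best one.
for filename in ["artists.txt", "flavors.top3.txt"]:
    m = re_topn.search(filename)
    print(filename, 1 if m is None else int(m.group(1)))

# interrogate() returns a BLIP caption plus comma-separated CLIP matches, or None on failure.
prompt = shared.interrogator.interrogate(Image.open("example.png"))  # hypothetical image
print(prompt)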
1 change: 1 addition & 0 deletions modules/paths.py
@@ -18,6 +18,7 @@
(sd_path, 'ldm', 'Stable Diffusion'),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers'),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
]

paths = {}
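
Each tuple registers a sibling repository as (directory, marker file, name); once BLIP is registered, its root is available through `paths.paths["BLIP"]`, which is how interrogate.py locates the BLIP med_config. A rough illustration (the printed path assumes the default `repositories\BLIP` layout created by webui.bat):

import os
from modules import paths

# Resolve the BLIP checkout by name instead of hard-coding a relative path.
blip_root = paths.paths["BLIP"]
med_config = os.path.join(blip_root, "configs", "med_config.json")
print(med_config)  # e.g. .../repositories/BLIP/configs/med_config.json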
8 changes: 8 additions & 0 deletions modules/shared.py
@@ -11,6 +11,7 @@
from modules.paths import script_path, sd_path
from modules.devices import get_optimal_device
import modules.styles
import modules.interrogate

config_filename = "config.json"

@@ -77,6 +78,8 @@ def nextjob(self):
styles_filename = os.path.join(script_path, 'styles.csv')
prompt_styles = modules.styles.load_styles(styles_filename)

interrogator = modules.interrogate.InterrogateModels("interrogate")

face_restorers = []

class Options:
@@ -123,6 +126,11 @@ def __init__(self, default=None, label="", component=None, component_args=None):
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"interrogate_keep_models_in_memory": OptionInfo(True, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum descripton length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum descripton length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
}

def __init__(self):
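
The new `interrogate_*` entries surface as settings in the UI and are read back through `shared.opts`; a minimal sketch of how the BLIP caption parameters might be consumed, mirroring `generate_caption` above:

import modules.shared as shared

# OptionInfo defaults apply until the user changes them in Settings.
num_beams = shared.opts.interrogate_clip_num_beams    # default 1
min_len = shared.opts.interrogate_clip_min_length     # default 24
max_len = shared.opts.interrogate_clip_max_length     # default 48

if not shared.opts.interrogate_keep_models_in_memory:
    print("BLIP/CLIP will be moved back to the CPU after each interrogation")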
18 changes: 16 additions & 2 deletions modules/ui.py
@@ -242,9 +242,14 @@ def add_style(style_name, text):
return [update, update]


def interrogate(image):
prompt = shared.interrogator.interrogate(image)

return gr_show(True) if prompt is None else prompt

def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
with gr.Row():
with gr.Row(elem_id="toprow"):
txt2img_prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1)
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="txt2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1)
txt2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1)
@@ -365,10 +370,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
)

with gr.Blocks(analytics_enabled=False) as img2img_interface:
with gr.Row():
with gr.Row(elem_id="toprow"):
img2img_prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1)
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="img2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1)
img2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1)
img2img_interrogate = gr.Button('Interrogate', elem_id="img2img_interrogate", variant='primary')
submit = gr.Button('Generate', elem_id="img2img_generate", variant='primary')
check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)

@@ -461,6 +467,7 @@ def apply_mode(mode, uploadmask):
inpaint_full_res: gr_show(is_inpaint),
inpainting_mask_invert: gr_show(is_inpaint),
denoising_strength_change_factor: gr_show(is_loopback),
img2img_interrogate: gr_show(not is_inpaint),
}

switch_mode.change(
@@ -480,6 +487,7 @@ def apply_mode(mode, uploadmask):
inpaint_full_res,
inpainting_mask_invert,
denoising_strength_change_factor,
img2img_interrogate,
]
)

@@ -540,6 +548,12 @@ def apply_mode(mode, uploadmask):
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)

img2img_interrogate.click(
fn=interrogate,
inputs=[init_img],
outputs=[img2img_prompt],
)

check_progress.click(
fn=check_progress_call,
show_progress=False,
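
The `interrogate` helper falls back to a visibility update when interrogation fails, so a failed run does not blank the prompt box. A stripped-down sketch of the same wiring outside the full UI (the `gr_show` helper is paraphrased from the existing module; component names are illustrative):

import gradio as gr
import modules.shared as shared

def gr_show(visible=True):
    # Paraphrase of the ui.py helper: an update that only touches visibility.
    return gr.update(visible=visible)

def interrogate(image):
    prompt = shared.interrogator.interrogate(image)
    # On failure interrogate() returns None; send a visibility update instead of overwriting the textbox.
    return gr_show(True) if prompt is None else prompt

with gr.Blocks() as demo:  # minimal stand-in for the real img2img tab
    init_img = gr.Image(type="pil")
    img2img_prompt = gr.Textbox(placeholder="Prompt")
    gr.Button("Interrogate").click(fn=interrogate, inputs=[init_img], outputs=[img2img_prompt])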
2 changes: 2 additions & 0 deletions requirements.txt
@@ -15,3 +15,5 @@ fonts
font-roboto
git+https://github.com/crowsonkb/k-diffusion.git
git+https://github.com/TencentARC/GFPGAN.git
timm==0.4.12
fairscale==0.4.4
2 changes: 2 additions & 0 deletions requirements_versions.txt
@@ -11,3 +11,5 @@ pytorch_lightning==1.7.2
scikit-image==0.19.2
fonts
font-roboto
timm==0.4.12
fairscale==0.4.4
2 changes: 2 additions & 0 deletions script.js
@@ -51,6 +51,8 @@ titles = {
"Variation strength": "How strong of a variation to produce. At 0, there will be no effect. At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).",
"Resize seed from height": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
"Resize seed from width": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",

"Interrogate": "Reconstruct frompt from existing image and put it into the prompt field.",
}

function gradioApp(){
6 changes: 5 additions & 1 deletion style.css
@@ -5,6 +5,10 @@
max-width: 13em;
}

#img2img_interrogate{
max-width: 10em;
}

#subseed_show{
min-width: 6em;
max-width: 6em;
@@ -26,7 +30,7 @@
padding-right: 0;
}

#component-1 div{
#toprow div{
border: none;
gap: 0;
}
13 changes: 10 additions & 3 deletions webui.bat
@@ -85,7 +85,7 @@ if %ERRORLEVEL% == 0 goto :install_reqs
goto :show_stdout_stderr

:install_reqs
%PYTHON% -c "import omegaconf; import fonts" >tmp/stdout.txt 2>tmp/stderr.txt
%PYTHON% -c "import omegaconf; import fonts; import timm" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :make_dirs
echo Installing requirements...
%PYTHON% -m pip install -r %REQS_FILE% --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
@@ -117,12 +117,19 @@ goto :show_stdout_stderr

:install_codeformer_reqs
%PYTHON% -c "import lpips" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_model
if %ERRORLEVEL% == 0 goto :clone_blip
echo Installing requirements for CodeFormer...
%PYTHON% -m pip install -r repositories\CodeFormer\requirements.txt --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_model
if %ERRORLEVEL% == 0 goto :clone_blip
goto :show_stdout_stderr

:clone_blip
if exist repositories\BLIP goto :check_model
echo Cloning BLIP repository...
%GIT% clone https://github.com/salesforce/BLIP.git repositories\BLIP >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% NEQ 0 goto :show_stdout_stderr
%GIT% -C repositories/BLIP checkout 48211a1594f1321b00f14c9f7a5b4813144b2fb9 >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% NEQ 0 goto :show_stdout_stderr

:check_model
dir model.ckpt >tmp/stdout.txt 2>tmp/stderr.txt
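
The batch logic above clones BLIP into `repositories\BLIP` and pins it to a known revision; on non-Windows setups the equivalent steps could be scripted roughly like this (a sketch, not part of the commit):

import os
import subprocess

# Mirror of the :clone_blip step: clone once, then pin to the tested revision.
if not os.path.exists("repositories/BLIP"):
    subprocess.run(["git", "clone", "https://github.com/salesforce/BLIP.git", "repositories/BLIP"], check=True)
subprocess.run(["git", "-C", "repositories/BLIP", "checkout",
                "48211a1594f1321b00f14c9f7a5b4813144b2fb9"], check=True)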
4 changes: 3 additions & 1 deletion webui.py
@@ -20,7 +20,7 @@
import modules.face_restoration
import modules.realesrgan_model as realesrgan
import modules.esrgan_model as esrgan
import modules.extras
import modules.extras
import modules.lowvram
import modules.txt2img
import modules.img2img
@@ -33,6 +33,7 @@
esrgan.load_models(cmd_opts.esrgan_models_path)
realesrgan.setup_realesrgan()


def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
@@ -116,5 +117,6 @@ def sigint_handler(sig, frame):

demo.launch(share=cmd_opts.share, server_name="0.0.0.0" if cmd_opts.listen else None, server_port=cmd_opts.port)


if __name__ == "__main__":
webui()
