forked from AUTOMATIC1111/stable-diffusion-webui
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1 parent 13008ba · commit f194457
Showing 13 changed files with 204 additions and 13 deletions.
@@ -13,3 +13,4 @@ __pycache__
 /embeddings
 /styles.csv
 /webui-user.bat
+/interrogate
@@ -1,12 +1,16 @@
 import torch
 
 # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
 has_mps = getattr(torch, 'has_mps', False)
 
+cpu = torch.device("cpu")
+
 
 def get_optimal_device():
-    if torch.cuda.is_available():
-        return torch.device("cuda")
-    if has_mps:
-        return torch.device("mps")
-    return torch.device("cpu")
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+
+    if has_mps:
+        return torch.device("mps")
+
+    return cpu
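
The new module-level `cpu` constant exists so that other modules can park models in system RAM between uses; the interrogator added below moves BLIP and CLIP to `devices.cpu` in its `unload` method. A minimal standalone sketch of that pattern, using a placeholder `nn.Linear` instead of a real model:

import torch
import torch.nn as nn

cpu = torch.device("cpu")

def get_optimal_device():
    # Same preference order as the helper above: CUDA, then MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch, 'has_mps', False):
        return torch.device("mps")
    return cpu

model = nn.Linear(4, 4)                  # placeholder for a heavy model such as BLIP or CLIP
model = model.to(get_optimal_device())   # move to the best available device for inference
# ... run inference ...
model = model.to(cpu)                    # park the weights in RAM, freeing GPU memory until next use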
@@ -0,0 +1,142 @@
import os
import sys
import traceback
from collections import namedtuple
import re

import torch

from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import modules.shared as shared
from modules import devices, paths

blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'

Category = namedtuple("Category", ["name", "topn", "items"])

re_topn = re.compile(r"\.top(\d+)\.")


class InterrogateModels:
    blip_model = None
    clip_model = None
    clip_preprocess = None
    categories = None

    def __init__(self, content_dir):
        self.categories = []

        if os.path.exists(content_dir):
            for filename in os.listdir(content_dir):
                m = re_topn.search(filename)
                topn = 1 if m is None else int(m.group(1))

                with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
                    lines = [x.strip() for x in file.readlines()]

                self.categories.append(Category(name=filename, topn=topn, items=lines))

    def load_blip_model(self):
        import models.blip

        blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model.eval()

        return blip_model

    def load_clip_model(self):
        import clip

        model, preprocess = clip.load(clip_model_name)
        model.eval()
        model = model.to(shared.device)

        return model, preprocess

    def load(self):
        if self.blip_model is None:
            self.blip_model = self.load_blip_model()

        self.blip_model = self.blip_model.to(shared.device)

        if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()

        self.clip_model = self.clip_model.to(shared.device)

    def unload(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.clip_model is not None:
                self.clip_model = self.clip_model.to(devices.cpu)

            if self.blip_model is not None:
                self.blip_model = self.blip_model.to(devices.cpu)


    def rank(self, image_features, text_array, top_count=1):
        import clip

        top_count = min(top_count, len(text_array))
        text_tokens = clip.tokenize([text for text in text_array]).cuda()
        with torch.no_grad():
            text_features = self.clip_model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

        similarity = torch.zeros((1, len(text_array))).to(shared.device)
        for i in range(image_features.shape[0]):
            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
        similarity /= image_features.shape[0]

        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]


    def generate_caption(self, pil_image):
        gpu_image = transforms.Compose([
            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])(pil_image).unsqueeze(0).to(shared.device)

        with torch.no_grad():
            caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)

        return caption[0]

    def interrogate(self, pil_image):
        res = None

        try:
            self.load()

            caption = self.generate_caption(pil_image)
            res = caption

            images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device)

            with torch.no_grad():
                image_features = self.clip_model.encode_image(images).float()

                image_features /= image_features.norm(dim=-1, keepdim=True)

                if shared.opts.interrogate_use_builtin_artists:
                    artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]

                    res += ", " + artist[0]

                for name, topn, items in self.categories:
                    matches = self.rank(image_features, items, top_count=topn)
                    for match, score in matches:
                        res += ", " + match

        except Exception:
            print(f"Error interrogating", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        self.unload()

        return res
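
For orientation, a hypothetical usage sketch of the new class. The module path, content directory name, and image file below are illustrative assumptions, not taken from this commit; in the webui itself the instance is constructed elsewhere and fed images from the UI, and the BLIP and CLIP repositories must already be importable:

from PIL import Image

from modules.interrogate import InterrogateModels  # assumed module path for the new file

# The constructor takes a directory of category word lists; a file named e.g.
# "artists.top3.txt" (hypothetical) would contribute its top 3 CLIP matches.
interrogator = InterrogateModels("interrogate")

image = Image.open("example.png").convert("RGB")   # illustrative input image
prompt = interrogator.interrogate(image)           # BLIP caption + CLIP-ranked category terms
print(prompt)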
@@ -11,3 +11,5 @@ pytorch_lightning==1.7.2
 scikit-image==0.19.2
 fonts
 font-roboto
+timm==0.4.12
+fairscale==0.4.4