clipseg improvements
kijai committed May 13, 2024
1 parent 4812eff; commit 17a6b35
Showing 3 changed files with 77 additions and 10 deletions.
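At a glance: BatchCLIPSeg gains optional opt_model and prev_mask inputs, its threshold now applies to a min-max normalized sigmoid map (default raised from 0.1 to 0.5), and a new (Down)load CLIPSeg node fetches the fp16 checkpoint once and shares it between segmentation nodes. Below is a minimal sketch of how the pieces compose, written as a hypothetical standalone script; in practice the nodes are wired in the ComfyUI graph, and images here is a stand-in batch.

import torch
# Assumed import path within this repo; the node classes live in nodes/mask_nodes.py.
from nodes.mask_nodes import BatchCLIPSeg, DownloadAndLoadCLIPSeg

images = torch.rand(1, 512, 512, 3)  # ComfyUI IMAGE batch: [B, H, W, C] floats in 0..1

loader = DownloadAndLoadCLIPSeg()
(clipseg_model,) = loader.segment_image(model='Kijai/clipseg-rd64-refined-fp16')

seg = BatchCLIPSeg()
# Reuse the preloaded model via opt_model instead of letting the node download its own copy.
(dog_mask,) = seg.segment_image(images, "dog", threshold=0.5, binary_mask=True,
                                combine_mask=False, use_cuda=True, opt_model=clipseg_model)
# prev_mask accumulates: this call returns the dog and cat masks merged and clamped to 0..1.
(both_masks,) = seg.segment_image(images, "cat", threshold=0.5, binary_mask=True,
                                  combine_mask=False, use_cuda=True, opt_model=clipseg_model,
                                  prev_mask=dog_mask)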
1 change: 1 addition & 0 deletions __init__.py
@@ -19,6 +19,7 @@
     "ConditioningSetMaskAndCombine5": {"class": ConditioningSetMaskAndCombine5, "name": "ConditioningSetMaskAndCombine5"},
     "CondPassThrough": {"class": CondPassThrough},
     #masking
+    "DownloadAndLoadCLIPSeg": {"class": DownloadAndLoadCLIPSeg, "name": "(Down)load CLIPSeg"},
     "BatchCLIPSeg": {"class": BatchCLIPSeg, "name": "Batch CLIPSeg"},
     "ColorToMask": {"class": ColorToMask, "name": "Color To Mask"},
     "CreateGradientMask": {"class": CreateGradientMask, "name": "Create Gradient Mask"},
85 changes: 75 additions & 10 deletions nodes/mask_nodes.py
@@ -31,14 +31,16 @@ def INPUT_TYPES(s):
             {
                 "images": ("IMAGE",),
                 "text": ("STRING", {"multiline": False}),
-                "threshold": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.001}),
+                "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.001}),
                 "binary_mask": ("BOOLEAN", {"default": True}),
                 "combine_mask": ("BOOLEAN", {"default": False}),
                 "use_cuda": ("BOOLEAN", {"default": True}),
             },
             "optional":
             {
                 "blur_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.1}),
+                "opt_model": ("CLIPSEGMODEL", ),
+                "prev_mask": ("MASK", {"default": None}),
             }
         }
 
@@ -50,7 +52,7 @@ def INPUT_TYPES(s):
 Segments an image or batch of images using CLIPSeg.
 """
 
-    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0):
+    def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0, opt_model=None, prev_mask=None):
         from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
         import torchvision.transforms as transforms
         offload_device = model_management.unet_offload_device()
@@ -59,10 +61,23 @@ def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0):
         else:
             device = torch.device("cpu")
         dtype = model_management.unet_dtype()
-        if not hasattr(self, "model"):
-            self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
-
-        processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+        if opt_model is None:
+            checkpoint_path = os.path.join(folder_paths.models_dir, 'clip_seg', 'clipseg-rd64-refined-fp16')
+            if not hasattr(self, "model"):
+                try:
+                    if not os.path.exists(checkpoint_path):
+                        from huggingface_hub import snapshot_download
+                        snapshot_download(repo_id="Kijai/clipseg-rd64-refined-fp16", local_dir=checkpoint_path, local_dir_use_symlinks=False)
+                    self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
+                except:
+                    checkpoint_path = "CIDAS/clipseg-rd64-refined"
+                    self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
+            processor = CLIPSegProcessor.from_pretrained(checkpoint_path)
+
+        else:
+            self.model = opt_model['model']
+            processor = opt_model['processor']
 
         self.model.to(dtype).to(device)
 
@@ -81,20 +96,20 @@ def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0):
             outputs = self.model(**input_prc)
 
         tensor = torch.sigmoid(outputs.logits)
-        tensor = torch.where(tensor > (threshold / 10), tensor, torch.tensor(0, dtype=torch.float))
-        print(tensor.min(), tensor.max())
+        tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
+        tensor = torch.where(tensor > (threshold), tensor, torch.tensor(0, dtype=torch.float))
+
 
         tensor = F.interpolate(tensor.unsqueeze(1), size=(H, W), mode='nearest')
         tensor = tensor.squeeze(1)
 
         self.model.to(offload_device)
         results = tensor.cpu().float()
-        print(results.min(), results.max())
 
         if binary_mask:
             tensor = (tensor > 0).float()
         if blur_sigma > 0:
-            kernel_size = int(6 * blur_sigma + 1)
+            kernel_size = int(6 * int(blur_sigma) + 1)
             blur = transforms.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=(blur_sigma, blur_sigma))
             tensor = blur(tensor)
 
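The rethresholding above is the behavioral core of the commit: sigmoid logits are min-max normalized to [0, 1] before the cutoff, so threshold (now defaulting to 0.5) acts on a known scale instead of the old threshold / 10 heuristic on raw sigmoid values. A small self-contained illustration, with dummy logits standing in for outputs.logits:

import torch

logits = torch.randn(2, 352, 352)                  # stand-in for CLIPSeg output logits
t = torch.sigmoid(logits)
t = (t - t.min()) / (t.max() - t.min())            # rescale to exactly [0, 1]
mask = torch.where(t > 0.5, t, torch.tensor(0.0))  # zero out everything below the cutoff
assert mask.min().item() == 0.0 and mask.max().item() == 1.0

Note that the min and max are taken over the whole batch, so one strongly activated image also rescales the maps of the others.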
@@ -105,8 +120,58 @@ def segment_image(self, images, text, threshold, binary_mask, combine_mask, use_cuda, blur_sigma=0.0):
         del outputs
         model_management.soft_empty_cache()
 
+        if prev_mask is not None:
+            tensor = tensor + prev_mask
+            tensor = torch.clamp(tensor, min=0.0, max=1.0)
+
         return tensor,
 
+class DownloadAndLoadCLIPSeg:
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(s):
+
+        return {"required":
+            {
+                "model": (
+                    [ 'Kijai/clipseg-rd64-refined-fp16',
+                      'CIDAS/clipseg-rd64-refined',
+                    ],
+                    {
+                        "default": 'Kijai/clipseg-rd64-refined-fp16'
+                    }),
+            },
+        }
+
+    CATEGORY = "KJNodes/masking"
+    RETURN_TYPES = ("CLIPSEGMODEL",)
+    RETURN_NAMES = ("clipseg_model",)
+    FUNCTION = "segment_image"
+    DESCRIPTION = """
+Downloads and loads a CLIPSeg model with huggingface_hub
+to ComfyUI/models/clip_seg.
+"""
+
+    def segment_image(self, model):
+        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+        checkpoint_path = os.path.join(folder_paths.models_dir, 'clip_seg', model)
+        if not hasattr(self, "model"):
+            if not os.path.exists(checkpoint_path):
+                from huggingface_hub import snapshot_download
+                snapshot_download(repo_id=model, local_dir=checkpoint_path, local_dir_use_symlinks=False)
+            self.model = CLIPSegForImageSegmentation.from_pretrained(checkpoint_path)
+
+        processor = CLIPSegProcessor.from_pretrained(checkpoint_path)
+
+        clipseg_model = {}
+        clipseg_model['model'] = self.model
+        clipseg_model['processor'] = processor
+
+        return clipseg_model,
+
 class CreateTextMask:
 
     RETURN_TYPES = ("IMAGE", "MASK",)
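Since the loader's contract is just a two-key dict, an already-downloaded checkpoint can be wrapped into the same CLIPSEGMODEL shape without going through the node; a sketch, assuming the download location named in the DESCRIPTION above:

from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Assumed on-disk location; matches the clip_seg folder the nodes download into.
local_path = "ComfyUI/models/clip_seg/clipseg-rd64-refined-fp16"
clipseg_model = {
    "model": CLIPSegForImageSegmentation.from_pretrained(local_path),
    "processor": CLIPSegProcessor.from_pretrained(local_path),
}
# Any CLIPSEGMODEL input, such as BatchCLIPSeg's opt_model, can consume this dict.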
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,3 +4,4 @@ pillow>=10.3.0
 scipy
 color-matcher
 matplotlib
+huggingface_hub
