forked from Uminosachi/sd-webui-inpaint-anything
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ia_sam_manager.py
120 lines (105 loc) · 4.87 KB
/
ia_sam_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import platform
import torch
from modules import devices
from fast_sam import FastSamAutomaticMaskGenerator, fast_sam_model_registry
from ia_check_versions import ia_check_versions
from ia_config import get_webui_setting
from ia_logging import ia_logging
from ia_threading import torch_default_load_cd
from mobile_sam import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorMobile
from mobile_sam import SamPredictor as SamPredictorMobile
from mobile_sam import sam_model_registry as sam_model_registry_mobile
from segment_anything_fb import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from segment_anything_hq import SamAutomaticMaskGenerator as SamAutomaticMaskGeneratorHQ
from segment_anything_hq import SamPredictor as SamPredictorHQ
from segment_anything_hq import sam_model_registry as sam_model_registry_hq
def _move_sam_to_device(sam, checkpoint_name):
    """Move a freshly built SAM model to the device it should run on.

    Shared by ``get_sam_mask_generator`` and ``get_sam_predictor`` so the
    placement rules live in one place.

    On macOS ("Darwin") the model goes to CPU when the checkpoint is a FastSAM
    checkpoint or when ``ia_check_versions.torch_mps_is_available`` is false,
    and to MPS otherwise; the ``inpaint_anything_sam_oncpu`` webui setting is
    not consulted on macOS. On every other platform the model goes to
    ``devices.cpu`` when that setting is enabled, and to ``devices.device``
    otherwise.

    Args:
        sam: Model object returned by one of the SAM model registries.
        checkpoint_name (str): Base file name of the checkpoint; only used to
            detect FastSAM checkpoints.
    """
    if platform.system() == "Darwin":
        if "FastSAM" in checkpoint_name or not ia_check_versions.torch_mps_is_available:
            sam.to(device=torch.device("cpu"))
        else:
            sam.to(device=torch.device("mps"))
    elif get_webui_setting("inpaint_anything_sam_oncpu", False):
        ia_logging.info("SAM is running on CPU... (the option has been checked)")
        sam.to(device=devices.cpu)
    else:
        sam.to(device=devices.device)


@torch_default_load_cd()
def get_sam_mask_generator(sam_checkpoint, anime_style_chk=False):
    """Get SAM mask generator.

    The model family is chosen from the checkpoint's file name: a name
    containing "_hq_" selects SAM-HQ, "FastSAM" selects FastSAM, "mobile_sam"
    selects MobileSAM, and anything else selects the original
    segment-anything implementation.

    Args:
        sam_checkpoint (str): SAM checkpoint path
        anime_style_chk (bool, optional): If True, use looser filtering
            thresholds (``pred_iou_thresh`` 0.83 instead of 0.88 and
            ``stability_score_thresh`` 0.9 instead of 0.95). Defaults to False.

    Returns:
        SamAutomaticMaskGenerator or None: SAM mask generator, or None when
        ``sam_checkpoint`` is not an existing file.
    """
    checkpoint_name = os.path.basename(sam_checkpoint)
    if "_hq_" in checkpoint_name:
        # Fixed slice of the file name is the registry key,
        # e.g. "sam_hq_vit_h.pth" -> "vit_h".
        model_type = checkpoint_name[7:12]
        sam_model_registry_local = sam_model_registry_hq
        SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorHQ
        points_per_batch = 32
    elif "FastSAM" in checkpoint_name:
        # FastSAM registry is keyed by the file name without its extension.
        model_type = os.path.splitext(checkpoint_name)[0]
        sam_model_registry_local = fast_sam_model_registry
        SamAutomaticMaskGeneratorLocal = FastSamAutomaticMaskGenerator
        points_per_batch = None
    elif "mobile_sam" in checkpoint_name:
        model_type = "vit_t"
        sam_model_registry_local = sam_model_registry_mobile
        SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGeneratorMobile
        points_per_batch = 64
    else:
        # Fixed slice of the file name is the registry key,
        # e.g. "sam_vit_h_4b8939.pth" -> "vit_h".
        model_type = checkpoint_name[4:9]
        sam_model_registry_local = sam_model_registry
        SamAutomaticMaskGeneratorLocal = SamAutomaticMaskGenerator
        points_per_batch = 64

    pred_iou_thresh = 0.88 if not anime_style_chk else 0.83
    stability_score_thresh = 0.95 if not anime_style_chk else 0.9

    if not os.path.isfile(sam_checkpoint):
        return None

    sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
    _move_sam_to_device(sam, checkpoint_name)
    return SamAutomaticMaskGeneratorLocal(
        model=sam,
        points_per_batch=points_per_batch,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
    )


@torch_default_load_cd()
def get_sam_predictor(sam_checkpoint):
    """Get SAM predictor.

    The model family is chosen from the checkpoint's file name: a name
    containing "_hq_" selects SAM-HQ, "mobile_sam" selects MobileSAM, and
    anything else (other than FastSAM) selects the original segment-anything
    implementation.

    Args:
        sam_checkpoint (str): SAM checkpoint path

    Returns:
        SamPredictor or None: SAM predictor, or None when ``sam_checkpoint``
        is not an existing file.

    Raises:
        NotImplementedError: If the file name contains "FastSAM". This is
            raised before the file-existence check.
    """
    checkpoint_name = os.path.basename(sam_checkpoint)
    if "_hq_" in checkpoint_name:
        # Fixed slice of the file name is the registry key,
        # e.g. "sam_hq_vit_h.pth" -> "vit_h".
        model_type = checkpoint_name[7:12]
        sam_model_registry_local = sam_model_registry_hq
        SamPredictorLocal = SamPredictorHQ
    elif "FastSAM" in checkpoint_name:
        raise NotImplementedError("FastSAM predictor is not implemented yet.")
    elif "mobile_sam" in checkpoint_name:
        model_type = "vit_t"
        sam_model_registry_local = sam_model_registry_mobile
        SamPredictorLocal = SamPredictorMobile
    else:
        # Fixed slice of the file name is the registry key,
        # e.g. "sam_vit_h_4b8939.pth" -> "vit_h".
        model_type = checkpoint_name[4:9]
        sam_model_registry_local = sam_model_registry
        SamPredictorLocal = SamPredictor

    if not os.path.isfile(sam_checkpoint):
        return None

    sam = sam_model_registry_local[model_type](checkpoint=sam_checkpoint)
    _move_sam_to_device(sam, checkpoint_name)
    return SamPredictorLocal(sam)