Skip to content

Commit

Permalink
Add tricks and tips with examples of jakob-ropers-snkeos
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-nguyen committed Aug 10, 2023
1 parent 4bf0e35 commit 6a70502
Show file tree
Hide file tree
Showing 7 changed files with 3,501 additions and 13 deletions.
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,19 @@ The detections will be saved at $OUTPUT_DIR/cnos_results. This script is used by

</details>

Tips and tricks for improving results on custom objects, illustrated with a sample provided by [jakob-ropers-snkeos](https://github.com/jakob-ropers-snkeos):

![tips](./media/demo2/result.png)

<details><summary>Click to expand</summary>


Please note that SAM or FastSAM can perform exceptionally well, even on very small objects. However, certain parameters from the original implementation require adjustment to achieve optimal results. For example, we recommend lowering `stability_score_thresh` from its default value of 0.97 to a smaller value such as 0.5 (the command below is applied after step 1 of rendering):
```
python -m src.scripts.inference_custom --template_dir $OUTPUT_DIR --rgb_path $RGB_PATH --stability_score_thresh 0.5
```
</details>

## Acknowledgement

The code is adapted from [Nope](https://github.com/nv-nguyen/nope), [Segment Anything](https://github.com/facebookresearch/segment-anything), and [DINOv2](https://github.com/facebookresearch/dinov2).
Expand Down
3,436 changes: 3,436 additions & 0 deletions media/demo2/NeedleHolder.ply

Large diffs are not rendered by default.

Binary file added media/demo2/ThreeToolTest.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added media/demo2/result.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 13 additions & 9 deletions src/scripts/inference_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def visualize(rgb, detections, save_path="./tmp/tmp.png"):
concat.paste(prediction, (img.shape[1], 0))
return concat

def run_inference(template_dir, rgb_path, num_max_dets, conf_threshold):
def run_inference(template_dir, rgb_path, num_max_dets, conf_threshold, stability_score_thresh):
with initialize(version_base=None, config_path="../../configs"):
cfg = compose(config_name='run_inference.yaml')

cfg.model.segmentor_model.stability_score_thresh = stability_score_thresh
metric = Similarity()
logging.info("Initializing model")
model = instantiate(cfg.model)
Expand Down Expand Up @@ -112,14 +112,14 @@ def run_inference(template_dir, rgb_path, num_max_dets, conf_threshold):
)
proposal_processor = CropResizePad(processing_config.image_size)
templates = proposal_processor(images=templates, boxes=boxes).cuda()
save_image(inv_rgb_transform(templates), f"{template_dir}/cnos_results/templates.png", nrow=7)
save_image(templates, f"{template_dir}/cnos_results/templates.png", nrow=7)
ref_feats = model.descriptor_model.compute_features(
templates, token_name="x_norm_clstoken"
)
logging.info(f"Ref feats: {ref_feats.shape}")

# run inference
rgb = Image.open(rgb_path)
rgb = Image.open(rgb_path).convert("RGB")
detections = model.segmentor_model.generate_masks(np.array(rgb))
detections = Detections(detections)
decriptors = model.descriptor_model.forward(np.array(rgb), detections)
Expand All @@ -134,6 +134,9 @@ def run_inference(template_dir, rgb_path, num_max_dets, conf_threshold):
# get top-k detections
scores, index = torch.topk(score_per_detection, k=num_max_dets, dim=-1)
detections.filter(index)

# keep only detections with score > conf_threshold
detections.filter(scores>conf_threshold)
detections.add_attribute("scores", scores)
detections.add_attribute("object_ids", torch.zeros_like(scores))

Expand All @@ -147,11 +150,12 @@ def run_inference(template_dir, rgb_path, num_max_dets, conf_threshold):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("template_dir", nargs="?", help="Path to root directory of the template")
parser.add_argument("rgb_path", nargs="?", help="Path to RGB image")
parser.add_argument("num_max_dets", nargs="?", default=1, type=int, help="Number of max detections")
parser.add_argument("confg_threshold", nargs="?", default=0.5, type=float, help="Confidence threshold")
parser.add_argument("--template_dir", nargs="?", help="Path to root directory of the template")
parser.add_argument("--rgb_path", nargs="?", help="Path to RGB image")
parser.add_argument("--num_max_dets", nargs="?", default=1, type=int, help="Number of max detections")
parser.add_argument("--confg_threshold", nargs="?", default=0.5, type=float, help="Confidence threshold")
parser.add_argument("--stability_score_thresh", nargs="?", default=0.97, type=float, help="stability_score_thresh of SAM")
args = parser.parse_args()

os.makedirs(f"{args.template_dir}/cnos_results", exist_ok=True)
run_inference(args.template_dir, args.rgb_path, num_max_dets=args.num_max_dets, conf_threshold=args.confg_threshold)
run_inference(args.template_dir, args.rgb_path, num_max_dets=args.num_max_dets, conf_threshold=args.confg_threshold, stability_score_thresh=args.stability_score_thresh)
2 changes: 1 addition & 1 deletion src/scripts/run_inference_custom.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
python -m src.scripts.inference_custom $OUTPUT_DIR $RGB_PATH
python -m src.scripts.inference_custom --template_dir $OUTPUT_DIR --rgb_path $RGB_PATH --stability_score_thresh 0.5
41 changes: 38 additions & 3 deletions src/utils/visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from torchvision import transforms
# from moviepy.video.io.bindings import mplfig_to_npimage
import cv2

import distinctipy
np.random.seed(2022)
COLORS_SPACE = np.random.randint(0, 255, size=(1000, 3))


from skimage.feature import canny
from skimage.morphology import binary_dilation
from tqdm import tqdm
def put_image_to_grid(list_imgs, adding_margin=True):
num_col = len(list_imgs)
b, c, h, w = list_imgs[0].shape
Expand Down Expand Up @@ -50,3 +51,37 @@ def resize_tensor(tensor, size):
return F.interpolate(tensor, size, mode="bilinear", align_corners=True)




def visualize_masks(rgb, masks, save_path="./tmp/tmp.png"):
    """Overlay masks on a grayscale copy of ``rgb``, save it, and return a comparison.

    Each mask is alpha-blended onto the grayscale image in its own distinct
    colour and outlined in white. The overlay is written to ``save_path`` and
    a side-by-side image (input on the left, overlay on the right) is returned.

    Args:
        rgb: Input image. It is pasted into a PIL canvas, so it is expected to
            be a ``PIL.Image`` in RGB mode.
        masks: Sized sequence of masks. Each mask is used to index the image
            array and is passed to ``canny`` — presumably boolean arrays with
            the same height/width as ``rgb``; confirm against the caller.
        save_path: File path the overlay image is written to. Missing parent
            directories are created.

    Returns:
        A ``PIL.Image`` of mode ``RGB`` whose left half is ``rgb`` and whose
        right half is the saved overlay.
    """
    import os  # local import: the rest of the module's import block is not visible here

    # Grayscale backdrop, expanded back to 3 channels so coloured overlays stand out.
    gray = cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2GRAY)
    img = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    colors = distinctipy.get_colors(len(masks))
    alpha = 0.33

    for mask_id, mask in enumerate(tqdm(masks)):
        edge = canny(mask)
        edge = binary_dilation(edge, np.ones((2, 2)))

        # One colour per mask. The previous code always picked colors[-1],
        # which made every mask indistinguishable from the others.
        r, g, b = (int(255 * channel) for channel in colors[mask_id][:3])
        img[mask, 0] = alpha * r + (1 - alpha) * img[mask, 0]
        img[mask, 1] = alpha * g + (1 - alpha) * img[mask, 1]
        img[mask, 2] = alpha * b + (1 - alpha) * img[mask, 2]
        img[edge, :] = 255

    # The default save_path lives in ./tmp, which may not exist yet.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    overlay = Image.fromarray(np.uint8(img))
    overlay.save(save_path)
    prediction = Image.open(save_path)

    # Concatenate side by side in PIL: input on the left, overlay on the right.
    height, width = img.shape[:2]
    concat = Image.new('RGB', (width + prediction.size[0], height))
    concat.paste(rgb, (0, 0))
    concat.paste(prediction, (width, 0))
    return concat

0 comments on commit 6a70502

Please sign in to comment.