
Add tutorial to run on custom datasets
nv-nguyen committed Jul 23, 2023
1 parent 1331759 commit ac130b4
Showing 9 changed files with 25,038 additions and 5 deletions.
36 changes: 36 additions & 0 deletions README.md
@@ -43,6 +43,9 @@ If you like this project, check out related works from our group:
- [PIZZA: A Powerful Image-only Zero-Shot Zero-CAD Approach to 6DoF Tracking (3DV 2022)](https://github.com/nv-nguyen/pizza)
- [BOP visualization toolkit](https://github.com/nv-nguyen/bop_viz_kit)

## Updates:
- Added a [tutorial](https://github.com/nv-nguyen/cnos#testing-on-custom-datasets-rocket) for running CNOS on custom datasets

## Installation :construction_worker:

<details><summary>Click to expand</summary>
@@ -144,6 +147,39 @@ python -m src.scripts.visualize_detectron2 dataset_name=$DATASET_NAME input_file

</details>

## Testing on custom datasets :rocket:

You can run CNOS on your own dataset given an RGB image and a CAD model of the target object. We provide an example of running CNOS on a "BBQ sauce" sample taken from [MegaPose](https://github.com/megapose6d/megapose6d).

![qualitative](./media/demo/result.png)

<details><summary>Click to expand</summary>

There are two steps to test CNOS on your own dataset. First, define the paths to your CAD model, query image, and output directory:
```
export CAD_PATH=./media/demo/hope_000002.ply
export RGB_PATH=./media/demo/bba_sauce_rgb.png
export OUTPUT_DIR=./tmp/custom_dataset
```
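Optionally, you can sanity-check the CAD model before rendering. The snippet below is an illustrative sketch, not part of the repository; it assumes the model is millimeter-scale, which is what the pose preprocessing in `src/poses/pyrender.py` expects when it divides translations by 1000:
```
# optional sanity check: confirm the CAD model loads and looks millimeter-scale
import trimesh

mesh = trimesh.load("./media/demo/hope_000002.ply")
print(mesh.extents)  # a hand-sized object should be on the order of 100 mm
```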

1. Render templates from the CAD model (`render_custom.sh` reads `CAD_PATH` and `OUTPUT_DIR` from the environment):
```
bash ./src/scripts/render_custom.sh
```
If the rendering quality is poor, you can adjust the lighting intensity (`LIGHTING_ITENSITY`) and the camera-to-object distance (`RADIUS`) in [this script](https://github.com/nv-nguyen/cnos/tree/main/src/scripts/render_custom.sh).

2. Run CNOS and visualize the results:
```
bash ./src/scripts/run_inference_custom.sh
```
The detections will be saved to `$OUTPUT_DIR/cnos_results`. By default, this script performs single-CAD object segmentation; to segment multiple objects, a few adaptations to [this script](https://github.com/nv-nguyen/cnos/tree/main/src/scripts/inference_custom.py) are needed, as sketched below.
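For instance, the scoring step can be generalized to several objects. This is a hedged sketch, not part of the repository: it assumes one set of template features per CAD model and that `Similarity` computes cosine similarity, mirroring the top-5 aggregation in `inference_custom.py`; the helper name `assign_objects` is hypothetical.
```
import torch

def assign_objects(proposal_feats, feats_per_object, k=5):
    # proposal_feats: (n_proposals, d); feats_per_object: list of (n_templates, d)
    per_object_scores = []
    for ref_feats in feats_per_object:
        # cosine similarity between every proposal and every template of one object
        sim = torch.nn.functional.cosine_similarity(
            proposal_feats[:, None, :], ref_feats[None, :, :], dim=-1
        )
        # aggregate as in inference_custom.py: mean of the top-5 similarities
        topk = sim.topk(k=min(k, sim.shape[-1]), dim=-1).values
        per_object_scores.append(topk.mean(dim=-1))
    scores = torch.stack(per_object_scores, dim=-1)  # (n_proposals, n_objects)
    best_scores, object_ids = scores.max(dim=-1)  # best-matching object per proposal
    return best_scores, object_ids
```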

</details>

## Acknowledgement

The code is adapted from [NOPE](https://github.com/nv-nguyen/nope), [Segment Anything](https://github.com/facebookresearch/segment-anything), and [DINOv2](https://github.com/facebookresearch/dinov2).
Binary file added media/demo/bba_sauce_rgb.png
24,830 changes: 24,830 additions & 0 deletions media/demo/hope_000002.ply


Binary file added media/demo/hope_000002.png
Binary file added media/demo/result.png
16 changes: 11 additions & 5 deletions src/poses/pyrender.py
@@ -19,6 +19,7 @@ def render(
    obj_poses,
    img_size,
    intrinsic,
    light_itensity=0.6,
    is_tless=False,
):
    # camera pose is fixed as np.eye(4)
@@ -28,12 +29,14 @@
    cam_pose[2, 2] = -1
    # create scene config
    ambient_light = np.array([0.02, 0.02, 0.02, 1.0])  # np.array([1.0, 1.0, 1.0, 1.0])
    # any non-default light intensity switches to full ambient light
    if light_itensity != 0.6:
        ambient_light = np.array([1.0, 1.0, 1.0, 1.0])
    scene = pyrender.Scene(
        bg_color=np.array([0.0, 0.0, 0.0, 0.0]), ambient_light=ambient_light
    )
    light = pyrender.SpotLight(
        color=np.ones(3),
        intensity=light_itensity,
        innerConeAngle=np.pi / 16.0,
        outerConeAngle=np.pi / 6.0,
    )
@@ -64,15 +67,17 @@
)
parser.add_argument("gpus_devices", nargs="?", help="GPU devices")
parser.add_argument("disable_output", nargs="?", help="Disable output of blender")
parser.add_argument("light_itensity", nargs="?", type=float, default=0.6, help="Light intensity")
parser.add_argument("radius", nargs="?", type=float, default=1, help="Distance from camera to object")
args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus_devices
poses = np.load(args.obj_pose)
# we could increase the light energy instead, but it is simpler to rescale the object translations to meters
# poses[:, :3, :3] = poses[:, :3, :3] / 1000.0
poses[:, :3, 3] = poses[:, :3, 3] / 1000.0
# optionally move the object closer to / farther from the camera
if args.radius != 1:
    poses[:, :3, 3] = poses[:, :3, 3] * args.radius
if "tless" in args.output_dir:
    intrinsic = np.asarray(
        [1075.65091572, 0.0, 360, 0.0, 1073.90347929, 270, 0.0, 0.0, 1.0]
@@ -97,11 +102,12 @@ render(
    mesh = pyrender.Mesh.from_trimesh(mesh, smooth=False)
else:
    mesh = pyrender.Mesh.from_trimesh(as_mesh(mesh))

os.makedirs(args.output_dir, exist_ok=True)
render(
    output_dir=args.output_dir,
    mesh=mesh,
    obj_poses=poses,
    intrinsic=intrinsic,
    img_size=(480, 640),
    light_itensity=args.light_itensity,
)
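Note that `radius` acts as a multiplicative scale on the meter-converted translations rather than an absolute distance. A standalone sketch of this preprocessing, illustrative only (the example translations are assumptions):
```
import numpy as np

poses = np.tile(np.eye(4), (3, 1, 1))
poses[:, :3, 3] = [[0.0, 0.0, 400.0], [0.0, 0.0, 600.0], [0.0, 0.0, 800.0]]  # millimeters

radius = 0.4
poses[:, :3, 3] = poses[:, :3, 3] / 1000.0  # mm -> m, as in pyrender.py
if radius != 1:
    poses[:, :3, 3] = poses[:, :3, 3] * radius  # scale the camera-object distance
print(poses[:, 2, 3])  # [0.16 0.24 0.32]
```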
157 changes: 157 additions & 0 deletions src/scripts/inference_custom.py
@@ -0,0 +1,157 @@
import os, sys
import os.path as osp
import numpy as np
import shutil
from tqdm import tqdm
import time
import torch
from PIL import Image
import logging
from hydra import initialize, compose
# set level logging
logging.basicConfig(level=logging.INFO)
from hydra.utils import instantiate
import argparse
import glob
from src.utils.bbox_utils import CropResizePad
from omegaconf import DictConfig, OmegaConf
from torchvision.utils import save_image
import torchvision.transforms as T
from src.model.utils import Detections, convert_npz_to_json
from src.model.loss import Similarity
from src.utils.inout import save_json_bop23
import cv2
import distinctipy
from skimage.feature import canny
from skimage.morphology import binary_dilation
from segment_anything.utils.amg import rle_to_mask

# inverse of the standard ImageNet normalization, used to save templates for inspection
inv_rgb_transform = T.Compose(
    [
        T.Normalize(
            mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
            std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
        ),
    ]
)

def visualize(rgb, detections, save_path="./tmp/tmp.png"):
    img = rgb.copy()
    gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
    img = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    # img = (255*img).astype(np.uint8)
    colors = distinctipy.get_colors(len(detections))
    alpha = 0.33

    for mask_idx, det in enumerate(detections):
        mask = rle_to_mask(det["segmentation"])
        edge = canny(mask)
        edge = binary_dilation(edge, np.ones((2, 2)))
        obj_id = det["category_id"]
        temp_id = obj_id - 1

        # blend the object mask with a distinct color and draw white edges
        r = int(255 * colors[temp_id][0])
        g = int(255 * colors[temp_id][1])
        b = int(255 * colors[temp_id][2])
        img[mask, 0] = alpha * r + (1 - alpha) * img[mask, 0]
        img[mask, 1] = alpha * g + (1 - alpha) * img[mask, 1]
        img[mask, 2] = alpha * b + (1 - alpha) * img[mask, 2]
        img[edge, :] = 255

    img = Image.fromarray(np.uint8(img))
    img.save(save_path)
    prediction = Image.open(save_path)

    # concat side by side in PIL
    img = np.array(img)
    concat = Image.new("RGB", (img.shape[1] + prediction.size[0], img.shape[0]))
    concat.paste(rgb, (0, 0))
    concat.paste(prediction, (img.shape[1], 0))
    return concat

def run_inference(template_dir, rgb_path, num_max_dets, conf_threshold):
    with initialize(version_base=None, config_path="../../configs"):
        cfg = compose(config_name="run_inference.yaml")

    metric = Similarity()
    logging.info("Initializing model")
    model = instantiate(cfg.model)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.descriptor_model.model = model.descriptor_model.model.to(device)
    model.descriptor_model.model.device = device
    # if there is predictor in the model, move it to device
    if hasattr(model.segmentor_model, "predictor"):
        model.segmentor_model.predictor.model = (
            model.segmentor_model.predictor.model.to(device)
        )
    else:
        model.segmentor_model.model.setup_model(device=device, verbose=True)
    logging.info(f"Moving models to {device} done!")

    logging.info("Initializing template")
    template_paths = glob.glob(f"{template_dir}/*.png")
    boxes, templates = [], []
    for path in template_paths:
        image = Image.open(path)
        boxes.append(image.getbbox())

        image = torch.from_numpy(np.array(image.convert("RGB")) / 255).float()
        templates.append(image)

    templates = torch.stack(templates).permute(0, 3, 1, 2)
    boxes = torch.tensor(np.array(boxes))

    processing_config = OmegaConf.create(
        {
            "image_size": 224,
        }
    )
    proposal_processor = CropResizePad(processing_config.image_size)
    templates = proposal_processor(images=templates, boxes=boxes).cuda()
    save_image(inv_rgb_transform(templates), f"{template_dir}/cnos_results/templates.png", nrow=7)
    ref_feats = model.descriptor_model.compute_features(
        templates, token_name="x_norm_clstoken"
    )
    logging.info(f"Ref feats: {ref_feats.shape}")

    # run inference
    rgb = Image.open(rgb_path)
    detections = model.segmentor_model.generate_masks(np.array(rgb))
    detections = Detections(detections)
    descriptors = model.descriptor_model.forward(np.array(rgb), detections)

    # score per proposal: mean of the top-5 similarities to the templates
    scores = metric(descriptors[:, None, :], ref_feats[None, :, :])
    score_per_detection = torch.topk(scores, k=5, dim=-1)[0]
    score_per_detection = torch.mean(
        score_per_detection, dim=-1
    )

    # keep the top-k detections
    scores, index = torch.topk(score_per_detection, k=num_max_dets, dim=-1)
    detections.filter(index)
    detections.add_attribute("scores", scores)
    detections.add_attribute("object_ids", torch.zeros_like(scores))

    detections.to_numpy()
    save_path = f"{template_dir}/cnos_results/detection"
    detections.save_to_file(0, 0, 0, save_path, "custom", return_results=False)
    detections = convert_npz_to_json(idx=0, list_npz_paths=[save_path + ".npz"])
    save_json_bop23(save_path + ".json", detections)
    vis_img = visualize(rgb, detections)
    vis_img.save(f"{template_dir}/cnos_results/vis.png")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("template_dir", nargs="?", help="Path to root directory of the template")
    parser.add_argument("rgb_path", nargs="?", help="Path to RGB image")
    parser.add_argument("num_max_dets", nargs="?", default=1, type=int, help="Number of max detections")
    parser.add_argument("conf_threshold", nargs="?", default=0.5, type=float, help="Confidence threshold")
    args = parser.parse_args()

    os.makedirs(f"{args.template_dir}/cnos_results", exist_ok=True)
    run_inference(args.template_dir, args.rgb_path, num_max_dets=args.num_max_dets, conf_threshold=args.conf_threshold)
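One detail worth unpacking from the script above: `inv_rgb_transform` undoes the standard ImageNet normalization so that the saved template grid is viewable. A quick standalone check (illustrative, not part of the repository):
```
import torch
import torchvision.transforms as T

# Normalize(-m/s, 1/s) is the algebraic inverse of Normalize(m, s):
# ((x - m) / s) * s + m = x, channel by channel.
norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
inv = T.Normalize(
    mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
    std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
)
x = torch.rand(3, 224, 224)
assert torch.allclose(inv(norm(x)), x, atol=1e-5)
```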
3 changes: 3 additions & 0 deletions src/scripts/render_custom.sh
@@ -0,0 +1,3 @@
export LIGHTING_ITENSITY=1.0 # lighting intensity
export RADIUS=0.4 # distance from camera to object (meters)
# positional args: CAD path, pose file, output dir, GPU id, disable_output flag, light intensity, radius
python -m src.poses.pyrender $CAD_PATH ./src/poses/predefined_poses/obj_poses_level0.npy $OUTPUT_DIR 0 False $LIGHTING_ITENSITY $RADIUS
1 change: 1 addition & 0 deletions src/scripts/run_inference_custom.sh
@@ -0,0 +1 @@
# args: template directory (where templates were rendered) and query RGB image
python -m src.scripts.inference_custom $OUTPUT_DIR $RGB_PATH
