deleted code to perform VQ inference w/o cache
srama2512 committed Feb 26, 2023
1 parent a335334 commit 239c3e9
Showing 4 changed files with 31 additions and 154 deletions.
77 changes: 24 additions & 53 deletions VQ2D/perform_vq_inference.py
@@ -20,8 +20,6 @@
create_similarity_network,
get_clip_name_from_clip_uid,
perform_retrieval,
perform_cached_retrieval,
SiamPredictor,
)
from vq2d.structures import ResponseTrack
from vq2d.tracking import Tracker
@@ -56,7 +54,7 @@ def __init__(self, annots):
for annot in self.annots
]

def run(self, predictor, similarity_net, tracker, cfg, device):
def run(self, similarity_net, tracker, cfg, device):

data_cfg = cfg.data
sig_cfg = cfg.signals
@@ -88,18 +86,15 @@ def run(self, predictor, similarity_net, tracker, cfg, device):
start_time = time.time()
# Retrieve nearest matches and their scores per image
cached_bboxes, cached_scores, cache_exists = None, None, False
if cfg.model.enable_cache:
assert cfg.model.cache_root != ""
cache_path = os.path.join(cfg.model.cache_root, f"{annot_key}.pt")
if os.path.isfile(cache_path):
cache = torch.load(cache_path)
cached_bboxes = cache["ret_bboxes"]
cached_scores = cache["ret_scores"]
cache_exists = True
assert len(cached_bboxes) == query_frame
assert len(cached_scores) == query_frame
else:
print(f"Could not find cached detections: {cache_path}")
assert cfg.model.cache_root != ""
cache_path = os.path.join(cfg.model.cache_root, f"{annot_key}.pt")
assert os.path.isfile(cache_path)
cache = torch.load(cache_path)
cached_bboxes = cache["ret_bboxes"]
cached_scores = cache["ret_scores"]
cache_exists = True
assert len(cached_bboxes) == query_frame
assert len(cached_scores) == query_frame

if visual_crop["frame_number"] >= len(video_reader):
print(
@@ -109,34 +104,19 @@ def run(self, predictor, similarity_net, tracker, cfg, device):
)
return {}

if cache_exists:
(
ret_bboxes,
ret_scores,
ret_imgs,
visual_crop_im,
) = perform_cached_retrieval(
video_reader,
visual_crop,
query_frame,
predictor,
cached_bboxes,
cached_scores,
recency_factor=cfg.model.recency_factor,
subsampling_factor=cfg.model.subsampling_factor,
visualize=cfg.logging.visualize,
)
else:
ret_bboxes, ret_scores, ret_imgs, visual_crop_im = perform_retrieval(
video_reader,
visual_crop,
query_frame,
predictor,
batch_size=data_cfg.rcnn_batch_size,
recency_factor=cfg.model.recency_factor,
subsampling_factor=cfg.model.subsampling_factor,
visualize=cfg.logging.visualize,
)
(ret_bboxes, ret_scores, ret_imgs, visual_crop_im,) = perform_retrieval(
video_reader,
visual_crop,
query_frame,
cached_bboxes,
cached_scores,
recency_factor=cfg.model.recency_factor,
subsampling_factor=cfg.model.subsampling_factor,
visualize=cfg.logging.visualize,
reference_pad=cfg.model.reference_pad,
reference_size=cfg.model.reference_size,
)

detection_time_taken = time.time() - start_time
start_time = time.time()
# Generate a time signal of scores
@@ -279,15 +259,6 @@ def work(self, task_queue, results_queue):

device = torch.device(f"cuda:{self.device_id}")

# Create detector
detectron_cfg = get_detectron_cfg()
detectron_cfg.merge_from_file(self.cfg.model.config_path)
detectron_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = self.cfg.model.score_thresh
detectron_cfg.MODEL.WEIGHTS = self.cfg.model.checkpoint_path
detectron_cfg.MODEL.DEVICE = f"cuda:{self.device_id}"
detectron_cfg.INPUT.FORMAT = "RGB"
predictor = SiamPredictor(detectron_cfg)

# Create tracker
similarity_net = create_similarity_network()
similarity_net.eval()
@@ -306,7 +277,7 @@ def work(self, task_queue, results_queue):
task = task_queue.get(timeout=1.0)
except QueueEmpty:
break
pred_rts = task.run(predictor, similarity_net, tracker, self.cfg, device)
pred_rts = task.run(similarity_net, tracker, self.cfg, device)
results_queue.put(pred_rts)


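Note that with this change perform_vq_inference.py no longer constructs a detector: every annotation must have precomputed detections under cfg.model.cache_root, and a missing {annot_key}.pt file now trips an assertion instead of falling back to on-the-fly inference. Below is a minimal sketch of writing a cache file in the format the loader above expects; the helper name save_detection_cache and the per_frame_* arguments are illustrative, not part of the repository.

    import os
    import torch

    def save_detection_cache(cache_root, annot_key, per_frame_bboxes, per_frame_scores, query_frame):
        # One entry per frame preceding the query frame, matching the asserts
        # in perform_vq_inference.py (len(cache["ret_bboxes"]) == query_frame).
        # Each per-frame entry holds whatever perform_retrieval consumes
        # (lists of vq2d BBox objects and their scores in the original pipeline).
        assert len(per_frame_bboxes) == query_frame
        assert len(per_frame_scores) == query_frame
        cache_path = os.path.join(cache_root, f"{annot_key}.pt")
        torch.save({"ret_bboxes": per_frame_bboxes, "ret_scores": per_frame_scores}, cache_path)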
3 changes: 1 addition & 2 deletions VQ2D/vq2d/baselines/__init__.py
@@ -2,7 +2,7 @@
from .dataset import (
register_visual_query_datasets,
)
from .feature_retrieval import perform_retrieval, perform_cached_retrieval
from .feature_retrieval import perform_retrieval
from .predictor import SiamPredictor
from .utils import (
create_similarity_network,
@@ -21,7 +21,6 @@
"get_clip_name_from_clip_uid",
"get_image_name_from_clip_uid",
"perform_retrieval",
"perform_cached_retrieval",
"extract_window_with_context",
"register_visual_query_datasets",
"SiamPredictor",
102 changes: 4 additions & 98 deletions VQ2D/vq2d/baselines/feature_retrieval.py
@@ -14,107 +14,13 @@ def perform_retrieval(
video_reader: Any,
visual_crop: Dict[str, Any],
query_frame: int,
net: DefaultPredictor,
batch_size: int = 8,
downscale_height: int = 700,
recency_factor: float = 1.0, # Search only within the most recent frames.
subsampling_factor: float = 1.0, # Search only within a subsampled set of frames.
visualize: bool = False,
):
"""
Given a visual crop and frames from a clip, retrieve the bounding box proposal
from each frame that is most similar to the visual crop.
"""
vc_fno = visual_crop["frame_number"]
owidth, oheight = visual_crop["original_width"], visual_crop["original_height"]

# Load visual crop frame
reference = video_reader[vc_fno] # RGB format
## Resize visual crop if stored aspect ratio was incorrect
if (reference.shape[0] != oheight) or (reference.shape[1] != owidth):
reference = cv2.resize(reference, (owidth, oheight))
reference = torch.as_tensor(rearrange(reference, "h w c -> () c h w"))
reference = reference.float()
ref_bbox = (
visual_crop["x"],
visual_crop["y"],
visual_crop["x"] + visual_crop["width"],
visual_crop["y"] + visual_crop["height"],
)
reference = extract_window_with_context(
reference,
ref_bbox,
net.cfg.INPUT.REFERENCE_CONTEXT_PAD,
size=net.cfg.INPUT.REFERENCE_SIZE,
pad_value=125,
)
reference = rearrange(asnumpy(reference.byte()), "() c h w -> h w c")
# Define search window
search_window = list(range(0, query_frame))
## Pick recent k% of frames
window_size = int(round(len(search_window) * recency_factor))
if len(search_window[-window_size:]) > 0:
search_window = search_window[-window_size:]
## Subsample only k% of frames
window_size = len(search_window)
idxs_to_sample = np.linspace(
0, window_size - 1, int(subsampling_factor * window_size)
).astype(int)
if len(idxs_to_sample) > 0:
search_window = [search_window[i] for i in idxs_to_sample]

# Load reference frames and perform detection
ret_bboxes = []
ret_scores = []
ret_imgs = []
# Batch extract predictions
for i in range(0, len(search_window), batch_size):
bimages = []
breferences = []
image_scales = []
i_end = min(i + batch_size, len(search_window))
for j in range(i, i_end):
orig_image = video_reader[search_window[j]] # RGB format
image = orig_image
if image.shape[:2] != (oheight, owidth):
image = cv2.resize(image, (owidth, oheight))
# print("Incorrect aspect ratio encountered!")
# Scale-down image to reduce memory consumption
image_scale = float(downscale_height) / image.shape[0]
image = cv2.resize(image, None, fx=image_scale, fy=image_scale)
bimages.append(image)
breferences.append(reference)
if visualize:
ret_imgs.append(orig_image)
image_scales.append(image_scale)
# Perform inference
all_outputs = net(bimages, breferences)
# Unpack outputs
for j in range(i, i_end):
instances = all_outputs[j - i]["instances"]
image_scale = image_scales[j - i]
# Re-scale bboxes
ret_bbs = (
asnumpy(instances.pred_boxes.tensor / image_scale).astype(int).tolist()
)
ret_bbs = [BBox(search_window[j], *bbox) for bbox in ret_bbs]
ret_scs = asnumpy(instances.scores).tolist()
ret_bboxes.append(ret_bbs)
ret_scores.append(ret_scs)
del all_outputs
return ret_bboxes, ret_scores, ret_imgs, reference


def perform_cached_retrieval(
video_reader: Any,
visual_crop: Dict[str, Any],
query_frame: int,
net: DefaultPredictor,
cached_bboxes: List[BBox],
cached_scores: List[float],
recency_factor: float = 1.0, # Search only within the most recent frames.
subsampling_factor: float = 1.0, # Search only within a subsampled set of frames.
visualize: bool = False,
reference_pad: int = 16,
reference_size: int = 256,
):
"""
Given a visual crop and frames from a clip, retrieve the bounding box proposal
@@ -139,8 +45,8 @@ def perform_cached_retrieval(
reference = extract_window_with_context(
reference,
ref_bbox,
net.cfg.INPUT.REFERENCE_CONTEXT_PAD,
size=net.cfg.INPUT.REFERENCE_SIZE,
reference_pad,
size=reference_size,
pad_value=125,
)
reference = rearrange(asnumpy(reference.byte()), "() c h w -> h w c")
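Putting the two hunks together, the retained function in feature_retrieval.py appears to end up with the following signature after this commit. This is a reconstruction from the surviving lines above (the old perform_cached_retrieval parameters minus net, plus the new reference_pad / reference_size arguments); the body and imports otherwise follow the module as shown.

    def perform_retrieval(
        video_reader: Any,
        visual_crop: Dict[str, Any],
        query_frame: int,
        cached_bboxes: List[BBox],
        cached_scores: List[float],
        recency_factor: float = 1.0,  # Search only within the most recent frames.
        subsampling_factor: float = 1.0,  # Search only within a subsampled set of frames.
        visualize: bool = False,
        reference_pad: int = 16,
        reference_size: int = 256,
    ):
        ...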
3 changes: 2 additions & 1 deletion VQ2D/vq2d/config.yaml
@@ -37,7 +37,8 @@ model:
subsampling_factor: 1.0
recency_factor: 1.0
cache_root: ""
enable_cache: False
reference_pad: 16
reference_size: 256

tracker:
type: "kys" # Options: [ kys | pfilter ]
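For reference, the model section of VQ2D/vq2d/config.yaml after this change reads roughly as follows; keys outside the hunk are elided, and the inline comment is editorial. Since the enable_cache toggle is gone, cache_root now has to point at a directory of precomputed <annot_key>.pt detection files.

    model:
      ...
      subsampling_factor: 1.0
      recency_factor: 1.0
      cache_root: ""        # directory holding the precomputed <annot_key>.pt files
      reference_pad: 16
      reference_size: 256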
