Skip to content

Commit

Permalink
Option to not shuffle mv neighbors when subsampling
Browse files Browse the repository at this point in the history
  • Loading branch information
trancelestial committed Nov 3, 2022
1 parent ad8e94d commit b53354e
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions nerfstudio/data/dataparsers/uniscene_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,19 @@
CONSOLE = Console()


def get_src_from_pairs(ref_idx, all_imgs, pairs_srcs, neighbors_num=None) -> Dict[str, TensorType]:
def get_src_from_pairs(
ref_idx, all_imgs, pairs_srcs, neighbors_num=None, neighbors_shuffle=False
) -> Dict[str, TensorType]:
src_idx = pairs_srcs[ref_idx] # src_idx[0] is ref img
# randomly sample neighbors
if neighbors_num and neighbors_num > -1 and neighbors_num < len(src_idx) - 1:
perm_idx = torch.randperm(len(src_idx) - 1) + 1
src_idx = torch.cat([src_idx[[0]], src_idx[perm_idx[:neighbors_num]]])
# perm_idx = torch.cat(torch.tensor([0]), perm_idx)
# src_idx = src_idx[perm_idx[:neighbors_num+1]]
if neighbors_shuffle:
perm_idx = torch.randperm(len(src_idx) - 1) + 1
src_idx = torch.cat([src_idx[[0]], src_idx[perm_idx[:neighbors_num]]])
# perm_idx = torch.cat(torch.tensor([0]), perm_idx)
# src_idx = src_idx[perm_idx[:neighbors_num+1]]
else:
src_idx = src_idx[: neighbors_num + 1]
return {"src_imgs": all_imgs[src_idx], "src_idxs": src_idx}


Expand Down Expand Up @@ -130,6 +135,7 @@ class UniSceneDataParserConfig(DataParserConfig):
] = "center_crop_for_dtu"
"""center crop type as monosdf, we should create a dataset that don't need this"""
neighbors_num: Optional[int] = None
neighbors_shuffle: Optional[bool] = False


@dataclass
Expand Down Expand Up @@ -292,7 +298,12 @@ def glob_data(data_dir):

additional_inputs_dict["pairs"] = {
"func": get_src_from_pairs,
"kwargs": {"all_imgs": all_imgs, "pairs_srcs": pairs_srcs, "neighbors_num": self.config.neighbors_num},
"kwargs": {
"all_imgs": all_imgs,
"pairs_srcs": pairs_srcs,
"neighbors_num": self.config.neighbors_num,
"neighbors_shuffle": self.config.neighbors_shuffle,
},
}

dataparser_outputs = DataparserOutputs(
Expand Down

0 comments on commit b53354e

Please sign in to comment.