Skip to content

Commit 11d6d1a

Browse files
authored May 19, 2023
Merge pull request kerrj#25 from shengjie-lin/multi-phrases-fix
Multi phrases fix
2 parents ca94b13 + b3fa5f8 commit 11d6d1a

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed
 

‎lerf/data/utils/feature_dataloader.py

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def load(self):
4242
self.data = torch.from_numpy(np.load(self.cache_path)).to(self.device)
4343

4444
def save(self):
45+
os.makedirs(self.cache_path.parent, exist_ok=True)
4546
cache_info_path = self.cache_path.with_suffix(".info")
4647
with open(cache_info_path, "w") as f:
4748
f.write(json.dumps(self.cfg))

‎lerf/lerf.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def get_max_across(self, ray_samples, weights, hashgrid_field, scales_shape, pre
9090
n_phrases = len(self.image_encoder.positives)
9191
n_phrases_maxs = [None for _ in range(n_phrases)]
9292
n_phrases_sims = [None for _ in range(n_phrases)]
93-
for _, scale in enumerate(scales_list):
93+
for i, scale in enumerate(scales_list):
9494
scale = scale.item()
9595
with torch.no_grad():
9696
clip_output = self.lerf_field.get_output_from_hashgrid(
@@ -100,12 +100,13 @@ def get_max_across(self, ray_samples, weights, hashgrid_field, scales_shape, pre
100100
)
101101
clip_output = self.renderer_clip(embeds=clip_output, weights=weights.detach())
102102

103-
for i in range(n_phrases):
104-
probs = self.image_encoder.get_relevancy(clip_output, i)
105-
pos_prob = probs[..., 0:1]
106-
if n_phrases_maxs[i] is None or pos_prob.max() > n_phrases_sims[i].max():
107-
n_phrases_maxs[i] = scale
108-
n_phrases_sims[i] = pos_prob
103+
for j in range(n_phrases):
104+
if preset_scales is None or j == i:
105+
probs = self.image_encoder.get_relevancy(clip_output, j)
106+
pos_prob = probs[..., 0:1]
107+
if n_phrases_maxs[j] is None or pos_prob.max() > n_phrases_sims[j].max():
108+
n_phrases_maxs[j] = scale
109+
n_phrases_sims[j] = pos_prob
109110
return torch.stack(n_phrases_sims), torch.Tensor(n_phrases_maxs)
110111

111112
def get_outputs(self, ray_bundle: RayBundle):

0 commit comments

Comments (0)