@@ -90,7 +90,7 @@ def get_max_across(self, ray_samples, weights, hashgrid_field, scales_shape, pre
         n_phrases = len(self.image_encoder.positives)
         n_phrases_maxs = [None for _ in range(n_phrases)]
         n_phrases_sims = [None for _ in range(n_phrases)]
-        for _, scale in enumerate(scales_list):
+        for i, scale in enumerate(scales_list):
             scale = scale.item()
             with torch.no_grad():
                 clip_output = self.lerf_field.get_output_from_hashgrid(
@@ -100,12 +100,13 @@ def get_max_across(self, ray_samples, weights, hashgrid_field, scales_shape, pre
                 )
             clip_output = self.renderer_clip(embeds=clip_output, weights=weights.detach())

-            for i in range(n_phrases):
-                probs = self.image_encoder.get_relevancy(clip_output, i)
-                pos_prob = probs[..., 0:1]
-                if n_phrases_maxs[i] is None or pos_prob.max() > n_phrases_sims[i].max():
-                    n_phrases_maxs[i] = scale
-                    n_phrases_sims[i] = pos_prob
+            for j in range(n_phrases):
+                if preset_scales is None or j == i:
+                    probs = self.image_encoder.get_relevancy(clip_output, j)
+                    pos_prob = probs[..., 0:1]
+                    if n_phrases_maxs[j] is None or pos_prob.max() > n_phrases_sims[j].max():
+                        n_phrases_maxs[j] = scale
+                        n_phrases_sims[j] = pos_prob
         return torch.stack(n_phrases_sims), torch.Tensor(n_phrases_maxs)

     def get_outputs(self, ray_bundle: RayBundle):
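
What this diff changes: the outer loop index `i` is now kept, and the inner loop scores phrase `j` only when `preset_scales is None or j == i`, so with preset scales each phrase is evaluated at a single matching position in `scales_list` instead of at every scale. Below is a minimal, self-contained sketch of that selection rule; `pick_best_scales` and `relevancy_fn` are hypothetical stand-ins, not names from this repository, and it assumes the caller has already built `scales_list` from the preset scales when they are provided.

import torch

def pick_best_scales(relevancy_fn, scales_list, n_phrases, preset_scales=None):
    # For each phrase, keep the scale with the highest positive relevancy.
    # Mirrors the `preset_scales is None or j == i` guard added above:
    # when preset scales are given, phrase j is scored only at scales_list[j].
    best_scales = [None] * n_phrases
    best_sims = [None] * n_phrases
    for i, scale in enumerate(scales_list):
        scale = float(scale)
        for j in range(n_phrases):
            if preset_scales is None or j == i:
                pos_prob = relevancy_fn(scale, j)  # stand-in for get_relevancy(...)[..., 0:1]
                if best_scales[j] is None or pos_prob.max() > best_sims[j].max():
                    best_scales[j] = scale
                    best_sims[j] = pos_prob
    return torch.stack(best_sims), torch.Tensor(best_scales)

# Toy usage: two phrases whose relevancy peaks at different scales.
sims, scales = pick_best_scales(
    relevancy_fn=lambda s, j: torch.tensor([[1.0 - abs(s - (0.5 + j))]]),
    scales_list=[0.5, 1.5],
    n_phrases=2,
)
print(scales)  # tensor([0.5000, 1.5000])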