Skip to content

Commit

Permalink
correcting gather mode bug in KPConv blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
HuguesTHOMAS committed Dec 24, 2024
1 parent 88d8473 commit 54e644a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions Pointcept-wrapper/models/kpconvx/utils/kpconv_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def forward(self, q_pts: Tensor,
# Gathering only worth if K > 25, nearest mode and depthwise_conv
if self.gather_mode:

# Collect nearest kernel point weights -> (M, H, G, C//G, 0//G)
# Collect nearest kernel point weights -> (M, H, G, C//G, O//G)
neighbors_weights = gather(self.weights, neighbors_1nn)

# Apply influence weights
Expand All @@ -571,10 +571,10 @@ def forward(self, q_pts: Tensor,

# Depthwise
neighbors_weights = neighbors_weights.view(-1, H, self.groups, self.out_channels_per_group) # (M, H, G=C, O//G)
neighbor_feats = neighbors_weights.view(-1, H, self.groups, 1) # (M, H, C, 1)
neighbor_feats = neighbor_feats.view(-1, H, self.groups, 1) # (M, H, C, 1)

# Apply weights and summation
output_feats = torch.sum(neighbor_feats * neighbors_weights, dim=1) # -> (M, G, 0//G)
output_feats = torch.sum(neighbor_feats * neighbors_weights, dim=1) # -> (M, G, O//G)
output_feats = output_feats.reshape((-1, self.out_channels)) # -> (M, O)


Expand Down
6 changes: 3 additions & 3 deletions Pointcept-wrapper/models/kpnext/kpconv_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def forward(self, q_pts: Tensor,
# Gathering only worth if K > 25, nearest mode and depthwise_conv
if self.gather_mode:

# Collect nearest kernel point weights -> (M, H, G, C//G, 0//G)
# Collect nearest kernel point weights -> (M, H, G, C//G, O//G)
neighbors_weights = gather(self.weights, neighbors_1nn)

# Apply influence weights
Expand All @@ -593,10 +593,10 @@ def forward(self, q_pts: Tensor,

# Depthwise
neighbors_weights = neighbors_weights.view(-1, H, self.groups, self.out_channels_per_group) # (M, H, G=C, O//G)
neighbor_feats = neighbors_weights.view(-1, H, self.groups, 1) # (M, H, C, 1)
neighbor_feats = neighbor_feats.view(-1, H, self.groups, 1) # (M, H, C, 1)

# Apply weights and summation
output_feats = torch.sum(neighbor_feats * neighbors_weights, dim=1) # -> (M, G, 0//G)
output_feats = torch.sum(neighbor_feats * neighbors_weights, dim=1) # -> (M, G, O//G)
output_feats = output_feats.reshape((-1, self.out_channels)) # -> (M, O)


Expand Down
6 changes: 3 additions & 3 deletions Standalone/KPConvX/models/kpconv_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def forward(self, q_pts: Tensor,
# Gathering only worth if K > 25, nearest mode and depthwise_conv
if self.gather_mode:

# Collect nearest kernel point weights -> (M, H, G, C//G, 0//G)
# Collect nearest kernel point weights -> (M, H, G, C//G, O//G)
neighbors_weights = gather(self.weights, neighbors_1nn)

# Apply influence weights
Expand All @@ -572,10 +572,10 @@ def forward(self, q_pts: Tensor,

# Depthwise
neighbors_weights = neighbors_weights.view(-1, H, self.groups, self.out_channels_per_group) # (M, H, G=C, O//G)
neighbor_feats = neighbors_weights.view(-1, H, self.groups, 1) # (M, H, C, 1)
neighbor_feats = neighbor_feats.view(-1, H, self.groups, 1) # (M, H, C, 1)

# Apply weights and summation
output_feats = torch.sum(neighbor_feats * neighbors_weights, dim=1) # -> (M, G, 0//G)
output_feats = torch.sum(neighbor_feats * neighbors_weights, dim=1) # -> (M, G, O//G)
output_feats = output_feats.reshape((-1, self.out_channels)) # -> (M, O)


Expand Down

0 comments on commit 54e644a

Please sign in to comment.