Skip to content

Commit

Permalink
iw3: Scaling the delta in row_flow model
Browse files Browse the repository at this point in the history
this is more of a bug fix than an improvement.
grid_sample value is 1/(width/2) scale, so it scale NN outputs.
  • Loading branch information
nagadomi committed Aug 21, 2023
1 parent 10e5af5 commit 4711c71
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions iw3/models/row_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(self):
nn.ReLU(inplace=True),
nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1, padding_mode="replicate"),
)
self.register_buffer("delta_scale", torch.tensor(1.0 / 127.0))

for m in self.modules():
if isinstance(m, (nn.Conv2d,)):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
Expand All @@ -32,8 +34,7 @@ def forward(self, x):
rgb = x[:, 0:3, :, ]
grid = x[:, 6:8, :, ]
x = x[:, 3:6, :, ] # depth + diverdence feature + convergence

delta = self.conv(x)
delta = self.conv(x) * self.delta_scale
grid = grid + torch.cat([delta, torch.zeros_like(delta)], dim=1)
grid = grid.permute(0, 2, 3, 1)
z = F.grid_sample(rgb, grid, mode="bilinear", padding_mode="border", align_corners=True)
Expand Down

0 comments on commit 4711c71

Please sign in to comment.