Skip to content

Commit

Permalink
Update local_correlation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
qkqhd222 authored Oct 12, 2023
1 parent 0263a81 commit a274747
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions roma/utils/local_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def local_correlation(
# If flow is None, assume feature0 and feature1 are aligned
coords = torch.meshgrid(
(
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device),
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device),
))
coords = torch.stack((coords[1], coords[0]), dim=-1)[
None
Expand All @@ -27,8 +27,8 @@ def local_correlation(
coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
local_window = torch.meshgrid(
(
torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device="cuda"),
torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device="cuda"),
torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device),
torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device),
))
local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
None
Expand All @@ -41,4 +41,4 @@ def local_correlation(
)
window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
return corr
return corr

0 comments on commit a274747

Please sign in to comment.