Skip to content

Commit

Permalink
Corret Boundary Loss for SAM_RS
Browse files Browse the repository at this point in the history
  • Loading branch information
sstary committed Nov 8, 2024
1 parent bdfd950 commit 5ac5f59
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions SAM_RS/utils_loveda.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,6 @@ def grouper(n, iterable):
return
yield chunk


def one_hot(label, n_classes, requires_grad=True):
"""Return One Hot Label"""
one_hot_label = torch.eye(
n_classes, device='cuda', requires_grad=requires_grad)[label]
one_hot_label = one_hot_label.transpose(1, 3).transpose(2, 3)

return one_hot_label

class ObjectLoss(nn.Module):
def __init__(self, max_object=50):
super().__init__()
Expand All @@ -400,14 +391,7 @@ def forward(self, pred, gt):

return total_object_loss

class BoundaryLoss(nn.Module):
def __init__(self, theta0=3, theta=5):
super().__init__()

self.theta0 = theta0
self.theta = theta

def forward(self, pred, gt):
def forward(self, pred, gt):
"""
Input:
- pred: the output from model (before softmax)
Expand All @@ -418,15 +402,31 @@ def forward(self, pred, gt):
- boundary loss, averaged over mini-bathc
"""

n, c, _, _ = pred.shape
n, _, _, _ = pred.shape
# softmax so that predicted map can be distributed in [0, 1]
pred = torch.softmax(pred, dim=1)
# one-hot vector of ground truth
one_hot_gt = one_hot(gt, c)
pred = pred.argmax(dim=1).cpu() # Get Class Map with the Shape: [B, H, W]

### Other edge detection algorithms are also encoduraged to be used
### for boundary extraction from prediction map such as Canny edge detection
# Shift class map in four directions
shift_up = torch.roll(pred, shifts=-1, dims=1)
shift_down = torch.roll(pred, shifts=1, dims=1)
shift_left = torch.roll(pred, shifts=-1, dims=2)
shift_right = torch.roll(pred, shifts=1, dims=2)
# Compute boundaries by finding differences in class labels
boundary_up = (pred != shift_up).int()
boundary_down = (pred != shift_down).int()
boundary_left = (pred != shift_left).int()
boundary_right = (pred != shift_right).int()
# Combine the boundary maps to get a single boundary prediction
boundary_map = (boundary_up + boundary_down + boundary_left + boundary_right) > 0
boundary_map = boundary_map.int()

# boundary map
gt_b = F.max_pool2d(
1 - one_hot_gt, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
gt_b -= 1 - one_hot_gt
1 - gt, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
gt_b -= 1 - gt

pred_b = F.max_pool2d(
1 - pred, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
Expand All @@ -440,10 +440,10 @@ def forward(self, pred, gt):
pred_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)

# reshape
gt_b = gt_b.view(n, c, -1)
pred_b = pred_b.view(n, c, -1)
gt_b_ext = gt_b_ext.view(n, c, -1)
pred_b_ext = pred_b_ext.view(n, c, -1)
gt_b = gt_b.view(n, 2, -1)
pred_b = pred_b.view(n, 2, -1)
gt_b_ext = gt_b_ext.view(n, 2, -1)
pred_b_ext = pred_b_ext.view(n, 2, -1)

# Precision, Recall
P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7)
Expand Down

0 comments on commit 5ac5f59

Please sign in to comment.