-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 7acf46c
Showing
14 changed files
with
2,069 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
.idea/ | ||
.vscode/ | ||
ProstateX-0002/ | ||
# Created by .ignore support plugin (hsz.mobi) | ||
### Python template | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
Data/ | ||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
Loss_ToolBox | ||
--- | ||
|
||
## Introduction | ||
This repository include several losses for 3D image segmentation. | ||
1. [Focal Loss](https://arxiv.org/abs/1708.02002) (PS:Borrow some code from [c0nn3r/RetinaNet](https://github.com/c0nn3r/RetinaNet)) | ||
2. [Lovasz-Softmax Loss](https://arxiv.org/abs/1705.08790)(Modify from orinial implementation [LovaszSoftmax](https://github.com/bermanmaxim/LovaszSoftmax)) | ||
3. [DiceLoss](https://arxiv.org/abs/1606.04797) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .dice_loss import * | ||
from focal_loss import * | ||
from lovasz_loss import * | ||
from tverskyloss import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,248 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import numpy as np | ||
|
||
from .focal_loss import BinaryFocalLoss | ||
|
||
|
||
def make_one_hot(input, num_classes=None): | ||
"""Convert class index tensor to one hot encoding tensor. | ||
Args: | ||
input: A tensor of shape [N, 1, *] | ||
num_classes: An int of number of class | ||
Shapes: | ||
predict: A tensor of shape [N, *] without sigmoid activation function applied | ||
target: A tensor of shape same with predict | ||
Returns: | ||
A tensor of shape [N, num_classes, *] | ||
""" | ||
if num_classes is None: | ||
num_classes = input.max() + 1 | ||
shape = np.array(input.shape) | ||
shape[1] = num_classes | ||
shape = tuple(shape) | ||
result = torch.zeros(shape) | ||
result = result.scatter_(1, input.cpu().long(), 1) | ||
return result | ||
|
||
|
||
class BinaryDiceLoss(nn.Module): | ||
"""Dice loss of binary class | ||
Args: | ||
ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient | ||
reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum' | ||
Shapes: | ||
output: A tensor of shape [N, *] without sigmoid activation function applied | ||
target: A tensor of shape same with output | ||
Returns: | ||
Loss tensor according to arg reduction | ||
Raise: | ||
Exception if unexpected reduction | ||
""" | ||
|
||
def __init__(self, ignore_index=None, reduction='mean', **kwargs): | ||
super(BinaryDiceLoss, self).__init__() | ||
self.smooth = 1 # suggest set a large number when target area is large,like '10|100' | ||
self.ignore_index = ignore_index | ||
self.reduction = reduction | ||
self.batch_dice = False # treat a large map when True | ||
if 'batch_loss' in kwargs.keys(): | ||
self.batch_dice = kwargs['batch_loss'] | ||
|
||
def forward(self, output, target, use_sigmoid=True): | ||
assert output.shape[0] == target.shape[0], "output & target batch size don't match" | ||
if use_sigmoid: | ||
output = torch.sigmoid(output) | ||
|
||
if self.ignore_index is not None: | ||
validmask = (target != self.ignore_index).float() | ||
output = output.mul(validmask) # can not use inplace for bp | ||
target = target.float().mul(validmask) | ||
|
||
dim0 = output.shape[0] | ||
if self.batch_dice: | ||
dim0 = 1 | ||
|
||
output = output.contiguous().view(dim0, -1) | ||
target = target.contiguous().view(dim0, -1).float() | ||
|
||
num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth | ||
den = torch.sum(output.abs() + target.abs(), dim=1) + self.smooth | ||
|
||
loss = 1 - (num / den) | ||
|
||
if self.reduction == 'mean': | ||
return loss.mean() | ||
elif self.reduction == 'sum': | ||
return loss.sum() | ||
elif self.reduction == 'none': | ||
return loss | ||
else: | ||
raise Exception('Unexpected reduction {}'.format(self.reduction)) | ||
|
||
|
||
class DiceLoss(nn.Module): | ||
"""Dice loss, need one hot encode input | ||
Args: | ||
weight: An array of shape [num_classes,] | ||
ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient | ||
output: A tensor of shape [N, C, *] | ||
target: A tensor of same shape with output | ||
other args pass to BinaryDiceLoss | ||
Return: | ||
same as BinaryDiceLoss | ||
""" | ||
|
||
def __init__(self, weight=None, ignore_index=None, **kwargs): | ||
super(DiceLoss, self).__init__() | ||
self.kwargs = kwargs | ||
self.weight = weight | ||
if isinstance(ignore_index, (int, float)): | ||
self.ignore_index = [int(ignore_index)] | ||
elif ignore_index is None: | ||
self.ignore_index = [] | ||
elif isinstance(ignore_index, (list, tuple)): | ||
self.ignore_index = ignore_index | ||
else: | ||
raise TypeError("Expect 'int|float|list|tuple', while get '{}'".format(type(ignore_index))) | ||
|
||
def forward(self, output, target): | ||
assert output.shape == target.shape, 'output & target shape do not match' | ||
dice = BinaryDiceLoss(**self.kwargs) | ||
total_loss = 0 | ||
output = F.softmax(output, dim=1) | ||
for i in range(target.shape[1]): | ||
if i not in self.ignore_index: | ||
dice_loss = dice(output[:, i], target[:, i], use_sigmoid=False) | ||
if self.weight is not None: | ||
assert self.weight.shape[0] == target.shape[1], \ | ||
'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0]) | ||
dice_loss *= self.weights[i] | ||
total_loss += (dice_loss) | ||
loss = total_loss / (target.size(1) - len(self.ignore_index)) | ||
return loss | ||
|
||
|
||
class WBCEWithLogitLoss(nn.Module): | ||
""" | ||
Weighted Binary Cross Entropy. | ||
`WBCE(p,t)=-β*t*log(p)-(1-t)*log(1-p)` | ||
To decrease the number of false negatives, set β>1. | ||
To decrease the number of false positives, set β<1. | ||
Args: | ||
@param weight: positive sample weight | ||
Shapes: | ||
output: A tensor of shape [N, 1,(d,), h, w] without sigmoid activation function applied | ||
target: A tensor of shape same with output | ||
""" | ||
|
||
def __init__(self, weight=1.0, ignore_index=None, reduction='mean'): | ||
super(WBCEWithLogitLoss, self).__init__() | ||
assert reduction in ['none', 'mean', 'sum'] | ||
self.ignore_index = ignore_index | ||
weight = float(weight) | ||
self.weight = weight | ||
self.reduction = reduction | ||
self.smooth = 0.01 | ||
|
||
def forward(self, output, target): | ||
assert output.shape[0] == target.shape[0], "output & target batch size don't match" | ||
|
||
if self.ignore_index is not None: | ||
valid_mask = (target != self.ignore_index).float() | ||
output = output.mul(valid_mask) # can not use inplace for bp | ||
target = target.float().mul(valid_mask) | ||
|
||
batch_size = output.size(0) | ||
output = output.view(batch_size, -1) | ||
target = target.view(batch_size, -1) | ||
|
||
output = torch.sigmoid(output) | ||
# avoid `nan` loss | ||
eps = 1e-6 | ||
output = torch.clamp(output, min=eps, max=1.0 - eps) | ||
# soft label | ||
target = torch.clamp(target, min=self.smooth, max=1.0 - self.smooth) | ||
|
||
# loss = self.bce(output, target) | ||
loss = -self.weight * target.mul(torch.log(output)) - ((1.0 - target).mul(torch.log(1.0 - output))) | ||
if self.reduction == 'mean': | ||
loss = torch.mean(loss) | ||
elif self.reduction == 'sum': | ||
loss = torch.sum(loss) | ||
elif self.reduction == 'none': | ||
loss = loss | ||
else: | ||
raise NotImplementedError | ||
return loss | ||
|
||
|
||
class WBCE_DiceLoss(nn.Module): | ||
def __init__(self, alpha=1.0, weight=1.0, ignore_index=None, reduction='mean'): | ||
""" | ||
combination of Weight Binary Cross Entropy and Binary Dice Loss | ||
Args: | ||
@param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient | ||
@param reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum' | ||
@param alpha: weight between WBCE('Weight Binary Cross Entropy') and binary dice, apply on WBCE | ||
Shapes: | ||
output: A tensor of shape [N, *] without sigmoid activation function applied | ||
target: A tensor of shape same with output | ||
""" | ||
super(WBCE_DiceLoss, self).__init__() | ||
assert reduction in ['none', 'mean', 'sum'] | ||
assert 0 <= alpha <= 1, '`alpha` should in [0,1]' | ||
self.alpha = alpha | ||
self.ignore_index = ignore_index | ||
self.reduction = reduction | ||
self.dice = BinaryDiceLoss(ignore_index=ignore_index, reduction=reduction, general=True) | ||
self.wbce = WBCEWithLogitLoss(weight=weight, ignore_index=ignore_index, reduction=reduction) | ||
self.dice_loss = None | ||
self.wbce_loss = None | ||
|
||
def forward(self, output, target): | ||
self.dice_loss = self.dice(output, target) | ||
self.dice_loss = -torch.log(1 - self.dice_loss) | ||
self.wbce_loss = self.wbce(output, target) | ||
loss = self.alpha * self.wbce_loss + self.dice_loss | ||
return loss | ||
|
||
|
||
class Binary_Focal_Dice(nn.Module): | ||
def __init__(self, **kwargs): | ||
super(Binary_Focal_Dice, self).__init__() | ||
self.dice = BinaryDiceLoss(**kwargs) | ||
self.focal = BinaryFocalLoss(**kwargs) | ||
|
||
def forward(self, logits, target): | ||
dice_loss = self.dice(logits, target) | ||
dice_loss = -torch.log(1 - dice_loss) | ||
focal_loss = self.focal(logits, target) | ||
loss = dice_loss + focal_loss | ||
return loss, (dice_loss.detach(), focal_loss.detach()) | ||
|
||
|
||
def test(): | ||
input = torch.rand((3, 1, 32, 32, 32)) | ||
model = nn.Conv3d(1, 4, 3, padding=1) | ||
target = torch.randint(0, 4, (3, 1, 32, 32, 32)).float() | ||
target = make_one_hot(target, num_classes=4) | ||
criterion = DiceLoss(ignore_index=[2, 3], reduction='mean') | ||
loss = criterion(model(input), target) | ||
loss.backward() | ||
print(loss.item()) | ||
|
||
# input = torch.zeros((1, 2, 32, 32, 32)) | ||
# input[:, 0, ...] = 1 | ||
# target = torch.ones((1, 1, 32, 32, 32)).long() | ||
# target_one_hot = make_one_hot(target, num_classes=2) | ||
# # print(target_one_hot.size()) | ||
# criterion = DiceLoss() | ||
# loss = criterion(input, target_one_hot) | ||
# print(loss.item()) | ||
|
||
|
||
if __name__ == '__main__': | ||
test() |
Oops, something went wrong.