Commit 7acf46c

local change

Hsuxu committed Oct 13, 2021 (initial commit, 0 parents)
Showing 14 changed files with 2,069 additions and 0 deletions.
110 changes: 110 additions & 0 deletions .gitignore
@@ -0,0 +1,110 @@
.idea/
.vscode/
ProstateX-0002/
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
Data/
# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

9 changes: 9 additions & 0 deletions README.md
@@ -0,0 +1,9 @@
Loss_ToolBox
---

## Introduction
This repository includes several loss functions for 3D image segmentation; a minimal usage sketch follows the list.
1. [Focal Loss](https://arxiv.org/abs/1708.02002) (borrows some code from [c0nn3r/RetinaNet](https://github.com/c0nn3r/RetinaNet))
2. [Lovasz-Softmax Loss](https://arxiv.org/abs/1705.08790) (modified from the original implementation, [LovaszSoftmax](https://github.com/bermanmaxim/LovaszSoftmax))
3. [Dice Loss](https://arxiv.org/abs/1606.04797)
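A minimal usage sketch for the binary dice loss (a sketch, not part of the original commit; it assumes the repository root is on `PYTHONPATH` so that `seg_loss` is importable):

```python
import torch
from seg_loss.dice_loss import BinaryDiceLoss

# Hypothetical shapes for illustration: [N, H, W] logits, no sigmoid applied.
logits = torch.randn(2, 32, 32, requires_grad=True)
target = torch.randint(0, 2, (2, 32, 32)).float()

criterion = BinaryDiceLoss(reduction='mean')
loss = criterion(logits, target)  # sigmoid is applied inside forward()
loss.backward()
print(loss.item())
```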

4 changes: 4 additions & 0 deletions seg_loss/__init__.py
@@ -0,0 +1,4 @@
from .dice_loss import *
from .focal_loss import *
from .lovasz_loss import *
from .tverskyloss import *
248 changes: 248 additions & 0 deletions seg_loss/dice_loss.py
@@ -0,0 +1,248 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .focal_loss import BinaryFocalLoss


def make_one_hot(input, num_classes=None):
"""Convert class index tensor to one hot encoding tensor.
Args:
input: A tensor of shape [N, 1, *]
num_classes: An int of number of class
Shapes:
predict: A tensor of shape [N, *] without sigmoid activation function applied
target: A tensor of shape same with predict
Returns:
A tensor of shape [N, num_classes, *]
"""
if num_classes is None:
        num_classes = int(input.max()) + 1
shape = np.array(input.shape)
shape[1] = num_classes
shape = tuple(shape)
    result = torch.zeros(shape)  # built on CPU; move to input.device afterwards if needed
    result = result.scatter_(1, input.cpu().long(), 1)
return result
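# Example (illustrative sketch): for a target of shape [2, 1, 4, 4] with values
# in {0, 1, 2}, make_one_hot(target, num_classes=3) returns a [2, 3, 4, 4]
# tensor in which result[n, c, h, w] == 1 exactly where target[n, 0, h, w] == c.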


class BinaryDiceLoss(nn.Module):
"""Dice loss of binary class
Args:
ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
Shapes:
output: A tensor of shape [N, *] without sigmoid activation function applied
target: A tensor of shape same with output
Returns:
Loss tensor according to arg reduction
Raise:
Exception if unexpected reduction
"""

def __init__(self, ignore_index=None, reduction='mean', **kwargs):
super(BinaryDiceLoss, self).__init__()
        self.smooth = 1  # consider a larger value (e.g. 10 or 100) when the target area is large
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.batch_dice = False  # when True, compute dice over the whole batch as one map
        if 'batch_loss' in kwargs:
            self.batch_dice = kwargs['batch_loss']

def forward(self, output, target, use_sigmoid=True):
assert output.shape[0] == target.shape[0], "output & target batch size don't match"
if use_sigmoid:
output = torch.sigmoid(output)

if self.ignore_index is not None:
            validmask = (target != self.ignore_index).float()
            output = output.mul(validmask)  # avoid in-place ops, which would break backprop
            target = target.float().mul(validmask)

dim0 = output.shape[0]
if self.batch_dice:
dim0 = 1

output = output.contiguous().view(dim0, -1)
target = target.contiguous().view(dim0, -1).float()

num = 2 * torch.sum(torch.mul(output, target), dim=1) + self.smooth
den = torch.sum(output.abs() + target.abs(), dim=1) + self.smooth
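        # Soft dice here: loss = 1 - (2*|P∩T| + s) / (|P| + |T| + s) with s = self.smooth.
        # Worked example: perfect overlap on k foreground pixels gives num = 2k + s
        # and den = 2k + s, hence loss = 0; zero overlap gives loss close to 1.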

loss = 1 - (num / den)

if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
elif self.reduction == 'none':
return loss
else:
raise Exception('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
"""Dice loss, need one hot encode input
Args:
weight: An array of shape [num_classes,]
ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
output: A tensor of shape [N, C, *]
target: A tensor of same shape with output
other args pass to BinaryDiceLoss
Return:
same as BinaryDiceLoss
"""

def __init__(self, weight=None, ignore_index=None, **kwargs):
super(DiceLoss, self).__init__()
self.kwargs = kwargs
self.weight = weight
if isinstance(ignore_index, (int, float)):
self.ignore_index = [int(ignore_index)]
elif ignore_index is None:
self.ignore_index = []
elif isinstance(ignore_index, (list, tuple)):
self.ignore_index = ignore_index
else:
raise TypeError("Expect 'int|float|list|tuple', while get '{}'".format(type(ignore_index)))

def forward(self, output, target):
assert output.shape == target.shape, 'output & target shape do not match'
dice = BinaryDiceLoss(**self.kwargs)
total_loss = 0
output = F.softmax(output, dim=1)
for i in range(target.shape[1]):
if i not in self.ignore_index:
dice_loss = dice(output[:, i], target[:, i], use_sigmoid=False)
                if self.weight is not None:
                    assert self.weight.shape[0] == target.shape[1], \
                        'Expect weight shape [{}], got [{}]'.format(target.shape[1], self.weight.shape[0])
                    dice_loss *= self.weight[i]
                total_loss += dice_loss
loss = total_loss / (target.size(1) - len(self.ignore_index))
return loss


class WBCEWithLogitLoss(nn.Module):
"""
Weighted Binary Cross Entropy.
`WBCE(p,t)=-β*t*log(p)-(1-t)*log(1-p)`
To decrease the number of false negatives, set β>1.
To decrease the number of false positives, set β<1.
Args:
@param weight: positive sample weight
Shapes:
output: A tensor of shape [N, 1,(d,), h, w] without sigmoid activation function applied
target: A tensor of shape same with output
"""

def __init__(self, weight=1.0, ignore_index=None, reduction='mean'):
super(WBCEWithLogitLoss, self).__init__()
assert reduction in ['none', 'mean', 'sum']
self.ignore_index = ignore_index
weight = float(weight)
self.weight = weight
self.reduction = reduction
self.smooth = 0.01

def forward(self, output, target):
assert output.shape[0] == target.shape[0], "output & target batch size don't match"

if self.ignore_index is not None:
valid_mask = (target != self.ignore_index).float()
output = output.mul(valid_mask) # can not use inplace for bp
target = target.float().mul(valid_mask)

batch_size = output.size(0)
output = output.view(batch_size, -1)
target = target.view(batch_size, -1)

output = torch.sigmoid(output)
# avoid `nan` loss
eps = 1e-6
output = torch.clamp(output, min=eps, max=1.0 - eps)
# soft label
target = torch.clamp(target, min=self.smooth, max=1.0 - self.smooth)

        loss = -self.weight * target.mul(torch.log(output)) - ((1.0 - target).mul(torch.log(1.0 - output)))
if self.reduction == 'mean':
loss = torch.mean(loss)
elif self.reduction == 'sum':
loss = torch.sum(loss)
        elif self.reduction == 'none':
            pass
else:
raise NotImplementedError
return loss


class WBCE_DiceLoss(nn.Module):
def __init__(self, alpha=1.0, weight=1.0, ignore_index=None, reduction='mean'):
"""
combination of Weight Binary Cross Entropy and Binary Dice Loss
Args:
@param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient
@param reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'
@param alpha: weight between WBCE('Weight Binary Cross Entropy') and binary dice, apply on WBCE
Shapes:
output: A tensor of shape [N, *] without sigmoid activation function applied
target: A tensor of shape same with output
"""
super(WBCE_DiceLoss, self).__init__()
assert reduction in ['none', 'mean', 'sum']
assert 0 <= alpha <= 1, '`alpha` should in [0,1]'
self.alpha = alpha
self.ignore_index = ignore_index
self.reduction = reduction
        self.dice = BinaryDiceLoss(ignore_index=ignore_index, reduction=reduction)
self.wbce = WBCEWithLogitLoss(weight=weight, ignore_index=ignore_index, reduction=reduction)
self.dice_loss = None
self.wbce_loss = None

def forward(self, output, target):
        self.dice_loss = self.dice(output, target)
        self.dice_loss = -torch.log(1.0 - self.dice_loss)  # -log(dice coefficient): same optimum, steeper gradients
self.wbce_loss = self.wbce(output, target)
loss = self.alpha * self.wbce_loss + self.dice_loss
return loss


class Binary_Focal_Dice(nn.Module):
def __init__(self, **kwargs):
super(Binary_Focal_Dice, self).__init__()
self.dice = BinaryDiceLoss(**kwargs)
self.focal = BinaryFocalLoss(**kwargs)

def forward(self, logits, target):
dice_loss = self.dice(logits, target)
dice_loss = -torch.log(1 - dice_loss)
focal_loss = self.focal(logits, target)
loss = dice_loss + focal_loss
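        # Return the combined loss plus the detached components so the dice and
        # focal terms can be logged separately; only `loss` should be back-propagated.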
return loss, (dice_loss.detach(), focal_loss.detach())


def test():
    x = torch.rand((3, 1, 32, 32, 32))  # avoid shadowing the builtin `input`
    model = nn.Conv3d(1, 4, 3, padding=1)
    target = torch.randint(0, 4, (3, 1, 32, 32, 32)).float()
    target = make_one_hot(target, num_classes=4)
    criterion = DiceLoss(ignore_index=[2, 3], reduction='mean')
    loss = criterion(model(x), target)
    loss.backward()
    print(loss.item())

# input = torch.zeros((1, 2, 32, 32, 32))
# input[:, 0, ...] = 1
# target = torch.ones((1, 1, 32, 32, 32)).long()
# target_one_hot = make_one_hot(target, num_classes=2)
# # print(target_one_hot.size())
# criterion = DiceLoss()
# loss = criterion(input, target_one_hot)
# print(loss.item())


if __name__ == '__main__':
test()