医学图像分割损失函数概述

持续更新补充代码,代码基于pytorch

在医学图像分割中,选取合适的损失函数是十分有必要的。已有的文献中提出了许多的损失函数,但只有少部分的文章对提出的损失函数进行了具体的研究。
损失函数主要是用于评估模型的 预测值和真实标签的匹配程度的重要指标。在过去的几年,不同的损失函数被提出并应用到医学图像分割上。一般形式上,损失函数$L$采用期望风险最小化的形式表示:

其中, $G={g_i}$, $S={s_i}$分别表示真实标签和预测的分割图像。

基于分布的损失函数

基于分布的损失函数旨在最小化两种分布的差异。这一类中最基本的是交叉熵,其他的都是基于交叉熵变换来的

cross entropy(CE)

交叉熵是从Kullback-Leibler(KL)散度推导出来的,它是衡量两种分布之间不同的度量。对于一般的机器学习任务,数据的分布是由训练集给出的。因此最小化KL散度等价于最小化交叉熵。交叉熵损失函数被广泛用于分类,由于分割问题是像素级分类问题,因此在分割问题上也适用。交叉熵被定义为:

其中,如果标签$c$是像素$i$的正确分类,则$g_i^c$是二值指标,$s_i^c$是对应的预测概率。

weighted cross entropy(WCE)

加权交叉熵是交叉熵的一般扩展形式:

其中,$w_c$是每类分类的权重;一般情况下,$w_c$是和类别频率呈反比的。

balanced cross entropy

balanced cross entropy与weighted cross entropy相似,也是解决不均衡样本的。不同之处在于不仅权衡了正例,还权衡了负例。

Code

1
2
3
4
5
import torch
import torch.nn.functional as F
# BCELoss
torch.nn.BCELoss()
F.binary_cross_entropy(inputs, targets, reduction='mean')

TopK loss

TopK loss损失函数旨在强迫网络在训练过程中关注硬样本。

其中,$t \in (0,1\ ]$是一个阈值,$1{…}$是一个二元指示函数。

Focal loss

Focal loss是采用标注的CE处理图像中前景和背景分布不均匀,可以减小正确分类类别的损失值。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(FocalLoss, self).__init__()

def forward(self, inputs, targets, alpha=0.8, gamma=2, smooth=1):
#comment out if your model contains a sigmoid or equivalent activation layer
inputs = torch.sigmoid(inputs)
#flatten label and prediction tensors
inputs = inputs.contiguous().view(-1)
targets = targets.contiguous().view(-1)
#first compute binary cross-entropy
BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
BCE_EXP = torch.exp(-BCE)
focal_loss = alpha * (1-BCE_EXP)**gamma * BCE
return focal_loss

Distance map penalizad cross entropy loss(DPCE)

DPCE损失旨在引导忘了将重点放在难以识别的图像边缘部分。

其中$D$是惩罚项,·是哈达玛积(Hadamard product)。具体来说,$D$是通过计算ground truth的距离变换,然后将其还原得到的。

基于区域的损失函数

基于区域的损失函数旨在最小化ground truth $G$和预测分割区域$S$二者不匹配的区域,或者最大化$G$和$S$重叠区域。主要代表Dice loss。

Dice loss

dice loss损失可以直接优化dice coefficient,是最常用的分割指标之一。与交叉熵不同,它不需要对不平衡分割任务重新加权。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
class DiceBCELoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(DiceBCELoss, self).__init__()

def forward(self, inputs, targets, smooth=1):
#comment out if your model contains a sigmoid or equivalent activation layer
inputs = torch.sigmoid(inputs)
#flatten label and prediction tensors
inputs = inputs.contiguous().view(-1)
targets = targets.contiguous().view(-1)
intersection = (inputs * targets).sum()
dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) # 注意这里已经使用1-dice
return dice_loss

sensitivity-specificity loss

sensitivity-specificity loss通过提高特异性的权重来解决类别不平衡的问题。

其中参数$w$控制这第一项和第二项之间的平衡

IoU loss

IoU loss和dice loss类似,是直接优化目标类别的分割指标

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class IoULoss(nn.Module):

def __init__(self, weight=None, size_average=True):
super(IoULoss, self).__init__()

def forward(self, inputs, targets, smooth=1):
#comment out if your model contains a sigmoid or equivalent activation layer
inputs = torch.sigmoid(inputs)
#flatten label and prediction tensors
inputs = inputs.contiguous().view(-1)
targets = targets.contiguous().view(-1)
#intersection is equivalent to True Positive count
#union is the mutually inclusive area of all labels & predictions
intersection = (inputs * targets).sum()
total = (inputs + targets).sum()
union = total - intersection
IoU = (intersection + smooth)/(union + smooth)
return 1 - IoU

Tversky loss

为了在精度和召回率之间取得更好的平衡,Tversky损失重塑率dice loss并强调了错误的否定。

其中, $\alpha$和$\beta$是超参,控制着假阴性(false negative)和假阳性(false positive)的平衡

Focal Tversky loss

其中,

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class TverskyLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(TverskyLoss, self).__init__()

def forward(self, inputs, targets, smooth=1, alpha=0.5, beta=0.5):
#comment out if your model contains a sigmoid or equivalent activation layer
inputs = torch.sigmoid(inputs)
#flatten label and prediction tensors
inputs = inputs.contiguous().view(-1)
targets = targets.contiguous().view(-1)
#True Positives, False Positives & False Negatives
TP = (inputs * targets).sum()
FP = ((1-targets) * inputs).sum()
FN = (targets * (1-inputs)).sum()
Tversky = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth)
return 1 - Tversky

generalized dice loss

generalized dice loss是dice loss多分类的扩展。

其中。用于不同标签集的属性不变性。

log-Cosh Dice loss

基于边界的损失函数

基于边界的损失函数是一种新的损失函数类型,旨在最小化ground truth和predicated segmentation的边界距离。

boundary(BD) loss

为了可微的形式计算两个边界间的距离$(\partial G,\partial S)$,边界损失使用边界上的损失而不是使用区域内的不平衡积分来减轻高度不平衡分割的困难。

Hausdorff Distance(HD) loss

直接最小化HD是十分困难的。Karimi等人提出了估计ground truth和predicated sensitivity的HD方法。用以下HD损失函数能减小HD,并用于直接训练。

其中,$d_G$和$d_S$为ground truth和segmentation的距离转换

Shape aware loss

复合损失函数

combo loss

combo loss是CE和dice loss的加权和。

exponential logarithmic loss

Wong等人提出了对CE和dice loss进行指数和对数变换。这样网络就可以被迫的关注预测不准的部分。

其中,