Skip to content

Commit ca521c9

Browse files
authored
add implementation of TverskyLoss and TverskyLossFocall (#405)
* add implementation of TverskyLoss and TverskyLossFocall * add tests for TverskyLoss
1 parent 8a71ed5 commit ca521c9

File tree

5 files changed

+194
-29
lines changed

5 files changed

+194
-29
lines changed

segmentation_models_pytorch/losses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .lovasz import LovaszLoss
77
from .soft_bce import SoftBCEWithLogitsLoss
88
from .soft_ce import SoftCrossEntropyLoss
9+
from .tversky import TverskyLoss, TverskyLossFocal

segmentation_models_pytorch/losses/_functional.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import torch
77
import torch.nn.functional as F
88

9-
109
__all__ = [
1110
"focal_loss_with_logits",
1211
"softmax_focal_loss_with_logits",
@@ -35,14 +34,14 @@ def to_tensor(x, dtype=None) -> torch.Tensor:
3534

3635

3736
def focal_loss_with_logits(
38-
output: torch.Tensor,
39-
target: torch.Tensor,
40-
gamma: float = 2.0,
41-
alpha: Optional[float] = 0.25,
42-
reduction: str = "mean",
43-
normalized: bool = False,
44-
reduced_threshold: Optional[float] = None,
45-
eps: float = 1e-6,
37+
output: torch.Tensor,
38+
target: torch.Tensor,
39+
gamma: float = 2.0,
40+
alpha: Optional[float] = 0.25,
41+
reduction: str = "mean",
42+
normalized: bool = False,
43+
reduced_threshold: Optional[float] = None,
44+
eps: float = 1e-6,
4645
) -> torch.Tensor:
4746
"""Compute binary focal loss between target and output logits.
4847
See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.
@@ -98,13 +97,13 @@ def focal_loss_with_logits(
9897

9998

10099
def softmax_focal_loss_with_logits(
101-
output: torch.Tensor,
102-
target: torch.Tensor,
103-
gamma: float = 2.0,
104-
reduction="mean",
105-
normalized=False,
106-
reduced_threshold: Optional[float] = None,
107-
eps: float = 1e-6,
100+
output: torch.Tensor,
101+
target: torch.Tensor,
102+
gamma: float = 2.0,
103+
reduction="mean",
104+
normalized=False,
105+
reduced_threshold: Optional[float] = None,
106+
eps: float = 1e-6,
108107
) -> torch.Tensor:
109108
"""Softmax version of focal loss between target and output logits.
110109
See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.
@@ -151,7 +150,7 @@ def softmax_focal_loss_with_logits(
151150

152151

153152
def soft_jaccard_score(
154-
output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
153+
output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
155154
) -> torch.Tensor:
156155
assert output.size() == target.size()
157156
if dims is not None:
@@ -167,7 +166,7 @@ def soft_jaccard_score(
167166

168167

169168
def soft_dice_score(
170-
output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
169+
output: torch.Tensor, target: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None
171170
) -> torch.Tensor:
172171
assert output.size() == target.size()
173172
if dims is not None:
@@ -180,6 +179,22 @@ def soft_dice_score(
180179
return dice_score
181180

182181

182+
def soft_tversky_score(output: torch.Tensor, target: torch.Tensor, alpha: float, beta: float,
183+
smooth: float = 0.0, eps: float = 1e-7, dims=None) -> torch.Tensor:
184+
assert output.size() == target.size()
185+
if dims is not None:
186+
intersection = torch.sum(output * target, dim=dims) # TP
187+
fp = torch.sum(output * (1. - target), dim=dims)
188+
fn = torch.sum((1 - output) * target, dim=dims)
189+
else:
190+
intersection = torch.sum(output * target) # TP
191+
fp = torch.sum(output * (1. - target))
192+
fn = torch.sum((1 - output) * target)
193+
194+
tversky_score = (intersection + smooth) / (intersection + alpha * fp + beta * fn + smooth).clamp_min(eps)
195+
return tversky_score
196+
197+
183198
def wing_loss(output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"):
184199
"""
185200
https://arxiv.org/pdf/1711.06753.pdf
@@ -211,7 +226,7 @@ def wing_loss(output: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5
211226

212227

213228
def label_smoothed_nll_loss(
214-
lprobs: torch.Tensor, target: torch.Tensor, epsilon: float, ignore_index=None, reduction="mean", dim=-1
229+
lprobs: torch.Tensor, target: torch.Tensor, epsilon: float, ignore_index=None, reduction="mean", dim=-1
215230
) -> torch.Tensor:
216231
"""
217232
Source: https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py

segmentation_models_pytorch/losses/dice.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
class DiceLoss(_Loss):
1313

1414
def __init__(
15-
self,
16-
mode: str,
17-
classes: Optional[List[int]] = None,
18-
log_loss: bool = False,
19-
from_logits: bool = True,
20-
smooth: float = 0.0,
21-
ignore_index: Optional[int] = None,
22-
eps: float = 1e-7,
15+
self,
16+
mode: str,
17+
classes: Optional[List[int]] = None,
18+
log_loss: bool = False,
19+
from_logits: bool = True,
20+
smooth: float = 0.0,
21+
ignore_index: Optional[int] = None,
22+
eps: float = 1e-7,
2323
):
2424
"""Implementation of Dice loss for image segmentation task.
2525
It supports binary, multiclass and multilabel cases
@@ -104,7 +104,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
104104
y_pred = y_pred * mask
105105
y_true = y_true * mask
106106

107-
scores = soft_dice_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)
107+
scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)
108108

109109
if self.log_loss:
110110
loss = -torch.log(scores.clamp_min(self.eps))
@@ -122,4 +122,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
122122
if self.classes is not None:
123123
loss = loss[self.classes]
124124

125+
return self.aggregate_loss(loss)
126+
127+
def aggregate_loss(self, loss):
125128
return loss.mean()
129+
130+
def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
131+
return soft_dice_score(output, target, smooth, eps, dims)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from typing import List
2+
3+
import torch
4+
from ._functional import soft_tversky_score
5+
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
6+
from .dice import DiceLoss
7+
8+
__all__ = ["TverskyLoss", "TverskyLossFocal"]
9+
10+
11+
class TverskyLoss(DiceLoss):
12+
"""
13+
Implementation of Tversky loss for image segmentation task. Where TP and FP is weighted by alpha and beta params.
14+
With alpha == beta == 0.5, this loss becomes equal DiceLoss.
15+
It supports binary, multiclass and multilabel cases
16+
"""
17+
18+
def __init__(
19+
self,
20+
mode: str,
21+
classes: List[int] = None,
22+
log_loss=False,
23+
from_logits=True,
24+
smooth: float = 0.0,
25+
ignore_index=None,
26+
eps=1e-7,
27+
alpha=0.5,
28+
beta=0.5
29+
):
30+
"""
31+
:param mode: Metric mode {'binary', 'multiclass', 'multilabel'}
32+
:param classes: Optional list of classes that contribute in loss computation;
33+
By default, all channels are included.
34+
:param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
35+
:param from_logits: If True assumes input is raw logits
36+
:param smooth:
37+
:param ignore_index: Label that indicates ignored pixels (does not contribute to loss)
38+
:param eps: Small epsilon for numerical stability
39+
:param alpha: Weight constant that penalize model for FPs (False Positives)
40+
:param beta: Weight constant that penalize model for FNs (False Positives)
41+
"""
42+
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
43+
super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps)
44+
self.alpha = alpha
45+
self.beta = beta
46+
47+
def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
48+
return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims)
49+
50+
51+
class TverskyLossFocal(TverskyLoss):
52+
"""
53+
A variant on the Tversky loss that also includes the gamma modifier from Focal Loss https://arxiv.org/abs/1708.02002
54+
It supports binary, multiclass and multilabel cases
55+
"""
56+
57+
def __init__(
58+
self,
59+
mode: str,
60+
classes: List[int] = None,
61+
log_loss=False,
62+
from_logits=True,
63+
smooth: float = 0.0,
64+
ignore_index=None,
65+
eps=1e-7,
66+
alpha=0.5,
67+
beta=0.5,
68+
gamma=1
69+
):
70+
"""
71+
:param mode: Metric mode {'binary', 'multiclass', 'multilabel'}
72+
:param classes: Optional list of classes that contribute in loss computation;
73+
By default, all channels are included.
74+
:param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
75+
:param from_logits: If True assumes input is raw logits
76+
:param smooth:
77+
:param ignore_index: Label that indicates ignored pixels (does not contribute to loss)
78+
:param eps: Small epsilon for numerical stability
79+
:param alpha: Weight constant that penalize model for FPs (False Positives)
80+
:param beta: Weight constant that penalize model for FNs (False Positives)
81+
:param gamma: Constant that squares the error function
82+
"""
83+
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
84+
super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps, alpha, beta)
85+
self.gamma = gamma
86+
87+
def aggregate_loss(self, loss):
88+
return loss.mean() ** self.gamma

tests/test_losses.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import torch
33
import segmentation_models_pytorch as smp
44
import segmentation_models_pytorch.losses._functional as F
5-
from segmentation_models_pytorch.losses import DiceLoss, JaccardLoss, SoftBCEWithLogitsLoss, SoftCrossEntropyLoss
5+
from segmentation_models_pytorch.losses import DiceLoss, JaccardLoss, SoftBCEWithLogitsLoss, SoftCrossEntropyLoss, \
6+
TverskyLoss, TverskyLossFocal
67

78

89
def test_focal_loss_with_logits():
@@ -71,6 +72,21 @@ def test_soft_dice_score(y_true, y_pred, expected, eps):
7172
assert float(actual) == pytest.approx(expected, eps)
7273

7374

75+
@pytest.mark.parametrize(
76+
["y_true", "y_pred", "expected", "eps", "alpha", "beta"],
77+
[
78+
[[1, 1, 1, 1], [1, 1, 1, 1], 1.0, 1e-5, 0.5, 0.5],
79+
[[0, 1, 1, 0], [0, 1, 1, 0], 1.0, 1e-5, 0.5, 0.5],
80+
[[1, 1, 1, 1], [1, 1, 0, 0], 2.0 / 3.0, 1e-5, 0.5, 0.5],
81+
],
82+
)
83+
def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta):
84+
y_true = torch.tensor(y_true, dtype=torch.float32)
85+
y_pred = torch.tensor(y_pred, dtype=torch.float32)
86+
actual = F.soft_tversky_score(y_pred, y_true, eps=eps, alpha=alpha, beta=beta)
87+
assert float(actual) == pytest.approx(expected, eps)
88+
89+
7490
@torch.no_grad()
7591
def test_dice_loss_binary():
7692
eps = 1e-5
@@ -109,6 +125,45 @@ def test_dice_loss_binary():
109125
assert float(loss) == pytest.approx(1.0, abs=eps)
110126

111127

128+
@torch.no_grad()
129+
def test_tversky_loss_binary():
130+
eps = 1e-5
131+
# with alpha=0.5; beta=0.5 it is equal to DiceLoss
132+
criterion = TverskyLoss(mode=smp.losses.BINARY_MODE, from_logits=False, alpha=0.5, beta=0.5)
133+
134+
# Ideal case
135+
y_pred = torch.tensor([1.0, 1.0, 1.0]).view(1, 1, 1, -1)
136+
y_true = torch.tensor(([1, 1, 1])).view(1, 1, 1, -1)
137+
loss = criterion(y_pred, y_true)
138+
assert float(loss) == pytest.approx(0.0, abs=eps)
139+
140+
y_pred = torch.tensor([1.0, 0.0, 1.0]).view(1, 1, 1, -1)
141+
y_true = torch.tensor(([1, 0, 1])).view(1, 1, 1, -1)
142+
loss = criterion(y_pred, y_true)
143+
assert float(loss) == pytest.approx(0.0, abs=eps)
144+
145+
y_pred = torch.tensor([0.0, 0.0, 0.0]).view(1, 1, 1, -1)
146+
y_true = torch.tensor(([0, 0, 0])).view(1, 1, 1, -1)
147+
loss = criterion(y_pred, y_true)
148+
assert float(loss) == pytest.approx(0.0, abs=eps)
149+
150+
# Worst case
151+
y_pred = torch.tensor([1.0, 1.0, 1.0]).view(1, 1, -1)
152+
y_true = torch.tensor([0, 0, 0]).view(1, 1, 1, -1)
153+
loss = criterion(y_pred, y_true)
154+
assert float(loss) == pytest.approx(0.0, abs=eps)
155+
156+
y_pred = torch.tensor([1.0, 0.0, 1.0]).view(1, 1, -1)
157+
y_true = torch.tensor([0, 1, 0]).view(1, 1, 1, -1)
158+
loss = criterion(y_pred, y_true)
159+
assert float(loss) == pytest.approx(1.0, abs=eps)
160+
161+
y_pred = torch.tensor([0.0, 0.0, 0.0]).view(1, 1, -1)
162+
y_true = torch.tensor([1, 1, 1]).view(1, 1, 1, -1)
163+
loss = criterion(y_pred, y_true)
164+
assert float(loss) == pytest.approx(1.0, abs=eps)
165+
166+
112167
@torch.no_grad()
113168
def test_binary_jaccard_loss():
114169
eps = 1e-5

0 commit comments

Comments
 (0)