Skip to content

Commit 0455936

Browse files
committed
fix encoder depth & output stride
1 parent d490cdf commit 0455936

File tree

2 files changed

+46
-22
lines changed

2 files changed

+46
-22
lines changed

segmentation_models_pytorch/decoders/deeplabv3/decoder.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def __init__(
6161
nn.BatchNorm2d(out_channels),
6262
nn.ReLU(),
6363
)
64-
self.out_channels = out_channels
6564

6665
def forward(self, *features):
6766
return super().forward(features[-1])
@@ -71,21 +70,21 @@ class DeepLabV3PlusDecoder(nn.Module):
7170
def __init__(
7271
self,
7372
encoder_channels: Sequence[int, ...],
73+
encoder_depth: int,
7474
out_channels: int,
7575
atrous_rates: Iterable[int],
7676
output_stride: Literal[8, 16],
7777
aspp_separable: bool,
7878
aspp_dropout: float,
7979
):
8080
super().__init__()
81-
if output_stride not in {8, 16}:
81+
if encoder_depth < 3:
8282
raise ValueError(
83-
"Output stride should be 8 or 16, got {}.".format(output_stride)
83+
"Encoder depth for DeepLabV3Plus decoder cannot be less than 3, got {}.".format(
84+
encoder_depth
85+
)
8486
)
8587

86-
self.out_channels = out_channels
87-
self.output_stride = output_stride
88-
8988
self.aspp = nn.Sequential(
9089
ASPP(
9190
encoder_channels[-1],
@@ -101,10 +100,10 @@ def __init__(
101100
nn.ReLU(),
102101
)
103102

104-
scale_factor = 2 if output_stride == 8 else 4
103+
scale_factor = 4 if output_stride == 16 and encoder_depth > 3 else 2
105104
self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
106105

107-
highres_in_channels = encoder_channels[-4]
106+
highres_in_channels = encoder_channels[2]
108107
highres_out_channels = 48 # proposed by authors of paper
109108
self.block1 = nn.Sequential(
110109
nn.Conv2d(
@@ -128,7 +127,7 @@ def __init__(
128127
def forward(self, *features):
129128
aspp_features = self.aspp(features[-1])
130129
aspp_features = self.up(aspp_features)
131-
high_res_features = self.block1(features[-4])
130+
high_res_features = self.block1(features[2])
132131
concat_features = torch.cat([aspp_features, high_res_features], dim=1)
133132
fused_features = self.block2(concat_features)
134133
return fused_features
@@ -228,13 +227,13 @@ def forward(self, x):
228227
class SeparableConv2d(nn.Sequential):
229228
def __init__(
230229
self,
231-
in_channels,
232-
out_channels,
233-
kernel_size,
234-
stride=1,
235-
padding=0,
236-
dilation=1,
237-
bias=True,
230+
in_channels: int,
231+
out_channels: int,
232+
kernel_size: int,
233+
stride: int = 1,
234+
padding: int = 0,
235+
dilation: int = 1,
236+
bias: bool = True,
238237
):
239238
dephtwise_conv = nn.Conv2d(
240239
in_channels,

segmentation_models_pytorch/decoders/deeplabv3/model.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,16 @@ class DeepLabV3(SegmentationModel):
3535
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
3636
**callable** and **None**.
3737
Default is **None**
38-
upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity).
38+
upsampling: Final upsampling factor. Default is **None**, in which case the factor is derived from ``encoder_depth`` and ``encoder_output_stride`` to preserve input-output spatial shape identity.
3939
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
4040
on top of encoder if **aux_params** is not **None** (default). Supported params:
4141
- classes (int): A number of classes
4242
- pooling (str): One of "max", "avg". Default is "avg"
4343
- dropout (float): Dropout factor in [0, 1)
4444
- activation (str): An activation function to apply "sigmoid"/"softmax"
4545
(could be **None** to return logits)
46-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
46+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
47+
Keys with ``None`` values are pruned before passing.
4748
4849
Returns:
4950
``torch.nn.Module``: **DeepLabV3**
@@ -72,6 +73,13 @@ def __init__(
7273
):
7374
super().__init__()
7475

76+
if encoder_output_stride not in [8, 16]:
77+
raise ValueError(
78+
"DeeplabV3 support output stride 8 or 16, got {}.".format(
79+
encoder_output_stride
80+
)
81+
)
82+
7583
self.encoder = get_encoder(
7684
encoder_name,
7785
in_channels=in_channels,
@@ -81,6 +89,14 @@ def __init__(
8189
**kwargs,
8290
)
8391

92+
if upsampling is None:
93+
if encoder_depth <= 3:
94+
scale_factor = 2 ** encoder_depth
95+
else:
96+
scale_factor = encoder_output_stride
97+
else:
98+
scale_factor = upsampling
99+
84100
self.decoder = DeepLabV3Decoder(
85101
in_channels=self.encoder.out_channels[-1],
86102
out_channels=decoder_channels,
@@ -90,11 +106,11 @@ def __init__(
90106
)
91107

92108
self.segmentation_head = SegmentationHead(
93-
in_channels=self.decoder.out_channels,
109+
in_channels=decoder_channels,
94110
out_channels=classes,
95111
activation=activation,
96112
kernel_size=1,
97-
upsampling=encoder_output_stride if upsampling is None else upsampling,
113+
upsampling=scale_factor,
98114
)
99115

100116
if aux_params is not None:
@@ -137,7 +153,8 @@ class DeepLabV3Plus(SegmentationModel):
137153
- dropout (float): Dropout factor in [0, 1)
138154
- activation (str): An activation function to apply "sigmoid"/"softmax"
139155
(could be **None** to return logits)
140-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
156+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
157+
Keys with ``None`` values are pruned before passing.
141158
142159
Returns:
143160
``torch.nn.Module``: **DeepLabV3Plus**
@@ -166,6 +183,13 @@ def __init__(
166183
):
167184
super().__init__()
168185

186+
if encoder_output_stride not in [8, 16]:
187+
raise ValueError(
188+
"DeeplabV3Plus support output stride 8 or 16, got {}.".format(
189+
encoder_output_stride
190+
)
191+
)
192+
169193
self.encoder = get_encoder(
170194
encoder_name,
171195
in_channels=in_channels,
@@ -177,6 +201,7 @@ def __init__(
177201

178202
self.decoder = DeepLabV3PlusDecoder(
179203
encoder_channels=self.encoder.out_channels,
204+
encoder_depth=encoder_depth,
180205
out_channels=decoder_channels,
181206
atrous_rates=decoder_atrous_rates,
182207
output_stride=encoder_output_stride,
@@ -185,7 +210,7 @@ def __init__(
185210
)
186211

187212
self.segmentation_head = SegmentationHead(
188-
in_channels=self.decoder.out_channels,
213+
in_channels=decoder_channels,
189214
out_channels=classes,
190215
activation=activation,
191216
kernel_size=1,

0 commit comments

Comments
 (0)