You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: segmentation_models_pytorch/decoders/deeplabv3/model.py
+31-6Lines changed: 31 additions & 6 deletions
Original file line number
Diff line number
Diff line change
@@ -35,15 +35,16 @@ class DeepLabV3(SegmentationModel):
35
35
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
36
36
**callable** and **None**.
37
37
Default is **None**
38
-
upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity).
38
+
upsampling: Final upsampling factor. Default is **None** to preserve input-output spatial shape identity
39
39
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
40
40
on top of encoder if **aux_params** is not **None** (default). Supported params:
41
41
- classes (int): A number of classes
42
42
- pooling (str): One of "max", "avg". Default is "avg"
43
43
- dropout (float): Dropout factor in [0, 1)
44
44
- activation (str): An activation function to apply "sigmoid"/"softmax"
45
45
(could be **None** to return logits)
46
-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
46
+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
47
+
Keys with ``None`` values are pruned before passing.
47
48
48
49
Returns:
49
50
``torch.nn.Module``: **DeepLabV3**
@@ -72,6 +73,13 @@ def __init__(
72
73
):
73
74
super().__init__()
74
75
76
+
ifencoder_output_stridenotin [8, 16]:
77
+
raiseValueError(
78
+
"DeeplabV3 support output stride 8 or 16, got {}.".format(
@@ -137,7 +153,8 @@ class DeepLabV3Plus(SegmentationModel):
137
153
- dropout (float): Dropout factor in [0, 1)
138
154
- activation (str): An activation function to apply "sigmoid"/"softmax"
139
155
(could be **None** to return logits)
140
-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
156
+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
157
+
Keys with ``None`` values are pruned before passing.
141
158
142
159
Returns:
143
160
``torch.nn.Module``: **DeepLabV3Plus**
@@ -166,6 +183,13 @@ def __init__(
166
183
):
167
184
super().__init__()
168
185
186
+
ifencoder_output_stridenotin [8, 16]:
187
+
raiseValueError(
188
+
"DeeplabV3Plus support output stride 8 or 16, got {}.".format(
0 commit comments