diff --git a/README.md b/README.md
index 0c367fea..574101c5 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
 The main features of this library are:
 
  - High level API (just two lines to create neural network)
- - 8 models architectures for binary and multi class segmentation (including legendary Unet)
+ - 9 models architectures for binary and multi class segmentation (including legendary Unet)
  - 99 available encoders
  - All encoders have pre-trained weights for faster and better convergence
@@ -76,6 +76,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
 #### Architectures
  - Unet [[paper](https://arxiv.org/abs/1505.04597)] [[docs](https://smp.readthedocs.io/en/latest/models.html#unet)]
  - Unet++ [[paper](https://arxiv.org/pdf/1807.10165.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id2)]
+ - MAnet [[paper](https://ieeexplore.ieee.org/abstract/document/9201310)] [[docs](https://smp.readthedocs.io/en/latest/models.html#manet)]
 - Linknet [[paper](https://arxiv.org/abs/1707.03718)] [[docs](https://smp.readthedocs.io/en/latest/models.html#linknet)]
 - FPN [[paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#fpn)]
 - PSPNet [[paper](https://arxiv.org/abs/1612.01105)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pspnet)]
diff --git a/docs/models.rst b/docs/models.rst
index 1bda2d27..47de61ee 100644
--- a/docs/models.rst
+++ b/docs/models.rst
@@ -9,6 +9,10 @@ Unet++
 ~~~~~~
 .. autoclass:: segmentation_models_pytorch.UnetPlusPlus
 
+MAnet
+~~~~~~
+.. autoclass:: segmentation_models_pytorch.MAnet
+
 Linknet
 ~~~~~~~
 .. autoclass:: segmentation_models_pytorch.Linknet
diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py
index 354c6687..a53b8e51 100644
--- a/segmentation_models_pytorch/__init__.py
+++ b/segmentation_models_pytorch/__init__.py
@@ -1,5 +1,6 @@
 from .unet import Unet
 from .unetplusplus import UnetPlusPlus
+from .manet import MAnet
 from .linknet import Linknet
 from .fpn import FPN
 from .pspnet import PSPNet
@@ -24,10 +25,10 @@ def create_model(
     **kwargs,
 ) -> torch.nn.Module:
     """Models wrapper. Allows to create any model just with parametes
-
+
     """
-
-    archs = [Unet, UnetPlusPlus, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
+
+    archs = [Unet, UnetPlusPlus, MAnet, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
     archs_dict = {a.__name__.lower(): a for a in archs}
     try:
         model_class = archs_dict[arch.lower()]
diff --git a/segmentation_models_pytorch/manet/__init__.py b/segmentation_models_pytorch/manet/__init__.py
new file mode 100644
index 00000000..f3bdc788
--- /dev/null
+++ b/segmentation_models_pytorch/manet/__init__.py
@@ -0,0 +1 @@
+from .model import MAnet
diff --git a/segmentation_models_pytorch/manet/decoder.py b/segmentation_models_pytorch/manet/decoder.py
new file mode 100644
index 00000000..2d587671
--- /dev/null
+++ b/segmentation_models_pytorch/manet/decoder.py
@@ -0,0 +1,188 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ..base import modules as md
+
+
+class PAB(nn.Module):
+    def __init__(self, in_channels, out_channels, pab_channels=64):
+        super(PAB, self).__init__()
+        # Series of 1x1 conv to generate attention feature maps
+        self.pab_channels = pab_channels
+        self.in_channels = in_channels
+        self.top_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
+        self.center_conv = nn.Conv2d(in_channels, pab_channels, kernel_size=1)
+        self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+        self.map_softmax = nn.Softmax(dim=1)
+        self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+
+    def forward(self, x):
+        bsize = x.size()[0]
+        h = x.size()[2]
+        w = x.size()[3]
+        x_top = self.top_conv(x)
+        x_center = self.center_conv(x)
+        x_bottom = self.bottom_conv(x)
+
+        x_top = x_top.flatten(2)
+        x_center = x_center.flatten(2).transpose(1, 2)
+        x_bottom = x_bottom.flatten(2).transpose(1, 2)
+
+        sp_map = torch.matmul(x_center, x_top)
+        sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w, h*w)
+        sp_map = torch.matmul(sp_map, x_bottom)
+        sp_map = sp_map.reshape(bsize, self.in_channels, h, w)
+        x = x + sp_map
+        x = self.out_conv(x)
+        return x
+
+
+class MFAB(nn.Module):
+    def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
+        # MFAB is just a modified version of SE-blocks, one for skip, one for input
+        super(MFAB, self).__init__()
+        self.hl_conv = nn.Sequential(
+            md.Conv2dReLU(
+                in_channels,
+                in_channels,
+                kernel_size=3,
+                padding=1,
+                use_batchnorm=use_batchnorm,
+            ),
+            md.Conv2dReLU(
+                in_channels,
+                skip_channels,
+                kernel_size=1,
+                use_batchnorm=use_batchnorm,
+            )
+        )
+        self.SE_ll = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(skip_channels, skip_channels // reduction, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(skip_channels // reduction, skip_channels, 1),
+            nn.Sigmoid(),
+        )
+        self.SE_hl = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(skip_channels, skip_channels // reduction, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(skip_channels // reduction, skip_channels, 1),
+            nn.Sigmoid(),
+        )
+        self.conv1 = md.Conv2dReLU(
+            skip_channels + skip_channels,  # we transform C-prime from high level to C from skip connection
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        self.conv2 = md.Conv2dReLU(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+
+    def forward(self, x, skip=None):
+        x = self.hl_conv(x)
+        x = F.interpolate(x, scale_factor=2, mode="nearest")
+        attention_hl = self.SE_hl(x)
+        if skip is not None:
+            attention_ll = self.SE_ll(skip)
+            attention_hl = attention_hl + attention_ll
+            x = x * attention_hl
+            x = torch.cat([x, skip], dim=1)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x
+
+
+class DecoderBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        skip_channels,
+        out_channels,
+        use_batchnorm=True
+    ):
+        super().__init__()
+        self.conv1 = md.Conv2dReLU(
+            in_channels + skip_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+        self.conv2 = md.Conv2dReLU(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            padding=1,
+            use_batchnorm=use_batchnorm,
+        )
+
+    def forward(self, x, skip=None):
+        x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if skip is not None:
+            x = torch.cat([x, skip], dim=1)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x
+
+
+class MAnetDecoder(nn.Module):
+    def __init__(
+        self,
+        encoder_channels,
+        decoder_channels,
+        n_blocks=5,
+        reduction=16,
+        use_batchnorm=True,
+        pab_channels=64
+    ):
+        super().__init__()
+
+        if n_blocks != len(decoder_channels):
+            raise ValueError(
+                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
+                    n_blocks, len(decoder_channels)
+                )
+            )
+
+        encoder_channels = encoder_channels[1:]  # remove first skip with same spatial resolution
+        encoder_channels = encoder_channels[::-1]  # reverse channels to start from head of encoder
+
+        # computing blocks input and output channels
+        head_channels = encoder_channels[0]
+        in_channels = [head_channels] + list(decoder_channels[:-1])
+        skip_channels = list(encoder_channels[1:]) + [0]
+        out_channels = decoder_channels
+
+        self.center = PAB(head_channels, head_channels, pab_channels=pab_channels)
+
+        # combine decoder keyword arguments
+        kwargs = dict(use_batchnorm=use_batchnorm)  # no attention type here
+        blocks = [
+            MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) if skip_ch > 0 else
+            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
+            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
+        ]
+        # for the last block we don't have a skip connection -> use simple decoder block
+        self.blocks = nn.ModuleList(blocks)
+
+    def forward(self, *features):
+
+        features = features[1:]  # remove first skip with same spatial resolution
+        features = features[::-1]  # reverse channels to start from head of encoder
+
+        head = features[0]
+        skips = features[1:]
+
+        x = self.center(head)
+        for i, decoder_block in enumerate(self.blocks):
+            skip = skips[i] if i < len(skips) else None
+            x = decoder_block(x, skip)
+
+        return x
diff --git a/segmentation_models_pytorch/manet/model.py b/segmentation_models_pytorch/manet/model.py
new file mode 100644
index 00000000..0dab8d67
--- /dev/null
+++ b/segmentation_models_pytorch/manet/model.py
@@ -0,0 +1,96 @@
+from typing import Optional, Union, List
+from .decoder import MAnetDecoder
+from ..encoders import get_encoder
+from ..base import SegmentationModel
+from ..base import SegmentationHead, ClassificationHead
+
+
+class MAnet(SegmentationModel):
+    """MAnet_ : Multi-scale Attention Net.
+    The MA-Net can capture rich contextual dependencies based on the attention mechanism, using two blocks:
+    Position-wise Attention Block (PAB, which captures the spatial dependencies between pixels in a global view)
+    and Multi-scale Fusion Attention Block (MFAB, which captures the channel dependencies between any feature map
+    by multi-scale semantic feature fusion).
+
+    Args:
+        encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
+            to extract features of different spatial resolution
+        encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
+            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+            Default is 5
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+            other pretrained weights (see table with available weights for each encoder_name)
+        decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder.
+            Length of the list should be the same as **encoder_depth**
+        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
+            is used. If **"inplace"**, InplaceABN will be used, which decreases memory consumption.
+            Available options are **True, False, "inplace"**
+        decoder_pab_channels: A number of channels for PAB module in decoder.
+            Default is 64.
+        in_channels: A number of input channels for the model, default is 3 (RGB images)
+        classes: A number of classes for output mask (or you can think of it as a number of channels of output mask)
+        activation: An activation function to apply after the final convolution layer.
+            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"identity"**, **callable** and **None**.
+            Default is **None**
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+            on top of encoder if **aux_params** is not **None** (default). Supported params:
+                - classes (int): A number of classes
+                - pooling (str): One of "max", "avg". Default is "avg"
+                - dropout (float): Dropout factor in [0, 1)
+                - activation (str): An activation function to apply "sigmoid"/"softmax" (could be **None** to return logits)
+
+    Returns:
+        ``torch.nn.Module``: **MAnet**
+
+    .. _MAnet:
+        https://ieeexplore.ieee.org/abstract/document/9201310
+
+    """
+
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: str = "imagenet",
+        decoder_use_batchnorm: bool = True,
+        decoder_channels: List[int] = (256, 128, 64, 32, 16),
+        decoder_pab_channels: int = 64,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[Union[str, callable]] = None,
+        aux_params: Optional[dict] = None
+    ):
+        super().__init__()
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+        )
+
+        self.decoder = MAnetDecoder(
+            encoder_channels=self.encoder.out_channels,
+            decoder_channels=decoder_channels,
+            n_blocks=encoder_depth,
+            use_batchnorm=decoder_use_batchnorm,
+            pab_channels=decoder_pab_channels
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=decoder_channels[-1],
+            out_channels=classes,
+            activation=activation,
+            kernel_size=3,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "manet-{}".format(encoder_name)
+        self.initialize()
diff --git a/tests/test_models.py b/tests/test_models.py
index 865f42f8..29f60f11 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -29,7 +29,7 @@ def get_encoders():
 
 
 def get_sample(model_class):
-    if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus]:
+    if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus, smp.MAnet]:
         sample = torch.ones([1, 3, 64, 64])
     elif model_class == smp.PAN:
         sample = torch.ones([2, 3, 256, 256])
@@ -58,7 +58,7 @@ def _test_forward_backward(model, sample, test_shape=False):
 @pytest.mark.parametrize("encoder_depth", [3, 5])
 @pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
 def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
-    if model_class is smp.Unet or model_class is smp.UnetPlusPlus:
+    if model_class is smp.Unet or model_class is smp.UnetPlusPlus or model_class is smp.MAnet:
         kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
     model = model_class(
         encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
@@ -75,7 +75,7 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
 
 @pytest.mark.parametrize(
     "model_class",
-    [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.DeepLabV3]
+    [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet, smp.DeepLabV3]
 )
 def test_forward_backward(model_class):
     sample = get_sample(model_class)
@@ -83,7 +83,7 @@ def test_forward_backward(model_class):
     _test_forward_backward(model, sample)
 
 
-@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
+@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet])
 def test_aux_output(model_class):
     model = model_class(
         DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)
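
For anyone reviewing or trying the patch, here is a minimal usage sketch (not part of the diff). It assumes the patched package is installed as `segmentation_models_pytorch`; the encoder name, input size, and class count are illustrative and mirror the 64x64 dummy batch used in the tests above.

```python
import torch
import segmentation_models_pytorch as smp

# Build the new architecture the same way as the existing ones;
# decoder_pab_channels is the only MAnet-specific argument introduced by this patch.
model = smp.MAnet(
    encoder_name="resnet34",   # illustrative encoder choice
    encoder_weights=None,      # or "imagenet" for pre-trained encoder weights
    in_channels=3,
    classes=1,
    decoder_pab_channels=64,
)
model.eval()

# The create_model() registry updated in this diff should also resolve the new name.
model_from_registry = smp.create_model("manet", encoder_name="resnet34", encoder_weights=None, classes=1)

# Dummy forward pass; the spatial size must be divisible by 2 ** encoder_depth (32 here).
x = torch.ones([1, 3, 64, 64])
with torch.no_grad():
    mask = model(x)
print(mask.shape)  # expected: torch.Size([1, 1, 64, 64])
```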