Skip to content

Commit 13bc310

Browse files
authored
Add models wrapper for configs (#309)
1 parent bdc58db commit 13bc310

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

segmentation_models_pytorch/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,35 @@
1010
from . import utils
1111

1212
from .__version__ import __version__
13+
14+
from typing import Optional
15+
import torch
16+
17+
18+
def create_model(
19+
arch: str,
20+
encoder_name: str = "resnet34",
21+
encoder_weights: Optional[str] = "imagenet",
22+
in_channels: int = 3,
23+
classes: int = 1,
24+
**kwargs,
25+
) -> torch.nn.Module:
26+
"""Models wrapper. Allows to create any model just with parametes
27+
28+
"""
29+
30+
archs = [Unet, UnetPlusPlus, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
31+
archs_dict = {a.__name__.lower(): a for a in archs}
32+
try:
33+
model_class = archs_dict[arch.lower()]
34+
except KeyError:
35+
raise KeyError("Wrong architecture type `{}`. Avalibale options are: {}".format(
36+
arch, list(archs_dict.keys()),
37+
))
38+
return model_class(
39+
encoder_name=encoder_name,
40+
encoder_weights=encoder_weights,
41+
in_channels=in_channels,
42+
classes=classes,
43+
**kwargs,
44+
)

0 commit comments

Comments
 (0)