Segmentation Models is a Python library of neural networks for image segmentation, based on PyTorch.
The main features of this library are:
- High-level API (just two lines to create a neural network)
- 4 model architectures for binary and multiclass segmentation (including the legendary Unet)
- 31 available encoders for each architecture
- All encoders have pre-trained weights for faster and better convergence
Because the library is built on PyTorch, a segmentation model is just a PyTorch nn.Module, which can be created as easily as:
import segmentation_models_pytorch as smp
model = smp.Unet()
Depending on the task, you can change the network architecture by choosing a backbone with fewer or more parameters and initialize it with pretrained weights:
model = smp.Unet('resnet34', encoder_weights='imagenet')
Change the number of output classes in the model:
model = smp.Unet('resnet34', classes=3, activation='softmax')
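With `classes=3` the output has one channel per class. The following is a minimal sketch of how per-class scores can be turned into probabilities and a label mask. It uses NumPy and a random array as a stand-in for the model output; the `(batch, classes, height, width)` layout is an assumption based on the usual PyTorch convention, and the `softmax` helper is written here for illustration only.

```python
import numpy as np

def softmax(scores, axis=1):
    """Numerically stable softmax along the class axis."""
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Hypothetical stand-in for raw model output: 2 images, 3 classes, 4x4 pixels.
rng = np.random.default_rng(0)
scores = rng.normal(size=(2, 3, 4, 4))

probs = softmax(scores, axis=1)  # per-pixel class probabilities
mask = probs.argmax(axis=1)      # per-pixel class index, shape (2, 4, 4)
```

Taking the argmax over the class axis gives one integer label per pixel, which is usually the form needed for visualization or metric computation.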
All models have pretrained encoders, so you have to prepare your data the same way it was prepared when the weights were pretrained:
from segmentation_models_pytorch.encoders import get_preprocessing_fn
preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
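To make the idea concrete, here is a rough sketch of what ImageNet-style preprocessing typically involves: scaling pixel values to [0, 1] and standardizing each channel. The mean and standard deviation below are the commonly used ImageNet statistics; the values the library actually applies depend on the encoder and weights, so treat them as an assumption. `normalize_image` is a hypothetical helper, not part of the library.

```python
import numpy as np

# Commonly used ImageNet channel statistics (RGB). Assumption: the exact
# values and input range depend on the encoder and weights you choose.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalize_image(image):
    """Scale an HxWx3 uint8 RGB image to [0, 1], then standardize per channel."""
    scaled = image.astype(np.float32) / 255.0
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD

# Hypothetical input: a mid-gray 8x8 RGB image.
image = np.full((8, 8, 3), 128, dtype=np.uint8)
normalized = normalize_image(image)
```

In practice, prefer the function returned by `get_preprocessing_fn`, since it is tied to the chosen encoder and weights.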
- Training a model for car segmentation on the CamVid dataset: here.
- Training a model with Catalyst (a high-level framework for PyTorch): here.
Encoder | Weights | Params, M |
---|---|---|
resnet18 | imagenet | 11M |
resnet34 | imagenet | 21M |
resnet50 | imagenet | 23M |
resnet101 | imagenet | 42M |
resnet152 | imagenet | 58M |
resnext50_32x4d | imagenet | 22M |
resnext101_32x8d | imagenet | 86M |
resnext101_32x16d | | 191M |
resnext101_32x32d | | 466M |
resnext101_32x48d | | 826M |
dpn68 | imagenet | 11M |
dpn68b | imagenet+5k | 11M |
dpn92 | imagenet+5k | 34M |
dpn98 | imagenet | 58M |
dpn107 | imagenet+5k | 84M |
dpn131 | imagenet | 76M |
vgg11 | imagenet | 9M |
vgg11_bn | imagenet | 9M |
vgg13 | imagenet | 9M |
vgg13_bn | imagenet | 9M |
vgg16 | imagenet | 14M |
vgg16_bn | imagenet | 14M |
vgg19 | imagenet | 20M |
vgg19_bn | imagenet | 20M |
senet154 | imagenet | 113M |
se_resnet50 | imagenet | 26M |
se_resnet101 | imagenet | 47M |
se_resnet152 | imagenet | 64M |
se_resnext50_32x4d | imagenet | 25M |
se_resnext101_32x4d | imagenet | 46M |
densenet121 | imagenet | 6M |
densenet169 | imagenet | 12M |
densenet201 | imagenet | 18M |
densenet161 | imagenet | 26M |
inceptionresnetv2 | imagenet, imagenet+background | 54M |
inceptionv4 | imagenet, imagenet+background | 41M |
efficientnet-b0 | imagenet | 4M |
efficientnet-b1 | imagenet | 6M |
efficientnet-b2 | imagenet | 7M |
efficientnet-b3 | imagenet | 10M |
efficientnet-b4 | imagenet | 17M |
efficientnet-b5 | imagenet | 28M |
efficientnet-b6 | imagenet | 40M |
efficientnet-b7 | imagenet | 63M |
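One way to use the table is to pick an encoder that fits a parameter budget. The sketch below copies a few rows from the table into a dictionary and selects the largest encoder under a given limit. `ENCODER_PARAMS_M` and `largest_encoder_within` are hypothetical names used for illustration, not part of the library.

```python
# Approximate parameter counts (millions), copied from a few table rows.
ENCODER_PARAMS_M = {
    "efficientnet-b0": 4,
    "densenet121": 6,
    "resnet18": 11,
    "resnet34": 21,
    "resnet50": 23,
    "resnet101": 42,
    "senet154": 113,
}

def largest_encoder_within(budget_m, params=ENCODER_PARAMS_M):
    """Return the encoder with the most parameters not exceeding budget_m, or None."""
    candidates = {name: size for name, size in params.items() if size <= budget_m}
    if not candidates:
        return None
    return max(candidates, key=candidates.get)

choice = largest_encoder_within(25)
```

The returned name can then be passed as the first argument to a model constructor such as `smp.Unet`. Parameter count is only a rough proxy for cost and accuracy, so it is worth validating the choice on your own data.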
- model.encoder - pretrained backbone that extracts features at different spatial resolutions
- model.decoder - segmentation head, depends on the model architecture (Unet/Linknet/PSPNet/FPN)
- model.activation - output activation function, one of sigmoid, softmax
- model.forward(x) - sequentially passes x through the model's encoder and decoder (returns logits!)
- model.predict(x) - inference method: switches the model to .eval() mode, calls .forward(x), and applies the activation function with torch.no_grad()
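The difference between `forward` and `predict` described above can be sketched with a tiny stand-in module. `TinySegmenter` is a hypothetical class written for illustration, with two convolutions in place of a real encoder and decoder; it is not the library's implementation and only mirrors the behavior described in the list.

```python
import torch
import torch.nn as nn

class TinySegmenter(nn.Module):
    """Hypothetical stand-in model: not the library's implementation."""

    def __init__(self, in_channels=3, classes=1):
        super().__init__()
        self.encoder = nn.Conv2d(in_channels, 8, kernel_size=3, padding=1)
        self.decoder = nn.Conv2d(8, classes, kernel_size=3, padding=1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        # Encoder, then decoder; returns raw logits with no activation.
        features = torch.relu(self.encoder(x))
        return self.decoder(features)

    def predict(self, x):
        # Switch to eval mode, disable gradients, then apply the activation.
        if self.training:
            self.eval()
        with torch.no_grad():
            return self.activation(self.forward(x))

model = TinySegmenter()
x = torch.randn(2, 3, 16, 16)
logits = model(x)         # unbounded values, tracks gradients
probs = model.predict(x)  # values in [0, 1], no gradient tracking
```

During training you would normally feed the logits from `forward` to the loss, and call `predict` only for inference.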
PyPI version:
$ pip install segmentation-models-pytorch
Latest version from source:
$ pip install git+https://github.com/qubvel/segmentation_models.pytorch
The project is distributed under the MIT License.
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py