diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 7cd9c3ec..91367cea 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -14,9 +14,7 @@ jobs: with: python-version: '3.9' - name: Install dependencies - run: | - pip install --upgrade pip uv - uv pip install build twine + run: pip install --upgrade pip twine - name: Build and publish env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index bf04ab0a..fbce9a45 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -51,7 +51,7 @@ jobs: run: uv pip list - name: Test with PyTest - run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml -k "not logits_match" + run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml --non-marked-only - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 @@ -73,7 +73,52 @@ jobs: - name: Show installed packages run: uv pip list - name: Test with PyTest - run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -k "logits_match" + run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -m "logits_match" + + test_torch_compile: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: uv pip install -r requirements/required.txt -r requirements/test.txt + - name: Show installed packages + run: uv pip list + - name: Test with PyTest + run: uv run pytest -v -rsx -n 2 -m "compile" + + test_torch_export: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: uv pip install -r requirements/required.txt -r requirements/test.txt + - name: Show installed packages + run: uv pip list + - name: Test with 
PyTest + run: uv run pytest -v -rsx -n 2 -m "torch_export" + + test_torch_script: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: uv pip install -r requirements/required.txt -r requirements/test.txt + - name: Show installed packages + run: uv pip list + - name: Test with PyTest + run: uv run pytest -v -rsx -n 2 -m "torch_script" minimum: runs-on: ubuntu-latest @@ -88,4 +133,4 @@ jobs: - name: Show installed packages run: uv pip list - name: Test with pytest - run: uv run pytest -v -rsx -n 2 -k "not logits_match" + run: uv run pytest -v -rsx -n 2 --non-marked-only diff --git a/Makefile b/Makefile index a58d230f..9cbf9bdc 100644 --- a/Makefile +++ b/Makefile @@ -1,26 +1,39 @@ -.PHONY: test +.PHONY: test # Declare the 'test' target as phony to avoid conflicts with files named 'test' +# Variables to store the paths of the python, pip, pytest, and ruff executables +PYTHON := $(shell which python) +PIP := $(shell which pip) +PYTEST := $(shell which pytest) +RUFF := $(shell which ruff) + +# Target to create a Python virtual environment .venv: - python3 -m venv .venv + $(PYTHON) -m venv $(shell dirname $(PYTHON)) +# Target to install development dependencies in the virtual environment install_dev: .venv - .venv/bin/pip install -e ".[test]" + $(PIP) install -e ".[test]" +# Target to run tests with pytest, using 2 parallel processes and only non-marked tests test: .venv - .venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match" + $(PYTEST) -v -rsx -n 2 tests/ --non-marked-only +# Target to run all tests with pytest, including slow tests, using 2 parallel processes test_all: .venv - RUN_SLOW=1 .venv/bin/pytest -v -rsx -n 2 tests/ + RUN_SLOW=1 $(PYTEST) -v -rsx -n 2 tests/ +# Target to generate a table by running a Python script table: - .venv/bin/python misc/generate_table.py + $(PYTHON) misc/generate_table.py +# Target to generate a 
table for timm by running a Python script table_timm: - .venv/bin/python misc/generate_table_timm.py + $(PYTHON) misc/generate_table_timm.py +# Target to fix and format code using ruff fixup: - .venv/bin/ruff check --fix - .venv/bin/ruff format + $(RUFF) check --fix + $(RUFF) format +# Target to run code formatting and tests all: fixup test - diff --git a/README.md b/README.md index b3a0b3ff..c3df5718 100644 --- a/README.md +++ b/README.md @@ -1,28 +1,51 @@
![logo](https://i.ibb.co/dc1XdhT/Segmentation-Models-V2-Side-1-1.png) -**Python library with Neural Networks for Image +**Python library with Neural Networks for Image Semantic Segmentation based on [PyTorch](https://pytorch.org/).** -[![Generic badge](https://img.shields.io/badge/License-MIT-.svg?style=for-the-badge)](https://github.com/qubvel/segmentation_models.pytorch/blob/main/LICENSE) + [![GitHub Workflow Status (branch)](https://img.shields.io/github/actions/workflow/status/qubvel/segmentation_models.pytorch/tests.yml?branch=main&style=for-the-badge)](https://github.com/qubvel/segmentation_models.pytorch/actions/workflows/tests.yml) +![Codecov](https://img.shields.io/codecov/c/github/qubvel-org/segmentation_models.pytorch?style=for-the-badge) [![Read the Docs](https://img.shields.io/readthedocs/smp?style=for-the-badge&logo=readthedocs&logoColor=white)](https://smp.readthedocs.io/en/latest/)
-[![PyPI](https://img.shields.io/pypi/v/segmentation-models-pytorch?color=blue&style=for-the-badge&logo=pypi&logoColor=white)](https://pypi.org/project/segmentation-models-pytorch/) -[![PyPI - Downloads](https://img.shields.io/pypi/dm/segmentation-models-pytorch?style=for-the-badge&color=blue)](https://pepy.tech/project/segmentation-models-pytorch) -
-[![PyTorch - Version](https://img.shields.io/badge/PYTORCH-1.4+-red?style=for-the-badge&logo=pytorch)](https://pepy.tech/project/segmentation-models-pytorch) +[![PyPI](https://img.shields.io/pypi/v/segmentation-models-pytorch?color=red&style=for-the-badge&logo=pypi&logoColor=white)](https://pypi.org/project/segmentation-models-pytorch/) +[![PyTorch - Version](https://img.shields.io/badge/PYTORCH-1.9+-red?style=for-the-badge&logo=pytorch)](https://pepy.tech/project/segmentation-models-pytorch) [![Python - Version](https://img.shields.io/badge/PYTHON-3.9+-red?style=for-the-badge&logo=python&logoColor=white)](https://pepy.tech/project/segmentation-models-pytorch) +
+[![Generic badge](https://img.shields.io/badge/License-MIT-.svg?style=for-the-badge&color=blue)](https://github.com/qubvel/segmentation_models.pytorch/blob/main/LICENSE) +[![PyPI - Downloads](https://img.shields.io/pypi/dm/segmentation-models-pytorch?style=for-the-badge&color=blue)](https://pepy.tech/project/segmentation-models-pytorch)
-The main features of this library are: - - - High-level API (just two lines to create a neural network) - - 11 models architectures for binary and multi class segmentation (including legendary Unet) - - 124 available encoders (and 500+ encoders from [timm](https://github.com/rwightman/pytorch-image-models)) - - All encoders have pre-trained weights for faster and better convergence - - Popular metrics and losses for training routines +The main features of the library are: + + - Super simple high-level API (just two lines to create a neural network) + - 12 encoder-decoder model architectures (Unet, Unet++, Segformer, DPT, ...) + - 800+ **pretrained** convolution- and transformer-based encoders, including [timm](https://github.com/huggingface/pytorch-image-models) support + - Popular metrics and losses for training routines (Dice, Jaccard, Tversky, ...) + - ONNX export and torch script/trace/compile friendly + +### Community-Driven Project, Supported By + + + + + +
+ + withoutBG API Logo + + + withoutBG API +
+ https://withoutbg.com +
+

+ High-quality background removal API +
+

+
### [πŸ“š Project Documentation πŸ“š](http://smp.readthedocs.io/) @@ -31,21 +54,18 @@ Visit [Read The Docs Project Page](https://smp.readthedocs.io/) or read the foll ### πŸ“‹ Table of content 1. [Quick start](#start) 2. [Examples](#examples) - 3. [Models](#models) - 1. [Architectures](#architectures) - 2. [Encoders](#encoders) - 3. [Timm Encoders](#timm) + 3. [Models and encoders](#models-and-encoders) 4. [Models API](#api) 1. [Input channels](#input-channels) 2. [Auxiliary classification output](#auxiliary-classification-output) 3. [Depth](#depth) 5. [Installation](#installation) - 6. [Competitions won with the library](#competitions-won-with-the-library) + 6. [Competitions won with the library](#competitions) 7. [Contributing](#contributing) 8. [Citing](#citing) 9. [License](#license) -### ⏳ Quick start +## ⏳ Quick start #### 1. Create your first Segmentation model with SMP @@ -76,361 +96,71 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet') Congratulations! You are done! Now you can train your model with your favorite framework! -### πŸ’‘ Examples - - Training model for pets binary segmentation with Pytorch-Lightning [notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/binary_segmentation_intro.ipynb) and [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/binary_segmentation_intro.ipynb) - - Training model for cars segmentation on CamVid dataset [here](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/cars%20segmentation%20(camvid).ipynb). 
- - Training SMP model with [Catalyst](https://github.com/catalyst-team/catalyst) (high-level framework for PyTorch), [TTAch](https://github.com/qubvel/ttach) (TTA library for PyTorch) and [Albumentations](https://github.com/albu/albumentations) (fast image augmentation library) - [here](https://github.com/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/v21.02rc0/examples/notebooks/segmentation-tutorial.ipynb) - - Training SMP model with [Pytorch-Lightning](https://pytorch-lightning.readthedocs.io) framework - [here](https://github.com/ternaus/cloths_segmentation) (clothes binary segmentation by [@ternaus](https://github.com/ternaus)). - - Export trained model to ONNX - [notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) - -### πŸ“¦ Models - -#### Architectures - - Unet [[paper](https://arxiv.org/abs/1505.04597)] [[docs](https://smp.readthedocs.io/en/latest/models.html#unet)] - - Unet++ [[paper](https://arxiv.org/pdf/1807.10165.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id2)] - - MAnet [[paper](https://ieeexplore.ieee.org/abstract/document/9201310)] [[docs](https://smp.readthedocs.io/en/latest/models.html#manet)] - - Linknet [[paper](https://arxiv.org/abs/1707.03718)] [[docs](https://smp.readthedocs.io/en/latest/models.html#linknet)] - - FPN [[paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#fpn)] - - PSPNet [[paper](https://arxiv.org/abs/1612.01105)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pspnet)] - - PAN 
[[paper](https://arxiv.org/abs/1805.10180)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pan)] - - DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)] - - DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)] - - UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)] - - Segformer [[paper](https://arxiv.org/abs/2105.15203)] [[docs](https://smp.readthedocs.io/en/latest/models.html#segformer)] - -#### Encoders - -The following is a list of supported encoders in the SMP. Select the appropriate family of encoders and click to expand the table and select a specific encoder and its pre-trained weights (`encoder_name` and `encoder_weights` parameters). - -
-ResNet -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|resnet18 |imagenet / ssl / swsl |11M | -|resnet34 |imagenet |21M | -|resnet50 |imagenet / ssl / swsl |23M | -|resnet101 |imagenet |42M | -|resnet152 |imagenet |58M | - -
-
- -
-ResNeXt -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|resnext50_32x4d |imagenet / ssl / swsl |22M | -|resnext101_32x4d |ssl / swsl |42M | -|resnext101_32x8d |imagenet / instagram / ssl / swsl|86M | -|resnext101_32x16d |instagram / ssl / swsl |191M | -|resnext101_32x32d |instagram |466M | -|resnext101_32x48d |instagram |826M | - -
-
- -
-ResNeSt -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|timm-resnest14d |imagenet |8M | -|timm-resnest26d |imagenet |15M | -|timm-resnest50d |imagenet |25M | -|timm-resnest101e |imagenet |46M | -|timm-resnest200e |imagenet |68M | -|timm-resnest269e |imagenet |108M | -|timm-resnest50d_4s2x40d |imagenet |28M | -|timm-resnest50d_1s4x24d |imagenet |23M | - -
-
- -
-Res2Ne(X)t -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|timm-res2net50_26w_4s |imagenet |23M | -|timm-res2net101_26w_4s |imagenet |43M | -|timm-res2net50_26w_6s |imagenet |35M | -|timm-res2net50_26w_8s |imagenet |46M | -|timm-res2net50_48w_2s |imagenet |23M | -|timm-res2net50_14w_8s |imagenet |23M | -|timm-res2next50 |imagenet |22M | - -
-
- -
-RegNet(x/y) -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|timm-regnetx_002 |imagenet |2M | -|timm-regnetx_004 |imagenet |4M | -|timm-regnetx_006 |imagenet |5M | -|timm-regnetx_008 |imagenet |6M | -|timm-regnetx_016 |imagenet |8M | -|timm-regnetx_032 |imagenet |14M | -|timm-regnetx_040 |imagenet |20M | -|timm-regnetx_064 |imagenet |24M | -|timm-regnetx_080 |imagenet |37M | -|timm-regnetx_120 |imagenet |43M | -|timm-regnetx_160 |imagenet |52M | -|timm-regnetx_320 |imagenet |105M | -|timm-regnety_002 |imagenet |2M | -|timm-regnety_004 |imagenet |3M | -|timm-regnety_006 |imagenet |5M | -|timm-regnety_008 |imagenet |5M | -|timm-regnety_016 |imagenet |10M | -|timm-regnety_032 |imagenet |17M | -|timm-regnety_040 |imagenet |19M | -|timm-regnety_064 |imagenet |29M | -|timm-regnety_080 |imagenet |37M | -|timm-regnety_120 |imagenet |49M | -|timm-regnety_160 |imagenet |80M | -|timm-regnety_320 |imagenet |141M | - -
-
- -
-GERNet -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|timm-gernet_s |imagenet |6M | -|timm-gernet_m |imagenet |18M | -|timm-gernet_l |imagenet |28M | - -
-
- -
-SE-Net -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|senet154 |imagenet |113M | -|se_resnet50 |imagenet |26M | -|se_resnet101 |imagenet |47M | -|se_resnet152 |imagenet |64M | -|se_resnext50_32x4d |imagenet |25M | -|se_resnext101_32x4d |imagenet |46M | +## πŸ’‘ Examples -
-
- -
-SK-ResNe(X)t -
+| Name | Link | Colab | +|-------------------------------------------|-----------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------| +| **Train** pets binary segmentation on OxfordPets | [Notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/binary_segmentation_intro.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/binary_segmentation_intro.ipynb) | +| **Train** cars binary segmentation on CamVid | [Notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/cars%20segmentation%20(camvid).ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/cars%20segmentation%20(camvid).ipynb) | +| **Train** multiclass segmentation on CamVid | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/camvid_segmentation_multiclass.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel-org/segmentation_models.pytorch/blob/main/examples/camvid_segmentation_multiclass.ipynb) | +| **Train** clothes binary segmentation by @ternaus | [Repo](https://github.com/ternaus/cloths_segmentation) | | +| **Load and inference** pretrained Segformer | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/segformer_inference_pretrained.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/segformer_inference_pretrained.ipynb) | +| **Load and inference** pretrained DPT | 
[Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/dpt_inference_pretrained.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/dpt_inference_pretrained.ipynb) | +| **Load and inference** pretrained UPerNet | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/upernet_inference_pretrained.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/upernet_inference_pretrained.ipynb) | +| **Save and load** models locally / to HuggingFace Hub | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb) | +| **Export** trained model to ONNX | [Notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) | -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|timm-skresnet18 |imagenet |11M | -|timm-skresnet34 |imagenet |21M | -|timm-skresnext50_32x4d |imagenet |25M | -
-
+## 📦 Models and encoders -
-DenseNet -
+### Architectures +| Architecture | Paper | Documentation | Checkpoints | +|--------------|-------|---------------|------------| +| Unet | [paper](https://arxiv.org/abs/1505.04597) | [docs](https://smp.readthedocs.io/en/latest/models.html#unet) | | +| Unet++ | [paper](https://arxiv.org/pdf/1807.10165.pdf) | [docs](https://smp.readthedocs.io/en/latest/models.html#unetplusplus) | | +| MAnet | [paper](https://ieeexplore.ieee.org/abstract/document/9201310) | [docs](https://smp.readthedocs.io/en/latest/models.html#manet) | | +| Linknet | [paper](https://arxiv.org/abs/1707.03718) | [docs](https://smp.readthedocs.io/en/latest/models.html#linknet) | | +| FPN | [paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf) | [docs](https://smp.readthedocs.io/en/latest/models.html#fpn) | | +| PSPNet | [paper](https://arxiv.org/abs/1612.01105) | [docs](https://smp.readthedocs.io/en/latest/models.html#pspnet) | | +| PAN | [paper](https://arxiv.org/abs/1805.10180) | [docs](https://smp.readthedocs.io/en/latest/models.html#pan) | | +| DeepLabV3 | [paper](https://arxiv.org/abs/1706.05587) | [docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3) | | +| DeepLabV3+ | [paper](https://arxiv.org/abs/1802.02611) | [docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3plus) | | +| UPerNet | [paper](https://arxiv.org/abs/1807.10221) | [docs](https://smp.readthedocs.io/en/latest/models.html#upernet) | [checkpoints](https://huggingface.co/collections/smp-hub/upernet-67fadcdbe08418c6ea94f768) | +| Segformer | [paper](https://arxiv.org/abs/2105.15203) | [docs](https://smp.readthedocs.io/en/latest/models.html#segformer) | [checkpoints](https://huggingface.co/collections/smp-hub/segformer-6749eb4923dea2c355f29a1f) | +| DPT | [paper](https://arxiv.org/abs/2103.13413) | [docs](https://smp.readthedocs.io/en/latest/models.html#dpt) | [checkpoints](https://huggingface.co/collections/smp-hub/dpt-67f30487327c0599a0c62d68) | -|Encoder |Weights |Params, M | 
-|--------------------------------|:------------------------------:|:------------------------------:| -|densenet121 |imagenet |6M | -|densenet169 |imagenet |12M | -|densenet201 |imagenet |18M | -|densenet161 |imagenet |26M | +### Encoders -
-
+The library provides a wide range of **pretrained** encoders (also known as backbones) for segmentation models. Instead of using features from the final layer of a classification model, we extract **intermediate features** and feed them into the decoder for segmentation tasks. -
-Inception -
+All encoders come with **pretrained weights**, which help achieve **faster and more stable convergence** when training segmentation models. -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|inceptionresnetv2 |imagenet / imagenet+background |54M | -|inceptionv4 |imagenet / imagenet+background |41M | -|xception |imagenet |22M | +Given the extensive selection of supported encoders, you can choose the best one for your specific use case, for example: +- **Lightweight encoders** for low-latency applications or real-time inference on edge devices (mobilenet/mobileone). +- **High-capacity architectures** for complex tasks involving a large number of segmented classes, providing superior accuracy (convnext/swin/mit). -
-
- -
-EfficientNet -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|efficientnet-b0 |imagenet |4M | -|efficientnet-b1 |imagenet |6M | -|efficientnet-b2 |imagenet |7M | -|efficientnet-b3 |imagenet |10M | -|efficientnet-b4 |imagenet |17M | -|efficientnet-b5 |imagenet |28M | -|efficientnet-b6 |imagenet |40M | -|efficientnet-b7 |imagenet |63M | -|timm-efficientnet-b0 |imagenet / advprop / noisy-student|4M | -|timm-efficientnet-b1 |imagenet / advprop / noisy-student|6M | -|timm-efficientnet-b2 |imagenet / advprop / noisy-student|7M | -|timm-efficientnet-b3 |imagenet / advprop / noisy-student|10M | -|timm-efficientnet-b4 |imagenet / advprop / noisy-student|17M | -|timm-efficientnet-b5 |imagenet / advprop / noisy-student|28M | -|timm-efficientnet-b6 |imagenet / advprop / noisy-student|40M | -|timm-efficientnet-b7 |imagenet / advprop / noisy-student|63M | -|timm-efficientnet-b8 |imagenet / advprop |84M | -|timm-efficientnet-l2 |noisy-student |474M | -|timm-efficientnet-lite0 |imagenet |4M | -|timm-efficientnet-lite1 |imagenet |5M | -|timm-efficientnet-lite2 |imagenet |6M | -|timm-efficientnet-lite3 |imagenet |8M | -|timm-efficientnet-lite4 |imagenet |13M | +By selecting the right encoder, you can balance **efficiency, performance, and model complexity** to suit your project needs. -
-
- -
-MobileNet -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|mobilenet_v2 |imagenet |2M | -|timm-mobilenetv3_large_075 |imagenet |1.78M | -|timm-mobilenetv3_large_100 |imagenet |2.97M | -|timm-mobilenetv3_large_minimal_100|imagenet |1.41M | -|timm-mobilenetv3_small_075 |imagenet |0.57M | -|timm-mobilenetv3_small_100 |imagenet |0.93M | -|timm-mobilenetv3_small_minimal_100|imagenet |0.43M | +All encoders and corresponding pretrained weights are listed in the documentation: + - [table](https://smp.readthedocs.io/en/latest/encoders.html) with natively ported encoders + - [table](https://smp.readthedocs.io/en/latest/encoders_timm.html) with supported [timm](https://github.com/huggingface/pytorch-image-models) encoders -
-
+## 🔍 Models API -
-DPN -
+### Input channels -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|dpn68 |imagenet |11M | -|dpn68b |imagenet+5k |11M | -|dpn92 |imagenet+5k |34M | -|dpn98 |imagenet |58M | -|dpn107 |imagenet+5k |84M | -|dpn131 |imagenet |76M | +The input channels parameter allows you to create a model that can process a tensor with an arbitrary number of channels. +If you use pretrained weights from ImageNet, the weights of the first convolution will be reused: + - For the 1-channel case, it would be a sum of the weights of the first convolution layer. + - Otherwise, channels would be populated with weights like `new_weight[:, i] = pretrained_weight[:, i % 3]`, and then scaled with `new_weight * 3 / new_in_channels`. -
-
- -
-VGG -
- -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|vgg11 |imagenet |9M | -|vgg11_bn |imagenet |9M | -|vgg13 |imagenet |9M | -|vgg13_bn |imagenet |9M | -|vgg16 |imagenet |14M | -|vgg16_bn |imagenet |14M | -|vgg19 |imagenet |20M | -|vgg19_bn |imagenet |20M | - -
-
- -
-Mix Vision Transformer -
- -Backbone from SegFormer pretrained on Imagenet! Can be used with other decoders from package, you can combine Mix Vision Transformer with Unet, FPN and others! - -Limitations: - - - encoder is **not** supported by Linknet, Unet++ - - encoder is supported by FPN only for encoder **depth = 5** - -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|mit_b0 |imagenet |3M | -|mit_b1 |imagenet |13M | -|mit_b2 |imagenet |24M | -|mit_b3 |imagenet |44M | -|mit_b4 |imagenet |60M | -|mit_b5 |imagenet |81M | - -
-
- -
-MobileOne -
- -Apple's "sub-one-ms" Backbone pretrained on Imagenet! Can be used with all decoders. - -Note: In the official github repo the s0 variant has additional num_conv_branches, leading to more params than s1. - -|Encoder |Weights |Params, M | -|--------------------------------|:------------------------------:|:------------------------------:| -|mobileone_s0 |imagenet |4.6M | -|mobileone_s1 |imagenet |4.0M | -|mobileone_s2 |imagenet |6.5M | -|mobileone_s3 |imagenet |8.8M | -|mobileone_s4 |imagenet |13.6M | - -
-
- - -\* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)). - -#### Timm Encoders - -[docs](https://smp.readthedocs.io/en/latest/encoders_timm.html) - -Pytorch Image Models (a.k.a. timm) has a lot of pretrained models and interface which allows using these models as encoders in smp, however, not all models are supported - - - not all transformer models have ``features_only`` functionality implemented that is required for encoder - - some models have inappropriate strides - -Total number of supported encoders: 549 - - [table with available encoders](https://smp.readthedocs.io/en/latest/encoders_timm.html) - -### πŸ” Models API - - - `model.encoder` - pretrained backbone to extract features of different spatial resolution - - `model.decoder` - depends on models architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`) - - `model.segmentation_head` - last block to produce required number of mask channels (include also optional upsampling and activation) - - `model.classification_head` - optional block which create classification head on top of encoder - - `model.forward(x)` - sequentially pass `x` through model\`s encoder, decoder and segmentation head (and classification head if specified) - -##### Input channels -Input channels parameter allows you to create models, which process tensors with arbitrary number of channels. -If you use pretrained weights from imagenet - weights of first convolution will be reused. For -1-channel case it would be a sum of weights of first convolution layer, otherwise channels would be -populated with weights like `new_weight[:, i] = pretrained_weight[:, i % 3]` and than scaled with `new_weight * 3 / new_in_channels`. 
```python model = smp.FPN('resnet34', in_channels=1) mask = model(torch.ones([1, 1, 64, 64])) ``` -##### Auxiliary classification output +### Auxiliary classification output + All models support `aux_params` parameters, which is default set to `None`. If `aux_params = None` then classification auxiliary output is not created, else model produce not only `mask`, but also `label` output with shape `NC`. @@ -447,50 +177,54 @@ model = smp.Unet('resnet34', classes=4, aux_params=aux_params) mask, label = model(x) ``` -##### Depth +### Depth + Depth parameter specify a number of downsampling operations in encoder, so you can make your model lighter if specify smaller `depth`. ```python model = smp.Unet('resnet34', encoder_depth=4) ``` - -### πŸ›  Installation +## πŸ›  Installation PyPI version: + ```bash $ pip install segmentation-models-pytorch ```` -Latest version from source: + +The latest version from GitHub: + ```bash $ pip install git+https://github.com/qubvel/segmentation_models.pytorch ```` -### πŸ† Competitions won with the library +## πŸ† Competitions won with the library -`Segmentation Models` package is widely used in the image segmentation competitions. +`Segmentation Models` package is widely used in image segmentation competitions. [Here](https://github.com/qubvel/segmentation_models.pytorch/blob/main/HALLOFFAME.md) you can find competitions, names of the winners and links to their solutions. -### 🀝 Contributing +## 🀝 Contributing -#### Install SMP +1. Install SMP in dev mode ```bash -make install_dev # create .venv, install SMP in dev mode +make install_dev # Create .venv, install SMP in dev mode ``` -#### Run tests and code checks +2. Run tests and code checks ```bash +make test # Run tests suite with pytest make fixup # Ruff for formatting and lint checks ``` -#### Update table with encoders +3. 
Update a table (in case you added an encoder) ```bash -make table # generate a table with encoders and print to stdout +make table # Generates a table with encoders and print to stdout ``` -### πŸ“ Citing +## πŸ“ Citing ``` @misc{Iakubovskii:2019, Author = {Pavel Iakubovskii}, @@ -502,5 +236,5 @@ make table # generate a table with encoders and print to stdout } ``` -### πŸ›‘οΈ License +## πŸ›‘οΈ License The project is primarily distributed under [MIT License](https://github.com/qubvel/segmentation_models.pytorch/blob/main/LICENSE), while some files are subject to other licenses. Please refer to [LICENSES](licenses/LICENSES.md) and license statements in each file for careful check, especially for commercial use. diff --git a/docs/conf.py b/docs/conf.py index c7dde9e5..4cc70a6b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -100,9 +100,7 @@ def get_version(): "timm", "cv2", "PIL", - "pretrainedmodels", "torchvision", - "efficientnet-pytorch", "segmentation_models_pytorch.encoders", "segmentation_models_pytorch.utils", # 'segmentation_models_pytorch.base', diff --git a/docs/encoders.rst b/docs/encoders.rst index 652745b7..2de35dec 100644 --- a/docs/encoders.rst +++ b/docs/encoders.rst @@ -1,363 +1,141 @@ πŸ” Available Encoders ===================== -ResNet -~~~~~~ - -+-------------+-------------------------+-------------+ -| Encoder | Weights | Params, M | -+=============+=========================+=============+ -| resnet18 | imagenet / ssl / swsl | 11M | -+-------------+-------------------------+-------------+ -| resnet34 | imagenet | 21M | -+-------------+-------------------------+-------------+ -| resnet50 | imagenet / ssl / swsl | 23M | -+-------------+-------------------------+-------------+ -| resnet101 | imagenet | 42M | -+-------------+-------------------------+-------------+ -| resnet152 | imagenet | 58M | -+-------------+-------------------------+-------------+ - -ResNeXt -~~~~~~~ - 
-+----------------------+-------------------------------------+-------------+ -| Encoder | Weights | Params, M | -+======================+=====================================+=============+ -| resnext50\_32x4d | imagenet / ssl / swsl | 22M | -+----------------------+-------------------------------------+-------------+ -| resnext101\_32x4d | ssl / swsl | 42M | -+----------------------+-------------------------------------+-------------+ -| resnext101\_32x8d | imagenet / instagram / ssl / swsl | 86M | -+----------------------+-------------------------------------+-------------+ -| resnext101\_32x16d | instagram / ssl / swsl | 191M | -+----------------------+-------------------------------------+-------------+ -| resnext101\_32x32d | instagram | 466M | -+----------------------+-------------------------------------+-------------+ -| resnext101\_32x48d | instagram | 826M | -+----------------------+-------------------------------------+-------------+ - -ResNeSt -~~~~~~~ - -+----------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+============================+============+=============+ -| timm-resnest14d | imagenet | 8M | -+----------------------------+------------+-------------+ -| timm-resnest26d | imagenet | 15M | -+----------------------------+------------+-------------+ -| timm-resnest50d | imagenet | 25M | -+----------------------------+------------+-------------+ -| timm-resnest101e | imagenet | 46M | -+----------------------------+------------+-------------+ -| timm-resnest200e | imagenet | 68M | -+----------------------------+------------+-------------+ -| timm-resnest269e | imagenet | 108M | -+----------------------------+------------+-------------+ -| timm-resnest50d\_4s2x40d | imagenet | 28M | -+----------------------------+------------+-------------+ -| timm-resnest50d\_1s4x24d | imagenet | 23M | -+----------------------------+------------+-------------+ - -Res2Ne(X)t -~~~~~~~~~~ - 
-+----------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+============================+============+=============+ -| timm-res2net50\_26w\_4s | imagenet | 23M | -+----------------------------+------------+-------------+ -| timm-res2net101\_26w\_4s | imagenet | 43M | -+----------------------------+------------+-------------+ -| timm-res2net50\_26w\_6s | imagenet | 35M | -+----------------------------+------------+-------------+ -| timm-res2net50\_26w\_8s | imagenet | 46M | -+----------------------------+------------+-------------+ -| timm-res2net50\_48w\_2s | imagenet | 23M | -+----------------------------+------------+-------------+ -| timm-res2net50\_14w\_8s | imagenet | 23M | -+----------------------------+------------+-------------+ -| timm-res2next50 | imagenet | 22M | -+----------------------------+------------+-------------+ - -RegNet(x/y) -~~~~~~~~~~ - -+---------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+=====================+============+=============+ -| timm-regnetx\_002 | imagenet | 2M | -+---------------------+------------+-------------+ -| timm-regnetx\_004 | imagenet | 4M | -+---------------------+------------+-------------+ -| timm-regnetx\_006 | imagenet | 5M | -+---------------------+------------+-------------+ -| timm-regnetx\_008 | imagenet | 6M | -+---------------------+------------+-------------+ -| timm-regnetx\_016 | imagenet | 8M | -+---------------------+------------+-------------+ -| timm-regnetx\_032 | imagenet | 14M | -+---------------------+------------+-------------+ -| timm-regnetx\_040 | imagenet | 20M | -+---------------------+------------+-------------+ -| timm-regnetx\_064 | imagenet | 24M | -+---------------------+------------+-------------+ -| timm-regnetx\_080 | imagenet | 37M | -+---------------------+------------+-------------+ -| timm-regnetx\_120 | imagenet | 43M | -+---------------------+------------+-------------+ -| timm-regnetx\_160 | imagenet | 
52M | -+---------------------+------------+-------------+ -| timm-regnetx\_320 | imagenet | 105M | -+---------------------+------------+-------------+ -| timm-regnety\_002 | imagenet | 2M | -+---------------------+------------+-------------+ -| timm-regnety\_004 | imagenet | 3M | -+---------------------+------------+-------------+ -| timm-regnety\_006 | imagenet | 5M | -+---------------------+------------+-------------+ -| timm-regnety\_008 | imagenet | 5M | -+---------------------+------------+-------------+ -| timm-regnety\_016 | imagenet | 10M | -+---------------------+------------+-------------+ -| timm-regnety\_032 | imagenet | 17M | -+---------------------+------------+-------------+ -| timm-regnety\_040 | imagenet | 19M | -+---------------------+------------+-------------+ -| timm-regnety\_064 | imagenet | 29M | -+---------------------+------------+-------------+ -| timm-regnety\_080 | imagenet | 37M | -+---------------------+------------+-------------+ -| timm-regnety\_120 | imagenet | 49M | -+---------------------+------------+-------------+ -| timm-regnety\_160 | imagenet | 80M | -+---------------------+------------+-------------+ -| timm-regnety\_320 | imagenet | 141M | -+---------------------+------------+-------------+ - -GERNet -~~~~~~ - -+-------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+=========================+============+=============+ -| timm-gernet\_s | imagenet | 6M | -+-------------------------+------------+-------------+ -| timm-gernet\_m | imagenet | 18M | -+-------------------------+------------+-------------+ -| timm-gernet\_l | imagenet | 28M | -+-------------------------+------------+-------------+ - -SE-Net -~~~~~~ - -+-------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+=========================+============+=============+ -| senet154 | imagenet | 113M | -+-------------------------+------------+-------------+ -| se\_resnet50 | imagenet | 26M | 
-+-------------------------+------------+-------------+ -| se\_resnet101 | imagenet | 47M | -+-------------------------+------------+-------------+ -| se\_resnet152 | imagenet | 64M | -+-------------------------+------------+-------------+ -| se\_resnext50\_32x4d | imagenet | 25M | -+-------------------------+------------+-------------+ -| se\_resnext101\_32x4d | imagenet | 46M | -+-------------------------+------------+-------------+ - -SK-ResNe(X)t -~~~~~~~~~~~~ - -+---------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+===========================+============+=============+ -| timm-skresnet18 | imagenet | 11M | -+---------------------------+------------+-------------+ -| timm-skresnet34 | imagenet | 21M | -+---------------------------+------------+-------------+ -| timm-skresnext50\_32x4d | imagenet | 25M | -+---------------------------+------------+-------------+ - -DenseNet -~~~~~~~~ - -+---------------+------------+-------------+ -| Encoder | Weights | Params, M | -+===============+============+=============+ -| densenet121 | imagenet | 6M | -+---------------+------------+-------------+ -| densenet169 | imagenet | 12M | -+---------------+------------+-------------+ -| densenet201 | imagenet | 18M | -+---------------+------------+-------------+ -| densenet161 | imagenet | 26M | -+---------------+------------+-------------+ - -Inception -~~~~~~~~~ - -+---------------------+----------------------------------+-------------+ -| Encoder | Weights | Params, M | -+=====================+==================================+=============+ -| inceptionresnetv2 | imagenet / imagenet+background | 54M | -+---------------------+----------------------------------+-------------+ -| inceptionv4 | imagenet / imagenet+background | 41M | -+---------------------+----------------------------------+-------------+ -| xception | imagenet | 22M | -+---------------------+----------------------------------+-------------+ - -EfficientNet -~~~~~~~~~~~~ - 
-+------------------------+--------------------------------------+-------------+ -| Encoder | Weights | Params, M | -+========================+======================================+=============+ -| efficientnet-b0 | imagenet | 4M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b1 | imagenet | 6M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b2 | imagenet | 7M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b3 | imagenet | 10M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b4 | imagenet | 17M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b5 | imagenet | 28M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b6 | imagenet | 40M | -+------------------------+--------------------------------------+-------------+ -| efficientnet-b7 | imagenet | 63M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b0 | imagenet / advprop / noisy-student | 4M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b1 | imagenet / advprop / noisy-student | 6M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b2 | imagenet / advprop / noisy-student | 7M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b3 | imagenet / advprop / noisy-student | 10M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b4 | imagenet / advprop / noisy-student | 17M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b5 | imagenet / advprop / noisy-student | 28M | 
-+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b6 | imagenet / advprop / noisy-student | 40M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b7 | imagenet / advprop / noisy-student | 63M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-b8 | imagenet / advprop | 84M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-l2 | noisy-student / noisy-student-475 | 474M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-lite0| imagenet | 4M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-lite1| imagenet | 4M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-lite2| imagenet | 6M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-lite3| imagenet | 8M | -+------------------------+--------------------------------------+-------------+ -| timm-efficientnet-lite4| imagenet | 13M | -+------------------------+--------------------------------------+-------------+ - -MobileNet -~~~~~~~~~ - -+---------------------------------------+------------+-------------+ -| Encoder | Weights | Params, M | -+=======================================+============+=============+ -| mobilenet\_v2 | imagenet | 2M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_large\_075 | imagenet | 1.78M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_large\_100 | imagenet | 2.97M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_large\_minimal\_100 | imagenet | 1.41M | -+---------------------------------------+------------+-------------+ -| 
timm-mobilenetv3\_small\_075 | imagenet | 0.57M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_small\_100 | imagenet | 0.93M | -+---------------------------------------+------------+-------------+ -| timm-mobilenetv3\_small\_minimal\_100 | imagenet | 0.43M | -+---------------------------------------+------------+-------------+ - -DPN -~~~ - -+-----------+---------------+-------------+ -| Encoder | Weights | Params, M | -+===========+===============+=============+ -| dpn68 | imagenet | 11M | -+-----------+---------------+-------------+ -| dpn68b | imagenet+5k | 11M | -+-----------+---------------+-------------+ -| dpn92 | imagenet+5k | 34M | -+-----------+---------------+-------------+ -| dpn98 | imagenet | 58M | -+-----------+---------------+-------------+ -| dpn107 | imagenet+5k | 84M | -+-----------+---------------+-------------+ -| dpn131 | imagenet | 76M | -+-----------+---------------+-------------+ - -VGG -~~~ - -+-------------+------------+-------------+ -| Encoder | Weights | Params, M | -+=============+============+=============+ -| vgg11 | imagenet | 9M | -+-------------+------------+-------------+ -| vgg11\_bn | imagenet | 9M | -+-------------+------------+-------------+ -| vgg13 | imagenet | 9M | -+-------------+------------+-------------+ -| vgg13\_bn | imagenet | 9M | -+-------------+------------+-------------+ -| vgg16 | imagenet | 14M | -+-------------+------------+-------------+ -| vgg16\_bn | imagenet | 14M | -+-------------+------------+-------------+ -| vgg19 | imagenet | 20M | -+-------------+------------+-------------+ -| vgg19\_bn | imagenet | 20M | -+-------------+------------+-------------+ - - -Mix Visual Transformer -~~~~~~~~~~~~~~~~~~~~~ - -+-----------+----------+------------+ -| Encoder | Weights | Params, M | -+===========+==========+============+ -| mit\_b0 | imagenet | 3M | -+-----------+----------+------------+ -| mit\_b1 | imagenet | 13M | -+-----------+----------+------------+ -| 
mit\_b2 | imagenet | 24M | -+-----------+----------+------------+ -| mit\_b3 | imagenet | 44M | -+-----------+----------+------------+ -| mit\_b4 | imagenet | 60M | -+-----------+----------+------------+ -| mit\_b5 | imagenet | 81M | -+-----------+----------+------------+ - -MobileOne -~~~~~~~~~~~~~~~~~~~~~ - -+-----------------+----------+------------+ -| Encoder | Weights | Params, M | -+=================+==========+============+ -| mobileone\_s0 | imagenet | 4.6M | -+-----------------+----------+------------+ -| mobileone\_s1 | imagenet | 4.0M | -+-----------------+----------+------------+ -| mobileone\_s2 | imagenet | 6.5M | -+-----------------+----------+------------+ -| mobileone\_s3 | imagenet | 8.8M | -+-----------------+----------+------------+ -| mobileone\_s4 | imagenet | 13.6M | -+-----------------+----------+------------+ +**Segmentation Models PyTorch** provides support for a wide range of encoders. +This flexibility allows you to use these encoders with any model in the library by +specifying the encoder name in the ``encoder_name`` parameter during model initialization. + +Here’s a quick example of using a ResNet34 encoder with the ``Unet`` model: + +.. code-block:: python + + from segmentation_models_pytorch import Unet + + # Initialize Unet with ResNet34 encoder pre-trained on ImageNet + model = Unet(encoder_name="resnet34", encoder_weights="imagenet") + + +The following encoder families are supported by the library, enabling you to choose the one that best fits your use case: + +- **Mix Vision Transformer (mit)** +- **MobileOne** +- **MobileNet** +- **EfficientNet** +- **ResNet** +- **ResNeXt** +- **SENet** +- **DPN** +- **VGG** +- **DenseNet** +- **Xception** +- **Inception** + +Choosing the Right Encoder +-------------------------- + +1. **Small Models for Edge Devices** + Consider encoders like **MobileNet** or **MobileOne**, which have a smaller parameter count and are optimized for lightweight deployment. + +2.
**High Performance** + If you require state-of-the-art accuracy, the **Mix Vision Transformer (mit)** and **EfficientNet** families offer an excellent balance between performance and computational efficiency. + +For each encoder, the table below provides detailed information: + +1. **Pretrained Weights**: + Specifies the available pretrained weights (e.g., ``imagenet``, ``imagenet21k``). + +2. **Params, M**: + The total number of parameters in the encoder, measured in millions. This metric helps you assess the model's size and computational requirements. + +3. **Script**: + Indicates whether the encoder can be scripted with ``torch.jit.script``. + +4. **Compile**: + Indicates whether the encoder is compatible with ``torch.compile(model, fullgraph=True, dynamic=True, backend="eager")``. + You may still run into issues with other backends, such as ``inductor``, depending on the versions of torch, CUDA, and related dependencies, + but most of the time it will work. + +5. **Export**: + Indicates whether the encoder can be exported using ``torch.export.export``, making it suitable for deployment in different environments (e.g., ONNX). 
+ + +============================ ==================================== =========== ======== ========= ======== +Encoder Pretrained weights Params, M Script Compile Export +============================ ==================================== =========== ======== ========= ======== +resnet18 imagenet / ssl / swsl 11M ✅ ✅ ✅ +resnet34 imagenet 21M ✅ ✅ ✅ +resnet50 imagenet / ssl / swsl 23M ✅ ✅ ✅ +resnet101 imagenet 42M ✅ ✅ ✅ +resnet152 imagenet 58M ✅ ✅ ✅ +resnext50_32x4d imagenet / ssl / swsl 22M ✅ ✅ ✅ +resnext101_32x4d ssl / swsl 42M ✅ ✅ ✅ +resnext101_32x8d imagenet / instagram / ssl / swsl 86M ✅ ✅ ✅ +resnext101_32x16d instagram / ssl / swsl 191M ✅ ✅ ✅ +resnext101_32x32d instagram 466M ✅ ✅ ✅ +resnext101_32x48d instagram 826M ✅ ✅ ✅ +dpn68 imagenet 11M ❌ ✅ ✅ +dpn68b imagenet+5k 11M ❌ ✅ ✅ +dpn92 imagenet+5k 34M ❌ ✅ ✅ +dpn98 imagenet 58M ❌ ✅ ✅ +dpn107 imagenet+5k 84M ❌ ✅ ✅ +dpn131 imagenet 76M ❌ ✅ ✅ +vgg11 imagenet 9M ✅ ✅ ✅ +vgg11_bn imagenet 9M ✅ ✅ ✅ +vgg13 imagenet 9M ✅ ✅ ✅ +vgg13_bn imagenet 9M ✅ ✅ ✅ +vgg16 imagenet 14M ✅ ✅ ✅ +vgg16_bn imagenet 14M ✅ ✅ ✅ +vgg19 imagenet 20M ✅ ✅ ✅ +vgg19_bn imagenet 20M ✅ ✅ ✅ +senet154 imagenet 113M ✅ ✅ ✅ +se_resnet50 imagenet 26M ✅ ✅ ✅ +se_resnet101 imagenet 47M ✅ ✅ ✅ +se_resnet152 imagenet 64M ✅ ✅ ✅ +se_resnext50_32x4d imagenet 25M ✅ ✅ ✅ +se_resnext101_32x4d imagenet 46M ✅ ✅ ✅ +densenet121 imagenet 6M ✅ ✅ ✅ +densenet169 imagenet 12M ✅ ✅ ✅ +densenet201 imagenet 18M ✅ ✅ ✅ +densenet161 imagenet 26M ✅ ✅ ✅ +inceptionresnetv2 imagenet / imagenet+background 54M ✅ ✅ ✅ +inceptionv4 imagenet / imagenet+background 41M ✅ ✅ ✅ +efficientnet-b0 imagenet / advprop 4M ✅ ✅ ✅ +efficientnet-b1 imagenet / advprop 6M ✅ ✅ ✅ +efficientnet-b2 imagenet / advprop 7M ✅ ✅ ✅ +efficientnet-b3 imagenet / advprop 10M ✅ ✅ ✅ +efficientnet-b4 
imagenet / advprop 17M ✅ ✅ ✅ +efficientnet-b5 imagenet / advprop 28M ✅ ✅ ✅ +efficientnet-b6 imagenet / advprop 40M ✅ ✅ ✅ +efficientnet-b7 imagenet / advprop 63M ✅ ✅ ✅ +mobilenet_v2 imagenet 2M ✅ ✅ ✅ +xception imagenet 20M ✅ ✅ ✅ +timm-efficientnet-b0 imagenet / advprop / noisy-student 4M ✅ ✅ ✅ +timm-efficientnet-b1 imagenet / advprop / noisy-student 6M ✅ ✅ ✅ +timm-efficientnet-b2 imagenet / advprop / noisy-student 7M ✅ ✅ ✅ +timm-efficientnet-b3 imagenet / advprop / noisy-student 10M ✅ ✅ ✅ +timm-efficientnet-b4 imagenet / advprop / noisy-student 17M ✅ ✅ ✅ +timm-efficientnet-b5 imagenet / advprop / noisy-student 28M ✅ ✅ ✅ +timm-efficientnet-b6 imagenet / advprop / noisy-student 40M ✅ ✅ ✅ +timm-efficientnet-b7 imagenet / advprop / noisy-student 63M ✅ ✅ ✅ +timm-efficientnet-b8 imagenet / advprop 84M ✅ ✅ ✅ +timm-efficientnet-l2 noisy-student / noisy-student-475 474M ✅ ✅ ✅ +timm-tf_efficientnet_lite0 imagenet 3M ✅ ✅ ✅ +timm-tf_efficientnet_lite1 imagenet 4M ✅ ✅ ✅ +timm-tf_efficientnet_lite2 imagenet 4M ✅ ✅ ✅ +timm-tf_efficientnet_lite3 imagenet 6M ✅ ✅ ✅ +timm-tf_efficientnet_lite4 imagenet 11M ✅ ✅ ✅ +timm-skresnet18 imagenet 11M ✅ ✅ ✅ +timm-skresnet34 imagenet 21M ✅ ✅ ✅ +timm-skresnext50_32x4d imagenet 23M ✅ ✅ ✅ +mit_b0 imagenet 3M ✅ ✅ ✅ +mit_b1 imagenet 13M ✅ ✅ ✅ +mit_b2 imagenet 24M ✅ ✅ ✅ +mit_b3 imagenet 44M ✅ ✅ ✅ +mit_b4 imagenet 60M ✅ ✅ ✅ +mit_b5 imagenet 81M ✅ ✅ ✅ +mobileone_s0 imagenet 4M ✅ ✅ ✅ +mobileone_s1 imagenet 3M ✅ ✅ ✅ +mobileone_s2 imagenet 5M ✅ ✅ ✅ +mobileone_s3 imagenet 8M ✅ ✅ ✅ +mobileone_s4 imagenet 12M ✅ ✅ ✅ +============================ ==================================== =========== ======== ========= ======== diff --git a/docs/encoders_dpt.rst b/docs/encoders_dpt.rst new file mode 100644 index 00000000..9ce3af31 --- /dev/null +++ 
b/docs/encoders_dpt.rst @@ -0,0 +1,461 @@ +.. _dpt-encoders: + +DPT Encoders +============ + +This is a list of Vision Transformer encoders that are compatible with the DPT architecture. +While other Vision Transformer encoders from timm may also be compatible, the ones listed below are tested to work properly. + +.. list-table:: Encoder Name + :widths: 100 + :header-rows: 0 + + * - tu-fastvit_ma36.apple_dist_in1k + * - tu-fastvit_ma36.apple_in1k + * - tu-fastvit_mci0.apple_mclip + * - tu-fastvit_mci1.apple_mclip + * - tu-fastvit_mci2.apple_mclip + * - tu-fastvit_s12.apple_dist_in1k + * - tu-fastvit_s12.apple_in1k + * - tu-fastvit_sa12.apple_dist_in1k + * - tu-fastvit_sa12.apple_in1k + * - tu-fastvit_sa24.apple_dist_in1k + * - tu-fastvit_sa24.apple_in1k + * - tu-fastvit_sa36.apple_dist_in1k + * - tu-fastvit_sa36.apple_in1k + * - tu-fastvit_t8.apple_dist_in1k + * - tu-fastvit_t8.apple_in1k + * - tu-fastvit_t12.apple_dist_in1k + * - tu-fastvit_t12.apple_in1k + * - tu-flexivit_base.300ep_in1k + * - tu-flexivit_base.300ep_in21k + * - tu-flexivit_base.600ep_in1k + * - tu-flexivit_base.1000ep_in21k + * - tu-flexivit_base.1200ep_in1k + * - tu-flexivit_base.patch16_in21k + * - tu-flexivit_base.patch30_in21k + * - tu-flexivit_large.300ep_in1k + * - tu-flexivit_large.600ep_in1k + * - tu-flexivit_large.1200ep_in1k + * - tu-flexivit_small.300ep_in1k + * - tu-flexivit_small.600ep_in1k + * - tu-flexivit_small.1200ep_in1k + * - tu-maxvit_base_tf_224.in1k + * - tu-maxvit_base_tf_224.in21k + * - tu-maxvit_base_tf_384.in1k + * - tu-maxvit_base_tf_384.in21k_ft_in1k + * - tu-maxvit_base_tf_512.in1k + * - tu-maxvit_base_tf_512.in21k_ft_in1k + * - tu-maxvit_large_tf_224.in1k + * - tu-maxvit_large_tf_224.in21k + * - tu-maxvit_large_tf_384.in1k + * - tu-maxvit_large_tf_384.in21k_ft_in1k + * - tu-maxvit_large_tf_512.in1k + * - tu-maxvit_large_tf_512.in21k_ft_in1k + * - tu-maxvit_nano_rw_256.sw_in1k + * - tu-maxvit_rmlp_base_rw_224.sw_in12k + * - tu-maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k 
+ * - tu-maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k + * - tu-maxvit_rmlp_nano_rw_256.sw_in1k + * - tu-maxvit_rmlp_pico_rw_256.sw_in1k + * - tu-maxvit_rmlp_small_rw_224.sw_in1k + * - tu-maxvit_rmlp_tiny_rw_256.sw_in1k + * - tu-maxvit_small_tf_224.in1k + * - tu-maxvit_small_tf_384.in1k + * - tu-maxvit_small_tf_512.in1k + * - tu-maxvit_tiny_rw_224.sw_in1k + * - tu-maxvit_tiny_tf_224.in1k + * - tu-maxvit_tiny_tf_384.in1k + * - tu-maxvit_tiny_tf_512.in1k + * - tu-maxvit_xlarge_tf_224.in21k + * - tu-maxvit_xlarge_tf_384.in21k_ft_in1k + * - tu-maxvit_xlarge_tf_512.in21k_ft_in1k + * - tu-maxxvit_rmlp_nano_rw_256.sw_in1k + * - tu-maxxvit_rmlp_small_rw_256.sw_in1k + * - tu-maxxvitv2_nano_rw_256.sw_in1k + * - tu-maxxvitv2_rmlp_base_rw_224.sw_in12k + * - tu-maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k + * - tu-maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k + * - tu-mobilevit_s.cvnets_in1k + * - tu-mobilevit_xs.cvnets_in1k + * - tu-mobilevit_xxs.cvnets_in1k + * - tu-mobilevitv2_050.cvnets_in1k + * - tu-mobilevitv2_075.cvnets_in1k + * - tu-mobilevitv2_100.cvnets_in1k + * - tu-mobilevitv2_125.cvnets_in1k + * - tu-mobilevitv2_150.cvnets_in1k + * - tu-mobilevitv2_150.cvnets_in22k_ft_in1k + * - tu-mobilevitv2_150.cvnets_in22k_ft_in1k_384 + * - tu-mobilevitv2_175.cvnets_in1k + * - tu-mobilevitv2_175.cvnets_in22k_ft_in1k + * - tu-mobilevitv2_175.cvnets_in22k_ft_in1k_384 + * - tu-mobilevitv2_200.cvnets_in1k + * - tu-mobilevitv2_200.cvnets_in22k_ft_in1k + * - tu-mobilevitv2_200.cvnets_in22k_ft_in1k_384 + * - tu-mvitv2_base.fb_in1k + * - tu-mvitv2_base_cls.fb_inw21k + * - tu-mvitv2_huge_cls.fb_inw21k + * - tu-mvitv2_large.fb_in1k + * - tu-mvitv2_large_cls.fb_inw21k + * - tu-mvitv2_small.fb_in1k + * - tu-mvitv2_tiny.fb_in1k + * - tu-samvit_base_patch16.sa1b + * - tu-samvit_huge_patch16.sa1b + * - tu-samvit_large_patch16.sa1b + * - tu-test_vit2.r160_in1k + * - tu-test_vit3.r160_in1k + * - tu-test_vit.r160_in1k + * - tu-vit_base_mci_224.apple_mclip + * - tu-vit_base_mci_224.apple_mclip_lt + * - 
tu-vit_base_patch8_224.augreg2_in21k_ft_in1k + * - tu-vit_base_patch8_224.augreg_in21k + * - tu-vit_base_patch8_224.augreg_in21k_ft_in1k + * - tu-vit_base_patch8_224.dino + * - tu-vit_base_patch16_224.augreg2_in21k_ft_in1k + * - tu-vit_base_patch16_224.augreg_in1k + * - tu-vit_base_patch16_224.augreg_in21k + * - tu-vit_base_patch16_224.augreg_in21k_ft_in1k + * - tu-vit_base_patch16_224.dino + * - tu-vit_base_patch16_224.mae + * - tu-vit_base_patch16_224.orig_in21k + * - tu-vit_base_patch16_224.orig_in21k_ft_in1k + * - tu-vit_base_patch16_224.sam_in1k + * - tu-vit_base_patch16_224_miil.in21k + * - tu-vit_base_patch16_224_miil.in21k_ft_in1k + * - tu-vit_base_patch16_384.augreg_in1k + * - tu-vit_base_patch16_384.augreg_in21k_ft_in1k + * - tu-vit_base_patch16_384.orig_in21k_ft_in1k + * - tu-vit_base_patch16_clip_224.datacompxl + * - tu-vit_base_patch16_clip_224.dfn2b + * - tu-vit_base_patch16_clip_224.laion2b + * - tu-vit_base_patch16_clip_224.laion2b_ft_in1k + * - tu-vit_base_patch16_clip_224.laion2b_ft_in12k + * - tu-vit_base_patch16_clip_224.laion2b_ft_in12k_in1k + * - tu-vit_base_patch16_clip_224.laion400m_e32 + * - tu-vit_base_patch16_clip_224.metaclip_2pt5b + * - tu-vit_base_patch16_clip_224.metaclip_400m + * - tu-vit_base_patch16_clip_224.openai + * - tu-vit_base_patch16_clip_224.openai_ft_in1k + * - tu-vit_base_patch16_clip_224.openai_ft_in12k + * - tu-vit_base_patch16_clip_224.openai_ft_in12k_in1k + * - tu-vit_base_patch16_clip_384.laion2b_ft_in1k + * - tu-vit_base_patch16_clip_384.laion2b_ft_in12k_in1k + * - tu-vit_base_patch16_clip_384.openai_ft_in1k + * - tu-vit_base_patch16_clip_384.openai_ft_in12k_in1k + * - tu-vit_base_patch16_clip_quickgelu_224.metaclip_2pt5b + * - tu-vit_base_patch16_clip_quickgelu_224.metaclip_400m + * - tu-vit_base_patch16_clip_quickgelu_224.openai + * - tu-vit_base_patch16_plus_clip_240.laion400m_e32 + * - tu-vit_base_patch16_rope_reg1_gap_256.sbb_in1k + * - tu-vit_base_patch16_rpn_224.sw_in1k + * - 
tu-vit_base_patch16_siglip_224.v2_webli + * - tu-vit_base_patch16_siglip_224.webli + * - tu-vit_base_patch16_siglip_256.v2_webli + * - tu-vit_base_patch16_siglip_256.webli + * - tu-vit_base_patch16_siglip_256.webli_i18n + * - tu-vit_base_patch16_siglip_384.v2_webli + * - tu-vit_base_patch16_siglip_384.webli + * - tu-vit_base_patch16_siglip_512.v2_webli + * - tu-vit_base_patch16_siglip_512.webli + * - tu-vit_base_patch16_siglip_gap_224.v2_webli + * - tu-vit_base_patch16_siglip_gap_224.webli + * - tu-vit_base_patch16_siglip_gap_256.v2_webli + * - tu-vit_base_patch16_siglip_gap_256.webli + * - tu-vit_base_patch16_siglip_gap_256.webli_i18n + * - tu-vit_base_patch16_siglip_gap_384.v2_webli + * - tu-vit_base_patch16_siglip_gap_384.webli + * - tu-vit_base_patch16_siglip_gap_512.v2_webli + * - tu-vit_base_patch16_siglip_gap_512.webli + * - tu-vit_base_patch32_224.augreg_in1k + * - tu-vit_base_patch32_224.augreg_in21k + * - tu-vit_base_patch32_224.augreg_in21k_ft_in1k + * - tu-vit_base_patch32_224.orig_in21k + * - tu-vit_base_patch32_224.sam_in1k + * - tu-vit_base_patch32_384.augreg_in1k + * - tu-vit_base_patch32_384.augreg_in21k_ft_in1k + * - tu-vit_base_patch32_clip_224.datacompxl + * - tu-vit_base_patch32_clip_224.laion2b + * - tu-vit_base_patch32_clip_224.laion2b_ft_in1k + * - tu-vit_base_patch32_clip_224.laion2b_ft_in12k_in1k + * - tu-vit_base_patch32_clip_224.laion400m_e32 + * - tu-vit_base_patch32_clip_224.metaclip_2pt5b + * - tu-vit_base_patch32_clip_224.metaclip_400m + * - tu-vit_base_patch32_clip_224.openai + * - tu-vit_base_patch32_clip_224.openai_ft_in1k + * - tu-vit_base_patch32_clip_256.datacompxl + * - tu-vit_base_patch32_clip_384.laion2b_ft_in12k_in1k + * - tu-vit_base_patch32_clip_384.openai_ft_in12k_in1k + * - tu-vit_base_patch32_clip_448.laion2b_ft_in12k_in1k + * - tu-vit_base_patch32_clip_quickgelu_224.laion400m_e32 + * - tu-vit_base_patch32_clip_quickgelu_224.metaclip_2pt5b + * - tu-vit_base_patch32_clip_quickgelu_224.metaclip_400m + * - 
tu-vit_base_patch32_clip_quickgelu_224.openai + * - tu-vit_base_patch32_siglip_256.v2_webli + * - tu-vit_base_patch32_siglip_gap_256.v2_webli + * - tu-vit_base_r50_s16_224.orig_in21k + * - tu-vit_base_r50_s16_384.orig_in21k_ft_in1k + * - tu-vit_betwixt_patch16_reg1_gap_256.sbb_in1k + * - tu-vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k + * - tu-vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k + * - tu-vit_betwixt_patch16_reg4_gap_256.sbb_in1k + * - tu-vit_betwixt_patch16_reg4_gap_256.sbb_in12k + * - tu-vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k + * - tu-vit_betwixt_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k + * - tu-vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k + * - tu-vit_betwixt_patch32_clip_224.tinyclip_laion400m + * - tu-vit_giant_patch16_gap_224.in22k_ijepa + * - tu-vit_giantopt_patch16_siglip_256.v2_webli + * - tu-vit_giantopt_patch16_siglip_384.v2_webli + * - tu-vit_giantopt_patch16_siglip_gap_256.v2_webli + * - tu-vit_giantopt_patch16_siglip_gap_384.v2_webli + * - tu-vit_huge_patch16_gap_448.in1k_ijepa + * - tu-vit_large_patch16_224.augreg_in21k + * - tu-vit_large_patch16_224.augreg_in21k_ft_in1k + * - tu-vit_large_patch16_224.mae + * - tu-vit_large_patch16_224.orig_in21k + * - tu-vit_large_patch16_384.augreg_in21k_ft_in1k + * - tu-vit_large_patch16_siglip_256.v2_webli + * - tu-vit_large_patch16_siglip_256.webli + * - tu-vit_large_patch16_siglip_384.v2_webli + * - tu-vit_large_patch16_siglip_384.webli + * - tu-vit_large_patch16_siglip_512.v2_webli + * - tu-vit_large_patch16_siglip_gap_256.v2_webli + * - tu-vit_large_patch16_siglip_gap_256.webli + * - tu-vit_large_patch16_siglip_gap_384.v2_webli + * - tu-vit_large_patch16_siglip_gap_384.webli + * - tu-vit_large_patch16_siglip_gap_512.v2_webli + * - tu-vit_large_patch32_224.orig_in21k + * - tu-vit_large_patch32_384.orig_in21k_ft_in1k + * - tu-vit_large_r50_s32_224.augreg_in21k + * - tu-vit_large_r50_s32_224.augreg_in21k_ft_in1k + * - tu-vit_large_r50_s32_384.augreg_in21k_ft_in1k + * - 
tu-vit_little_patch16_reg1_gap_256.sbb_in12k + * - tu-vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k + * - tu-vit_little_patch16_reg4_gap_256.sbb_in1k + * - tu-vit_medium_patch16_clip_224.tinyclip_yfcc15m + * - tu-vit_medium_patch16_gap_240.sw_in12k + * - tu-vit_medium_patch16_gap_256.sw_in12k_ft_in1k + * - tu-vit_medium_patch16_gap_384.sw_in12k_ft_in1k + * - tu-vit_medium_patch16_reg1_gap_256.sbb_in1k + * - tu-vit_medium_patch16_reg4_gap_256.sbb_in1k + * - tu-vit_medium_patch16_reg4_gap_256.sbb_in12k + * - tu-vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k + * - tu-vit_medium_patch16_rope_reg1_gap_256.sbb_in1k + * - tu-vit_medium_patch32_clip_224.tinyclip_laion400m + * - tu-vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k + * - tu-vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k + * - tu-vit_mediumd_patch16_reg4_gap_256.sbb_in12k + * - tu-vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k + * - tu-vit_mediumd_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k + * - tu-vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k + * - tu-vit_pwee_patch16_reg1_gap_256.sbb_in1k + * - tu-vit_relpos_base_patch16_224.sw_in1k + * - tu-vit_relpos_base_patch16_clsgap_224.sw_in1k + * - tu-vit_relpos_base_patch32_plus_rpn_256.sw_in1k + * - tu-vit_relpos_medium_patch16_224.sw_in1k + * - tu-vit_relpos_medium_patch16_cls_224.sw_in1k + * - tu-vit_relpos_medium_patch16_rpn_224.sw_in1k + * - tu-vit_relpos_small_patch16_224.sw_in1k + * - tu-vit_small_patch8_224.dino + * - tu-vit_small_patch16_224.augreg_in1k + * - tu-vit_small_patch16_224.augreg_in21k + * - tu-vit_small_patch16_224.augreg_in21k_ft_in1k + * - tu-vit_small_patch16_224.dino + * - tu-vit_small_patch16_384.augreg_in1k + * - tu-vit_small_patch16_384.augreg_in21k_ft_in1k + * - tu-vit_small_patch32_224.augreg_in21k + * - tu-vit_small_patch32_224.augreg_in21k_ft_in1k + * - tu-vit_small_patch32_384.augreg_in21k_ft_in1k + * - tu-vit_small_r26_s32_224.augreg_in21k + * - tu-vit_small_r26_s32_224.augreg_in21k_ft_in1k + * - 
tu-vit_small_r26_s32_384.augreg_in21k_ft_in1k + * - tu-vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k + * - tu-vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k + * - tu-vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k + * - tu-vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k + * - tu-vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k + * - tu-vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k_ft_in1k + * - tu-vit_so150m_patch16_reg4_gap_384.sbb_e250_in12k_ft_in1k + * - tu-vit_so400m_patch16_siglip_256.v2_webli + * - tu-vit_so400m_patch16_siglip_256.webli_i18n + * - tu-vit_so400m_patch16_siglip_384.v2_webli + * - tu-vit_so400m_patch16_siglip_512.v2_webli + * - tu-vit_so400m_patch16_siglip_gap_256.v2_webli + * - tu-vit_so400m_patch16_siglip_gap_256.webli_i18n + * - tu-vit_so400m_patch16_siglip_gap_384.v2_webli + * - tu-vit_so400m_patch16_siglip_gap_512.v2_webli + * - tu-vit_srelpos_medium_patch16_224.sw_in1k + * - tu-vit_srelpos_small_patch16_224.sw_in1k + * - tu-vit_tiny_patch16_224.augreg_in21k + * - tu-vit_tiny_patch16_224.augreg_in21k_ft_in1k + * - tu-vit_tiny_patch16_384.augreg_in21k_ft_in1k + * - tu-vit_tiny_r_s16_p8_224.augreg_in21k + * - tu-vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k + * - tu-vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k + * - tu-vit_wee_patch16_reg1_gap_256.sbb_in1k + * - tu-vit_xsmall_patch16_clip_224.tinyclip_yfcc15m + * - tu-vitamin_base_224.datacomp1b_clip + * - tu-vitamin_base_224.datacomp1b_clip_ltt + * - tu-vitamin_large2_224.datacomp1b_clip + * - tu-vitamin_large2_256.datacomp1b_clip + * - tu-vitamin_large2_336.datacomp1b_clip + * - tu-vitamin_large2_384.datacomp1b_clip + * - tu-vitamin_large_224.datacomp1b_clip + * - tu-vitamin_large_256.datacomp1b_clip + * - tu-vitamin_large_336.datacomp1b_clip + * - tu-vitamin_large_384.datacomp1b_clip + * - tu-vitamin_small_224.datacomp1b_clip + * - tu-vitamin_small_224.datacomp1b_clip_ltt + * - tu-vitamin_xlarge_256.datacomp1b_clip + * - tu-vitamin_xlarge_336.datacomp1b_clip + * - 
tu-vitamin_xlarge_384.datacomp1b_clip + * - tu-hiera_small_abswin_256.sbb2_e200_in12k + * - tu-hiera_small_abswin_256.sbb2_e200_in12k_ft_in1k + * - tu-hiera_small_abswin_256.sbb2_pd_e200_in12k + * - tu-hiera_small_abswin_256.sbb2_pd_e200_in12k_ft_in1k + * - tu-swin_base_patch4_window7_224.ms_in1k + * - tu-swin_base_patch4_window7_224.ms_in22k + * - tu-swin_base_patch4_window7_224.ms_in22k_ft_in1k + * - tu-swin_base_patch4_window12_384.ms_in1k + * - tu-swin_base_patch4_window12_384.ms_in22k + * - tu-swin_base_patch4_window12_384.ms_in22k_ft_in1k + * - tu-swin_large_patch4_window7_224.ms_in22k + * - tu-swin_large_patch4_window7_224.ms_in22k_ft_in1k + * - tu-swin_large_patch4_window12_384.ms_in22k + * - tu-swin_large_patch4_window12_384.ms_in22k_ft_in1k + * - tu-swin_s3_base_224.ms_in1k + * - tu-swin_s3_small_224.ms_in1k + * - tu-swin_s3_tiny_224.ms_in1k + * - tu-swin_small_patch4_window7_224.ms_in1k + * - tu-swin_small_patch4_window7_224.ms_in22k + * - tu-swin_small_patch4_window7_224.ms_in22k_ft_in1k + * - tu-swin_tiny_patch4_window7_224.ms_in1k + * - tu-swin_tiny_patch4_window7_224.ms_in22k + * - tu-swin_tiny_patch4_window7_224.ms_in22k_ft_in1k + * - tu-swinv2_base_window8_256.ms_in1k + * - tu-swinv2_base_window12_192.ms_in22k + * - tu-swinv2_base_window12to16_192to256.ms_in22k_ft_in1k + * - tu-swinv2_base_window12to24_192to384.ms_in22k_ft_in1k + * - tu-swinv2_base_window16_256.ms_in1k + * - tu-swinv2_cr_small_224.sw_in1k + * - tu-swinv2_cr_small_ns_224.sw_in1k + * - tu-swinv2_cr_tiny_ns_224.sw_in1k + * - tu-swinv2_large_window12_192.ms_in22k + * - tu-swinv2_large_window12to16_192to256.ms_in22k_ft_in1k + * - tu-swinv2_large_window12to24_192to384.ms_in22k_ft_in1k + * - tu-swinv2_small_window8_256.ms_in1k + * - tu-swinv2_small_window16_256.ms_in1k + * - tu-swinv2_tiny_window8_256.ms_in1k + * - tu-swinv2_tiny_window16_256.ms_in1k + * - tu-efficientformer_l1.snap_dist_in1k + * - tu-efficientformer_l3.snap_dist_in1k + * - tu-efficientformer_l7.snap_dist_in1k + * - 
tu-beit_base_patch16_224.in22k_ft_in22k + * - tu-beit_base_patch16_224.in22k_ft_in22k_in1k + * - tu-beit_base_patch16_384.in22k_ft_in22k_in1k + * - tu-beit_large_patch16_224.in22k_ft_in22k + * - tu-beit_large_patch16_224.in22k_ft_in22k_in1k + * - tu-beit_large_patch16_384.in22k_ft_in22k_in1k + * - tu-beit_large_patch16_512.in22k_ft_in22k_in1k + * - tu-beitv2_base_patch16_224.in1k_ft_in1k + * - tu-beitv2_base_patch16_224.in1k_ft_in22k + * - tu-beitv2_base_patch16_224.in1k_ft_in22k_in1k + * - tu-beitv2_large_patch16_224.in1k_ft_in1k + * - tu-beitv2_large_patch16_224.in1k_ft_in22k + * - tu-beitv2_large_patch16_224.in1k_ft_in22k_in1k + * - tu-cait_m36_384.fb_dist_in1k + * - tu-cait_m48_448.fb_dist_in1k + * - tu-cait_s24_224.fb_dist_in1k + * - tu-cait_s24_384.fb_dist_in1k + * - tu-cait_s36_384.fb_dist_in1k + * - tu-cait_xs24_384.fb_dist_in1k + * - tu-cait_xxs24_224.fb_dist_in1k + * - tu-cait_xxs24_384.fb_dist_in1k + * - tu-cait_xxs36_224.fb_dist_in1k + * - tu-cait_xxs36_384.fb_dist_in1k + * - tu-coatnet_0_rw_224.sw_in1k + * - tu-coatnet_1_rw_224.sw_in1k + * - tu-coatnet_2_rw_224.sw_in12k + * - tu-coatnet_2_rw_224.sw_in12k_ft_in1k + * - tu-coatnet_3_rw_224.sw_in12k + * - tu-coatnet_bn_0_rw_224.sw_in1k + * - tu-coatnet_nano_rw_224.sw_in1k + * - tu-coatnet_rmlp_1_rw2_224.sw_in12k + * - tu-coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k + * - tu-coatnet_rmlp_1_rw_224.sw_in1k + * - tu-coatnet_rmlp_2_rw_224.sw_in1k + * - tu-coatnet_rmlp_2_rw_224.sw_in12k + * - tu-coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k + * - tu-coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k + * - tu-coatnet_rmlp_nano_rw_224.sw_in1k + * - tu-deit3_base_patch16_224.fb_in1k + * - tu-deit3_base_patch16_224.fb_in22k_ft_in1k + * - tu-deit3_base_patch16_384.fb_in1k + * - tu-deit3_base_patch16_384.fb_in22k_ft_in1k + * - tu-deit3_large_patch16_224.fb_in1k + * - tu-deit3_large_patch16_224.fb_in22k_ft_in1k + * - tu-deit3_large_patch16_384.fb_in1k + * - tu-deit3_large_patch16_384.fb_in22k_ft_in1k + * - tu-deit3_medium_patch16_224.fb_in1k 
+ * - tu-deit3_medium_patch16_224.fb_in22k_ft_in1k + * - tu-deit3_small_patch16_224.fb_in1k + * - tu-deit3_small_patch16_224.fb_in22k_ft_in1k + * - tu-deit3_small_patch16_384.fb_in1k + * - tu-deit3_small_patch16_384.fb_in22k_ft_in1k + * - tu-deit_base_distilled_patch16_224.fb_in1k + * - tu-deit_base_distilled_patch16_384.fb_in1k + * - tu-deit_base_patch16_224.fb_in1k + * - tu-deit_base_patch16_384.fb_in1k + * - tu-deit_small_distilled_patch16_224.fb_in1k + * - tu-deit_small_patch16_224.fb_in1k + * - tu-deit_tiny_distilled_patch16_224.fb_in1k + * - tu-deit_tiny_patch16_224.fb_in1k + * - tu-regnety_160.deit_in1k + * - tu-twins_pcpvt_base.in1k + * - tu-twins_pcpvt_large.in1k + * - tu-twins_pcpvt_small.in1k + * - tu-twins_svt_base.in1k + * - tu-twins_svt_large.in1k + * - tu-twins_svt_small.in1k + * - tu-xcit_large_24_p8_224.fb_dist_in1k + * - tu-xcit_large_24_p8_224.fb_in1k + * - tu-xcit_large_24_p8_384.fb_dist_in1k + * - tu-xcit_large_24_p16_224.fb_dist_in1k + * - tu-xcit_large_24_p16_224.fb_in1k + * - tu-xcit_large_24_p16_384.fb_dist_in1k + * - tu-xcit_medium_24_p8_224.fb_dist_in1k + * - tu-xcit_medium_24_p8_224.fb_in1k + * - tu-xcit_medium_24_p8_384.fb_dist_in1k + * - tu-xcit_medium_24_p16_224.fb_dist_in1k + * - tu-xcit_medium_24_p16_224.fb_in1k + * - tu-xcit_medium_24_p16_384.fb_dist_in1k + * - tu-xcit_nano_12_p8_224.fb_dist_in1k + * - tu-xcit_nano_12_p8_224.fb_in1k + * - tu-xcit_nano_12_p8_384.fb_dist_in1k + * - tu-xcit_nano_12_p16_224.fb_dist_in1k + * - tu-xcit_nano_12_p16_224.fb_in1k + * - tu-xcit_nano_12_p16_384.fb_dist_in1k + * - tu-xcit_small_12_p8_224.fb_dist_in1k + * - tu-xcit_small_12_p8_224.fb_in1k + * - tu-xcit_small_12_p8_384.fb_dist_in1k + * - tu-xcit_small_12_p16_224.fb_dist_in1k + * - tu-xcit_small_12_p16_224.fb_in1k + * - tu-xcit_small_12_p16_384.fb_dist_in1k + * - tu-xcit_small_24_p8_224.fb_dist_in1k + * - tu-xcit_small_24_p8_224.fb_in1k + * - tu-xcit_small_24_p8_384.fb_dist_in1k + * - tu-xcit_small_24_p16_224.fb_dist_in1k + * - 
tu-xcit_small_24_p16_224.fb_in1k + * - tu-xcit_small_24_p16_384.fb_dist_in1k + * - tu-xcit_tiny_12_p8_224.fb_dist_in1k + * - tu-xcit_tiny_12_p8_224.fb_in1k + * - tu-xcit_tiny_12_p8_384.fb_dist_in1k + * - tu-xcit_tiny_12_p16_224.fb_dist_in1k + * - tu-xcit_tiny_12_p16_224.fb_in1k + * - tu-xcit_tiny_12_p16_384.fb_dist_in1k + * - tu-xcit_tiny_24_p8_224.fb_dist_in1k + * - tu-xcit_tiny_24_p8_224.fb_in1k + * - tu-xcit_tiny_24_p8_384.fb_dist_in1k + * - tu-xcit_tiny_24_p16_224.fb_dist_in1k + * - tu-xcit_tiny_24_p16_224.fb_in1k + * - tu-xcit_tiny_24_p16_384.fb_dist_in1k \ No newline at end of file diff --git a/docs/models.rst b/docs/models.rst index c2037afb..ab04bb5e 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -81,3 +81,18 @@ Segformer ~~~~~~~~~ .. autoclass:: segmentation_models_pytorch.Segformer + +.. _dpt: + +DPT +~~~ + +.. note:: + + See full list of DPT-compatible timm encoders in :ref:`dpt-encoders`. + +.. note:: + + For some encoders, the model requires ``dynamic_img_size=True`` to be passed in order to work with resolutions different from what the encoder was trained for. + +.. autoclass:: segmentation_models_pytorch.DPT diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 7fc04dd7..e6627b83 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -53,7 +53,7 @@ You are done! Now you can train your model with your favorite framework, or as s for images, gt_masks in dataloader: - predicted_mask = model(image) + predicted_mask = model(images) loss = loss_fn(predicted_mask, gt_masks) loss.backward() diff --git a/docs/save_load.rst b/docs/save_load.rst index e90e4eba..15434eb6 100644 --- a/docs/save_load.rst +++ b/docs/save_load.rst @@ -40,6 +40,14 @@ For example: # Alternatively, load the model directly from the Hugging Face Hub model = smp.from_pretrained('username/my-model') +Loading pre-trained model with different number of classes for fine-tuning: + +.. 
code:: python + + import segmentation_models_pytorch as smp + + model = smp.from_pretrained('username/my-model', classes=5, strict=False) + Saving model Metrics and Dataset Name ------------------------------------- diff --git a/examples/binary_segmentation_buildings.py b/examples/binary_segmentation_buildings.py new file mode 100644 index 00000000..33636477 --- /dev/null +++ b/examples/binary_segmentation_buildings.py @@ -0,0 +1,498 @@ +""" +This script demonstrates how to train a binary segmentation model using the +CamVid dataset and segmentation_models_pytorch. The CamVid dataset is a +collection of videos with pixel-level annotations for semantic segmentation. +The dataset includes 367 training images, 101 validation images, and 233 test images. +Each training image has a corresponding mask that labels each pixel as belonging +to these classes with the numerical labels as follows: +- Sky: 0 +- Building: 1 +- Pole: 2 +- Road: 3 +- Pavement: 4 +- Tree: 5 +- SignSymbol: 6 +- Fence: 7 +- Car: 8 +- Pedestrian: 9 +- Bicyclist: 10 +- Unlabelled: 11 + +In this script, we focus on binary segmentation, where the goal is to classify +each pixel as either belonging to a certain class (Foreground) or +not (Background). + +Class Labels: +- 0: Background +- 1: Foreground + +The script includes the following steps: +1. Set the device to GPU if available, otherwise use CPU. +2. Download the CamVid dataset if it is not already present. +3. Define hyperparameters for training. +4. Define a custom dataset class for loading and preprocessing the CamVid + dataset. +5. Define a function to visualize images and masks. +6. Create datasets and dataloaders for training, validation, and testing. +7. Define a model class for the segmentation task. +8. Train the model using the training and validation datasets. +9. Evaluate the model using the test dataset and save the output masks and + metrics.
+""" + +import logging +import os + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.optim import lr_scheduler +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as BaseDataset +from tqdm import tqdm + +import segmentation_models_pytorch as smp + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(message)s", + datefmt="%d:%m:%Y %H:%M:%S", +) + +# ---------------------------- +# Set the device to GPU if available +# ---------------------------- +device = "cuda" if torch.cuda.is_available() else "cpu" +logging.info(f"Using device: {device}") +if device == "cpu": + os.system("export OMP_NUM_THREADS=64") + torch.set_num_threads(os.cpu_count()) + +# ---------------------------- +# Download the CamVid dataset, if needed +# ---------------------------- +# Change this to your desired directory +main_dir = "./examples/binary_segmentation_data/" + +data_dir = os.path.join(main_dir, "dataset") +if not os.path.exists(data_dir): + logging.info("Loading data...") + os.system(f"git clone https://github.com/alexgkendall/SegNet-Tutorial {data_dir}") + logging.info("Done!") + +# Create a directory to store the output masks +output_dir = os.path.join(main_dir, "output_images") +os.makedirs(output_dir, exist_ok=True) + +# ---------------------------- +# Define the hyperparameters +# ---------------------------- +epochs_max = 200 # Number of epochs to train the model +adam_lr = 2e-4 # Learning rate for the Adam optimizer +eta_min = 1e-5 # Minimum learning rate for the scheduler +batch_size = 8 # Batch size for training +input_image_reshape = (320, 320) # Desired shape for the input images and masks +foreground_class = 1 # 1 for binary segmentation + + +# ---------------------------- +# Define a custom dataset class for the CamVid dataset +# ---------------------------- +class Dataset(BaseDataset): + """ + A custom dataset class for binary segmentation tasks. 
+ + Parameters: + ---------- + + - images_dir (str): Directory containing the input images. + - masks_dir (str): Directory containing the corresponding masks. + - input_image_reshape (tuple, optional): Desired shape for the input + images and masks. Default is (320, 320). + - foreground_class (int, optional): The class value in the mask to be + considered as the foreground. Default is 1. + - augmentation (callable, optional): A function/transform to apply to the + images and masks for data augmentation. + """ + + def __init__( + self, + images_dir, + masks_dir, + input_image_reshape=(320, 320), + foreground_class=1, + augmentation=None, + ): + self.ids = os.listdir(images_dir) + self.images_filepaths = [ + os.path.join(images_dir, image_id) for image_id in self.ids + ] + self.masks_filepaths = [ + os.path.join(masks_dir, image_id) for image_id in self.ids + ] + + self.input_image_reshape = input_image_reshape + self.foreground_class = foreground_class + self.augmentation = augmentation + + def __getitem__(self, i): + """ + Retrieves the image and corresponding mask at index `i`. + + Parameters: + ---------- + + - i (int): Index of the image and mask to retrieve. + Returns: + - A tuple containing: + - image (torch.Tensor): The preprocessed image tensor of shape + (1, input_image_reshape) - e.g., (1, 320, 320) - normalized to [0, 1]. + - mask_remap (torch.Tensor): The preprocessed mask tensor of + shape input_image_reshape with values 0 or 1. 
+ """ + # Read the image + image = cv2.imread( + self.images_filepaths[i], cv2.IMREAD_GRAYSCALE + ) # Read image as grayscale + image = np.expand_dims(image, axis=-1) # Add channel dimension + + # resize image to input_image_reshape + image = cv2.resize(image, self.input_image_reshape) + + # Read the mask in grayscale mode + mask = cv2.imread(self.masks_filepaths[i], 0) + + # Update the mask: Set foreground_class to 1 and the rest to 0 + mask_remap = np.where(mask == self.foreground_class, 1, 0).astype(np.uint8) + + # resize mask to input_image_reshape + mask_remap = cv2.resize(mask_remap, self.input_image_reshape) + + if self.augmentation: + sample = self.augmentation(image=image, mask=mask_remap) + image, mask_remap = sample["image"], sample["mask"] + + # Convert to PyTorch tensors + # Add channel dimension if missing + if image.ndim == 2: + image = np.expand_dims(image, axis=-1) + + # HWC -> CHW and normalize to [0, 1] + image = torch.tensor(image).float().permute(2, 0, 1) / 255.0 + + # Ensure mask is LongTensor + mask_remap = torch.tensor(mask_remap).long() + + return image, mask_remap + + def __len__(self): + return len(self.ids) + + +# Define a class for the CamVid model +class CamVidModel(torch.nn.Module): + """ + A PyTorch model for binary segmentation using the Segmentation Models + PyTorch library. + + Parameters: + ---------- + + - arch (str): The architecture name of the segmentation model + (e.g., 'Unet', 'FPN'). + - encoder_name (str): The name of the encoder to use + (e.g., 'resnet34', 'vgg16'). + - in_channels (int, optional): Number of input channels (e.g., 3 for RGB). + - out_classes (int, optional): Number of output classes (e.g., 1 for binary) + **kwargs: Additional keyword arguments to pass to the model + creation function. 
+ """ + + def __init__(self, arch, encoder_name, in_channels=3, out_classes=1, **kwargs): + super().__init__() + self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) + self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) + self.model = smp.create_model( + arch, + encoder_name=encoder_name, + in_channels=in_channels, + classes=out_classes, + **kwargs, + ) + + def forward(self, image): + # Normalize image + image = (image - self.mean) / self.std + mask = self.model(image) + return mask + + +def visualize(output_dir, image_filename, **images): + """Plot images in one row.""" + n = len(images) + plt.figure(figsize=(16, 5)) + for i, (name, image) in enumerate(images.items()): + plt.subplot(1, n, i + 1) + plt.xticks([]) + plt.yticks([]) + plt.title(" ".join(name.split("_")).title()) + plt.imshow(image) + plt.show() + plt.savefig(os.path.join(output_dir, image_filename)) + plt.close() + + +# Run one training epoch followed by one validation pass +def train_and_evaluate_one_epoch( + model, train_dataloader, valid_dataloader, optimizer, scheduler, loss_fn, device +): + # Set the model to training mode + model.train() + train_loss = 0 + for batch in tqdm(train_dataloader, desc="Training"): + images, masks = batch + images, masks = images.to(device), masks.to(device) + + optimizer.zero_grad() + outputs = model(images) + + loss = loss_fn(outputs, masks) + loss.backward() + optimizer.step() + + train_loss += loss.item() + + scheduler.step() + avg_train_loss = train_loss / len(train_dataloader) + + # Set the model to evaluation mode + model.eval() + val_loss = 0 + with torch.no_grad(): + for batch in tqdm(valid_dataloader, desc="Evaluating"): + images, masks = batch + images, masks = images.to(device), masks.to(device) + + outputs = model(images) + loss = loss_fn(outputs, masks) + + val_loss += loss.item() + + avg_val_loss = val_loss / len(valid_dataloader) + return avg_train_loss, avg_val_loss + + +def train_model( + model, + train_dataloader, +
valid_dataloader, + optimizer, + scheduler, + loss_fn, + device, + epochs, +): + train_losses = [] + val_losses = [] + + for epoch in range(epochs): + avg_train_loss, avg_val_loss = train_and_evaluate_one_epoch( + model, + train_dataloader, + valid_dataloader, + optimizer, + scheduler, + loss_fn, + device, + ) + train_losses.append(avg_train_loss) + val_losses.append(avg_val_loss) + + logging.info( + f"Epoch {epoch + 1}/{epochs}, Training Loss: {avg_train_loss:.2f}, Validation Loss: {avg_val_loss:.2f}" + ) + + history = { + "train_losses": train_losses, + "val_losses": val_losses, + } + return history + + +def test_model(model, output_dir, test_dataloader, loss_fn, device): + # Set the model to evaluation mode + model.eval() + test_loss = 0 + tp, fp, fn, tn = 0, 0, 0, 0 + with torch.no_grad(): + for batch in tqdm(test_dataloader, desc="Evaluating"): + images, masks = batch + images, masks = images.to(device), masks.to(device) + + outputs = model(images) + loss = loss_fn(outputs, masks) + + for i, output in enumerate(outputs): + input = images[i].cpu().numpy().transpose(1, 2, 0) + output = output.squeeze().cpu().numpy() + + visualize( + output_dir, + f"output_{i}.png", + input_image=input, + output_mask=output, + binary_mask=output > 0.5, + ) + + test_loss += loss.item() + + prob_mask = outputs.sigmoid().squeeze(1) + pred_mask = (prob_mask > 0.5).long() + batch_tp, batch_fp, batch_fn, batch_tn = smp.metrics.get_stats( + pred_mask, masks, mode="binary" + ) + tp += batch_tp.sum().item() + fp += batch_fp.sum().item() + fn += batch_fn.sum().item() + tn += batch_tn.sum().item() + + test_loss_mean = test_loss / len(test_dataloader) + logging.info(f"Test Loss: {test_loss_mean:.2f}") + + iou_score = smp.metrics.iou_score( + torch.tensor([tp]), + torch.tensor([fp]), + torch.tensor([fn]), + torch.tensor([tn]), + reduction="micro", + ) + + return test_loss_mean, iou_score.item() + + +# ---------------------------- +# Define the data directories and create the datasets +# 
---------------------------- +x_train_dir = os.path.join(data_dir, "CamVid", "train") +y_train_dir = os.path.join(data_dir, "CamVid", "trainannot") + +x_val_dir = os.path.join(data_dir, "CamVid", "val") +y_val_dir = os.path.join(data_dir, "CamVid", "valannot") + +x_test_dir = os.path.join(data_dir, "CamVid", "test") +y_test_dir = os.path.join(data_dir, "CamVid", "testannot") + +train_dataset = Dataset( + x_train_dir, + y_train_dir, + input_image_reshape=input_image_reshape, + foreground_class=foreground_class, +) +valid_dataset = Dataset( + x_val_dir, + y_val_dir, + input_image_reshape=input_image_reshape, + foreground_class=foreground_class, +) +test_dataset = Dataset( + x_test_dir, + y_test_dir, + input_image_reshape=input_image_reshape, + foreground_class=foreground_class, +) + +image, mask = train_dataset[0] +logging.info(f"Unique values in mask: {np.unique(mask)}") +logging.info(f"Image shape: {image.shape}") +logging.info(f"Mask shape: {mask.shape}") + +# ---------------------------- +# Create the dataloaders using the datasets +# ---------------------------- +logging.info(f"Train size: {len(train_dataset)}") +logging.info(f"Valid size: {len(valid_dataset)}") +logging.info(f"Test size: {len(test_dataset)}") + +train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) +valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) +test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) + +# ---------------------------- +# Lets look at some samples +# ---------------------------- +# Visualize and save train sample +sample = train_dataset[0] +visualize( + output_dir, + "train_sample.png", + train_image=sample[0].numpy().transpose(1, 2, 0), + train_mask=sample[1].squeeze(), +) + +# Visualize and save validation sample +sample = valid_dataset[0] +visualize( + output_dir, + "validation_sample.png", + validation_image=sample[0].numpy().transpose(1, 2, 0), + validation_mask=sample[1].squeeze(), +) 
+ +# Visualize and save test sample +sample = test_dataset[0] +visualize( + output_dir, + "test_sample.png", + test_image=sample[0].numpy().transpose(1, 2, 0), + test_mask=sample[1].squeeze(), +) + +# ---------------------------- +# Create and train the model +# ---------------------------- +max_iter = epochs_max * len(train_dataloader) # Total number of iterations + +model = CamVidModel("Unet", "resnet34", in_channels=3, out_classes=1) + +# Training loop +model = model.to(device) + +# Define the Adam optimizer +optimizer = torch.optim.Adam(model.parameters(), lr=adam_lr) + +# Define the learning rate scheduler +scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iter, eta_min=eta_min) + +# Define the loss function +loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True) + +# Train the model +history = train_model( + model, + train_dataloader, + valid_dataloader, + optimizer, + scheduler, + loss_fn, + device, + epochs_max, +) + +# Visualize the training and validation losses +plt.figure(figsize=(10, 5)) +plt.plot(history["train_losses"], label="Train Loss") +plt.plot(history["val_losses"], label="Validation Loss") +plt.xlabel("Epochs") +plt.ylabel("Loss") +plt.title("Training and Validation Losses") +plt.legend() +plt.savefig(os.path.join(output_dir, "train_val_losses.png")) +plt.close() + + +# Evaluate the model +test_loss = test_model(model, output_dir, test_dataloader, loss_fn, device) + +logging.info(f"Test Loss: {test_loss[0]}, IoU Score: {test_loss[1]}") +logging.info(f"The output masks are saved in {output_dir}.") diff --git a/examples/dpt_inference_pretrained.ipynb b/examples/dpt_inference_pretrained.ipynb new file mode 100644 index 00000000..adfb5a15 --- /dev/null +++ b/examples/dpt_inference_pretrained.ipynb @@ -0,0 +1,138 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/dpt_inference_pretrained.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# make sure you have the latest version of the libraries\n", + "!pip install -U segmentation-models-pytorch\n", + "!pip install albumentations matplotlib requests pillow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "import numpy as np\n", + "import albumentations as A\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import torch\n", + "import segmentation_models_pytorch as smp\n", + "\n", + "from PIL import Image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading weights from local directory\n" + ] + } + ], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# More checkpoints can be found here:\n", + "checkpoint = \"smp-hub/dpt-large-ade20k\"\n", + "\n", + "# Load pretrained model and preprocessing function\n", + "model = smp.from_pretrained(checkpoint).eval().to(device)\n", + "preprocessing = A.Compose.from_pretrained(checkpoint)\n", + "\n", + "# Load image\n", + "url = \"https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg\"\n", + "image = Image.open(requests.get(url, stream=True).raw)\n", + "\n", + "# Preprocess image\n", + "image = np.array(image)\n", + "normalized_image = preprocessing(image=image)[\"image\"]\n", + "input_tensor = torch.as_tensor(normalized_image)\n", + "input_tensor = input_tensor.permute(2, 0, 1).unsqueeze(0) # HWC -> BCHW\n", + "input_tensor = input_tensor.to(device)\n", + "\n", + "# Perform inference\n", + "with 
torch.no_grad():\n", + " output_mask = model(input_tensor)\n", + "\n", + "# Postprocess mask\n", + "mask = torch.nn.functional.interpolate(\n", + " output_mask, size=image.shape[:2], mode=\"bilinear\", align_corners=False\n", + ")\n", + "mask = mask[0].argmax(0).cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot results\n", + "plt.figure(figsize=(12, 6))\n", + "\n", + "plt.subplot(121)\n", + "plt.axis(\"off\")\n", + "plt.imshow(image)\n", + "plt.title(\"Input Image\")\n", + "\n", + "plt.subplot(122)\n", + "plt.axis(\"off\")\n", + "plt.imshow(mask)\n", + "plt.title(\"Output Mask\")\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/segformer_inference_pretrained.ipynb b/examples/segformer_inference_pretrained.ipynb index a0dda7d4..d2d195fd 100644 --- a/examples/segformer_inference_pretrained.ipynb +++ b/examples/segformer_inference_pretrained.ipynb @@ -13,9 +13,9 @@ "metadata": {}, "outputs": [], "source": [ - "# fix for HF hub download\n", - "# see PR https://github.com/albumentations-team/albumentations/pull/2171\n", - "!pip install -U git+https://github.com/qubvel/albumentations@patch-2" + "# make sure you have the latest version of the libraries\n", + "!pip install -U segmentation-models-pytorch\n", + "!pip install albumentations matplotlib requests pillow" ] }, { diff --git a/examples/upernet_inference_pretrained.ipynb b/examples/upernet_inference_pretrained.ipynb new file mode 100644 index 00000000..aa644858 --- /dev/null +++ b/examples/upernet_inference_pretrained.ipynb @@ -0,0 +1,153 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/upernet_inference_pretrained.ipynb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# make sure you have the latest version of the libraries\n", + "!pip install -U segmentation-models-pytorch\n", + "!pip install albumentations matplotlib requests pillow" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/projects/segmentation_models.pytorch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import requests\n", + "import numpy as np\n", + "import albumentations as A\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import torch\n", + "import segmentation_models_pytorch as smp\n", + "\n", + "from PIL import Image" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preprocessing:\n", + " Compose([\n", + " Resize(p=1.0, height=512, width=512, interpolation=1, mask_interpolation=0),\n", + " Normalize(p=1.0, mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375), max_pixel_value=1.0, normalization='standard'),\n", + "], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}, is_check_shapes=True)\n" + ] + } + ], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# More checkpoints can be found here:\n", + "# https://huggingface.co/collections/smp-hub/upernet-67fadcdbe08418c6ea94f768\n", + "checkpoint = 
\"smp-hub/upernet-swin-tiny\"\n", + "\n", + "# Load pretrained model and preprocessing function\n", + "model = smp.from_pretrained(checkpoint).eval().to(device)\n", + "preprocessing = A.Compose.from_pretrained(checkpoint)\n", + "print(\"Preprocessing:\\n\", preprocessing)\n", + "\n", + "# Load image\n", + "url = \"https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg\"\n", + "image = Image.open(requests.get(url, stream=True).raw)\n", + "\n", + "# Preprocess image\n", + "image = np.array(image)\n", + "normalized_image = preprocessing(image=image)[\"image\"]\n", + "input_tensor = torch.as_tensor(normalized_image)\n", + "input_tensor = input_tensor.permute(2, 0, 1).unsqueeze(0) # HWC -> BCHW\n", + "input_tensor = input_tensor.to(device)\n", + "\n", + "# Perform inference\n", + "with torch.no_grad():\n", + " output_mask = model(input_tensor)\n", + "\n", + "# Postprocess mask\n", + "mask = torch.nn.functional.interpolate(\n", + " output_mask, size=image.shape[:2], mode=\"bilinear\", align_corners=False\n", + ")\n", + "mask = mask[0].argmax(0).cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot results\n", + "plt.figure(figsize=(12, 6))\n", + "\n", + "plt.subplot(121)\n", + "plt.axis(\"off\")\n", + "plt.imshow(image)\n", + "plt.title(\"Input Image\")\n", + "\n", + "plt.subplot(122)\n", + "plt.axis(\"off\")\n", + "plt.imshow(mask)\n", + "plt.title(\"Output Mask\")\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/licenses/LICENSES.md b/licenses/LICENSES.md index 670764a2..e51ad8d0 100644 --- a/licenses/LICENSES.md +++ b/licenses/LICENSES.md @@ -13,13 +13,20 @@ The majority of the code is licensed under the [MIT License](LICENSE). 
However, * [segmentation_models_pytorch/encoders/mix_transformer.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/mix_transformer.py) * [LICENSE_nvidia](LICENSE_nvidia.md) - - Apple License * Applies to the MobileOne encoder * [segmentation_models_pytorch/encoders/mobileone.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/mobileone.py) * [LICENSE_apple](LICENSE_apple.md) - BSD 3-Clause License - * Applies to the DeepLabV3 decoder + * Applies to several encoders and the DeepLabV3 decoder + * [segmentation_models_pytorch/encoders/_dpn.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_dpn.py) + * [segmentation_models_pytorch/encoders/_inceptionresnetv2.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_inceptionresnetv2.py) + * [segmentation_models_pytorch/encoders/_inceptionv4.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_inceptionv4.py) + * [segmentation_models_pytorch/encoders/_senet.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_senet.py) + * [segmentation_models_pytorch/encoders/_xception.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_xception.py) * [segmentation_models_pytorch/decoders/deeplabv3/decoder.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/decoders/deeplabv3/decoder.py) +- Apache-2.0 License + * Applies to the EfficientNet encoder + * [segmentation_models_pytorch/encoders/_efficientnet.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_efficientnet.py) diff --git a/licenses/LICENSE_apache.md b/licenses/LICENSE_apache.md new file mode 100644 index 
00000000..d6456956 --- /dev/null +++ b/licenses/LICENSE_apache.md @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/misc/generate_table.py b/misc/generate_table.py index f14b1a3c..4e0efed5 100644 --- a/misc/generate_table.py +++ b/misc/generate_table.py @@ -1,10 +1,17 @@ +import os import segmentation_models_pytorch as smp +from tqdm import tqdm + encoders = smp.encoders.encoders WIDTH = 32 -COLUMNS = ["Encoder", "Weights", "Params, M"] +COLUMNS = ["Encoder", "Pretrained weights", "Params, M", "Script", "Compile", "Export"] +FILE = "encoders_table.md" + +if os.path.exists(FILE): + os.remove(FILE) def wrap_row(r): @@ -16,18 +23,23 @@ def wrap_row(r): ["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1) ) -print(wrap_row(header)) -print(wrap_row(separator)) +print(wrap_row(header), file=open(FILE, "a")) +print(wrap_row(separator), file=open(FILE, "a")) -for encoder_name, encoder in encoders.items(): +for encoder_name, encoder in tqdm(encoders.items()): weights = "<br>".join(encoder["pretrained_settings"].keys()) - encoder_name = encoder_name.ljust(WIDTH, " ") - weights = weights.ljust(WIDTH, " ") model = encoder["encoder"](**encoder["params"], depth=5) + + script = "✅" if model._is_torch_scriptable else "❌" + compile = "✅" if model._is_torch_compilable else "❌" + export = "✅" if model._is_torch_exportable else "❌" + params = sum(p.numel() for p in model.parameters()) params = str(params // 1000000) + "M" - params = params.ljust(WIDTH, " ") - row = "|".join([encoder_name, weights, params]) - print(wrap_row(row)) + row = [encoder_name, weights, params, script, compile, export] + row = [str(r).ljust(WIDTH, " ") for r in row] + row = "|".join(row) + + print(wrap_row(row), file=open(FILE, "a")) diff --git a/misc/generate_table_timm.py b/misc/generate_table_timm.py index 6c2a1b24..8e875583 100644 --- a/misc/generate_table_timm.py +++ b/misc/generate_table_timm.py @@ -17,30 +17,68 @@ def has_dilation_support(name): return False +def valid_vit_encoder_for_dpt(name): + if "vit" not in name: + return False + encoder = timm.create_model(name) + feature_info = encoder.feature_info + feature_info_obj = timm.models.FeatureInfo( + feature_info=feature_info, out_indices=[0, 1, 2, 3] + ) + reduction_scales = list(feature_info_obj.reduction()) + + if len(set(reduction_scales)) > 1: + return False + + output_stride = reduction_scales[0] + if bin(output_stride).count("1") != 1: + return False + + return True + + def make_table(data): names = data.keys() max_len1 = max([len(x) for x in names]) + 2 max_len2 = len("support dilation") + 2 + max_len3 = len("Supported for DPT") + 2 - l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n" - l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n" + l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+" + "-" * max_len3 + "+\n" + l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+" + "=" * max_len3 + "+\n" top = ( "| " + "Encoder name".ljust(max_len1 - 2) + " | " + "Support
dilation".center(max_len2 - 2) + + " | " + + "Supported for DPT".center(max_len3 - 2) + " |\n" ) table = l1 + top + l2 for k in sorted(data.keys()): - support = ( - "✅".center(max_len2 - 3) - if data[k]["has_dilation"] - else " ".center(max_len2 - 2) + if "has_dilation" in data[k] and data[k]["has_dilation"]: + support = "✅".center(max_len2 - 3) + + else: + support = " ".center(max_len2 - 2) + + if "supported_only_for_dpt" in data[k]: + supported_for_dpt = "✅".center(max_len3 - 3) + + else: + supported_for_dpt = " ".center(max_len3 - 2) + + table += ( + "| " + + k.ljust(max_len1 - 2) + + " | " + + support + + " | " + + supported_for_dpt + + " |\n" ) - table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n" table += l1 return table @@ -55,8 +93,13 @@ def make_table(data): check_features_and_reduction(name) has_dilation = has_dilation_support(name) supported_models[name] = dict(has_dilation=has_dilation) + except Exception: - continue + try: + if valid_vit_encoder_for_dpt(name): + supported_models[name] = dict(supported_only_for_dpt=True) + except Exception: + continue table = make_table(supported_models) print(table) diff --git a/misc/generate_test_models.py b/misc/generate_test_models.py index 61d6bfd0..a26cbc66 100644 --- a/misc/generate_test_models.py +++ b/misc/generate_test_models.py @@ -9,33 +9,50 @@ api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN")) -for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items(): - model = model_class(encoder_name=ENCODER_NAME) - model = model.eval() - - # generate test sample - torch.manual_seed(423553) - sample = torch.rand(1, 3, 256, 256) - - with torch.no_grad(): - output = model(sample) +def save_and_push(model, inputs, outputs, model_name, encoder_name): with tempfile.TemporaryDirectory() as tmpdir: # save model model.save_pretrained(f"{tmpdir}") # save input and output - torch.save(sample, f"{tmpdir}/input-tensor.pth") - torch.save(output, f"{tmpdir}/output-tensor.pth") + torch.save(inputs,
f"{tmpdir}/input-tensor.pth") + torch.save(outputs, f"{tmpdir}/output-tensor.pth") # create repo - repo_id = f"{HUB_REPO}/{model_name}-{ENCODER_NAME}" + repo_id = f"{HUB_REPO}/{model_name}-{encoder_name}" if not api.repo_exists(repo_id=repo_id): api.create_repo(repo_id=repo_id, repo_type="model") # upload to hub api.upload_folder( folder_path=tmpdir, - repo_id=f"{HUB_REPO}/{model_name}-{ENCODER_NAME}", + repo_id=f"{HUB_REPO}/{model_name}-{encoder_name}", repo_type="model", ) + + +for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items(): + if model_name == "dpt": + encoder_name = "tu-test_vit" + model = smp.DPT( + encoder_name=encoder_name, + decoder_readout="cat", + decoder_intermediate_channels=(16, 32, 64, 64), + decoder_fusion_channels=16, + dynamic_img_size=True, + ) + else: + encoder_name = ENCODER_NAME + model = model_class(encoder_name=encoder_name) + + model = model.eval() + + # generate test sample + torch.manual_seed(423553) + sample = torch.rand(1, 3, 256, 256) + + with torch.no_grad(): + output = model(sample) + + save_and_push(model, sample, output, model_name, encoder_name) diff --git a/pyproject.toml b/pyproject.toml index 0e9310b5..f3e55a96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,12 +17,10 @@ classifiers = [ 'Programming Language :: Python :: Implementation :: PyPy', ] dependencies = [ - 'efficientnet-pytorch>=0.6.1', 'huggingface-hub>=0.24', 'numpy>=1.19.3', 'pillow>=8', - 'pretrainedmodels>=0.7.1', - 'six>=1.5', + 'safetensors>=0.3.1', 'timm>=0.9', 'torch>=1.8', 'torchvision>=0.9', @@ -39,11 +37,13 @@ docs = [ 'sphinx-book-theme', ] test = [ + 'gitpython', 'packaging', 'pytest', 'pytest-cov', 'pytest-xdist', - 'ruff', + 'ruff>=0.9', + 'setuptools', ] [project.urls] @@ -61,18 +61,10 @@ include = ['segmentation_models_pytorch*'] [tool.pytest.ini_options] markers = [ - "deeplabv3", - "deeplabv3plus", - "fpn", - "linknet", - "manet", - "pan", - "psp", - "segformer", - "unet", - "unetplusplus", - "upernet", "logits_match", 
+ "compile", + "torch_export", + "torch_script", ] [tool.coverage.run] diff --git a/requirements/docs.txt b/requirements/docs.txt index 072a7e16..26afc33e 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,5 +1,5 @@ autodocsumm==0.2.14 -huggingface-hub==0.27.1 +huggingface-hub==0.30.2 six==1.17.0 -sphinx==8.1.3 -sphinx-book-theme==1.1.3 +sphinx==8.2.3 +sphinx-book-theme==1.1.4 diff --git a/requirements/minimum.old b/requirements/minimum.old index 1080bdb4..678f83f4 100644 --- a/requirements/minimum.old +++ b/requirements/minimum.old @@ -1,9 +1,7 @@ -efficientnet-pytorch==0.6.1 huggingface-hub==0.24.0 numpy==1.19.3 pillow==8.0.0 -pretrainedmodels==0.7.1 -six==1.5.0 +safetensors==0.3.1 timm==0.9.0 torch==1.9.0 torchvision==0.10.0 diff --git a/requirements/required.txt b/requirements/required.txt index e04033b5..bdb4d9e3 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -1,10 +1,8 @@ -efficientnet-pytorch==0.7.1 -huggingface_hub==0.27.1 -numpy==2.2.1 -pillow==11.1.0 -pretrainedmodels==0.7.4 -six==1.17.0 -timm==1.0.12 -torch==2.5.1 -torchvision==0.20.1 +huggingface_hub==0.30.2 +numpy==2.2.4 +pillow==11.2.1 +safetensors==0.5.3 +timm==1.0.15 +torch==2.6.0 +torchvision==0.21.0 tqdm==4.67.1 diff --git a/requirements/test.txt b/requirements/test.txt index 5f27affd..23f6025c 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,7 @@ +gitpython==3.1.44 packaging==24.2 -pytest==8.3.4 +pytest==8.3.5 pytest-xdist==3.6.1 -pytest-cov==6.0.0 -ruff==0.8.6 +pytest-cov==6.1.1 +ruff==0.11.5 +setuptools==78.1.0 \ No newline at end of file diff --git a/scripts/models-conversions/dpt-original-to-smp.py b/scripts/models-conversions/dpt-original-to-smp.py new file mode 100644 index 00000000..fab1705d --- /dev/null +++ b/scripts/models-conversions/dpt-original-to-smp.py @@ -0,0 +1,122 @@ +import cv2 +import torch +import albumentations as A +import segmentation_models_pytorch as smp + +MODEL_WEIGHTS_PATH = 
r"dpt_large-ade20k-b12dca68.pt" +HF_HUB_PATH = "qubvel-hf/dpt-large-ade20k" +PUSH_TO_HUB = False + + +def get_transform(): + return A.Compose( + [ + A.LongestMaxSize(max_size=480, interpolation=cv2.INTER_CUBIC), + A.Normalize( + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0 + ), + # This is not the correct transform: ideally the image should be resized, without padding, to a multiple of 32, + # but there is no such transform in albumentations, so this is the closest one + A.PadIfNeeded( + min_height=None, + min_width=None, + pad_height_divisor=32, + pad_width_divisor=32, + border_mode=cv2.BORDER_CONSTANT, + value=0, + p=1, + ), + ] + ) + + +if __name__ == "__main__": + # fmt: off + smp_model = smp.DPT(encoder_name="tu-vit_large_patch16_384", classes=150, dynamic_img_size=True) + dpt_model_dict = torch.load(MODEL_WEIGHTS_PATH, weights_only=True) + + for layer_index in range(0, 4): + for param in ["running_mean", "running_var", "num_batches_tracked", "weight", "bias"]: + for block_index in [1, 2]: + for bn_index in [1, 2]: + # Assigning weights of 4th fusion layer of original model to 1st layer of SMP DPT model, + # Assigning weights of 3rd fusion layer of original model to 2nd layer of SMP DPT model ... + # and so on ...
+ # This is because order of calling fusion layers is reversed in original DPT implementation + dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.batch_norm_{bn_index}.{param}"] = \ + dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.bn{bn_index}.{param}") + + if param in ["weight", "bias"]: + if param == "weight": + for block_index in [1, 2]: + for conv_index in [1, 2]: + dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.conv_{conv_index}.{param}"] = \ + dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.conv{conv_index}.{param}") + + dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.project_to_feature_dim.{param}"] = \ + dpt_model_dict.pop(f"scratch.layer{layer_index + 1}_rn.{param}") + + dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.project.{param}"] = \ + dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.out_conv.{param}") + + dpt_model_dict[f"decoder.projection_blocks.{layer_index}.project.0.{param}"] = \ + dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.0.project.0.{param}") + + dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.project_to_out_channel.{param}"] = \ + dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.3.{param}") + + if layer_index != 2: + dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.upsample.{param}"] = \ + dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.4.{param}") + + # Changing state dict keys for segmentation head + dpt_model_dict = { + name.replace("scratch.output_conv", "segmentation_head.head"): parameter + for name, parameter in dpt_model_dict.items() + } + + # Changing state dict keys for encoder layers + dpt_model_dict = { + name.replace("pretrained.model", "encoder.model"): parameter + for name, parameter in dpt_model_dict.items() + } + + # Removing keys, value pairs associated with auxiliary head + dpt_model_dict = 
{ + name: parameter + for name, parameter in dpt_model_dict.items() + if not name.startswith("auxlayer") + } + # fmt: on + + smp_model.load_state_dict(dpt_model_dict, strict=True) + + # ------- DO NOT touch this section ------- + smp_model.eval() + + input_tensor = torch.ones((1, 3, 384, 384)) + output = smp_model(input_tensor) + + print(output.shape) + print(output[0, 0, :3, :3]) + + expected_slice = torch.tensor( + [ + [3.4243, 3.4553, 3.4863], + [3.3332, 3.2876, 3.2419], + [3.2422, 3.1199, 2.9975], + ] + ) + + torch.testing.assert_close( + output[0, 0, :3, :3], expected_slice, atol=1e-4, rtol=1e-4 + ) + + # Saving + transform = get_transform() + + transform.save_pretrained(HF_HUB_PATH) + smp_model.save_pretrained(HF_HUB_PATH, push_to_hub=PUSH_TO_HUB) + + # Re-loading to make sure everything is saved correctly + smp_model = smp.from_pretrained(HF_HUB_PATH) diff --git a/scripts/models-conversions/upernet-hf-to-smp.py b/scripts/models-conversions/upernet-hf-to-smp.py new file mode 100644 index 00000000..8cd3162f --- /dev/null +++ b/scripts/models-conversions/upernet-hf-to-smp.py @@ -0,0 +1,249 @@ +import re +import torch +import albumentations as A +import segmentation_models_pytorch as smp +from huggingface_hub import hf_hub_download, HfApi +from collections import defaultdict + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# fmt: off +CONVNEXT_MAPPING = { + r"backbone.embeddings.patch_embeddings.(weight|bias)": r"encoder.model.stem_0.\1", + r"backbone.embeddings.layernorm.(weight|bias)": r"encoder.model.stem_1.\1", + r"backbone.encoder.stages.(\d+).layers.(\d+).layer_scale_parameter": r"encoder.model.stages_\1.blocks.\2.gamma", + r"backbone.encoder.stages.(\d+).layers.(\d+).dwconv.(weight|bias)": r"encoder.model.stages_\1.blocks.\2.conv_dw.\3", + r"backbone.encoder.stages.(\d+).layers.(\d+).layernorm.(weight|bias)": r"encoder.model.stages_\1.blocks.\2.norm.\3", + r"backbone.encoder.stages.(\d+).layers.(\d+).pwconv(\d+).(weight|bias)": 
r"encoder.model.stages_\1.blocks.\2.mlp.fc\3.\4", + r"backbone.encoder.stages.(\d+).downsampling_layer.(\d+).(weight|bias)": r"encoder.model.stages_\1.downsample.\2.\3", +} + +SWIN_MAPPING = { + r"backbone.embeddings.patch_embeddings.projection": r"encoder.model.patch_embed.proj", + r"backbone.embeddings.norm": r"encoder.model.patch_embed.norm", + r"backbone.encoder.layers.(\d+).blocks.(\d+).layernorm_before": r"encoder.model.layers_\1.blocks.\2.norm1", + r"backbone.encoder.layers.(\d+).blocks.(\d+).attention.self.relative_position_bias_table": r"encoder.model.layers_\1.blocks.\2.attn.relative_position_bias_table", + r"backbone.encoder.layers.(\d+).blocks.(\d+).attention.self.(query|key|value)": r"encoder.model.layers_\1.blocks.\2.attn.\3", + r"backbone.encoder.layers.(\d+).blocks.(\d+).attention.output.dense": r"encoder.model.layers_\1.blocks.\2.attn.proj", + r"backbone.encoder.layers.(\d+).blocks.(\d+).layernorm_after": r"encoder.model.layers_\1.blocks.\2.norm2", + r"backbone.encoder.layers.(\d+).blocks.(\d+).intermediate.dense": r"encoder.model.layers_\1.blocks.\2.mlp.fc1", + r"backbone.encoder.layers.(\d+).blocks.(\d+).output.dense": r"encoder.model.layers_\1.blocks.\2.mlp.fc2", + r"backbone.encoder.layers.(\d+).downsample.reduction": lambda x: f"encoder.model.layers_{1 + int(x.group(1))}.downsample.reduction", + r"backbone.encoder.layers.(\d+).downsample.norm": lambda x: f"encoder.model.layers_{1 + int(x.group(1))}.downsample.norm", +} + +DECODER_MAPPING = { + + # started from 1 in hf + r"backbone.hidden_states_norms.stage(\d+)": lambda x: f"decoder.feature_norms.{int(x.group(1)) - 1}", + + r"decode_head.psp_modules.(\d+).(\d+).conv.weight": r"decoder.psp.blocks.\1.\2.0.weight", + r"decode_head.psp_modules.(\d+).(\d+).batch_norm": r"decoder.psp.blocks.\1.\2.1", + r"decode_head.bottleneck.conv.weight": r"decoder.psp.out_conv.0.weight", + r"decode_head.bottleneck.batch_norm": r"decoder.psp.out_conv.1", + + # fpn blocks are in reverse order (3 blocks total, so 2 
- i) + r"decode_head.lateral_convs.(\d+).conv.weight": lambda x: f"decoder.fpn_lateral_blocks.{2 - int(x.group(1))}.conv_norm_relu.0.weight", + r"decode_head.lateral_convs.(\d+).batch_norm": lambda x: f"decoder.fpn_lateral_blocks.{2 - int(x.group(1))}.conv_norm_relu.1", + r"decode_head.fpn_convs.(\d+).conv.weight": lambda x: f"decoder.fpn_conv_blocks.{2 - int(x.group(1))}.0.weight", + r"decode_head.fpn_convs.(\d+).batch_norm": lambda x: f"decoder.fpn_conv_blocks.{2 - int(x.group(1))}.1", + + r"decode_head.fpn_bottleneck.conv.weight": r"decoder.fusion_block.0.weight", + r"decode_head.fpn_bottleneck.batch_norm": r"decoder.fusion_block.1", + r"decode_head.classifier": r"segmentation_head.0", +} +# fmt: on + +PRETRAINED_CHECKPOINTS = { + "convnext-tiny": { + "repo_id": "openmmlab/upernet-convnext-tiny", + "encoder_name": "tu-convnext_tiny", + "decoder_channels": 512, + "classes": 150, + "mapping": {**CONVNEXT_MAPPING, **DECODER_MAPPING}, + }, + "convnext-small": { + "repo_id": "openmmlab/upernet-convnext-small", + "encoder_name": "tu-convnext_small", + "decoder_channels": 512, + "classes": 150, + "mapping": {**CONVNEXT_MAPPING, **DECODER_MAPPING}, + }, + "convnext-base": { + "repo_id": "openmmlab/upernet-convnext-base", + "encoder_name": "tu-convnext_base", + "decoder_channels": 512, + "classes": 150, + "mapping": {**CONVNEXT_MAPPING, **DECODER_MAPPING}, + }, + "convnext-large": { + "repo_id": "openmmlab/upernet-convnext-large", + "encoder_name": "tu-convnext_large", + "decoder_channels": 512, + "classes": 150, + "mapping": {**CONVNEXT_MAPPING, **DECODER_MAPPING}, + }, + "convnext-xlarge": { + "repo_id": "openmmlab/upernet-convnext-xlarge", + "encoder_name": "tu-convnext_xlarge", + "decoder_channels": 512, + "classes": 150, + "mapping": {**CONVNEXT_MAPPING, **DECODER_MAPPING}, + }, + "swin-tiny": { + "repo_id": "openmmlab/upernet-swin-tiny", + "encoder_name": "tu-swin_tiny_patch4_window7_224", + "decoder_channels": 512, + "classes": 150, + "extra_kwargs": {"img_size": 
512}, + "mapping": {**SWIN_MAPPING, **DECODER_MAPPING}, + }, + "swin-small": { + "repo_id": "openmmlab/upernet-swin-small", + "encoder_name": "tu-swin_small_patch4_window7_224", + "decoder_channels": 512, + "classes": 150, + "extra_kwargs": {"img_size": 512}, + "mapping": {**SWIN_MAPPING, **DECODER_MAPPING}, + }, + "swin-large": { + "repo_id": "openmmlab/upernet-swin-large", + "encoder_name": "tu-swin_large_patch4_window12_384", + "decoder_channels": 512, + "classes": 150, + "extra_kwargs": {"img_size": 512}, + "mapping": {**SWIN_MAPPING, **DECODER_MAPPING}, + }, +} + + +def convert_old_keys_to_new_keys(state_dict_keys: dict, keys_mapping: dict): + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. + """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in keys_mapping.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # an empty line + continue + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +def group_qkv_layers(state_dict: dict) -> dict: + """Find corresponding layer names for query, key and value layers and stack them in a single layer""" + + state_dict = state_dict.copy() # shallow copy + + result = defaultdict(dict) + layer_names = list(state_dict.keys()) + qkv_names = ["query", "key", "value"] + for layer_name in layer_names: + for pattern in qkv_names: + if pattern in layer_name: + new_key = layer_name.replace(pattern, "qkv") + result[new_key][pattern] = state_dict.pop(layer_name) + break + + # merge them all + for new_key, patterns in result.items(): + state_dict[new_key] = torch.cat( + [patterns[qkv_name] for qkv_name in qkv_names], dim=0 + ) + + return state_dict + + +def convert_model(model_name: str, push_to_hub: bool = False): + params = 
PRETRAINED_CHECKPOINTS[model_name] + + print(f"Converting model: {model_name}") + print(f"Downloading weights from: {params['repo_id']}") + + hf_weights_path = hf_hub_download( + repo_id=params["repo_id"], filename="pytorch_model.bin" + ) + hf_state_dict = torch.load(hf_weights_path, weights_only=True) + print(f"Loaded HuggingFace state dict with {len(hf_state_dict)} keys") + + # Rename keys + keys_mapping = convert_old_keys_to_new_keys(hf_state_dict.keys(), params["mapping"]) + + smp_state_dict = {} + for old_key, new_key in keys_mapping.items(): + smp_state_dict[new_key] = hf_state_dict[old_key] + + # remove aux head + smp_state_dict = { + k: v for k, v in smp_state_dict.items() if "auxiliary_head." not in k + } + + # [swin] group qkv layers and remove `relative_position_index` + smp_state_dict = group_qkv_layers(smp_state_dict) + smp_state_dict = { + k: v for k, v in smp_state_dict.items() if "relative_position_index" not in k + } + + # Create model + print(f"Creating SMP UPerNet model with encoder: {params['encoder_name']}") + extra_kwargs = params.get("extra_kwargs", {}) + smp_model = smp.UPerNet( + encoder_name=params["encoder_name"], + encoder_weights=None, + decoder_channels=params["decoder_channels"], + classes=params["classes"], + **extra_kwargs, + ) + + print("Loading weights into SMP model...") + smp_model.load_state_dict(smp_state_dict, strict=True) + + # Check we can run the model + print("Verifying model with test inference...") + smp_model.eval() + sample = torch.ones(1, 3, 512, 512) + with torch.no_grad(): + output = smp_model(sample) + print(f"Test inference successful. 
Output shape: {output.shape}") + + # Save model with preprocessing + smp_repo_id = f"smp-hub/upernet-{model_name}" + print(f"Saving model to: {smp_repo_id}") + smp_model.save_pretrained(save_directory=smp_repo_id) + + transform = A.Compose( + [ + A.Resize(512, 512), + A.Normalize( + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + max_pixel_value=1.0, + ), + ] + ) + transform.save_pretrained(save_directory=smp_repo_id) + + if push_to_hub: + print(f"Pushing model to HuggingFace Hub: {smp_repo_id}") + api = HfApi() + if not api.repo_exists(smp_repo_id): + api.create_repo(repo_id=smp_repo_id, repo_type="model") + api.upload_folder( + repo_id=smp_repo_id, + folder_path=smp_repo_id, + repo_type="model", + ) + + print(f"Conversion of {model_name} completed successfully!") + + +if __name__ == "__main__": + print(f"Starting conversion of {len(PRETRAINED_CHECKPOINTS)} UPerNet models") + for model_name in PRETRAINED_CHECKPOINTS.keys(): + convert_model(model_name, push_to_hub=True) + print("All conversions completed!") diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py index f1807836..37c64ef6 100644 --- a/segmentation_models_pytorch/__init__.py +++ b/segmentation_models_pytorch/__init__.py @@ -1,5 +1,3 @@ -import warnings - from . import datasets from . import encoders from . 
import decoders @@ -16,6 +14,7 @@ from .decoders.pan import PAN from .decoders.upernet import UPerNet from .decoders.segformer import Segformer +from .decoders.dpt import DPT from .base.hub_mixin import from_pretrained from .__version__ import __version__ @@ -24,12 +23,6 @@ from typing import Optional as _Optional import torch as _torch -# Suppress the specific SyntaxWarning for `pretrainedmodels` -warnings.filterwarnings("ignore", message="is with a literal", category=SyntaxWarning) -warnings.filterwarnings( - "ignore", message=r'"is" with \'str\' literal.*', category=SyntaxWarning -) # for python >= 3.12 - _MODEL_ARCHITECTURES = [ Unet, UnetPlusPlus, @@ -42,6 +35,7 @@ PAN, UPerNet, Segformer, + DPT, ] MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES} @@ -92,6 +86,7 @@ def create_model( "PAN", "UPerNet", "Segformer", + "DPT", "from_pretrained", "create_model", "__version__", diff --git a/segmentation_models_pytorch/__version__.py b/segmentation_models_pytorch/__version__.py index b87975ee..3d187266 100644 --- a/segmentation_models_pytorch/__version__.py +++ b/segmentation_models_pytorch/__version__.py @@ -1,3 +1 @@ -VERSION = (0, 4, 0) - -__version__ = ".".join(map(str, VERSION)) +__version__ = "0.5.0" diff --git a/segmentation_models_pytorch/base/hub_mixin.py b/segmentation_models_pytorch/base/hub_mixin.py index 360aa521..a18380d1 100644 --- a/segmentation_models_pytorch/base/hub_mixin.py +++ b/segmentation_models_pytorch/base/hub_mixin.py @@ -1,3 +1,4 @@ +import torch import json from pathlib import Path from typing import Optional, Union @@ -114,12 +115,15 @@ def save_pretrained( return result @property + @torch.jit.unused def config(self) -> dict: return self._hub_mixin_config @wraps(PyTorchModelHubMixin.from_pretrained) -def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs): +def from_pretrained( + pretrained_model_name_or_path: str, *args, strict: bool = True, **kwargs +): config_path = 
Path(pretrained_model_name_or_path) / "config.json" if not config_path.exists(): config_path = hf_hub_download( @@ -135,7 +139,9 @@ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs): import segmentation_models_pytorch as smp model_class = getattr(smp, model_class_name) - return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + return model_class.from_pretrained( + pretrained_model_name_or_path, *args, **kwargs, strict=strict + ) def supports_config_loading(func): diff --git a/segmentation_models_pytorch/base/initialization.py b/segmentation_models_pytorch/base/initialization.py index 4bea4aa6..cf518edd 100644 --- a/segmentation_models_pytorch/base/initialization.py +++ b/segmentation_models_pytorch/base/initialization.py @@ -8,7 +8,9 @@ def initialize_decoder(module): if m.bias is not None: nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): + elif isinstance( + m, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm, nn.InstanceNorm2d) + ): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 6d7bf643..9b0db714 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -1,16 +1,29 @@ import torch +import warnings +from typing import TypeVar, Type from . 
import initialization as init from .hub_mixin import SMPHubMixin +from .utils import is_torch_compiling + +T = TypeVar("T", bound="SegmentationModel") class SegmentationModel(torch.nn.Module, SMPHubMixin): """Base class for all segmentation models.""" - # if model supports shape not divisible by 2 ^ n - # set to False + _is_torch_scriptable = True + _is_torch_exportable = True + _is_torch_compilable = True + + # if model supports shape not divisible by 2 ^ n set to False requires_divisible_input_shape = True + # Fix type-hint for models, to avoid HubMixin signature + def __new__(cls: Type[T], *args, **kwargs) -> T: + instance = super().__new__(cls, *args, **kwargs) + return instance + def initialize(self): init.initialize_decoder(self.decoder) init.initialize_head(self.segmentation_head) @@ -21,6 +34,9 @@ def check_input_shape(self, x): """Check if the input shape is divisible by the output stride. If not, raise a RuntimeError. """ + if not self.requires_divisible_input_shape: + return + h, w = x.shape[-2:] output_stride = self.encoder.output_stride if h % output_stride != 0 or w % output_stride != 0: @@ -42,11 +58,13 @@ def check_input_shape(self, x): def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" - if not torch.jit.is_tracing() or self.requires_divisible_input_shape: + if not ( + torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling() + ): self.check_input_shape(x) features = self.encoder(x) - decoder_output = self.decoder(*features) + decoder_output = self.decoder(features) masks = self.segmentation_head(decoder_output) @@ -69,7 +87,53 @@ def predict(self, x): """ if self.training: self.eval() + x = self(x) + return x - x = self.forward(x) + def load_state_dict(self, state_dict, **kwargs): + # for compatibility of weights for + # timm- ported encoders with TimmUniversalEncoder + from segmentation_models_pytorch.encoders import TimmUniversalEncoder - return x + if isinstance(self.encoder, 
TimmUniversalEncoder): + patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] + is_deprecated_encoder = any( + self.encoder.name.startswith(pattern) for pattern in patterns + ) + if is_deprecated_encoder: + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if key.startswith("encoder.") and not key.startswith( + "encoder.model." + ): + new_key = "encoder.model." + key.removeprefix("encoder.") + if "gernet" in self.encoder.name: + new_key = new_key.replace(".stages.", ".stages_") + state_dict[new_key] = state_dict.pop(key) + + # To be able to load weight with mismatched sizes + # We are going to filter mismatched sizes as well if strict=False + strict = kwargs.get("strict", True) + if not strict: + mismatched_keys = [] + model_state_dict = self.state_dict() + common_keys = set(model_state_dict.keys()) & set(state_dict.keys()) + for key in common_keys: + if model_state_dict[key].shape != state_dict[key].shape: + mismatched_keys.append( + (key, model_state_dict[key].shape, state_dict[key].shape) + ) + state_dict.pop(key) + + if mismatched_keys: + str_keys = "\n".join( + [ + f" - {key}: {s} (weights) -> {m} (model)" + for key, m, s in mismatched_keys + ] + ) + text = f"\n\n !!!!!! Mismatched keys !!!!!!\n\nYou should TRAIN the model to use it:\n{str_keys}\n" + warnings.warn(text, stacklevel=-1) + + return super().load_state_dict(state_dict, **kwargs) diff --git a/segmentation_models_pytorch/base/modules.py b/segmentation_models_pytorch/base/modules.py index cbd643b6..15cfdb12 100644 --- a/segmentation_models_pytorch/base/modules.py +++ b/segmentation_models_pytorch/base/modules.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Union + import torch import torch.nn as nn @@ -7,43 +9,109 @@ InPlaceABN = None +def get_norm_layer( + use_norm: Union[bool, str, Dict[str, Any]], out_channels: int +) -> nn.Module: + supported_norms = ("inplace", "batchnorm", "identity", "layernorm", "instancenorm") + + # Step 1. 
Convert to dict representation + + ## Check boolean + if use_norm is True: + norm_params = {"type": "batchnorm"} + elif use_norm is False: + norm_params = {"type": "identity"} + + ## Check string + elif isinstance(use_norm, str): + norm_str = use_norm.lower() + if norm_str == "inplace": + norm_params = { + "type": "inplace", + "activation": "leaky_relu", + "activation_param": 0.0, + } + elif norm_str in supported_norms: + norm_params = {"type": norm_str} + else: + raise ValueError( + f"Unrecognized normalization type string provided: {use_norm}. Should be in " + f"{supported_norms}" + ) + + ## Check dict + elif isinstance(use_norm, dict): + norm_params = use_norm + + else: + raise ValueError( + f"Invalid type for use_norm should either be a bool (batchnorm/identity), " + f"a string in {supported_norms}, or a dict like {{'type': 'batchnorm', **kwargs}}" + ) + + # Step 2. Check if the dict is valid + if "type" not in norm_params: + raise ValueError( + f"Malformed dictionary given in use_norm: {use_norm}. Should contain key 'type'." + ) + if norm_params["type"] not in supported_norms: + raise ValueError( + f"Unrecognized normalization type string provided: {use_norm}. Should be in {supported_norms}" + ) + if norm_params["type"] == "inplace" and InPlaceABN is None: + raise RuntimeError( + "In order to use `use_norm='inplace'` the inplace_abn package must be installed. Use:\n" + " $ pip install -U wheel setuptools\n" + " $ pip install inplace_abn --no-build-isolation\n" + "Also see: https://github.com/mapillary/inplace_abn" + ) + + # Step 3.
Initialize the norm layer + norm_type = norm_params["type"] + norm_kwargs = {k: v for k, v in norm_params.items() if k != "type"} + + if norm_type == "inplace": + norm = InPlaceABN(out_channels, **norm_kwargs) + elif norm_type == "batchnorm": + norm = nn.BatchNorm2d(out_channels, **norm_kwargs) + elif norm_type == "identity": + norm = nn.Identity() + elif norm_type == "layernorm": + norm = nn.LayerNorm(out_channels, **norm_kwargs) + elif norm_type == "instancenorm": + norm = nn.InstanceNorm2d(out_channels, **norm_kwargs) + else: + raise ValueError(f"Unrecognized normalization type: {norm_type}") + + return norm + + class Conv2dReLU(nn.Sequential): def __init__( self, - in_channels, - out_channels, - kernel_size, - padding=0, - stride=1, - use_batchnorm=True, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int = 0, + stride: int = 1, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): - if use_batchnorm == "inplace" and InPlaceABN is None: - raise RuntimeError( - "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. 
" - + "To install see: https://github.com/mapillary/inplace_abn" - ) + norm = get_norm_layer(use_norm, out_channels) + is_identity = isinstance(norm, nn.Identity) conv = nn.Conv2d( in_channels, out_channels, kernel_size, stride=stride, padding=padding, - bias=not (use_batchnorm), + bias=is_identity, ) - relu = nn.ReLU(inplace=True) - - if use_batchnorm == "inplace": - bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) - relu = nn.Identity() - elif use_batchnorm and use_batchnorm != "inplace": - bn = nn.BatchNorm2d(out_channels) - - else: - bn = nn.Identity() + is_inplaceabn = InPlaceABN is not None and isinstance(norm, InPlaceABN) + activation = nn.Identity() if is_inplaceabn else nn.ReLU(inplace=True) - super(Conv2dReLU, self).__init__(conv, bn, relu) + super(Conv2dReLU, self).__init__(conv, norm, activation) class SCSEModule(nn.Module): diff --git a/segmentation_models_pytorch/base/utils.py b/segmentation_models_pytorch/base/utils.py new file mode 100644 index 00000000..a0d41943 --- /dev/null +++ b/segmentation_models_pytorch/base/utils.py @@ -0,0 +1,14 @@ +import torch + + +@torch.jit.unused +def is_torch_compiling(): + try: + return torch.compiler.is_compiling() + except Exception: + try: + import torch._dynamo as dynamo # noqa: F401 + + return dynamo.is_compiling() + except Exception: + return False diff --git a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py index 3fd73786..6a801a70 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/decoder.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/decoder.py @@ -31,7 +31,7 @@ """ from collections.abc import Iterable, Sequence -from typing import Literal +from typing import Literal, List import torch from torch import nn @@ -40,7 +40,7 @@ __all__ = ["DeepLabV3Decoder", "DeepLabV3PlusDecoder"] -class DeepLabV3Decoder(nn.Sequential): +class DeepLabV3Decoder(nn.Module): def __init__( self, in_channels: int, @@ 
-49,21 +49,25 @@ def __init__( aspp_separable: bool, aspp_dropout: float, ): - super().__init__( - ASPP( - in_channels, - out_channels, - atrous_rates, - separable=aspp_separable, - dropout=aspp_dropout, - ), - nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(), + super().__init__() + self.aspp = ASPP( + in_channels, + out_channels, + atrous_rates, + separable=aspp_separable, + dropout=aspp_dropout, ) + self.conv = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU() - def forward(self, *features): - return super().forward(features[-1]) + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: + x = features[-1] + x = self.aspp(x) + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x class DeepLabV3PlusDecoder(nn.Module): @@ -124,7 +128,7 @@ def __init__( nn.ReLU(), ) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: aspp_features = self.aspp(features[-1]) aspp_features = self.up(aspp_features) high_res_features = self.block1(features[2]) @@ -174,7 +178,7 @@ def __init__(self, in_channels: int, out_channels: int): nn.ReLU(), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: size = x.shape[-2:] for mod in self: x = mod(x) @@ -216,7 +220,7 @@ def __init__( nn.Dropout(dropout), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: res = [] for conv in self.convs: res.append(conv(x)) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index 654e38d4..38ca9e04 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -34,8 +34,7 @@ class DeepLabV3(SegmentationModel): classes: A number of classes for output mask (or you can think as a number of channels of output 
mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. upsampling: Final upsampling factor. Default is **None** to preserve input-output spatial shape identity aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: @@ -121,6 +120,21 @@ def __init__( else: self.classification_head = None + def load_state_dict(self, state_dict, *args, **kwargs): + # For backward compatibility, previously Decoder module was Sequential + # and was not scriptable. + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if key.startswith("decoder.0."): + new_key = key.replace("decoder.0.", "decoder.aspp.") + elif key.startswith("decoder.1."): + new_key = key.replace("decoder.1.", "decoder.conv.") + elif key.startswith("decoder.2."): + new_key = key.replace("decoder.2.", "decoder.bn.") + state_dict[new_key] = state_dict.pop(key) + return super().load_state_dict(state_dict, *args, **kwargs) + class DeepLabV3Plus(SegmentationModel): """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable @@ -144,8 +158,7 @@ class DeepLabV3Plus(SegmentationModel): classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. 
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: diff --git a/segmentation_models_pytorch/decoders/dpt/__init__.py b/segmentation_models_pytorch/decoders/dpt/__init__.py new file mode 100644 index 00000000..c729fe90 --- /dev/null +++ b/segmentation_models_pytorch/decoders/dpt/__init__.py @@ -0,0 +1,3 @@ +from .model import DPT + +__all__ = ["DPT"] diff --git a/segmentation_models_pytorch/decoders/dpt/decoder.py b/segmentation_models_pytorch/decoders/dpt/decoder.py new file mode 100644 index 00000000..345ecca1 --- /dev/null +++ b/segmentation_models_pytorch/decoders/dpt/decoder.py @@ -0,0 +1,320 @@ +import torch +import torch.nn as nn +from segmentation_models_pytorch.base.modules import Activation +from typing import Optional, Sequence, Union, Callable, Literal + + +class ReadoutConcatBlock(nn.Module): + """ + Concatenates the cls tokens with the features to make use of the global information aggregated in the prefix (cls) tokens. + Projects the combined feature map to the original embedding dimension using a MLP. 
+ + According to: + https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L79-L90 + """ + + def __init__(self, embed_dim: int, has_prefix_tokens: bool): + super().__init__() + in_features = embed_dim * 2 if has_prefix_tokens else embed_dim + out_features = embed_dim + self.project = nn.Sequential( + nn.Linear(in_features, out_features), + nn.GELU(), + ) + + def forward( + self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None + ) -> torch.Tensor: + batch_size, embed_dim, height, width = features.shape + + # Rearrange to (batch_size, height * width, embed_dim) + features = features.view(batch_size, embed_dim, -1) + features = features.transpose(1, 2).contiguous() + + if prefix_tokens is not None: + # (batch_size, num_prefix_tokens, embed_dim) -> (batch_size, 1, embed_dim) + prefix_tokens = prefix_tokens[:, :1].expand_as(features) + features = torch.cat([features, prefix_tokens], dim=2) + + # Project to embedding dimension + features = self.project(features) + + # Rearrange back to (batch_size, embed_dim, height, width) + features = features.transpose(1, 2) + features = features.view(batch_size, -1, height, width) + + return features + + +class ReadoutAddBlock(nn.Module): + """ + Adds the prefix tokens to the features to make use of the global information aggregated in the prefix (cls) tokens. + + According to: + https://github.com/isl-org/DPT/blob/cd3fe90bb4c48577535cc4d51b602acca688a2ee/dpt/vit.py#L71-L76 + """ + + def forward( + self, features: torch.Tensor, prefix_tokens: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if prefix_tokens is not None: + batch_size, embed_dim, height, width = features.shape + prefix_tokens = prefix_tokens.mean(dim=1) + prefix_tokens = prefix_tokens.view(batch_size, embed_dim, 1, 1) + features = features + prefix_tokens + return features + + +class ReadoutIgnoreBlock(nn.Module): + """ + Ignores the prefix tokens and returns the features as is. 
+ """ + + def forward(self, features: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return features + + +class ReassembleBlock(nn.Module): + """ + Processes the features such that they have progressively increasing embedding size and progressively decreasing + spatial dimension + """ + + def __init__( + self, + in_channels: int, + mid_channels: int, + out_channels: int, + upsample_factor: int, + ): + super().__init__() + + self.project_to_out_channel = nn.Conv2d( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + ) + + if upsample_factor > 1.0: + self.upsample = nn.ConvTranspose2d( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=int(upsample_factor), + stride=int(upsample_factor), + ) + elif upsample_factor == 1.0: + self.upsample = nn.Identity() + else: + self.upsample = nn.Conv2d( + in_channels=mid_channels, + out_channels=mid_channels, + kernel_size=3, + stride=int(1 / upsample_factor), + padding=1, + ) + + self.project_to_feature_dim = nn.Conv2d( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.project_to_out_channel(x) + x = self.upsample(x) + x = self.project_to_feature_dim(x) + return x + + +class ResidualConvBlock(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + + self.conv_1 = nn.Conv2d( + in_channels=feature_dim, + out_channels=feature_dim, + kernel_size=3, + padding=1, + bias=False, + ) + self.batch_norm_1 = nn.BatchNorm2d(num_features=feature_dim) + self.conv_2 = nn.Conv2d( + in_channels=feature_dim, + out_channels=feature_dim, + kernel_size=3, + padding=1, + bias=False, + ) + self.batch_norm_2 = nn.BatchNorm2d(num_features=feature_dim) + self.activation = nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + # Block 1 + x = self.activation(x) + x = self.conv_1(x) + x = self.batch_norm_1(x) + + # Block 2 + x = 
self.activation(x) + x = self.conv_2(x) + x = self.batch_norm_2(x) + + # Add residual + x = x + residual + + return x + + +class FusionBlock(nn.Module): + """ + Fuses the processed encoder features in a residual manner and upsamples them + """ + + def __init__(self, feature_dim: int): + super().__init__() + self.residual_conv_block1 = ResidualConvBlock(feature_dim) + self.residual_conv_block2 = ResidualConvBlock(feature_dim) + self.project = nn.Conv2d(feature_dim, feature_dim, kernel_size=1) + self.activation = nn.ReLU() + + def forward( + self, + feature: torch.Tensor, + previous_feature: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + feature = self.residual_conv_block1(feature) + if previous_feature is not None: + feature = feature + previous_feature + feature = self.residual_conv_block2(feature) + feature = nn.functional.interpolate( + feature, scale_factor=2, align_corners=True, mode="bilinear" + ) + feature = self.project(feature) + return feature + + +class DPTDecoder(nn.Module): + """ + Decoder part for DPT + + Processes the encoder features and class tokens (if encoder has class_tokens) to have spatial downsampling ratios of + [1/4, 1/8, 1/16, 1/32, ...] relative to the input image spatial dimension. 
+ + The decoder then fuses these features in a residual manner and progressively upsamples them by a factor of 2 so that the + output has a downsampling ratio of 1/2 relative to the input image spatial dimension + + """ + + def __init__( + self, + encoder_out_channels: Sequence[int] = (756, 756, 756, 756), + encoder_output_strides: Sequence[int] = (16, 16, 16, 16), + encoder_has_prefix_tokens: bool = True, + readout: Literal["cat", "add", "ignore"] = "cat", + intermediate_channels: Sequence[int] = (256, 512, 1024, 1024), + fusion_channels: int = 256, + ): + super().__init__() + + if not ( + len(encoder_out_channels) + == len(encoder_output_strides) + == len(intermediate_channels) + ): + raise ValueError( + "encoder_out_channels, encoder_output_strides and intermediate_channels must have the same length" + ) + + num_blocks = len(encoder_out_channels) + + # If encoder has prefix tokens (e.g. cls_token), then we can concat/add/ignore them + # according to the readout mode + if readout == "cat": + blocks = [ + ReadoutConcatBlock(in_channels, encoder_has_prefix_tokens) + for in_channels in encoder_out_channels + ] + elif readout == "add": + blocks = [ReadoutAddBlock() for _ in encoder_out_channels] + elif readout == "ignore": + blocks = [ReadoutIgnoreBlock() for _ in encoder_out_channels] + else: + raise ValueError( + f"Invalid readout mode: {readout}, should be one of: 'cat', 'add', 'ignore'" + ) + self.projection_blocks = nn.ModuleList(blocks) + + # Upsample factors to resize features to [1/4, 1/8, 1/16, 1/32, ...] 
scales + scale_factors = [ + stride / 2 ** (i + 2) for i, stride in enumerate(encoder_output_strides) + ] + self.reassemble_blocks = nn.ModuleList() + for i in range(num_blocks): + block = ReassembleBlock( + in_channels=encoder_out_channels[i], + mid_channels=intermediate_channels[i], + out_channels=fusion_channels, + upsample_factor=scale_factors[i], + ) + self.reassemble_blocks.append(block) + + # Fusion blocks to fuse the processed features in a sequential manner + fusion_blocks = [FusionBlock(fusion_channels) for _ in range(num_blocks)] + self.fusion_blocks = nn.ModuleList(fusion_blocks) + + def forward( + self, features: list[torch.Tensor], prefix_tokens: list[Optional[torch.Tensor]] + ) -> torch.Tensor: + # Process the encoder features to scale of [1/4, 1/8, 1/16, 1/32, ...] + processed_features = [] + for i, (feature, prefix_tokens_i) in enumerate(zip(features, prefix_tokens)): + projected_feature = self.projection_blocks[i](feature, prefix_tokens_i) + processed_feature = self.reassemble_blocks[i](projected_feature) + processed_features.append(processed_feature) + + # Fusion and progressive upsampling starting from the last processed feature + processed_features = processed_features[::-1] + fused_feature = None + for fusion_block, feature in zip(self.fusion_blocks, processed_features): + fused_feature = fusion_block(feature, fused_feature) + + return fused_feature + + +class DPTSegmentationHead(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + activation: Optional[Union[str, Callable]] = None, + kernel_size: int = 3, + upsampling: float = 2.0, + ): + super().__init__() + + self.head = nn.Sequential( + nn.Conv2d( + in_channels, in_channels, kernel_size=kernel_size, padding=1, bias=False + ), + nn.BatchNorm2d(in_channels), + nn.ReLU(inplace=True), + nn.Dropout(p=0.1, inplace=False), + nn.Conv2d(in_channels, out_channels, kernel_size=1), + ) + self.activation = Activation(activation) + self.upsampling_factor = upsampling + + def 
forward(self, x: torch.Tensor) -> torch.Tensor: + head_output = self.head(x) + resized_output = nn.functional.interpolate( + head_output, + scale_factor=self.upsampling_factor, + mode="bilinear", + align_corners=True, + ) + activation_output = self.activation(resized_output) + return activation_output diff --git a/segmentation_models_pytorch/decoders/dpt/model.py b/segmentation_models_pytorch/decoders/dpt/model.py new file mode 100644 index 00000000..1294dd4f --- /dev/null +++ b/segmentation_models_pytorch/decoders/dpt/model.py @@ -0,0 +1,167 @@ +import warnings +from typing import Any, Optional, Union, Callable, Sequence, Literal + +import torch + +from segmentation_models_pytorch.base import ( + ClassificationHead, + SegmentationModel, +) +from segmentation_models_pytorch.encoders.timm_vit import TimmViTEncoder +from segmentation_models_pytorch.base.utils import is_torch_compiling +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading +from .decoder import DPTDecoder, DPTSegmentationHead + + +class DPT(SegmentationModel): + """ + DPT is a dense prediction architecture that leverages vision transformers in place of convolutional networks as + a backbone for dense prediction tasks + + It assembles tokens from various stages of the vision transformer into image-like representations at various resolutions + and progressively combines them into full-resolution predictions using a convolutional decoder. + + The transformer backbone processes representations at a constant and relatively high resolution and has a global receptive + field at every stage. These properties allow the dense vision transformer to provide finer-grained and more globally coherent + predictions when compared to fully-convolutional networks + + Note: + Since this model uses a Vision Transformer backbone, it typically requires a fixed input image size. 
+ To handle variable input sizes, you can set `dynamic_img_size=True` in the model initialization + (if supported by the specific `timm` encoder). You can check if an encoder requires fixed size + using `model.encoder.is_fixed_input_size`, and get the required input dimensions from + `model.encoder.input_size`, however it's no guarantee that information is available. + + Args: + encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) + to extract features of different spatial resolution. + encoder_depth: A number of stages used in encoder in range [1,4]. Each stage generate features + smaller by a factor equal to the ViT model patch_size in spatial dimensions. + Default is 4. + encoder_weights: One of **None** (random initialization), or not **None** (pretrained weights would be loaded + with respect to the encoder_name, e.g. for ``"tu-vit_base_patch16_224.augreg_in21k"`` - ``"augreg_in21k"`` + weights would be loaded). + encoder_output_indices: The indices of the encoder output features to use. If **None** will be sampled uniformly + across the number of blocks in encoder, e.g. if number of blocks is 4 and encoder has 20 blocks, then + encoder_output_indices will be (4, 9, 14, 19). If specified the number of indices should be equal to + encoder_depth. Default is **None**. + decoder_readout: The strategy to utilize the prefix tokens (e.g. cls_token) from the encoder. + Can be one of **"cat"**, **"add"**, or **"ignore"**. Default is **"cat"**. + decoder_intermediate_channels: The number of channels for the intermediate decoder layers. Reduce if you + want to reduce the number of parameters in the decoder. Default is (256, 512, 1024, 1024). + decoder_fusion_channels: The latent dimension to which the encoder features will be projected to before fusion. + Default is 256. 
+ in_channels: Number of input channels for the model, default is 3 (RGB images) + classes: Number of classes for output mask (or you can think of it as the number of channels of the output mask) + activation: An activation function to apply after the final convolution layer. + Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, + **callable** and **None**. Default is **None**. + aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built + on top of encoder if **aux_params** is not **None** (default). Supported params: + + - **classes** (*int*): A number of classes; + - **pooling** (*str*): One of "max", "avg". Default is "avg"; + - **dropout** (*float*): Dropout factor in [0, 1); + - **activation** (*str*): An activation function to apply, "sigmoid"/"softmax" (could be **None** to return logits). + kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with + ``None`` values are pruned before passing. Specify ``dynamic_img_size=True`` to allow the model to handle images of different sizes.
+ + Returns: + ``torch.nn.Module``: DPT + + """ + + # fails for encoders with prefix tokens + _is_torch_scriptable = False + _is_torch_compilable = True + requires_divisible_input_shape = True + + @supports_config_loading + def __init__( + self, + encoder_name: str = "tu-vit_base_patch16_224.augreg_in21k", + encoder_depth: int = 4, + encoder_weights: Optional[str] = "imagenet", + encoder_output_indices: Optional[list[int]] = None, + decoder_readout: Literal["ignore", "add", "cat"] = "cat", + decoder_intermediate_channels: Sequence[int] = (256, 512, 1024, 1024), + decoder_fusion_channels: int = 256, + in_channels: int = 3, + classes: int = 1, + activation: Optional[Union[str, Callable]] = None, + aux_params: Optional[dict] = None, + **kwargs: dict[str, Any], + ): + super().__init__() + if encoder_name.startswith("tu-"): + encoder_name = encoder_name[3:] + else: + raise ValueError( + f"Only Timm encoders are supported for DPT. Encoder name must start with 'tu-', got {encoder_name}" + ) + + if decoder_readout not in ["ignore", "add", "cat"]: + raise ValueError( + f"Invalid decoder readout mode. Must be one of: 'ignore', 'add', 'cat'. Got: {decoder_readout}" + ) + + self.encoder = TimmViTEncoder( + name=encoder_name, + in_channels=in_channels, + depth=encoder_depth, + pretrained=encoder_weights is not None, + output_indices=encoder_output_indices, + **kwargs, + ) + + if not self.encoder.has_prefix_tokens and decoder_readout != "ignore": + warnings.warn( + f"Encoder does not have prefix tokens (e.g. cls_token), but `decoder_readout` is set to '{decoder_readout}'. 
" + f"It's recommended to set `decoder_readout='ignore'` when using a encoder without prefix tokens.", + UserWarning, + ) + + self.decoder = DPTDecoder( + encoder_out_channels=self.encoder.out_channels, + encoder_output_strides=self.encoder.output_strides, + encoder_has_prefix_tokens=self.encoder.has_prefix_tokens, + readout=decoder_readout, + intermediate_channels=decoder_intermediate_channels, + fusion_channels=decoder_fusion_channels, + ) + + self.segmentation_head = DPTSegmentationHead( + in_channels=decoder_fusion_channels, + out_channels=classes, + activation=activation, + kernel_size=3, + upsampling=2, + ) + + if aux_params is not None: + self.classification_head = ClassificationHead( + in_channels=self.encoder.out_channels[-1], **aux_params + ) + else: + self.classification_head = None + + self.name = "dpt-{}".format(encoder_name) + self.initialize() + + def forward(self, x): + """Sequentially pass `x` trough model`s encoder, decoder and heads""" + + if not ( + torch.jit.is_scripting() or torch.jit.is_tracing() or is_torch_compiling() + ): + self.check_input_shape(x) + + features, prefix_tokens = self.encoder(x) + decoder_output = self.decoder(features, prefix_tokens) + masks = self.segmentation_head(decoder_output) + + if self.classification_head is not None: + labels = self.classification_head(features[-1]) + return masks, labels + + return masks diff --git a/segmentation_models_pytorch/decoders/fpn/decoder.py b/segmentation_models_pytorch/decoders/fpn/decoder.py index 766190f4..b111843a 100644 --- a/segmentation_models_pytorch/decoders/fpn/decoder.py +++ b/segmentation_models_pytorch/decoders/fpn/decoder.py @@ -2,9 +2,11 @@ import torch.nn as nn import torch.nn.functional as F +from typing import List, Literal + class Conv3x3GNReLU(nn.Module): - def __init__(self, in_channels, out_channels, upsample=False): + def __init__(self, in_channels: int, out_channels: int, upsample: bool = False): super().__init__() self.upsample = upsample self.block = 
nn.Sequential( @@ -15,27 +17,33 @@ def __init__(self, in_channels, out_channels, upsample=False): nn.ReLU(inplace=True), ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.block(x) if self.upsample: - x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + x = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=True) return x class FPNBlock(nn.Module): - def __init__(self, pyramid_channels, skip_channels): + def __init__( + self, + pyramid_channels: int, + skip_channels: int, + interpolation_mode: str = "nearest", + ): super().__init__() self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1) + self.interpolation_mode = interpolation_mode - def forward(self, x, skip=None): - x = F.interpolate(x, scale_factor=2, mode="nearest") + def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: + x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode) skip = self.skip_conv(skip) x = x + skip return x class SegmentationBlock(nn.Module): - def __init__(self, in_channels, out_channels, n_upsamples=0): + def __init__(self, in_channels: int, out_channels: int, n_upsamples: int = 0): super().__init__() blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] @@ -51,7 +59,7 @@ def forward(self, x): class MergeBlock(nn.Module): - def __init__(self, policy): + def __init__(self, policy: Literal["add", "cat"]): super().__init__() if policy not in ["add", "cat"]: raise ValueError( @@ -59,28 +67,30 @@ def __init__(self, policy): ) self.policy = policy - def forward(self, x): + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: if self.policy == "add": - return sum(x) + output = torch.stack(x).sum(dim=0) elif self.policy == "cat": - return torch.cat(x, dim=1) + output = torch.cat(x, dim=1) else: raise ValueError( "`merge_policy` must be one of: ['add', 'cat'], got {}".format( self.policy ) ) + return output class FPNDecoder(nn.Module): def 
__init__( self, - encoder_channels, - encoder_depth=5, - pyramid_channels=256, - segmentation_channels=128, - dropout=0.2, - merge_policy="add", + encoder_channels: List[int], + encoder_depth: int = 5, + pyramid_channels: int = 256, + segmentation_channels: int = 128, + dropout: float = 0.2, + merge_policy: Literal["add", "cat"] = "add", + interpolation_mode: str = "nearest", ): super().__init__() @@ -100,9 +110,9 @@ def __init__( encoder_channels = encoder_channels[: encoder_depth + 1] self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1) - self.p4 = FPNBlock(pyramid_channels, encoder_channels[1]) - self.p3 = FPNBlock(pyramid_channels, encoder_channels[2]) - self.p2 = FPNBlock(pyramid_channels, encoder_channels[3]) + self.p4 = FPNBlock(pyramid_channels, encoder_channels[1], interpolation_mode) + self.p3 = FPNBlock(pyramid_channels, encoder_channels[2], interpolation_mode) + self.p2 = FPNBlock(pyramid_channels, encoder_channels[3], interpolation_mode) self.seg_blocks = nn.ModuleList( [ @@ -116,7 +126,7 @@ def __init__( self.merge = MergeBlock(merge_policy) self.dropout = nn.Dropout2d(p=dropout, inplace=True) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: c2, c3, c4, c5 = features[-4:] p5 = self.p5(c5) @@ -124,9 +134,12 @@ def forward(self, *features): p3 = self.p3(p4, c3) p2 = self.p2(p3, c2) - feature_pyramid = [ - seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2]) - ] + s5 = self.seg_blocks[0](p5) + s4 = self.seg_blocks[1](p4) + s3 = self.seg_blocks[2](p3) + s2 = self.seg_blocks[3](p2) + + feature_pyramid = [s5, s4, s3, s2] x = self.merge(feature_pyramid) x = self.dropout(x) diff --git a/segmentation_models_pytorch/decoders/fpn/model.py b/segmentation_models_pytorch/decoders/fpn/model.py index 7420b289..6e37109a 100644 --- a/segmentation_models_pytorch/decoders/fpn/model.py +++ b/segmentation_models_pytorch/decoders/fpn/model.py @@ -28,12 +28,13 @@ class 
FPN(SegmentationModel): decoder_merge_policy: Determines how to merge pyramid features inside FPN. Available options are **add** and **cat** decoder_dropout: Spatial dropout rate in range (0, 1) for feature pyramid in FPN_ + decoder_interpolation: Interpolation mode used in decoder of the model. Available options are + **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**. in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). 
Supported params: @@ -62,6 +63,7 @@ def __init__( decoder_segmentation_channels: int = 128, decoder_merge_policy: str = "add", decoder_dropout: float = 0.2, + decoder_interpolation: str = "nearest", in_channels: int = 3, classes: int = 1, activation: Optional[str] = None, @@ -92,6 +94,7 @@ def __init__( segmentation_channels=decoder_segmentation_channels, dropout=decoder_dropout, merge_policy=decoder_merge_policy, + interpolation_mode=decoder_interpolation, ) self.segmentation_head = SegmentationHead( diff --git a/segmentation_models_pytorch/decoders/linknet/decoder.py b/segmentation_models_pytorch/decoders/linknet/decoder.py index e16a32c8..95c7f9f6 100644 --- a/segmentation_models_pytorch/decoders/linknet/decoder.py +++ b/segmentation_models_pytorch/decoders/linknet/decoder.py @@ -1,26 +1,33 @@ +import torch import torch.nn as nn +from typing import Any, Dict, List, Optional, Union from segmentation_models_pytorch.base import modules class TransposeX2(nn.Sequential): - def __init__(self, in_channels, out_channels, use_batchnorm=True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): super().__init__() - layers = [ - nn.ConvTranspose2d( - in_channels, out_channels, kernel_size=4, stride=2, padding=1 - ), - nn.ReLU(inplace=True), - ] - - if use_batchnorm: - layers.insert(1, nn.BatchNorm2d(out_channels)) - - super().__init__(*layers) + conv = nn.ConvTranspose2d( + in_channels, out_channels, kernel_size=4, stride=2, padding=1 + ) + norm = modules.get_norm_layer(use_norm, out_channels) + activation = nn.ReLU(inplace=True) + super().__init__(conv, norm, activation) class DecoderBlock(nn.Module): - def __init__(self, in_channels, out_channels, use_batchnorm=True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): super().__init__() self.block = nn.Sequential( @@ -28,20 +35,20 @@ def __init__(self, in_channels, 
out_channels, use_batchnorm=True): in_channels, in_channels // 4, kernel_size=1, - use_batchnorm=use_batchnorm, - ), - TransposeX2( - in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm + use_norm=use_norm, ), + TransposeX2(in_channels // 4, in_channels // 4, use_norm=use_norm), modules.Conv2dReLU( in_channels // 4, out_channels, kernel_size=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ), ) - def forward(self, x, skip=None): + def forward( + self, x: torch.Tensor, skip: Optional[torch.Tensor] = None + ) -> torch.Tensor: x = self.block(x) if skip is not None: x = x + skip @@ -50,7 +57,11 @@ def forward(self, x, skip=None): class LinknetDecoder(nn.Module): def __init__( - self, encoder_channels, prefinal_channels=32, n_blocks=5, use_batchnorm=True + self, + encoder_channels: List[int], + prefinal_channels: int = 32, + n_blocks: int = 5, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() @@ -63,12 +74,16 @@ def __init__( self.blocks = nn.ModuleList( [ - DecoderBlock(channels[i], channels[i + 1], use_batchnorm=use_batchnorm) + DecoderBlock( + channels[i], + channels[i + 1], + use_norm=use_norm, + ) for i in range(n_blocks) ] ) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: features = features[1:] # remove first skip features = features[::-1] # reverse channels to start from head of encoder diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index 356468ed..38eac4c2 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -1,4 +1,5 @@ -from typing import Any, Optional, Union +import warnings +from typing import Any, Dict, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -29,15 +30,27 @@ class Linknet(SegmentationModel): Default is 5 encoder_weights: One of **None** (random 
initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": <norm_name>, **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + decoder_use_norm={"type": "layernorm", "eps": 1e-2} + ``` in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default).
Supported params: - classes (int): A number of classes @@ -60,10 +73,10 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -74,6 +87,15 @@ def __init__( "Encoder `{}` is not supported for Linknet".format(encoder_name) ) + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -86,7 +108,7 @@ def __init__( encoder_channels=self.encoder.out_channels, n_blocks=encoder_depth, prefinal_channels=32, - use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, ) self.segmentation_head = SegmentationHead( diff --git a/segmentation_models_pytorch/decoders/manet/decoder.py b/segmentation_models_pytorch/decoders/manet/decoder.py index 0f6af18d..39e117bf 100644 --- a/segmentation_models_pytorch/decoders/manet/decoder.py +++ b/segmentation_models_pytorch/decoders/manet/decoder.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, List, Optional, Union + import torch import torch.nn as nn import torch.nn.functional as F @@ -5,9 +7,10 @@ from segmentation_models_pytorch.base import modules as md -class PAB(nn.Module): - def __init__(self, in_channels, out_channels, pab_channels=64): - super(PAB, self).__init__() +class PABBlock(nn.Module): + def __init__(self, in_channels: int, pab_channels: int = 64): + super().__init__() + # Series of 1x1 conv to generate attention feature maps 
self.pab_channels = pab_channels self.in_channels = in_channels @@ -17,10 +20,9 @@ def __init__(self, in_channels, out_channels, pab_channels=64): self.map_softmax = nn.Softmax(dim=1) self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) - def forward(self, x): - bsize = x.size()[0] - h = x.size()[2] - w = x.size()[3] + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, _, height, width = x.shape + x_top = self.top_conv(x) x_center = self.center_conv(x) x_bottom = self.bottom_conv(x) @@ -30,30 +32,42 @@ def forward(self, x): x_bottom = x_bottom.flatten(2).transpose(1, 2) sp_map = torch.matmul(x_center, x_top) - sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h * w, h * w) + sp_map = self.map_softmax(sp_map.view(batch_size, -1)) + sp_map = sp_map.view(batch_size, height * width, height * width) + sp_map = torch.matmul(sp_map, x_bottom) - sp_map = sp_map.reshape(bsize, self.in_channels, h, w) + sp_map = sp_map.reshape(batch_size, self.in_channels, height, width) + x = x + sp_map x = self.out_conv(x) return x -class MFAB(nn.Module): +class MFABBlock(nn.Module): def __init__( - self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16 + self, + in_channels: int, + skip_channels: int, + out_channels: int, + interpolation_mode: str = "nearest", + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + reduction: int = 16, ): - # MFAB is just a modified version of SE-blocks, one for skip, one for input - super(MFAB, self).__init__() + # MFABBlock is just a modified version of SE-blocks, one for skip, one for input + super().__init__() self.hl_conv = nn.Sequential( md.Conv2dReLU( in_channels, in_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ), md.Conv2dReLU( - in_channels, skip_channels, kernel_size=1, use_batchnorm=use_batchnorm + in_channels, + skip_channels, + kernel_size=1, + use_norm=use_norm, ), ) reduced_channels = max(1, skip_channels // 
reduction) @@ -77,19 +91,22 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.conv2 = md.Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) + self.interpolation_mode = interpolation_mode - def forward(self, x, skip=None): + def forward( + self, x: torch.Tensor, skip: Optional[torch.Tensor] = None + ) -> torch.Tensor: x = self.hl_conv(x) - x = F.interpolate(x, scale_factor=2, mode="nearest") + x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode) attention_hl = self.SE_hl(x) if skip is not None: attention_ll = self.SE_ll(skip) @@ -102,25 +119,35 @@ def forward(self, x, skip=None): class DecoderBlock(nn.Module): - def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True): + def __init__( + self, + in_channels: int, + skip_channels: int, + out_channels: int, + interpolation_mode: str = "nearest", + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): super().__init__() self.conv1 = md.Conv2dReLU( in_channels + skip_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.conv2 = md.Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) + self.interpolation_mode = interpolation_mode - def forward(self, x, skip=None): - x = F.interpolate(x, scale_factor=2, mode="nearest") + def forward( + self, x: torch.Tensor, skip: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode) if skip is not None: x = torch.cat([x, skip], dim=1) x = self.conv1(x) @@ -131,12 +158,13 @@ def forward(self, x, skip=None): class MAnetDecoder(nn.Module): def __init__( self, - encoder_channels, - decoder_channels, - n_blocks=5, - reduction=16, - use_batchnorm=True, - pab_channels=64, + encoder_channels: List[int], + decoder_channels: 
List[int], + n_blocks: int = 5, + reduction: int = 16, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + pab_channels: int = 64, + interpolation_mode: str = "nearest", ): super().__init__() @@ -159,12 +187,14 @@ def __init__( skip_channels = list(encoder_channels[1:]) + [0] out_channels = decoder_channels - self.center = PAB(head_channels, head_channels, pab_channels=pab_channels) + self.center = PABBlock(head_channels, pab_channels=pab_channels) # combine decoder keyword arguments - kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here + kwargs = dict( + use_norm=use_norm, interpolation_mode=interpolation_mode + ) # no attention type here blocks = [ - MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) + MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) if skip_ch > 0 else DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) @@ -172,7 +202,7 @@ def __init__( # for the last we dont have skip connection -> use simple decoder block self.blocks = nn.ModuleList(blocks) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: features = features[1:] # remove first skip with same spatial resolution features = features[::-1] # reverse channels to start from head of encoder diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index 6ed59207..568a7f58 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -1,4 +1,5 @@ -from typing import Any, List, Optional, Union +import warnings +from typing import Any, Dict, Optional, Union, Sequence, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -29,17 +30,31 @@ class MAnet(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers 
which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": <norm_name>, **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + decoder_use_norm={"type": "layernorm", "eps": 1e-2} + ``` decoder_pab_channels: A number of channels for PAB module in decoder. Default is 64. + decoder_interpolation: Interpolation mode used in decoder of the model. Available options are + **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**. in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. aux_params: Dictionary with parameters of the auxiliary output (classification head).
Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: - classes (int): A number of classes @@ -63,17 +78,27 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, - decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_pab_channels: int = 64, + decoder_interpolation: str = "nearest", in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): super().__init__() + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -86,8 +111,9 @@ def __init__( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, - use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, pab_channels=decoder_pab_channels, + interpolation_mode=decoder_interpolation, ) self.segmentation_head = SegmentationHead( diff --git a/segmentation_models_pytorch/decoders/pan/decoder.py b/segmentation_models_pytorch/decoders/pan/decoder.py index fa0bb261..729c76ed 100644 --- a/segmentation_models_pytorch/decoders/pan/decoder.py +++ b/segmentation_models_pytorch/decoders/pan/decoder.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Literal +from typing import Literal, List import torch import torch.nn as nn @@ -31,18 +31,22 @@ def __init__( bias=bias, groups=groups, ) + 
self.activation = nn.ReLU(inplace=True) + self.bn = nn.BatchNorm2d(out_channels) + self.add_relu = add_relu self.interpolate = interpolate - self.bn = nn.BatchNorm2d(out_channels) - self.activation = nn.ReLU(inplace=True) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = self.bn(x) + if self.add_relu: x = self.activation(x) + if self.interpolate: - x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) + x = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=True) + return x @@ -50,7 +54,7 @@ class FPABlock(nn.Module): def __init__( self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear" ): - super(FPABlock, self).__init__() + super().__init__() self.upscale_mode = upscale_mode if self.upscale_mode == "bilinear": @@ -70,7 +74,7 @@ def __init__( ), ) - # midddle branch + # middle branch self.mid = nn.Sequential( ConvBnRelu( in_channels=in_channels, @@ -112,41 +116,64 @@ def __init__( in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3 ) - def forward(self, x): - h, w = x.size(2), x.size(3) - b1 = self.branch1(x) - upscale_parameters = dict( - mode=self.upscale_mode, align_corners=self.align_corners + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, _, height, width = x.shape + + branch1_output = self.branch1(x) + branch1_output = F.interpolate( + branch1_output, + size=(height, width), + mode=self.upscale_mode, + align_corners=self.align_corners, ) - b1 = F.interpolate(b1, size=(h, w), **upscale_parameters) - mid = self.mid(x) + middle_output = self.mid(x) + x1 = self.down1(x) x2 = self.down2(x1) x3 = self.down3(x2) - x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters) + x3 = F.interpolate( + x3, + size=(height // 4, width // 4), + mode=self.upscale_mode, + align_corners=self.align_corners, + ) x2 = self.conv2(x2) x = x2 + x3 - x = F.interpolate(x, size=(h // 2, w // 2), **upscale_parameters) + x = F.interpolate( + x, + 
size=(height // 2, width // 2), + mode=self.upscale_mode, + align_corners=self.align_corners, + ) x1 = self.conv1(x1) x = x + x1 - x = F.interpolate(x, size=(h, w), **upscale_parameters) + x = F.interpolate( + x, + size=(height, width), + mode=self.upscale_mode, + align_corners=self.align_corners, + ) + + x = torch.mul(x, middle_output) + x = x + branch1_output - x = torch.mul(x, mid) - x = x + b1 return x class GAUBlock(nn.Module): def __init__( - self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear" + self, + in_channels: int, + out_channels: int, + interpolation_mode: str = "bilinear", ): super(GAUBlock, self).__init__() - self.upscale_mode = upscale_mode - self.align_corners = True if upscale_mode == "bilinear" else None + self.interpolation_mode = interpolation_mode + self.align_corners = True if interpolation_mode == "bilinear" else None self.conv1 = nn.Sequential( nn.AdaptiveAvgPool2d(1), @@ -162,15 +189,18 @@ def __init__( in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 ) - def forward(self, x, y): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ Args: x: low level feature y: high level feature """ - h, w = x.size(2), x.size(3) + height, width = x.shape[2:] y_up = F.interpolate( - y, size=(h, w), mode=self.upscale_mode, align_corners=self.align_corners + y, + size=(height, width), + mode=self.interpolation_mode, + align_corners=self.align_corners, ) x = self.conv2(x) y = self.conv1(y) @@ -184,7 +214,7 @@ def __init__( encoder_channels: Sequence[int], encoder_depth: Literal[3, 4, 5], decoder_channels: int, - upscale_mode: str = "bilinear", + interpolation_mode: str = "bilinear", ): super().__init__() @@ -205,22 +235,22 @@ def __init__( self.gau3 = GAUBlock( in_channels=encoder_channels[2], out_channels=decoder_channels, - upscale_mode=upscale_mode, + interpolation_mode=interpolation_mode, ) if encoder_depth >= 4: self.gau2 = GAUBlock( in_channels=encoder_channels[1], 
out_channels=decoder_channels, - upscale_mode=upscale_mode, + interpolation_mode=interpolation_mode, ) if encoder_depth >= 3: self.gau1 = GAUBlock( in_channels=encoder_channels[0], out_channels=decoder_channels, - upscale_mode=upscale_mode, + interpolation_mode=interpolation_mode, ) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: features = features[2:] # remove first and second skip out = self.fpa(features[-1]) # 1/16 or 1/32 diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py index 6d5e78c2..0ea1dfbb 100644 --- a/segmentation_models_pytorch/decoders/pan/model.py +++ b/segmentation_models_pytorch/decoders/pan/model.py @@ -1,4 +1,5 @@ from typing import Any, Callable, Literal, Optional, Union +import warnings from segmentation_models_pytorch.base import ( ClassificationHead, @@ -30,12 +31,13 @@ class PAN(SegmentationModel): encoder_output_stride: 16 or 32, if 16 use dilation in encoder last layer. Doesn't work with ***ception***, **vgg***, **densenet*`** backbones.Default is 16. decoder_channels: A number of convolution layer filters in decoder blocks + decoder_interpolation: Interpolation mode used in decoder of the model. Available options are + **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"bilinear"**. in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. upsampling: Final upsampling factor. 
Default is 4 to preserve input-output spatial shape identity aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: @@ -62,6 +64,7 @@ def __init__( encoder_weights: Optional[str] = "imagenet", encoder_output_stride: Literal[16, 32] = 16, decoder_channels: int = 32, + decoder_interpolation: str = "bilinear", in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, Callable]] = None, @@ -78,6 +81,15 @@ def __init__( ) ) + upscale_mode = kwargs.pop("upscale_mode", None) + if upscale_mode is not None: + warnings.warn( + "The usage of upscale_mode is deprecated. Please modify your code for decoder_interpolation", + DeprecationWarning, + stacklevel=2, + ) + decoder_interpolation = upscale_mode + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -91,6 +103,7 @@ def __init__( encoder_channels=self.encoder.out_channels, encoder_depth=encoder_depth, decoder_channels=decoder_channels, + interpolation_mode=decoder_interpolation, ) self.segmentation_head = SegmentationHead( diff --git a/segmentation_models_pytorch/decoders/pspnet/decoder.py b/segmentation_models_pytorch/decoders/pspnet/decoder.py index 40d2e945..80ad289c 100644 --- a/segmentation_models_pytorch/decoders/pspnet/decoder.py +++ b/segmentation_models_pytorch/decoders/pspnet/decoder.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, List, Tuple, Union + import torch import torch.nn as nn import torch.nn.functional as F @@ -6,26 +8,39 @@ class PSPBlock(nn.Module): - def __init__(self, in_channels, out_channels, pool_size, use_bathcnorm=True): + def __init__( + self, + in_channels: int, + out_channels: int, + pool_size: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): super().__init__() + if pool_size == 1: - use_bathcnorm = False # PyTorch does not support BatchNorm for 1x1 shape + use_norm = "identity" # PyTorch does not 
support BatchNorm for 1x1 shape + self.pool = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), modules.Conv2dReLU( - in_channels, out_channels, (1, 1), use_batchnorm=use_bathcnorm + in_channels, out_channels, kernel_size=1, use_norm=use_norm ), ) - def forward(self, x): - h, w = x.size(2), x.size(3) + def forward(self, x: torch.Tensor) -> torch.Tensor: + height, width = x.shape[2:] x = self.pool(x) - x = F.interpolate(x, size=(h, w), mode="bilinear", align_corners=True) + x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=True) return x class PSPModule(nn.Module): - def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True): + def __init__( + self, + in_channels: int, + sizes: Tuple[int, ...] = (1, 2, 3, 6), + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): super().__init__() self.blocks = nn.ModuleList( @@ -34,7 +49,7 @@ def __init__(self, in_channels, sizes=(1, 2, 3, 6), use_bathcnorm=True): in_channels, in_channels // len(sizes), size, - use_bathcnorm=use_bathcnorm, + use_norm=use_norm, ) for size in sizes ] @@ -48,26 +63,30 @@ def forward(self, x): class PSPDecoder(nn.Module): def __init__( - self, encoder_channels, use_batchnorm=True, out_channels=512, dropout=0.2 + self, + encoder_channels: List[int], + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + out_channels: int = 512, + dropout: float = 0.2, ): super().__init__() self.psp = PSPModule( in_channels=encoder_channels[-1], sizes=(1, 2, 3, 6), - use_bathcnorm=use_batchnorm, + use_norm=use_norm, ) self.conv = modules.Conv2dReLU( in_channels=encoder_channels[-1] * 2, out_channels=out_channels, kernel_size=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.dropout = nn.Dropout2d(p=dropout) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: x = features[-1] x = self.psp(x) x = self.conv(x) diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py 
b/segmentation_models_pytorch/decoders/pspnet/model.py index 8b99b3da..f7740891 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -1,4 +1,5 @@ -from typing import Any, Optional, Union +import warnings +from typing import Any, Dict, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -28,16 +29,28 @@ class PSPNet(SegmentationModel): encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and other pretrained weights (see table with available weights for each encoder_name) psp_out_channels: A number of filters in Spatial Pyramid - psp_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": , **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + decoder_use_norm={"type": "layernorm", "eps": 1e-2} + ``` psp_dropout: Spatial dropout rate in [0, 1) used in Spatial Pyramid in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. 
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: @@ -62,17 +75,26 @@ def __init__( encoder_weights: Optional[str] = "imagenet", encoder_depth: int = 3, psp_out_channels: int = 512, - psp_use_batchnorm: bool = True, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", psp_dropout: float = 0.2, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, upsampling: int = 8, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): super().__init__() + psp_use_batchnorm = kwargs.pop("psp_use_batchnorm", None) + if psp_use_batchnorm is not None: + warnings.warn( + "The usage of psp_use_batchnorm is deprecated. 
Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = psp_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -83,7 +105,7 @@ def __init__( self.decoder = PSPDecoder( encoder_channels=self.encoder.out_channels, - use_batchnorm=psp_use_batchnorm, + use_norm=decoder_use_norm, out_channels=psp_out_channels, dropout=psp_dropout, ) diff --git a/segmentation_models_pytorch/decoders/segformer/decoder.py b/segmentation_models_pytorch/decoders/segformer/decoder.py index daa78b37..2bfadfff 100644 --- a/segmentation_models_pytorch/decoders/segformer/decoder.py +++ b/segmentation_models_pytorch/decoders/segformer/decoder.py @@ -2,11 +2,12 @@ import torch.nn as nn import torch.nn.functional as F +from typing import List from segmentation_models_pytorch.base import modules as md class MLP(nn.Module): - def __init__(self, skip_channels, segmentation_channels): + def __init__(self, skip_channels: int, segmentation_channels: int): super().__init__() self.linear = nn.Linear(skip_channels, segmentation_channels) @@ -22,9 +23,9 @@ def forward(self, x: torch.Tensor): class SegformerDecoder(nn.Module): def __init__( self, - encoder_channels, - encoder_depth=5, - segmentation_channels=256, + encoder_channels: List[int], + encoder_depth: int = 5, + segmentation_channels: int = 256, ): super().__init__() @@ -36,9 +37,9 @@ def __init__( ) if encoder_channels[1] == 0: - encoder_channels = tuple( + encoder_channels = [ channel for index, channel in enumerate(encoder_channels) if index != 1 - ) + ] encoder_channels = encoder_channels[::-1] self.mlp_stage = nn.ModuleList( @@ -49,10 +50,10 @@ def __init__( in_channels=(len(encoder_channels) - 1) * segmentation_channels, out_channels=segmentation_channels, kernel_size=1, - use_batchnorm=True, + use_norm="batchnorm", ) - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: # Resize all features to the size of the 
largest feature target_size = [dim // 4 for dim in features[0].shape[2:]] @@ -60,8 +61,8 @@ def forward(self, *features): features = features[::-1] # reverse channels to start from head of encoder resized_features = [] - for feature, stage in zip(features, self.mlp_stage): - feature = stage(feature) + for i, mlp_layer in enumerate(self.mlp_stage): + feature = mlp_layer(features[i]) resized_feature = F.interpolate( feature, size=target_size, mode="bilinear", align_corners=False ) diff --git a/segmentation_models_pytorch/decoders/segformer/model.py b/segmentation_models_pytorch/decoders/segformer/model.py index 45805de7..03deeeef 100644 --- a/segmentation_models_pytorch/decoders/segformer/model.py +++ b/segmentation_models_pytorch/decoders/segformer/model.py @@ -28,8 +28,8 @@ class Segformer(SegmentationModel): classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. + upsampling: A number to upsample the output of the model, default is 4 (same size as input) aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). 
Supported params: - classes (int): A number of classes @@ -57,6 +57,7 @@ def __init__( in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, Callable]] = None, + upsampling: int = 4, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -81,7 +82,7 @@ def __init__( out_channels=classes, activation=activation, kernel_size=1, - upsampling=4, + upsampling=upsampling, ) if aux_params is not None: diff --git a/segmentation_models_pytorch/decoders/unet/decoder.py b/segmentation_models_pytorch/decoders/unet/decoder.py index 33061542..cfeb267e 100644 --- a/segmentation_models_pytorch/decoders/unet/decoder.py +++ b/segmentation_models_pytorch/decoders/unet/decoder.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, List, Optional, Sequence, Union + import torch import torch.nn as nn import torch.nn.functional as F @@ -5,22 +7,26 @@ from segmentation_models_pytorch.base import modules as md -class DecoderBlock(nn.Module): +class UnetDecoderBlock(nn.Module): + """A decoder block in the U-Net architecture that performs upsampling and feature fusion.""" + def __init__( self, - in_channels, - skip_channels, - out_channels, - use_batchnorm=True, - attention_type=None, + in_channels: int, + skip_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + attention_type: Optional[str] = None, + interpolation_mode: str = "nearest", ): super().__init__() + self.interpolation_mode = interpolation_mode self.conv1 = md.Conv2dReLU( in_channels + skip_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention1 = md.Attention( attention_type, in_channels=in_channels + skip_channels @@ -30,49 +36,73 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention2 = md.Attention(attention_type, in_channels=out_channels) - def forward(self, x, skip=None): - x = F.interpolate(x, scale_factor=2, 
mode="nearest") - if skip is not None: - x = torch.cat([x, skip], dim=1) - x = self.attention1(x) - x = self.conv1(x) - x = self.conv2(x) - x = self.attention2(x) - return x + def forward( + self, + feature_map: torch.Tensor, + target_height: int, + target_width: int, + skip_connection: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + feature_map = F.interpolate( + feature_map, + size=(target_height, target_width), + mode=self.interpolation_mode, + ) + if skip_connection is not None: + feature_map = torch.cat([feature_map, skip_connection], dim=1) + feature_map = self.attention1(feature_map) + feature_map = self.conv1(feature_map) + feature_map = self.conv2(feature_map) + feature_map = self.attention2(feature_map) + return feature_map + +class UnetCenterBlock(nn.Sequential): + """Center block of the Unet decoder. Applied to the last feature map of the encoder.""" -class CenterBlock(nn.Sequential): - def __init__(self, in_channels, out_channels, use_batchnorm=True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): conv1 = md.Conv2dReLU( in_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) conv2 = md.Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) super().__init__(conv1, conv2) class UnetDecoder(nn.Module): + """The decoder part of the U-Net architecture. + + Takes encoded features from different stages of the encoder and progressively upsamples them while + combining with skip connections. This helps preserve fine-grained details in the final segmentation. 
+ """ + def __init__( self, - encoder_channels, - decoder_channels, - n_blocks=5, - use_batchnorm=True, - attention_type=None, - center=False, + encoder_channels: Sequence[int], + decoder_channels: Sequence[int], + n_blocks: int = 5, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + attention_type: Optional[str] = None, + add_center_block: bool = False, + interpolation_mode: str = "nearest", ): super().__init__() @@ -94,31 +124,47 @@ def __init__( skip_channels = list(encoder_channels[1:]) + [0] out_channels = decoder_channels - if center: - self.center = CenterBlock( - head_channels, head_channels, use_batchnorm=use_batchnorm + if add_center_block: + self.center = UnetCenterBlock( + head_channels, + head_channels, + use_norm=use_norm, ) else: self.center = nn.Identity() # combine decoder keyword arguments - kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) - blocks = [ - DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) - for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) - ] - self.blocks = nn.ModuleList(blocks) - - def forward(self, *features): + self.blocks = nn.ModuleList() + for block_in_channels, block_skip_channels, block_out_channels in zip( + in_channels, skip_channels, out_channels + ): + block = UnetDecoderBlock( + block_in_channels, + block_skip_channels, + block_out_channels, + use_norm=use_norm, + attention_type=attention_type, + interpolation_mode=interpolation_mode, + ) + self.blocks.append(block) + + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: + # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...] 
+ spatial_shapes = [feature.shape[2:] for feature in features] + spatial_shapes = spatial_shapes[::-1] + features = features[1:] # remove first skip with same spatial resolution features = features[::-1] # reverse channels to start from head of encoder head = features[0] - skips = features[1:] + skip_connections = features[1:] x = self.center(head) + for i, decoder_block in enumerate(self.blocks): - skip = skips[i] if i < len(skips) else None - x = decoder_block(x, skip) + # upsample to the next spatial shape + height, width = spatial_shapes[i + 1] + skip_connection = skip_connections[i] if i < len(skip_connections) else None + x = decoder_block(x, height, width, skip_connection=skip_connection) return x diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 547581eb..3df36e32 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -1,4 +1,5 @@ -from typing import Any, Optional, Union, Tuple, Callable +import warnings +from typing import Any, Dict, Optional, Union, Callable, Sequence from segmentation_models_pytorch.base import ( ClassificationHead, @@ -12,10 +13,21 @@ class Unet(SegmentationModel): - """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder* - and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial - resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation* - for fusing decoder blocks with skip connections. + """ + U-Net is a fully convolutional neural network architecture designed for semantic image segmentation. + + It consists of two main parts: + + 1. An encoder (downsampling path) that extracts increasingly abstract features + 2. 
A decoder (upsampling path) that gradually recovers spatial details + + The key is the use of skip connections between corresponding encoder and decoder layers. + These connections allow the decoder to access fine-grained details from earlier encoder layers, + which helps produce more precise segmentation masks. + + The skip connections work by concatenating feature maps from the encoder directly into the decoder + at corresponding resolutions. This helps preserve important spatial information that would + otherwise be lost during the encoding process. Args: encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) @@ -28,17 +40,31 @@ class Unet(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": , **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + decoder_use_norm={"type": "layernorm", "eps": 1e-2} + ``` decoder_attention_type: Attention module used in decoder of the model. 
Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). + decoder_interpolation: Interpolation mode used in decoder of the model. Available options are + **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**. in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). Supported params: - classes (int): A number of classes @@ -51,20 +77,41 @@ class Unet(SegmentationModel): Returns: ``torch.nn.Module``: Unet + Example: + .. code-block:: python + + import torch + import segmentation_models_pytorch as smp + + model = smp.Unet("resnet18", encoder_weights="imagenet", classes=5) + model.eval() + + # generate random images + images = torch.rand(2, 3, 256, 256) + + with torch.inference_mode(): + mask = model(images) + + print(mask.shape) + # torch.Size([2, 5, 256, 256]) + .. _Unet: https://arxiv.org/abs/1505.04597 """ + requires_divisible_input_shape = False + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, - decoder_channels: Tuple[int, ...] 
= (256, 128, 64, 32, 16), + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, + decoder_interpolation: str = "nearest", in_channels: int = 3, classes: int = 1, activation: Optional[Union[str, Callable]] = None, @@ -73,6 +120,15 @@ def __init__( ): super().__init__() + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -81,13 +137,16 @@ def __init__( **kwargs, ) + add_center_block = encoder_name.startswith("vgg") + self.decoder = UnetDecoder( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, - use_batchnorm=decoder_use_batchnorm, - center=True if encoder_name.startswith("vgg") else False, + use_norm=decoder_use_norm, + add_center_block=add_center_block, attention_type=decoder_attention_type, + interpolation_mode=decoder_interpolation, ) self.segmentation_head = SegmentationHead( diff --git a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py index 54ec7576..b42a73a9 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/decoder.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/decoder.py @@ -2,17 +2,20 @@ import torch.nn as nn import torch.nn.functional as F +from typing import Any, Dict, List, Optional, Union, Sequence + from segmentation_models_pytorch.base import modules as md class DecoderBlock(nn.Module): def __init__( self, - in_channels, - skip_channels, - out_channels, - use_batchnorm=True, - attention_type=None, + in_channels: int, + skip_channels: int, + out_channels: int, 
+ use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + attention_type: Optional[str] = None, + interpolation_mode: str = "nearest", ): super().__init__() self.conv1 = md.Conv2dReLU( @@ -20,7 +23,7 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention1 = md.Attention( attention_type, in_channels=in_channels + skip_channels @@ -30,12 +33,15 @@ def __init__( out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) self.attention2 = md.Attention(attention_type, in_channels=out_channels) + self.interpolation_mode = interpolation_mode - def forward(self, x, skip=None): - x = F.interpolate(x, scale_factor=2, mode="nearest") + def forward( + self, x: torch.Tensor, skip: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode) if skip is not None: x = torch.cat([x, skip], dim=1) x = self.attention1(x) @@ -46,20 +52,25 @@ def forward(self, x, skip=None): class CenterBlock(nn.Sequential): - def __init__(self, in_channels, out_channels, use_batchnorm=True): + def __init__( + self, + in_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): conv1 = md.Conv2dReLU( in_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) conv2 = md.Conv2dReLU( out_channels, out_channels, kernel_size=3, padding=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ) super().__init__(conv1, conv2) @@ -67,20 +78,19 @@ def __init__(self, in_channels, out_channels, use_batchnorm=True): class UnetPlusPlusDecoder(nn.Module): def __init__( self, - encoder_channels, - decoder_channels, - n_blocks=5, - use_batchnorm=True, - attention_type=None, - center=False, + encoder_channels: Sequence[int], + decoder_channels: Sequence[int], + n_blocks: int = 5, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + attention_type: 
Optional[str] = None, + interpolation_mode: str = "nearest", + center: bool = False, ): super().__init__() if n_blocks != len(decoder_channels): raise ValueError( - "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( - n_blocks, len(decoder_channels) - ) + f"Model depth is {n_blocks}, but you provide `decoder_channels` for {len(decoder_channels)} blocks." ) # remove first skip with same spatial resolution @@ -95,13 +105,19 @@ def __init__( self.out_channels = decoder_channels if center: self.center = CenterBlock( - head_channels, head_channels, use_batchnorm=use_batchnorm + head_channels, + head_channels, + use_norm=use_norm, ) else: self.center = nn.Identity() # combine decoder keyword arguments - kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) + kwargs = dict( + use_norm=use_norm, + attention_type=attention_type, + interpolation_mode=interpolation_mode, + ) blocks = {} for layer_idx in range(len(self.in_channels) - 1): @@ -119,15 +135,16 @@ def __init__( blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock( in_ch, skip_ch, out_ch, **kwargs ) - blocks[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock( + blocks[f"x_{0}_{len(self.in_channels) - 1}"] = DecoderBlock( self.in_channels[-1], 0, self.out_channels[-1], **kwargs ) self.blocks = nn.ModuleDict(blocks) self.depth = len(self.in_channels) - 1 - def forward(self, *features): + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: features = features[1:] # remove first skip with same spatial resolution features = features[::-1] # reverse channels to start from head of encoder + # start building dense connections dense_x = {} for layer_idx in range(len(self.in_channels) - 1): @@ -148,8 +165,8 @@ def forward(self, *features): ) dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[ f"x_{depth_idx}_{dense_l_i}" - ](dense_x[f"x_{depth_idx}_{dense_l_i-1}"], cat_features) + ](dense_x[f"x_{depth_idx}_{dense_l_i - 1}"], cat_features) 
dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"]( - dense_x[f"x_{0}_{self.depth-1}"] + dense_x[f"x_{0}_{self.depth - 1}"] ) return dense_x[f"x_{0}_{self.depth}"] diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 9d4a1e35..5448abcb 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -1,4 +1,5 @@ -from typing import Any, List, Optional, Union +import warnings +from typing import Any, Dict, Sequence, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -28,17 +29,31 @@ class UnetPlusPlus(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_channels: List of integers which specify **in_channels** parameter for convolutions used in decoder. Length of the list should be the same as **encoder_depth** - decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers - is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption. - Available options are **True, False, "inplace"** + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": , **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. 
+ + **Example**: + ```python + decoder_use_norm={"type": "layernorm", "eps": 1e-2} + ``` decoder_attention_type: Attention module used in decoder of the model. Available options are **None** and **scse** (https://arxiv.org/abs/1808.08127). + decoder_interpolation: Interpolation mode used in decoder of the model. Available options are + **"nearest"**, **"bilinear"**, **"bicubic"**, **"area"**, **"nearest-exact"**. Default is **"nearest"**. in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). 
Supported params: - classes (int): A number of classes @@ -56,18 +71,21 @@ class UnetPlusPlus(SegmentationModel): """ + _is_torch_scriptable = False + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_use_batchnorm: bool = True, - decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + decoder_channels: Sequence[int] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, + decoder_interpolation: str = "nearest", in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -78,6 +96,15 @@ def __init__( "UnetPlusPlus is not support encoder_name={}".format(encoder_name) ) + decoder_use_batchnorm = kwargs.pop("decoder_use_batchnorm", None) + if decoder_use_batchnorm is not None: + warnings.warn( + "The usage of decoder_use_batchnorm is deprecated. 
Please modify your code for decoder_use_norm", + DeprecationWarning, + stacklevel=2, + ) + decoder_use_norm = decoder_use_batchnorm + self.encoder = get_encoder( encoder_name, in_channels=in_channels, @@ -90,9 +117,10 @@ def __init__( encoder_channels=self.encoder.out_channels, decoder_channels=decoder_channels, n_blocks=encoder_depth, - use_batchnorm=decoder_use_batchnorm, + use_norm=decoder_use_norm, center=True if encoder_name.startswith("vgg") else False, attention_type=decoder_attention_type, + interpolation_mode=decoder_interpolation, ) self.segmentation_head = SegmentationHead( diff --git a/segmentation_models_pytorch/decoders/upernet/decoder.py b/segmentation_models_pytorch/decoders/upernet/decoder.py index 092de36a..435927df 100644 --- a/segmentation_models_pytorch/decoders/upernet/decoder.py +++ b/segmentation_models_pytorch/decoders/upernet/decoder.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Union, Sequence, List + import torch import torch.nn as nn import torch.nn.functional as F @@ -8,10 +10,10 @@ class PSPModule(nn.Module): def __init__( self, - in_channels, - out_channels, - sizes=(1, 2, 3, 6), - use_batchnorm=True, + in_channels: int, + out_channels: int, + sizes: Sequence[int] = (1, 2, 3, 6), + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() self.blocks = nn.ModuleList( @@ -20,63 +22,85 @@ def __init__( nn.AdaptiveAvgPool2d(size), md.Conv2dReLU( in_channels, - in_channels // len(sizes), + out_channels, kernel_size=1, - use_batchnorm=use_batchnorm, + use_norm=use_norm, ), ) for size in sizes ] ) self.out_conv = md.Conv2dReLU( - in_channels=in_channels * 2, + in_channels=in_channels + len(sizes) * out_channels, out_channels=out_channels, - kernel_size=1, - use_batchnorm=True, + kernel_size=3, + padding=1, + use_norm="batchnorm", ) - def forward(self, x): - _, _, height, width = x.shape - out = [x] + [ - F.interpolate( - block(x), size=(height, width), mode="bilinear", align_corners=False + def forward(self, 
feature: torch.Tensor) -> torch.Tensor: + _, _, height, width = feature.shape + pyramid_features = [feature] + for block in self.blocks: + pooled_feature = block(feature) + resized_feature = F.interpolate( + pooled_feature, + size=(height, width), + mode="bilinear", + align_corners=False, ) - for block in self.blocks - ] - out = self.out_conv(torch.cat(out, dim=1)) - return out + pyramid_features.append(resized_feature) + fused_feature = self.out_conv(torch.cat(pyramid_features, dim=1)) + return fused_feature -class FPNBlock(nn.Module): - def __init__(self, skip_channels, pyramid_channels, use_bathcnorm=True): +class LayerNorm2d(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 3, 1) # to channels_last + normed_x = nn.functional.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + normed_x = normed_x.permute(0, 3, 1, 2) # to channels_first + return normed_x + + +class FPNLateralBlock(nn.Module): + def __init__( + self, + lateral_channels: int, + out_channels: int, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", + ): super().__init__() - self.skip_conv = ( - md.Conv2dReLU( - skip_channels, - pyramid_channels, - kernel_size=1, - use_batchnorm=use_bathcnorm, - ) - if skip_channels != 0 - else nn.Identity() + self.conv_norm_relu = md.Conv2dReLU( + lateral_channels, + out_channels, + kernel_size=1, + use_norm=use_norm, ) - def forward(self, x, skip): - _, channels, height, width = skip.shape - x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=False) - if channels != 0: - skip = self.skip_conv(skip) - x = x + skip - return x + def forward( + self, state_feature: torch.Tensor, lateral_feature: torch.Tensor + ) -> torch.Tensor: + # 1. Apply block to encoder feature + lateral_feature = self.conv_norm_relu(lateral_feature) + # 2. 
Upsample the "state" feature to the encoder (lateral) feature resolution + _, _, height, width = lateral_feature.shape + state_feature = F.interpolate( + state_feature, size=(height, width), mode="bilinear", align_corners=False + ) + # 3. Sum state and encoder features + fused_feature = state_feature + lateral_feature + return fused_feature class UPerNetDecoder(nn.Module): def __init__( self, - encoder_channels, - encoder_depth=5, - pyramid_channels=256, - segmentation_channels=64, + encoder_channels: Sequence[int], + encoder_depth: int = 5, + decoder_channels: int = 256, + use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", ): super().__init__() @@ -87,51 +111,101 @@ def __init__( ) ) - encoder_channels = encoder_channels[::-1] + # Encoder channels for input features starting from the highest resolution + # [1, 1/2, 1/4, 1/8, 1/16, ...] for num_features = encoder_depth + 1, + # but we use only [1/4, 1/8, 1/16, ...] for UPerNet + encoder_channels = encoder_channels[2:] + + self.feature_norms = nn.ModuleList( + [LayerNorm2d(channels, eps=1e-6) for channels in encoder_channels] + ) # PSP Module + lowest_resolution_feature_channels = encoder_channels[-1] self.psp = PSPModule( - in_channels=encoder_channels[0], - out_channels=pyramid_channels, + in_channels=lowest_resolution_feature_channels, + out_channels=decoder_channels, sizes=(1, 2, 3, 6), - use_batchnorm=True, + use_norm=use_norm, ) # FPN Module - self.fpn_stages = nn.ModuleList( - [FPNBlock(ch, pyramid_channels) for ch in encoder_channels[1:]] - ) + # we skip lower resolution feature maps + reverse the order + # [1/4, 1/8, 1/16, 1/32] -> [1/16, 1/8, 1/4] + lateral_channels = encoder_channels[:-1][::-1] + self.fpn_lateral_blocks = nn.ModuleList([]) + self.fpn_conv_blocks = nn.ModuleList([]) + for channels in lateral_channels: + block = FPNLateralBlock( + lateral_channels=channels, + out_channels=decoder_channels, + use_norm=use_norm, + ) + self.fpn_lateral_blocks.append(block) + conv_block = md.Conv2dReLU( 
in_channels=decoder_channels, + out_channels=decoder_channels, + kernel_size=3, + padding=1, + use_norm=use_norm, + ) + self.fpn_conv_blocks.append(conv_block) - self.fpn_bottleneck = md.Conv2dReLU( - in_channels=(len(encoder_channels) - 1) * pyramid_channels, - out_channels=segmentation_channels, + num_blocks_to_fuse = len(self.fpn_conv_blocks) + 1 # +1 for the PSP module + self.fusion_block = md.Conv2dReLU( + in_channels=num_blocks_to_fuse * decoder_channels, + out_channels=decoder_channels, kernel_size=3, padding=1, - use_batchnorm=True, + use_norm=use_norm, ) - def forward(self, *features): - output_size = features[0].shape[2:] - target_size = [size // 4 for size in output_size] + def forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Args: + features (List[torch.Tensor]): + features with: [1, 1/2, 1/4, 1/8, 1/16, ...] spatial resolutions, + where the first feature is the highest resolution and the number + of features is equal to encoder_depth + 1. + """ + + # skip 1/1 and 1/2 resolution features + features = features[2:] - features = features[1:] # remove first skip with same spatial resolution - features = features[::-1] # reverse channels to start from head of encoder + # normalize feature maps + for i, norm in enumerate(self.feature_norms): + features[i] = norm(features[i]) - psp_out = self.psp(features[0]) + # pass lowest resolution feature to PSP module + psp_out = self.psp(features[-1]) + # skip lowest features for FPN + reverse the order + # [1/4, 1/8, 1/16, 1/32] -> [1/16, 1/8, 1/4] + fpn_lateral_features = features[:-1][::-1] fpn_features = [psp_out] - for feature, stage in zip(features[1:], self.fpn_stages): - fpn_feature = stage(fpn_features[-1], feature) + for i, block in enumerate(self.fpn_lateral_blocks): + # 1. for each encoder (skip) feature we apply 1x1 ConvNormRelu, + # 2. upsample latest fpn feature to it's resolution + # 3. 
sum them together + lateral_feature = fpn_lateral_features[i] + state_feature = fpn_features[-1] + fpn_feature = block(state_feature, lateral_feature) fpn_features.append(fpn_feature) + # Apply FPN conv blocks, but skip PSP module + for i, conv_block in enumerate(self.fpn_conv_blocks, start=1): + fpn_features[i] = conv_block(fpn_features[i]) + # Resize all FPN features to 1/4 of the original resolution. resized_fpn_features = [] + target_size = fpn_features[-1].shape[2:] # 1/4 of the original resolution for feature in fpn_features: resized_feature = F.interpolate( feature, size=target_size, mode="bilinear", align_corners=False ) resized_fpn_features.append(resized_feature) - output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1)) - + # reverse and concatenate + stacked_features = torch.cat(resized_fpn_features[::-1], dim=1) + output = self.fusion_block(stacked_features) return output diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py index 076ed2de..54f578b3 100644 --- a/segmentation_models_pytorch/decoders/upernet/model.py +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -25,12 +25,27 @@ class UPerNet(SegmentationModel): other pretrained weights (see table with available weights for each encoder_name) decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256 decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64 + decoder_use_norm: Specifies normalization between Conv2D and activation. + Accepts the following types: + - **True**: Defaults to `"batchnorm"`. + - **False**: No normalization (`nn.Identity`). + - **str**: Specifies normalization type using default parameters. 
Available values: + `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. + - **dict**: Fully customizable normalization settings. Structure: + ```python + {"type": <norm_name>, **kwargs} + ``` + where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. + + **Example**: + ```python + use_norm={"type": "layernorm", "eps": 1e-2} + ``` in_channels: A number of input channels for the model, default is 3 (RGB images) classes: A number of classes for output mask (or you can think as a number of channels of output mask) activation: An activation function to apply after the final convolution layer. Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, - **callable** and **None**. - Default is **None** + **callable** and **None**. Default is **None**. aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build on top of encoder if **aux_params** is not **None** (default). 
Supported params: - classes (int): A number of classes @@ -54,11 +69,12 @@ def __init__( encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", - decoder_pyramid_channels: int = 256, - decoder_segmentation_channels: int = 64, + decoder_channels: int = 256, + decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, + upsampling: int = 4, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): @@ -75,16 +91,16 @@ def __init__( self.decoder = UPerNetDecoder( encoder_channels=self.encoder.out_channels, encoder_depth=encoder_depth, - pyramid_channels=decoder_pyramid_channels, - segmentation_channels=decoder_segmentation_channels, + decoder_channels=decoder_channels, + use_norm=decoder_use_norm, ) self.segmentation_head = SegmentationHead( - in_channels=decoder_segmentation_channels, + in_channels=decoder_channels, out_channels=classes, activation=activation, kernel_size=1, - upsampling=4, + upsampling=upsampling, ) if aux_params is not None: diff --git a/segmentation_models_pytorch/encoders/__init__.py b/segmentation_models_pytorch/encoders/__init__.py index c4a4c037..287a921a 100644 --- a/segmentation_models_pytorch/encoders/__init__.py +++ b/segmentation_models_pytorch/encoders/__init__.py @@ -1,6 +1,12 @@ +import json import timm +import copy +import warnings import functools -import torch.utils.model_zoo as model_zoo +from torch.utils.model_zoo import load_url +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file + from .resnet import resnet_encoders from .dpn import dpn_encoders @@ -13,18 +19,23 @@ from .mobilenet import mobilenet_encoders from .xception import xception_encoders from .timm_efficientnet import timm_efficientnet_encoders -from .timm_resnest import timm_resnest_encoders -from .timm_res2net import 
timm_res2net_encoders -from .timm_regnet import timm_regnet_encoders from .timm_sknet import timm_sknet_encoders -from .timm_mobilenetv3 import timm_mobilenetv3_encoders -from .timm_gernet import timm_gernet_encoders from .mix_transformer import mix_transformer_encoders from .mobileone import mobileone_encoders from .timm_universal import TimmUniversalEncoder +from .timm_vit import TimmViTEncoder # noqa F401 from ._preprocessing import preprocess_input +from ._legacy_pretrained_settings import pretrained_settings + +__all__ = [ + "encoders", + "get_encoder", + "get_encoder_names", + "get_preprocessing_params", + "get_preprocessing_fn", +] encoders = {} encoders.update(resnet_encoders) @@ -38,17 +49,39 @@ encoders.update(mobilenet_encoders) encoders.update(xception_encoders) encoders.update(timm_efficientnet_encoders) -encoders.update(timm_resnest_encoders) -encoders.update(timm_res2net_encoders) -encoders.update(timm_regnet_encoders) encoders.update(timm_sknet_encoders) -encoders.update(timm_mobilenetv3_encoders) -encoders.update(timm_gernet_encoders) encoders.update(mix_transformer_encoders) encoders.update(mobileone_encoders) +def is_equivalent_to_timm_universal(name): + patterns = [ + "timm-regnet", + "timm-res2", + "timm-resnest", + "timm-mobilenetv3", + "timm-gernet", + ] + for pattern in patterns: + if name.startswith(pattern): + return True + return False + + def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs): + if name.startswith("timm-"): + warnings.warn( + "`timm-` encoders are deprecated and will be removed in the future. 
" + "Please use `tu-` equivalent encoders instead (see 'Timm encoders' section in the documentation).", + DeprecationWarning, + ) + + # convert timm- models to tu- models + if is_equivalent_to_timm_universal(name): + name = name.replace("timm-", "tu-") + if "mobilenetv3" in name: + name = name.replace("tu-", "tu-tf_") + if name.startswith("tu-"): name = name[3:] encoder = TimmUniversalEncoder( @@ -61,29 +94,56 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, ** ) return encoder - try: - Encoder = encoders[name]["encoder"] - except KeyError: + if name not in encoders: raise KeyError( - "Wrong encoder name `{}`, supported encoders: {}".format( - name, list(encoders.keys()) - ) + f"Wrong encoder name `{name}`, supported encoders: {list(encoders.keys())}" ) - params = encoders[name]["params"] - params.update(depth=depth) - encoder = Encoder(**params) + params = copy.deepcopy(encoders[name]["params"]) + params["depth"] = depth + params["output_stride"] = output_stride + + EncoderClass = encoders[name]["encoder"] + encoder = EncoderClass(**params) if weights is not None: - try: - settings = encoders[name]["pretrained_settings"][weights] - except KeyError: + if weights not in encoders[name]["pretrained_settings"]: + available_weights = list(encoders[name]["pretrained_settings"].keys()) raise KeyError( - "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format( - weights, name, list(encoders[name]["pretrained_settings"].keys()) - ) + f"Wrong pretrained weights `{weights}` for encoder `{name}`. " + f"Available options are: {available_weights}" ) - encoder.load_state_dict(model_zoo.load_url(settings["url"])) + + settings = encoders[name]["pretrained_settings"][weights] + repo_id = settings["repo_id"] + revision = settings["revision"] + + # First, try to load from HF-Hub, but as far as I know not all countries have + # access to the Hub (e.g. 
China), so we try to load from the original url if + # the first attempt fails. + weights_path = None + try: + hf_hub_download(repo_id, filename="config.json", revision=revision) + weights_path = hf_hub_download( + repo_id, filename="model.safetensors", revision=revision + ) + except Exception as e: + if name in pretrained_settings and weights in pretrained_settings[name]: + message = ( + f"Error loading {name} `{weights}` weights from Hugging Face Hub, " + "trying loading from original url..." + ) + warnings.warn(message, UserWarning) + url = pretrained_settings[name][weights]["url"] + state_dict = load_url(url, map_location="cpu") + else: + raise e + + if weights_path is not None: + state_dict = load_file(weights_path, device="cpu") + + # Load model weights + encoder.load_state_dict(state_dict) encoder.set_in_channels(in_channels, pretrained=weights is not None) if output_stride != 32: @@ -110,7 +170,25 @@ def get_preprocessing_params(encoder_name, pretrained="imagenet"): raise ValueError( "Available pretrained options {}".format(all_settings.keys()) ) - settings = all_settings[pretrained] + + repo_id = all_settings[pretrained]["repo_id"] + revision = all_settings[pretrained]["revision"] + + # Load config and model + try: + config_path = hf_hub_download( + repo_id, filename="config.json", revision=revision + ) + with open(config_path, "r") as f: + settings = json.load(f) + except Exception as e: + if ( + encoder_name in pretrained_settings + and pretrained in pretrained_settings[encoder_name] + ): + settings = pretrained_settings[encoder_name][pretrained] + else: + raise e formatted_settings = {} formatted_settings["input_space"] = settings.get("input_space", "RGB") diff --git a/segmentation_models_pytorch/encoders/_base.py b/segmentation_models_pytorch/encoders/_base.py index 3b877075..98c431fb 100644 --- a/segmentation_models_pytorch/encoders/_base.py +++ b/segmentation_models_pytorch/encoders/_base.py @@ -1,3 +1,6 @@ +import torch +from typing import Sequence, 
Dict + from . import _utils as utils @@ -7,7 +10,14 @@ class EncoderMixin: - patching first convolution for arbitrary input channels """ - _output_stride = 32 + _is_torch_scriptable = True + _is_torch_exportable = True + _is_torch_compilable = True + + def __init__(self): + self._depth = 5 + self._in_channels = 3 + self._output_stride = 32 @property def out_channels(self): @@ -25,34 +35,27 @@ def set_in_channels(self, in_channels, pretrained=True): self._in_channels = in_channels if self._out_channels[0] == 3: - self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) + self._out_channels = [in_channels] + self._out_channels[1:] utils.patch_first_conv( model=self, new_in_channels=in_channels, pretrained=pretrained ) - def get_stages(self): - """Override it in your implementation""" + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: + """Override it in your implementation, should return a dictionary with keys as + the output stride and values as the list of modules + """ raise NotImplementedError def make_dilated(self, output_stride): - if output_stride == 16: - stage_list = [5] - dilation_list = [2] - - elif output_stride == 8: - stage_list = [4, 5] - dilation_list = [2, 4] - - else: - raise ValueError( - "Output stride should be 16 or 8, got {}.".format(output_stride) - ) - - self._output_stride = output_stride + if output_stride not in [8, 16]: + raise ValueError(f"Output stride should be 16 or 8, got {output_stride}.") stages = self.get_stages() - for stage_indx, dilation_rate in zip(stage_list, dilation_list): - utils.replace_strides_with_dilation( - module=stages[stage_indx], dilation_rate=dilation_rate - ) + for stage_stride, stage_modules in stages.items(): + if stage_stride <= output_stride: + continue + + dilation_rate = stage_stride // output_stride + for module in stage_modules: + utils.replace_strides_with_dilation(module, dilation_rate) diff --git a/segmentation_models_pytorch/encoders/_dpn.py 
b/segmentation_models_pytorch/encoders/_dpn.py new file mode 100644 index 00000000..e7292615 --- /dev/null +++ b/segmentation_models_pytorch/encoders/_dpn.py @@ -0,0 +1,364 @@ +"""PyTorch implementation of DualPathNetworks +Ported to PyTorch by [Ross Wightman](https://github.com/rwightman/pytorch-dpn-pretrained) + +Based on original MXNet implementation https://github.com/cypw/DPNs with +many ideas from another PyTorch implementation https://github.com/oyam/pytorch-DPNs. + +This implementation is compatible with the pretrained weights +from cypw's MXNet implementation. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from collections import OrderedDict + + +class CatBnAct(nn.Module): + def __init__(self, in_chs, activation_fn=nn.ReLU(inplace=True)): + super(CatBnAct, self).__init__() + self.bn = nn.BatchNorm2d(in_chs, eps=0.001) + self.act = activation_fn + + def forward(self, x): + x = torch.cat(x, dim=1) if isinstance(x, tuple) else x + return self.act(self.bn(x)) + + +class BnActConv2d(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size, + stride, + padding=0, + groups=1, + activation_fn=nn.ReLU(inplace=True), + ): + super(BnActConv2d, self).__init__() + self.bn = nn.BatchNorm2d(in_chs, eps=0.001) + self.act = activation_fn + self.conv = nn.Conv2d( + in_chs, out_chs, kernel_size, stride, padding, groups=groups, bias=False + ) + + def forward(self, x): + return self.conv(self.act(self.bn(x))) + + +class InputBlock(nn.Module): + def __init__( + self, + num_init_features, + kernel_size=7, + padding=3, + activation_fn=nn.ReLU(inplace=True), + ): + super(InputBlock, self).__init__() + self.conv = nn.Conv2d( + 3, + num_init_features, + kernel_size=kernel_size, + stride=2, + padding=padding, + bias=False, + ) + self.bn = nn.BatchNorm2d(num_init_features, eps=0.001) + self.act = activation_fn + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = 
self.act(x) + x = self.pool(x) + return x + + +class DualPathBlock(nn.Module): + def __init__( + self, + in_chs, + num_1x1_a, + num_3x3_b, + num_1x1_c, + inc, + groups, + block_type="normal", + b=False, + ): + super(DualPathBlock, self).__init__() + self.num_1x1_c = num_1x1_c + self.inc = inc + self.b = b + if block_type == "proj": + self.key_stride = 1 + self.has_proj = True + elif block_type == "down": + self.key_stride = 2 + self.has_proj = True + else: + assert block_type == "normal" + self.key_stride = 1 + self.has_proj = False + + if self.has_proj: + # Using different member names here to allow easier parameter key matching for conversion + if self.key_stride == 2: + self.c1x1_w_s2 = BnActConv2d( + in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2 + ) + else: + self.c1x1_w_s1 = BnActConv2d( + in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1 + ) + self.c1x1_a = BnActConv2d( + in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1 + ) + self.c3x3_b = BnActConv2d( + in_chs=num_1x1_a, + out_chs=num_3x3_b, + kernel_size=3, + stride=self.key_stride, + padding=1, + groups=groups, + ) + if b: + self.c1x1_c = CatBnAct(in_chs=num_3x3_b) + self.c1x1_c1 = nn.Conv2d(num_3x3_b, num_1x1_c, kernel_size=1, bias=False) + self.c1x1_c2 = nn.Conv2d(num_3x3_b, inc, kernel_size=1, bias=False) + else: + self.c1x1_c = BnActConv2d( + in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1 + ) + + def forward(self, x): + x_in = torch.cat(x, dim=1) if isinstance(x, tuple) else x + if self.has_proj: + if self.key_stride == 2: + x_s = self.c1x1_w_s2(x_in) + else: + x_s = self.c1x1_w_s1(x_in) + x_s1 = x_s[:, : self.num_1x1_c, :, :] + x_s2 = x_s[:, self.num_1x1_c :, :, :] + else: + x_s1 = x[0] + x_s2 = x[1] + x_in = self.c1x1_a(x_in) + x_in = self.c3x3_b(x_in) + if self.b: + x_in = self.c1x1_c(x_in) + out1 = self.c1x1_c1(x_in) + out2 = self.c1x1_c2(x_in) + else: + x_in = self.c1x1_c(x_in) + out1 = x_in[:, : self.num_1x1_c, :, :] + out2 = 
x_in[:, self.num_1x1_c :, :, :] + resid = x_s1 + out1 + dense = torch.cat([x_s2, out2], dim=1) + return resid, dense + + +class DPN(nn.Module): + def __init__( + self, + small=False, + num_init_features=64, + k_r=96, + groups=32, + b=False, + k_sec=(3, 4, 20, 3), + inc_sec=(16, 32, 24, 128), + num_classes=1000, + test_time_pool=False, + ): + super(DPN, self).__init__() + self.test_time_pool = test_time_pool + self.b = b + bw_factor = 1 if small else 4 + + blocks = OrderedDict() + + # conv1 + if small: + blocks["conv1_1"] = InputBlock(num_init_features, kernel_size=3, padding=1) + else: + blocks["conv1_1"] = InputBlock(num_init_features, kernel_size=7, padding=3) + + # conv2 + bw = 64 * bw_factor + inc = inc_sec[0] + r = (k_r * bw) // (64 * bw_factor) + blocks["conv2_1"] = DualPathBlock( + num_init_features, r, r, bw, inc, groups, "proj", b + ) + in_chs = bw + 3 * inc + for i in range(2, k_sec[0] + 1): + blocks["conv2_" + str(i)] = DualPathBlock( + in_chs, r, r, bw, inc, groups, "normal", b + ) + in_chs += inc + + # conv3 + bw = 128 * bw_factor + inc = inc_sec[1] + r = (k_r * bw) // (64 * bw_factor) + blocks["conv3_1"] = DualPathBlock(in_chs, r, r, bw, inc, groups, "down", b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[1] + 1): + blocks["conv3_" + str(i)] = DualPathBlock( + in_chs, r, r, bw, inc, groups, "normal", b + ) + in_chs += inc + + # conv4 + bw = 256 * bw_factor + inc = inc_sec[2] + r = (k_r * bw) // (64 * bw_factor) + blocks["conv4_1"] = DualPathBlock(in_chs, r, r, bw, inc, groups, "down", b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[2] + 1): + blocks["conv4_" + str(i)] = DualPathBlock( + in_chs, r, r, bw, inc, groups, "normal", b + ) + in_chs += inc + + # conv5 + bw = 512 * bw_factor + inc = inc_sec[3] + r = (k_r * bw) // (64 * bw_factor) + blocks["conv5_1"] = DualPathBlock(in_chs, r, r, bw, inc, groups, "down", b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[3] + 1): + blocks["conv5_" + str(i)] = DualPathBlock( + in_chs, r, r, bw, inc, 
groups, "normal", b + ) + in_chs += inc + blocks["conv5_bn_ac"] = CatBnAct(in_chs) + + self.features = nn.Sequential(blocks) + + # Using 1x1 conv for the FC layer to allow the extra pooling scheme + self.last_linear = nn.Conv2d(in_chs, num_classes, kernel_size=1, bias=True) + + def logits(self, features): + if not self.training and self.test_time_pool: + x = F.avg_pool2d(features, kernel_size=7, stride=1) + out = self.last_linear(x) + # The extra test time pool should be pooling an img_size//32 - 6 size patch + out = adaptive_avgmax_pool2d(out, pool_type="avgmax") + else: + x = adaptive_avgmax_pool2d(features, pool_type="avg") + out = self.last_linear(x) + return out.view(out.size(0), -1) + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x + + +""" PyTorch selectable adaptive pooling +Adaptive pooling with the ability to select the type of pooling from: + * 'avg' - Average pooling + * 'max' - Max pooling + * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 + * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim + +Both a functional and a nn.Module version of the pooling is provided. 
+ +Author: Ross Wightman (rwightman) +""" + + +def pooling_factor(pool_type="avg"): + return 2 if pool_type == "avgmaxc" else 1 + + +def adaptive_avgmax_pool2d(x, pool_type="avg", padding=0, count_include_pad=False): + """Selectable global pooling function with dynamic input kernel size""" + if pool_type == "avgmaxc": + x = torch.cat( + [ + F.avg_pool2d( + x, + kernel_size=(x.size(2), x.size(3)), + padding=padding, + count_include_pad=count_include_pad, + ), + F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding), + ], + dim=1, + ) + elif pool_type == "avgmax": + x_avg = F.avg_pool2d( + x, + kernel_size=(x.size(2), x.size(3)), + padding=padding, + count_include_pad=count_include_pad, + ) + x_max = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding) + x = 0.5 * (x_avg + x_max) + elif pool_type == "max": + x = F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=padding) + else: + if pool_type != "avg": + print( + "Invalid pool type %s specified. Defaulting to average pooling." + % pool_type + ) + x = F.avg_pool2d( + x, + kernel_size=(x.size(2), x.size(3)), + padding=padding, + count_include_pad=count_include_pad, + ) + return x + + +class AdaptiveAvgMaxPool2d(torch.nn.Module): + """Selectable global pooling layer with dynamic input kernel size""" + + def __init__(self, output_size=1, pool_type="avg"): + super(AdaptiveAvgMaxPool2d, self).__init__() + self.output_size = output_size + self.pool_type = pool_type + if pool_type == "avgmaxc" or pool_type == "avgmax": + self.pool = nn.ModuleList( + [nn.AdaptiveAvgPool2d(output_size), nn.AdaptiveMaxPool2d(output_size)] + ) + elif pool_type == "max": + self.pool = nn.AdaptiveMaxPool2d(output_size) + else: + if pool_type != "avg": + print( + "Invalid pool type %s specified. Defaulting to average pooling." 
+ % pool_type + ) + self.pool = nn.AdaptiveAvgPool2d(output_size) + + def forward(self, x): + if self.pool_type == "avgmaxc": + x = torch.cat([p(x) for p in self.pool], dim=1) + elif self.pool_type == "avgmax": + x = 0.5 * torch.sum(torch.stack([p(x) for p in self.pool]), 0).squeeze( + dim=0 + ) + else: + x = self.pool(x) + return x + + def factor(self): + return pooling_factor(self.pool_type) + + def __repr__(self): + return ( + self.__class__.__name__ + + " (" + + "output_size=" + + str(self.output_size) + + ", pool_type=" + + self.pool_type + + ")" + ) diff --git a/segmentation_models_pytorch/encoders/_efficientnet.py b/segmentation_models_pytorch/encoders/_efficientnet.py new file mode 100644 index 00000000..b2847a56 --- /dev/null +++ b/segmentation_models_pytorch/encoders/_efficientnet.py @@ -0,0 +1,883 @@ +"""model.py - Model and module class for EfficientNet. +They are built to mirror those in the official TensorFlow implementation. +""" + +# Author: lukemelas (github username) +# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +# With adjustments and added comments by workingcoder (github username). 
+ +import torch +from torch import nn +from torch.nn import functional as F +import re +import math +import collections +from functools import partial +from typing import List, Optional + +# Parameters for the entire model (stem, all blocks, and head) +GlobalParams = collections.namedtuple( + "GlobalParams", + [ + "width_coefficient", + "depth_coefficient", + "image_size", + "dropout_rate", + "num_classes", + "batch_norm_momentum", + "batch_norm_epsilon", + "drop_connect_rate", + "depth_divisor", + "min_depth", + "include_top", + ], +) + +# Parameters for an individual model block +BlockArgs = collections.namedtuple( + "BlockArgs", + [ + "num_repeat", + "kernel_size", + "stride", + "expand_ratio", + "input_filters", + "output_filters", + "se_ratio", + "id_skip", + ], +) + +# Set GlobalParams and BlockArgs's defaults +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + + +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck Block. + + Args: + block_args (namedtuple): BlockArgs, defined in this module. + global_params (namedtuple): GlobalParams, defined in this module. + image_size (tuple or list): [image_height, image_width]. 
+ + References: + [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) + [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) + [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) + """ + + def __init__( + self, block_args: BlockArgs, global_params: GlobalParams, image_size=None + ): + super().__init__() + + self._has_expansion = block_args.expand_ratio != 1 + self._has_se = block_args.se_ratio is not None and 0 < block_args.se_ratio <= 1 + self._has_drop_connect = ( + block_args.id_skip + and block_args.stride == 1 + and block_args.input_filters == block_args.output_filters + ) + + # Pytorch's difference from tensorflow + bn_momentum = 1 - global_params.batch_norm_momentum + bn_eps = global_params.batch_norm_epsilon + + # Expansion phase (Inverted Bottleneck) + input_channels = block_args.input_filters + expanded_channels = input_channels * block_args.expand_ratio + + if self._has_expansion: + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._expand_conv = Conv2d( + input_channels, expanded_channels, kernel_size=1, bias=False + ) + self._bn0 = nn.BatchNorm2d( + expanded_channels, + momentum=bn_momentum, + eps=bn_eps, + ) + else: + # for torchscript compatibility + self._expand_conv = nn.Identity() + self._bn0 = nn.Identity() + + # Depthwise convolution phase + kernel_size = block_args.kernel_size + stride = block_args.stride + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._depthwise_conv = Conv2d( + in_channels=expanded_channels, + out_channels=expanded_channels, + groups=expanded_channels, # groups makes it depthwise + kernel_size=kernel_size, + stride=stride, + bias=False, + ) + self._bn1 = nn.BatchNorm2d( + expanded_channels, + momentum=bn_momentum, + eps=bn_eps, + ) + image_size = calculate_output_image_size(image_size, stride) + + # Squeeze and Excitation layer, if desired + if self._has_se: + squeezed_channels = int(input_channels * block_args.se_ratio) + squeezed_channels = max(1, squeezed_channels) + Conv2d = 
get_same_padding_conv2d(image_size=(1, 1)) + self._se_reduce = Conv2d( + in_channels=expanded_channels, + out_channels=squeezed_channels, + kernel_size=1, + ) + self._se_expand = Conv2d( + in_channels=squeezed_channels, + out_channels=expanded_channels, + kernel_size=1, + ) + + # Pointwise convolution phase + output_channels = block_args.output_filters + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._project_conv = Conv2d( + in_channels=expanded_channels, + out_channels=output_channels, + kernel_size=1, + bias=False, + ) + self._bn2 = nn.BatchNorm2d( + num_features=output_channels, + momentum=bn_momentum, + eps=bn_eps, + ) + self._swish = nn.SiLU() + + def forward(self, inputs: torch.Tensor, drop_connect_rate: Optional[float] = None): + """MBConvBlock's forward function. + + Args: + inputs (tensor): Input tensor. + drop_connect_rate (float, optional): Drop connect rate, between 0 and 1. + + Returns: + Output of this block after processing. + """ + + # Expansion and Depthwise Convolution + x = inputs + if self._has_expansion: + x = self._expand_conv(inputs) + x = self._bn0(x) + x = self._swish(x) + + x = self._depthwise_conv(x) + x = self._bn1(x) + x = self._swish(x) + + # Squeeze and Excitation + if self._has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_reduce(x_squeezed) + x_squeezed = self._swish(x_squeezed) + x_squeezed = self._se_expand(x_squeezed) + x = torch.sigmoid(x_squeezed) * x + + # Pointwise Convolution + x = self._project_conv(x) + x = self._bn2(x) + + # Skip connection and drop connect + if self._has_drop_connect: + # The combination of skip connection and drop connect brings about stochastic depth. + if drop_connect_rate is not None and drop_connect_rate > 0: + x = drop_connect(x, p=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + +class EfficientNet(nn.Module): + """EfficientNet model. 
+ + Args: + blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. + global_params (namedtuple): A set of GlobalParams shared between blocks. + + References: + [1] https://arxiv.org/abs/1905.11946 (EfficientNet) + + Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> model.eval() + >>> outputs = model(inputs) + """ + + def __init__(self, blocks_args: List[BlockArgs], global_params: GlobalParams): + super().__init__() + + if not isinstance(blocks_args, list): + raise ValueError("blocks_args should be a list") + if len(blocks_args) == 0: + raise ValueError("block args must be greater than 0") + + self._global_params = global_params + self._blocks_args = blocks_args + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Get stem static or dynamic convolution depending on image size + image_size = global_params.image_size + Conv2d = get_same_padding_conv2d(image_size=image_size) + + # Stem + in_channels = 3 # rgb + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias=False + ) + self._bn0 = nn.BatchNorm2d(out_channels, momentum=bn_mom, eps=bn_eps) + image_size = calculate_output_image_size(image_size, 2) + + # Build blocks + self._blocks = nn.ModuleList([]) + for block_args in blocks_args: + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters( + block_args.input_filters, self._global_params + ), + output_filters=round_filters( + block_args.output_filters, self._global_params + ), + num_repeat=round_repeats(block_args.num_repeat, self._global_params), + ) + + # The first block needs to take care of stride and filter size increase. 
+ self._blocks.append( + MBConvBlock(block_args, self._global_params, image_size=image_size) + ) + image_size = calculate_output_image_size(image_size, block_args.stride) + if block_args.num_repeat > 1: # modify block_args to keep same output size + block_args = block_args._replace( + input_filters=block_args.output_filters, stride=1 + ) + for _ in range(block_args.num_repeat - 1): + self._blocks.append( + MBConvBlock(block_args, self._global_params, image_size=image_size) + ) + # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d( + num_features=out_channels, momentum=bn_mom, eps=bn_eps + ) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + if self._global_params.include_top: + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + + self._swish = nn.SiLU() + + def extract_features(self, inputs): + """Use convolution layer to extract feature. + + Args: + inputs (tensor): Input tensor. + + Returns: + Output of the final convolution + layer in the efficientnet model. + """ + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + # scale drop connect_rate + drop_connect_rate *= float(idx) / len(self._blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """EfficientNet's forward function. + Calls extract_features to extract features, applies final linear layer, and returns logits. 
+ + Args: + inputs (tensor): Input tensor. + + Returns: + Output of this model after processing. + """ + # Convolution layers + x = self.extract_features(inputs) + + # Pooling and final linear layer + x = self._avg_pooling(x) + if self._global_params.include_top: + x = x.flatten(start_dim=1) + x = self._dropout(x) + x = self._fc(x) + return x + + +################################################################################ +# Help functions for model architecture +################################################################################ + +# GlobalParams and BlockArgs: Two namedtuples +# round_filters and round_repeats: +# Functions to calculate params for scaling model width and depth ! ! ! +# get_width_and_height_from_size and calculate_output_image_size +# drop_connect: A structural design +# get_same_padding_conv2d: +# Conv2dDynamicSamePadding +# Conv2dStaticSamePadding +# get_same_padding_maxPool2d: +# MaxPool2dDynamicSamePadding +# MaxPool2dStaticSamePadding +# It's an additional function, not used in EfficientNet, +# but can be used in other model (such as EfficientDet). + + +def round_filters(filters, global_params): + """Calculate and round number of filters based on width multiplier. + Use width_coefficient, depth_divisor and min_depth of global_params. + + Args: + filters (int): Filters number to be calculated. + global_params (namedtuple): Global params of the model. + + Returns: + new_filters: New filters number after calculating. + """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + # TODO: modify the params names. + # maybe the names (width_divisor,min_width) + # are more suitable than (depth_divisor,min_depth). 
def drop_connect(inputs: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    """Drop connect (per-sample stochastic depth).

    In training mode each sample in the batch is either zeroed (with
    probability ``p``) or kept and rescaled by ``1 / (1 - p)``. In eval mode
    the input is returned untouched.

    Args:
        inputs (tensor: BCWH): Input of this structure.
        p (float: 0.0~1.0): Probability of drop connection.
        training (bool): The running mode.

    Returns:
        output: Output after drop connection.

    Raises:
        AssertionError: If ``p`` is outside ``[0, 1]``.
    """
    # Kept as an assert so the exception type seen by callers is unchanged.
    assert 0 <= p <= 1, "p must be in range of [0,1]"

    if not training:
        return inputs

    keep_prob = 1 - p
    if keep_prob == 0:
        # p == 1 drops every sample. Dividing by keep_prob here would
        # produce NaN/inf (x / 0 * 0), so return exact zeros instead.
        return torch.zeros_like(inputs)

    batch_size = inputs.shape[0]

    # generate binary_tensor mask according to probability (p for 0, 1-p for 1):
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0.
    random_tensor = keep_prob + torch.rand(
        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
    )
    binary_tensor = torch.floor(random_tensor)

    # Rescale kept samples so the expected value matches the input.
    output = inputs / keep_prob * binary_tensor
    return output
def calculate_output_image_size(input_image_size, stride):
    """Calculates the output image size when using Conv2dSamePadding with a stride.
    Necessary for static padding. Thanks to mannatsingh for pointing this out.

    The result is ``ceil(size / stride)`` per axis. A stride given as a
    two-element sequence is applied per axis (height, width); an int or a
    one-element sequence is applied to both axes.

    Args:
        input_image_size (int, tuple or list): Size of input image, or None.
        stride (int, tuple or list): Conv2d operation's stride.

    Returns:
        output_image_size: A list [H,W], or None when ``input_image_size`` is None.
    """
    if input_image_size is None:
        return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h = stride[0]
        # Previously stride[0] was used for both axes, which gave the wrong
        # width for an asymmetric (sh, sw) stride.
        stride_w = stride[1] if len(stride) > 1 else stride[0]
    image_height = int(math.ceil(image_height / stride_h))
    image_width = int(math.ceil(image_width / stride_w))
    return [image_height, image_width]
class Conv2dStaticSamePadding(nn.Conv2d):
    """Conv2d with TensorFlow-style 'SAME' padding for a known input size.

    The amount of padding is derived once, at construction time, from
    ``image_size``; ``forward`` applies the stored padding module and then
    the convolution.
    """

    # Padding formula is the same one used by Conv2dDynamicSamePadding.

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        image_size=None,
        **kwargs,
    ):
        super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs)
        # A one-element stride is expanded so it can be unpacked per axis.
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2

        # Work out the padding for the given image size and keep it as a module.
        assert image_size is not None
        if isinstance(image_size, int):
            in_h, in_w = image_size, image_size
        else:
            in_h, in_w = image_size
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride
        out_h = math.ceil(in_h / stride_h)
        out_w = math.ceil(in_w / stride_w)
        pad_h = max(
            (out_h - 1) * self.stride[0] + (kernel_h - 1) * self.dilation[0] + 1 - in_h,
            0,
        )
        pad_w = max(
            (out_w - 1) * self.stride[1] + (kernel_w - 1) * self.dilation[1] + 1 - in_w,
            0,
        )

        if pad_h <= 0 and pad_w <= 0:
            self.static_padding = nn.Identity()
        else:
            # Any odd pixel goes to the right/bottom side.
            left = pad_w // 2
            top = pad_h // 2
            self.static_padding = nn.ZeroPad2d(
                (left, pad_w - left, top, pad_h - top)
            )

    def forward(self, x):
        padded = self.static_padding(x)
        return F.conv2d(
            padded,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
class MaxPool2dStaticSamePadding(nn.MaxPool2d):
    """MaxPool2d with TensorFlow-style 'SAME' output size for a known input size.

    The padding module is built in the constructor from ``image_size`` and
    applied in ``forward`` before pooling.
    """

    def __init__(self, kernel_size, stride, image_size=None, **kwargs):
        super().__init__(kernel_size, stride, **kwargs)
        # Normalise scalar hyper-parameters to two-element lists.
        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2
        if isinstance(self.dilation, int):
            self.dilation = [self.dilation] * 2

        # Work out the padding for the given image size and keep it as a module.
        assert image_size is not None
        if isinstance(image_size, int):
            in_h, in_w = image_size, image_size
        else:
            in_h, in_w = image_size
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        out_h = math.ceil(in_h / stride_h)
        out_w = math.ceil(in_w / stride_w)
        pad_h = max(
            (out_h - 1) * self.stride[0] + (kernel_h - 1) * self.dilation[0] + 1 - in_h,
            0,
        )
        pad_w = max(
            (out_w - 1) * self.stride[1] + (kernel_w - 1) * self.dilation[1] + 1 - in_w,
            0,
        )

        if pad_h <= 0 and pad_w <= 0:
            self.static_padding = nn.Identity()
        else:
            # The border is filled with zeros, so padded positions take part
            # in the max as the value 0.
            left = pad_w // 2
            top = pad_h // 2
            self.static_padding = nn.ZeroPad2d(
                (left, pad_w - left, top, pad_h - top)
            )

    def forward(self, x):
        padded = self.static_padding(x)
        return F.max_pool2d(
            padded,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.ceil_mode,
            self.return_indices,
        )
def efficientnet_params(model_name):
    """Map EfficientNet model name to parameter coefficients.

    Args:
        model_name (str): Model name to be queried.

    Returns:
        params_dict[model_name]: A (width,depth,res,dropout) tuple.

    Raises:
        KeyError: If ``model_name`` is not one of the known variants. The
            message lists the accepted names.
    """
    params_dict = {
        # Coefficients:   width,depth,res,dropout
        "efficientnet-b0": (1.0, 1.0, 224, 0.2),
        "efficientnet-b1": (1.0, 1.1, 240, 0.2),
        "efficientnet-b2": (1.1, 1.2, 260, 0.3),
        "efficientnet-b3": (1.2, 1.4, 300, 0.3),
        "efficientnet-b4": (1.4, 1.8, 380, 0.4),
        "efficientnet-b5": (1.6, 2.2, 456, 0.4),
        "efficientnet-b6": (1.8, 2.6, 528, 0.5),
        "efficientnet-b7": (2.0, 3.1, 600, 0.5),
        "efficientnet-b8": (2.2, 3.6, 672, 0.5),
        "efficientnet-l2": (4.3, 5.3, 800, 0.5),
    }
    try:
        return params_dict[model_name]
    except KeyError:
        # Same exception type as before, but say what was expected instead of
        # surfacing a bare key.
        raise KeyError(
            "unknown EfficientNet model name: {!r}; expected one of {}".format(
                model_name, sorted(params_dict)
            )
        ) from None
class BasicConv2d(nn.Module):
    """Bias-free convolution followed by batch norm and a non-inplace ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        # No conv bias: the following batch norm carries the affine shift.
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(
            out_planes,
            eps=0.001,  # value found in tensorflow
            momentum=0.1,  # default pytorch value
            affine=True,
        )
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
class Block35(nn.Module):
    """Residual Inception block operating on 320-channel feature maps.

    Three parallel branches (32 + 32 + 64 = 128 channels) are concatenated,
    projected back to 320 channels with a 1x1 conv, multiplied by ``scale``
    and added to the input before a ReLU.
    """

    def __init__(self, scale=1.0):
        super().__init__()

        self.scale = scale

        self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        branch2_layers = [
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        merged = torch.cat([self.branch0(x), self.branch1(x), self.branch2(x)], 1)
        residual = self.conv2d(merged)
        return self.relu(residual * self.scale + x)
class Mixed_7a(nn.Module):
    """Reduction block: three strided conv branches plus a max-pool branch.

    Takes 1088 input channels; the outputs (384 + 288 + 320 conv channels and
    the 1088 pooled input channels) are concatenated along the channel axis.
    """

    def __init__(self):
        super().__init__()

        branch0_layers = [
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2),
        ]
        self.branch0 = nn.Sequential(*branch0_layers)

        branch1_layers = [
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        branch2_layers = [
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv2d(288, 320, kernel_size=3, stride=2),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        conv_a = self.branch0(x)
        conv_b = self.branch1(x)
        conv_c = self.branch2(x)
        pooled = self.branch3(x)
        return torch.cat([conv_a, conv_b, conv_c, pooled], 1)
class InceptionResNetV2(nn.Module):
    """Inception-ResNet-v2 classifier.

    ``features`` runs the convolutional trunk, ``logits`` pools and applies
    the final linear layer, and ``forward`` chains the two.

    Args:
        num_classes (int): Output size of the final linear layer.
    """

    def __init__(self, num_classes=1001):
        super().__init__()
        # Special attributes
        self.input_space = None
        self.input_size = (299, 299, 3)
        self.mean = None
        self.std = None
        # Stem
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
        # Residual stages; each nn.Sequential holds identical blocks in a row.
        self.mixed_5b = Mixed_5b()
        self.repeat = nn.Sequential(*[Block35(scale=0.17) for _ in range(10)])
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = nn.Sequential(*[Block17(scale=0.10) for _ in range(20)])
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = nn.Sequential(*[Block8(scale=0.20) for _ in range(9)])
        self.block8 = Block8(noReLU=True)
        # Head
        self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
        self.avgpool_1a = nn.AvgPool2d(8, count_include_pad=False)
        self.last_linear = nn.Linear(1536, num_classes)

    def features(self, input):
        """Return the 1536-channel feature map produced by the trunk."""
        # Stem
        out = self.conv2d_1a(input)
        out = self.conv2d_2b(self.conv2d_2a(out))
        out = self.maxpool_3a(out)
        out = self.conv2d_4a(self.conv2d_3b(out))
        out = self.maxpool_5a(out)
        # Residual stages
        out = self.repeat(self.mixed_5b(out))
        out = self.repeat_1(self.mixed_6a(out))
        out = self.repeat_2(self.mixed_7a(out))
        out = self.block8(out)
        return self.conv2d_7b(out)

    def logits(self, features):
        """Pool the feature map, flatten it and apply the final linear layer."""
        pooled = self.avgpool_1a(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.last_linear(flat)

    def forward(self, input):
        return self.logits(self.features(input))
class Inception_A(nn.Module):
    """Inception-A block: four parallel branches over 384 input channels.

    Each branch produces 96 channels; the results are concatenated along the
    channel axis (4 x 96 = 384).
    """

    def __init__(self):
        super().__init__()
        self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        branch2_layers = [
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        branch3_layers = [
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(384, 96, kernel_size=1, stride=1),
        ]
        self.branch3 = nn.Sequential(*branch3_layers)

    def forward(self, x):
        pointwise = self.branch0(x)
        conv3 = self.branch1(x)
        double_conv3 = self.branch2(x)
        pooled = self.branch3(x)
        return torch.cat([pointwise, conv3, double_conv3, pooled], 1)
class Inception_B(nn.Module):
    """Inception-B block: four parallel branches over 1024 input channels.

    Uses factorised 1x7 / 7x1 convolutions; branch outputs
    (384 + 256 + 256 + 128 = 1024 channels) are concatenated along the
    channel axis.
    """

    def __init__(self):
        super().__init__()
        self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        branch2_layers = [
            BasicConv2d(1024, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        branch3_layers = [
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(1024, 128, kernel_size=1, stride=1),
        ]
        self.branch3 = nn.Sequential(*branch3_layers)

    def forward(self, x):
        pointwise = self.branch0(x)
        factorised = self.branch1(x)
        double_factorised = self.branch2(x)
        pooled = self.branch3(x)
        return torch.cat([pointwise, factorised, double_factorised, pooled], 1)
out = torch.cat((x0, x1, x2), 1) + return out + + +class Inception_C(nn.Module): + def __init__(self): + super(Inception_C, self).__init__() + + self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) + + self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch1_1a = BasicConv2d( + 384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1) + ) + self.branch1_1b = BasicConv2d( + 384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0) + ) + + self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch2_1 = BasicConv2d( + 384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0) + ) + self.branch2_2 = BasicConv2d( + 448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1) + ) + self.branch2_3a = BasicConv2d( + 512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1) + ) + self.branch2_3b = BasicConv2d( + 512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1536, 256, kernel_size=1, stride=1), + ) + + def forward(self, x): + x0 = self.branch0(x) + + x1_0 = self.branch1_0(x) + x1_1a = self.branch1_1a(x1_0) + x1_1b = self.branch1_1b(x1_0) + x1 = torch.cat((x1_1a, x1_1b), 1) + + x2_0 = self.branch2_0(x) + x2_1 = self.branch2_1(x2_0) + x2_2 = self.branch2_2(x2_1) + x2_3a = self.branch2_3a(x2_2) + x2_3b = self.branch2_3b(x2_2) + x2 = torch.cat((x2_3a, x2_3b), 1) + + x3 = self.branch3(x) + + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class InceptionV4(nn.Module): + def __init__(self, num_classes=1001): + super(InceptionV4, self).__init__() + # Special attributes + self.input_space = None + self.input_size = (299, 299, 3) + self.mean = None + self.std = None + # Modules + self.features = nn.Sequential( + BasicConv2d(3, 32, kernel_size=3, stride=2), + BasicConv2d(32, 32, kernel_size=3, stride=1), + BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), + Mixed_3a(), + Mixed_4a(), + Mixed_5a(), + Inception_A(), + 
Inception_A(), + Inception_A(), + Inception_A(), + Reduction_A(), # Mixed_6a + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Inception_B(), + Reduction_B(), # Mixed_7a + Inception_C(), + Inception_C(), + Inception_C(), + ) + self.last_linear = nn.Linear(1536, num_classes) + + def logits(self, features): + # Allows image of any size to be processed + adaptiveAvgPoolWidth = features.shape[2] + x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x diff --git a/segmentation_models_pytorch/encoders/_legacy_pretrained_settings.py b/segmentation_models_pytorch/encoders/_legacy_pretrained_settings.py new file mode 100644 index 00000000..21f5691e --- /dev/null +++ b/segmentation_models_pytorch/encoders/_legacy_pretrained_settings.py @@ -0,0 +1,1062 @@ +pretrained_settings = { + "resnet18": { + "imagenet": { + "url": "https://download.pytorch.org/models/resnet18-5c106cde.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + }, + "resnet34": { + "imagenet": { + "url": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + 
"input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "resnet50": { + "imagenet": { + "url": "https://download.pytorch.org/models/resnet50-19c8e357.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + }, + "resnet101": { + "imagenet": { + "url": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "resnet152": { + "imagenet": { + "url": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "resnext50_32x4d": { + "imagenet": { + "url": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], 
+ "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + }, + "resnext101_32x4d": { + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + }, + "resnext101_32x8d": { + "imagenet": { + "url": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "instagram": { + "url": "https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 
0.225], + "num_classes": 1000, + }, + }, + "resnext101_32x16d": { + "instagram": { + "url": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "ssl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + "swsl": { + "url": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + }, + }, + "resnext101_32x32d": { + "instagram": { + "url": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "resnext101_32x48d": { + "instagram": { + "url": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "dpn68": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn68-4af7d88d2.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.48627450980392156, 0.4588235294117647, 0.40784313725490196], + "std": [0.23482446870963955, 0.23482446870963955, 0.23482446870963955], + "num_classes": 1000, + } + }, + "dpn68b": { + "imagenet+5k": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-363ab9c19.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 
1], + "mean": [0.48627450980392156, 0.4588235294117647, 0.40784313725490196], + "std": [0.23482446870963955, 0.23482446870963955, 0.23482446870963955], + "num_classes": 1000, + } + }, + "dpn92": { + "imagenet+5k": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-fda993c95.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.48627450980392156, 0.4588235294117647, 0.40784313725490196], + "std": [0.23482446870963955, 0.23482446870963955, 0.23482446870963955], + "num_classes": 1000, + } + }, + "dpn98": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn98-722954780.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.48627450980392156, 0.4588235294117647, 0.40784313725490196], + "std": [0.23482446870963955, 0.23482446870963955, 0.23482446870963955], + "num_classes": 1000, + } + }, + "dpn107": { + "imagenet+5k": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-b7f9f4cc9.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.48627450980392156, 0.4588235294117647, 0.40784313725490196], + "std": [0.23482446870963955, 0.23482446870963955, 0.23482446870963955], + "num_classes": 1000, + } + }, + "dpn131": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/dpn131-7af84be88.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.48627450980392156, 0.4588235294117647, 0.40784313725490196], + "std": [0.23482446870963955, 0.23482446870963955, 0.23482446870963955], + "num_classes": 1000, + } + }, + "vgg11": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg11_bn": { + "imagenet": { + "url": 
"https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg13": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg13-c768596a.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg13_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg16": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg16-397923af.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg16_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg19": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg19_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "senet154": { + "imagenet": { + "url": 
"http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "se_resnet50": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "se_resnet101": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "se_resnet152": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "se_resnext50_32x4d": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "se_resnext101_32x4d": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "densenet121": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet121-fbdb23505.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 
0.225], + "num_classes": 1000, + } + }, + "densenet169": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet169-f470b90a4.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "densenet201": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet201-5750cbb1e.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "densenet161": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/densenet161-347e6b360.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "inceptionresnetv2": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth", + "input_space": "RGB", + "input_size": [3, 299, 299], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": 1000, + }, + "imagenet+background": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth", + "input_space": "RGB", + "input_size": [3, 299, 299], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": 1001, + }, + }, + "inceptionv4": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth", + "input_space": "RGB", + "input_size": [3, 299, 299], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": 1000, + }, + "imagenet+background": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth", + "input_space": "RGB", + "input_size": [3, 299, 299], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 
0.5, 0.5], + "num_classes": 1001, + }, + }, + "efficientnet-b0": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + "advprop": { + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + }, + "efficientnet-b1": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + "advprop": { + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + }, + "efficientnet-b2": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + "advprop": { + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + }, + "efficientnet-b3": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + "advprop": { + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "url": 
"https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + }, + "efficientnet-b4": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + "advprop": { + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + }, + "efficientnet-b5": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + "advprop": { + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + }, + "efficientnet-b6": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + "advprop": { + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + }, + "efficientnet-b7": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + "advprop": { + 
"mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "url": "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth", + "input_space": "RGB", + "input_range": [0, 1], + }, + }, + "mobilenet_v2": { + "imagenet": { + "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "input_space": "RGB", + "input_range": [0, 1], + } + }, + "xception": { + "imagenet": { + "url": "http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth", + "input_space": "RGB", + "input_size": [3, 299, 299], + "input_range": [0, 1], + "mean": [0.5, 0.5, 0.5], + "std": [0.5, 0.5, 0.5], + "num_classes": 1000, + "scale": 0.8975, + } + }, + "timm-efficientnet-b0": { + "imagenet": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "advprop": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "noisy-student": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + }, + "timm-efficientnet-b1": { + "imagenet": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "advprop": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "noisy-student": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + }, + "timm-efficientnet-b2": { + "imagenet": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "advprop": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "noisy-student": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + }, + "timm-efficientnet-b3": { + "imagenet": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "advprop": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "noisy-student": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + }, + "timm-efficientnet-b4": { + "imagenet": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "advprop": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "noisy-student": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + }, + "timm-efficientnet-b5": { + "imagenet": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "advprop": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "noisy-student": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + }, + "timm-efficientnet-b6": { + "imagenet": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "advprop": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "noisy-student": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + }, + "timm-efficientnet-b7": { + "imagenet": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "advprop": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "noisy-student": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + }, + "timm-efficientnet-b8": { + "imagenet": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "advprop": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth", 
+ "input_range": (0, 1), + "input_space": "RGB", + }, + }, + "timm-efficientnet-l2": { + "noisy-student": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + "noisy-student-475": { + "mean": (0.485, 0.456, 0.406), + "std": (0.229, 0.224, 0.225), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth", + "input_range": (0, 1), + "input_space": "RGB", + }, + }, + "timm-tf_efficientnet_lite0": { + "imagenet": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth", + "input_range": (0, 1), + "input_space": "RGB", + } + }, + "timm-tf_efficientnet_lite1": { + "imagenet": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth", + "input_range": (0, 1), + "input_space": "RGB", + } + }, + "timm-tf_efficientnet_lite2": { + "imagenet": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth", + "input_range": (0, 1), + "input_space": "RGB", + } + }, + "timm-tf_efficientnet_lite3": { + "imagenet": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth", + "input_range": (0, 1), + "input_space": "RGB", + } + }, + "timm-tf_efficientnet_lite4": { + "imagenet": { + "mean": (0.5, 0.5, 0.5), + "std": (0.5, 0.5, 0.5), + "url": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth", + "input_range": (0, 1), + "input_space": "RGB", + } + }, + "timm-skresnet18": { + "imagenet": { + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "timm-skresnet34": { + "imagenet": { + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "timm-skresnext50_32x4d": { + "imagenet": { + "url": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "mit_b0": { + "imagenet": { + "url": "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/mit_b0.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + } + }, + "mit_b1": { + "imagenet": { + "url": "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/mit_b1.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + } + }, + "mit_b2": { + "imagenet": { + "url": "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/mit_b2.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + } + }, + "mit_b3": { + 
"imagenet": { + "url": "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/mit_b3.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + } + }, + "mit_b4": { + "imagenet": { + "url": "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/mit_b4.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + } + }, + "mit_b5": { + "imagenet": { + "url": "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/mit_b5.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + } + }, + "mobileone_s0": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar", + "input_space": "RGB", + "input_range": [0, 1], + } + }, + "mobileone_s1": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar", + "input_space": "RGB", + "input_range": [0, 1], + } + }, + "mobileone_s2": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar", + "input_space": "RGB", + "input_range": [0, 1], + } + }, + "mobileone_s3": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar", + "input_space": "RGB", + "input_range": [0, 1], + } + }, + "mobileone_s4": { + "imagenet": { + "mean": [0.485, 0.456, 0.406], + 
"std": [0.229, 0.224, 0.225], + "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar", + "input_space": "RGB", + "input_range": [0, 1], + } + }, +} diff --git a/segmentation_models_pytorch/encoders/_senet.py b/segmentation_models_pytorch/encoders/_senet.py new file mode 100644 index 00000000..f56c776a --- /dev/null +++ b/segmentation_models_pytorch/encoders/_senet.py @@ -0,0 +1,337 @@ +""" +ResNet code gently borrowed from +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py +""" + +from collections import OrderedDict +import math + +import torch.nn as nn + + +class SEModule(nn.Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class Bottleneck(nn.Module): + """ + Base class for bottlenecks that implements `forward()` method. + """ + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = self.se_module(out) + residual + out = self.relu(out) + + return out + + +class SEBottleneck(Bottleneck): + """ + Bottleneck for SENet154. 
+ """ + + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): + super(SEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes * 2) + self.conv2 = nn.Conv2d( + planes * 2, + planes * 4, + kernel_size=3, + stride=stride, + padding=1, + groups=groups, + bias=False, + ) + self.bn2 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBottleneck(Bottleneck): + """ + ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe + implementation and uses `stride=stride` in `conv1` and not in `conv2` + (the latter is used in the torchvision implementation of ResNet). + """ + + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): + super(SEResNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=1, bias=False, stride=stride + ) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, groups=groups, bias=False + ) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNeXtBottleneck(Bottleneck): + """ + ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
+ """ + + expansion = 4 + + def __init__( + self, + inplanes, + planes, + groups, + reduction, + stride=1, + downsample=None, + base_width=4, + ): + super(SEResNeXtBottleneck, self).__init__() + width = math.floor(planes * (base_width / 64)) * groups + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d( + width, + width, + kernel_size=3, + stride=stride, + padding=1, + groups=groups, + bias=False, + ) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SENet(nn.Module): + def __init__( + self, + block, + layers, + groups, + reduction, + dropout_p=0.2, + inplanes=128, + input_3x3=True, + downsample_kernel_size=3, + downsample_padding=1, + num_classes=1000, + ): + """ + Parameters + ---------- + block (nn.Module): Bottleneck class. + - For SENet154: SEBottleneck + - For SE-ResNet models: SEResNetBottleneck + - For SE-ResNeXt models: SEResNeXtBottleneck + layers (list of ints): Number of residual blocks for 4 layers of the + network (layer1...layer4). + groups (int): Number of groups for the 3x3 convolution in each + bottleneck block. + - For SENet154: 64 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 32 + reduction (int): Reduction ratio for Squeeze-and-Excitation modules. + - For all models: 16 + dropout_p (float or None): Drop probability for the Dropout layer. + If `None` the Dropout layer is not used. + - For SENet154: 0.2 + - For SE-ResNet models: None + - For SE-ResNeXt models: None + inplanes (int): Number of input channels for layer1. 
+ - For SENet154: 128 + - For SE-ResNet models: 64 + - For SE-ResNeXt models: 64 + input_3x3 (bool): If `True`, use three 3x3 convolutions instead of + a single 7x7 convolution in layer0. + - For SENet154: True + - For SE-ResNet models: False + - For SE-ResNeXt models: False + downsample_kernel_size (int): Kernel size for downsampling convolutions + in layer2, layer3 and layer4. + - For SENet154: 3 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 1 + downsample_padding (int): Padding for downsampling convolutions in + layer2, layer3 and layer4. + - For SENet154: 1 + - For SE-ResNet models: 0 + - For SE-ResNeXt models: 0 + num_classes (int): Number of outputs in `last_linear` layer. + - For all models: 1000 + """ + super(SENet, self).__init__() + self.inplanes = inplanes + if input_3x3: + layer0_modules = [ + ("conv1", nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)), + ("bn1", nn.BatchNorm2d(64)), + ("relu1", nn.ReLU(inplace=True)), + ("conv2", nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), + ("bn2", nn.BatchNorm2d(64)), + ("relu2", nn.ReLU(inplace=True)), + ("conv3", nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)), + ("bn3", nn.BatchNorm2d(inplanes)), + ("relu3", nn.ReLU(inplace=True)), + ] + else: + layer0_modules = [ + ( + "conv1", + nn.Conv2d( + 3, inplanes, kernel_size=7, stride=2, padding=3, bias=False + ), + ), + ("bn1", nn.BatchNorm2d(inplanes)), + ("relu1", nn.ReLU(inplace=True)), + ] + # To preserve compatibility with Caffe weights `ceil_mode=True` + # is used instead of `padding=1`. 
+ layer0_modules.append(("pool", nn.MaxPool2d(3, stride=2, ceil_mode=True))) + self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) + self.layer1 = self._make_layer( + block, + planes=64, + blocks=layers[0], + groups=groups, + reduction=reduction, + downsample_kernel_size=1, + downsample_padding=0, + ) + self.layer2 = self._make_layer( + block, + planes=128, + blocks=layers[1], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding, + ) + self.layer3 = self._make_layer( + block, + planes=256, + blocks=layers[2], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding, + ) + self.layer4 = self._make_layer( + block, + planes=512, + blocks=layers[3], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding, + ) + self.avg_pool = nn.AvgPool2d(7, stride=1) + self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None + self.last_linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer( + self, + block, + planes, + blocks, + groups, + reduction, + stride=1, + downsample_kernel_size=1, + downsample_padding=0, + ): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=downsample_kernel_size, + stride=stride, + padding=downsample_padding, + bias=False, + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, groups, reduction, stride, downsample) + ) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups, reduction)) + + return nn.Sequential(*layers) + + def features(self, x): + x = self.layer0(x) + x = self.layer1(x) + x = 
self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def logits(self, x): + x = self.avg_pool(x) + if self.dropout is not None: + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, x): + x = self.features(x) + x = self.logits(x) + return x diff --git a/segmentation_models_pytorch/encoders/_xception.py b/segmentation_models_pytorch/encoders/_xception.py new file mode 100644 index 00000000..4b6f308b --- /dev/null +++ b/segmentation_models_pytorch/encoders/_xception.py @@ -0,0 +1,231 @@ +""" +Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) + +@author: tstandley +Adapted by cadene + +Creates an Xception Model as defined in: + +Francois Chollet +Xception: Deep Learning with Depthwise Separable Convolutions +https://arxiv.org/pdf/1610.02357.pdf + +This weights ported from the Keras implementation. Achieves the following performance on the validation set: + +Loss:0.9173 Prec@1:78.892 Prec@5:94.292 + +REMEMBER to set your image size to 3x299x299 for both test and validation + +normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5]) + +The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 +""" + +import torch.nn as nn +import torch.nn.functional as F + + +class SeparableConv2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=False, + ): + super(SeparableConv2d, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels, + in_channels, + kernel_size, + stride, + padding, + dilation, + groups=in_channels, + bias=bias, + ) + self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) + + def forward(self, x): + x = self.conv1(x) + x = self.pointwise(x) + return x + + +class Block(nn.Module): + def __init__( + self, + in_filters, + out_filters, + reps, + strides=1, + start_with_relu=True, + 
grow_first=True, + ): + super(Block, self).__init__() + + if out_filters != in_filters or strides != 1: + self.skip = nn.Conv2d( + in_filters, out_filters, 1, stride=strides, bias=False + ) + self.skipbn = nn.BatchNorm2d(out_filters) + else: + self.skip = None + + rep = [] + + filters = in_filters + if grow_first: + rep.append(nn.ReLU(inplace=True)) + rep.append( + SeparableConv2d( + in_filters, out_filters, 3, stride=1, padding=1, bias=False + ) + ) + rep.append(nn.BatchNorm2d(out_filters)) + filters = out_filters + + for i in range(reps - 1): + rep.append(nn.ReLU(inplace=True)) + rep.append( + SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False) + ) + rep.append(nn.BatchNorm2d(filters)) + + if not grow_first: + rep.append(nn.ReLU(inplace=True)) + rep.append( + SeparableConv2d( + in_filters, out_filters, 3, stride=1, padding=1, bias=False + ) + ) + rep.append(nn.BatchNorm2d(out_filters)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if strides != 1: + rep.append(nn.MaxPool2d(3, strides, 1)) + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x += skip + return x + + +class Xception(nn.Module): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + + def __init__(self, num_classes=1000): + """Constructor + Args: + num_classes: number of classes + """ + super(Xception, self).__init__() + self.num_classes = num_classes + + self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, 3, bias=False) + self.bn2 = nn.BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + # do relu here + + self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True) + self.block2 = Block(128, 256, 2, 2, start_with_relu=True, 
grow_first=True) + self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True) + + self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + + self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + + self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False) + + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) + self.bn3 = nn.BatchNorm2d(1536) + self.relu3 = nn.ReLU(inplace=True) + + # do relu here + self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(2048) + + self.fc = nn.Linear(2048, num_classes) + + # #------- init weights -------- + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + # elif isinstance(m, nn.BatchNorm2d): + # m.weight.data.fill_(1) + # m.bias.data.zero_() + # #----------------------------- + + def features(self, input): + x = self.conv1(input) + x = self.bn1(x) + x = self.relu1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + + x = self.conv3(x) + x = self.bn3(x) + x = self.relu3(x) + + x = self.conv4(x) + x = self.bn4(x) + return x + + def logits(self, features): + x = nn.ReLU(inplace=True)(features) + + x = F.adaptive_avg_pool2d(x, (1, 1)) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, input): + x = self.features(input) + x = self.logits(x) + return x diff --git a/segmentation_models_pytorch/encoders/densenet.py b/segmentation_models_pytorch/encoders/densenet.py index c4bd0ce2..ad0e0c25 100644 --- a/segmentation_models_pytorch/encoders/densenet.py +++ b/segmentation_models_pytorch/encoders/densenet.py @@ -24,33 +24,25 @@ """ import re -import torch.nn as nn -from pretrainedmodels.models.torchvision_models import pretrained_settings from torchvision.models.densenet import DenseNet from ._base import EncoderMixin -class TransitionWithSkip(nn.Module): - def __init__(self, module): - super().__init__() - self.module = module - - def forward(self, x): - for module in self.module: - x = module(x) - if isinstance(module, nn.ReLU): - skip = x - return x, skip - - class DenseNetEncoder(DenseNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__(self, out_channels, depth=5, output_stride=32, **kwargs): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + super().__init__(**kwargs) - self._out_channels = 
out_channels + self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.classifier def make_dilated(self, *args, **kwargs): @@ -59,37 +51,44 @@ def make_dilated(self, *args, **kwargs): "due to pooling operation for downsampling!" ) - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential( - self.features.conv0, self.features.norm0, self.features.relu0 - ), - nn.Sequential( - self.features.pool0, - self.features.denseblock1, - TransitionWithSkip(self.features.transition1), - ), - nn.Sequential( - self.features.denseblock2, TransitionWithSkip(self.features.transition2) - ), - nn.Sequential( - self.features.denseblock3, TransitionWithSkip(self.features.transition3) - ), - nn.Sequential(self.features.denseblock4, self.features.norm5), - ] - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - if isinstance(x, (list, tuple)): - x, skip = x - features.append(skip) - else: - features.append(x) + features = [x] + + if self._depth >= 1: + x = self.features.conv0(x) + x = self.features.norm0(x) + x = self.features.relu0(x) + features.append(x) + + if self._depth >= 2: + x = self.features.pool0(x) + x = self.features.denseblock1(x) + x = self.features.transition1.norm(x) + x = self.features.transition1.relu(x) + features.append(x) + + if self._depth >= 3: + x = self.features.transition1.conv(x) + x = self.features.transition1.pool(x) + x = self.features.denseblock2(x) + x = self.features.transition2.norm(x) + x = self.features.transition2.relu(x) + features.append(x) + + if self._depth >= 4: + x = self.features.transition2.conv(x) + x = self.features.transition2.pool(x) + x = self.features.denseblock3(x) + x = self.features.transition3.norm(x) + x = self.features.transition3.relu(x) + features.append(x) + + if self._depth >= 5: + x = self.features.transition3.conv(x) + x = self.features.transition3.pool(x) + x = 
self.features.denseblock4(x) + x = self.features.norm5(x) + features.append(x) return features @@ -114,42 +113,62 @@ def load_state_dict(self, state_dict): densenet_encoders = { "densenet121": { "encoder": DenseNetEncoder, - "pretrained_settings": pretrained_settings["densenet121"], "params": { - "out_channels": (3, 64, 256, 512, 1024, 1024), + "out_channels": [3, 64, 256, 512, 1024, 1024], "num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 24, 16), }, + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/densenet121.imagenet", + "revision": "a17c96896a265b61338f66f61d3887b24f61995a", + } + }, }, "densenet169": { "encoder": DenseNetEncoder, - "pretrained_settings": pretrained_settings["densenet169"], "params": { - "out_channels": (3, 64, 256, 512, 1280, 1664), + "out_channels": [3, 64, 256, 512, 1280, 1664], "num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 32, 32), }, + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/densenet169.imagenet", + "revision": "8facfba9fc72f7750879dac9ac6ceb3ab990de8d", + } + }, }, "densenet201": { "encoder": DenseNetEncoder, - "pretrained_settings": pretrained_settings["densenet201"], "params": { - "out_channels": (3, 64, 256, 512, 1792, 1920), + "out_channels": [3, 64, 256, 512, 1792, 1920], "num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 48, 32), }, + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/densenet201.imagenet", + "revision": "ed5deb355d71659391d46fae5e7587460fbb5f84", + } + }, }, "densenet161": { "encoder": DenseNetEncoder, - "pretrained_settings": pretrained_settings["densenet161"], "params": { - "out_channels": (3, 96, 384, 768, 2112, 2208), + "out_channels": [3, 96, 384, 768, 2112, 2208], "num_init_features": 96, "growth_rate": 48, "block_config": (6, 12, 36, 24), }, + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/densenet161.imagenet", + "revision": "9afe0fec51ab2a627141769d97d6f83756d78446", + } + }, }, } 
diff --git a/segmentation_models_pytorch/encoders/dpn.py b/segmentation_models_pytorch/encoders/dpn.py index 220c66de..527bbc02 100644 --- a/segmentation_models_pytorch/encoders/dpn.py +++ b/segmentation_models_pytorch/encoders/dpn.py @@ -24,49 +24,73 @@ """ import torch -import torch.nn as nn import torch.nn.functional as F - -from pretrainedmodels.models.dpn import DPN -from pretrainedmodels.models.dpn import pretrained_settings +from typing import List, Dict, Sequence from ._base import EncoderMixin +from ._dpn import DPN class DPNEncoder(DPN, EncoderMixin): - def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + _is_torch_scriptable = False + _is_torch_exportable = True # since torch 2.6.0 + + def __init__( + self, + stage_idxs: List[int], + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + super().__init__(**kwargs) self._stage_idxs = stage_idxs self._depth = depth - self._out_channels = out_channels self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.last_linear - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential( - self.features[0].conv, self.features[0].bn, self.features[0].act - ), - nn.Sequential( - self.features[0].pool, self.features[1 : self._stage_idxs[0]] - ), - self.features[self._stage_idxs[0] : self._stage_idxs[1]], - self.features[self._stage_idxs[1] : self._stage_idxs[2]], - self.features[self._stage_idxs[2] : self._stage_idxs[3]], - ] - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - if isinstance(x, (list, tuple)): - features.append(F.relu(torch.cat(x, dim=1), inplace=True)) - else: - features.append(x) + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: + return { + 16: [self.features[self._stage_idxs[1] : 
self._stage_idxs[2]]], + 32: [self.features[self._stage_idxs[2] : self._stage_idxs[3]]], + } + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + features = [x] + + if self._depth >= 1: + x = self.features[0].conv(x) + x = self.features[0].bn(x) + x = self.features[0].act(x) + features.append(x) + + if self._depth >= 2: + x = self.features[0].pool(x) + x = self.features[1 : self._stage_idxs[0]](x) + skip = F.relu(torch.cat(x, dim=1), inplace=True) + features.append(skip) + + if self._depth >= 3: + x = self.features[self._stage_idxs[0] : self._stage_idxs[1]](x) + skip = F.relu(torch.cat(x, dim=1), inplace=True) + features.append(skip) + + if self._depth >= 4: + x = self.features[self._stage_idxs[1] : self._stage_idxs[2]](x) + skip = F.relu(torch.cat(x, dim=1), inplace=True) + features.append(skip) + + if self._depth >= 5: + x = self.features[self._stage_idxs[2] : self._stage_idxs[3]](x) + features.append(x) return features @@ -79,10 +103,15 @@ def load_state_dict(self, state_dict, **kwargs): dpn_encoders = { "dpn68": { "encoder": DPNEncoder, - "pretrained_settings": pretrained_settings["dpn68"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/dpn68.imagenet", + "revision": "c209aefdeae6bc93937556629e974b44d4e58535", + } + }, "params": { - "stage_idxs": (4, 8, 20, 24), - "out_channels": (3, 10, 144, 320, 704, 832), + "stage_idxs": [4, 8, 20, 24], + "out_channels": [3, 10, 144, 320, 704, 832], "groups": 32, "inc_sec": (16, 32, 32, 64), "k_r": 128, @@ -95,10 +124,15 @@ def load_state_dict(self, state_dict, **kwargs): }, "dpn68b": { "encoder": DPNEncoder, - "pretrained_settings": pretrained_settings["dpn68b"], + "pretrained_settings": { + "imagenet+5k": { + "repo_id": "smp-hub/dpn68b.imagenet-5k", + "revision": "6c6615e77688e390ae0eaa81e26821fbd83cee4b", + } + }, "params": { - "stage_idxs": (4, 8, 20, 24), - "out_channels": (3, 10, 144, 320, 704, 832), + "stage_idxs": [4, 8, 20, 24], + "out_channels": [3, 10, 144, 320, 704, 832], "b": True, 
"groups": 32, "inc_sec": (16, 32, 32, 64), @@ -112,10 +146,15 @@ def load_state_dict(self, state_dict, **kwargs): }, "dpn92": { "encoder": DPNEncoder, - "pretrained_settings": pretrained_settings["dpn92"], + "pretrained_settings": { + "imagenet+5k": { + "repo_id": "smp-hub/dpn92.imagenet-5k", + "revision": "d231f51ce4ad2c84ed5fcaf4ef0cfece6814a526", + } + }, "params": { - "stage_idxs": (4, 8, 28, 32), - "out_channels": (3, 64, 336, 704, 1552, 2688), + "stage_idxs": [4, 8, 28, 32], + "out_channels": [3, 64, 336, 704, 1552, 2688], "groups": 32, "inc_sec": (16, 32, 24, 128), "k_r": 96, @@ -127,10 +166,15 @@ def load_state_dict(self, state_dict, **kwargs): }, "dpn98": { "encoder": DPNEncoder, - "pretrained_settings": pretrained_settings["dpn98"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/dpn98.imagenet", + "revision": "b2836c86216c1ddce980d832f7deaa4ca22babd3", + } + }, "params": { - "stage_idxs": (4, 10, 30, 34), - "out_channels": (3, 96, 336, 768, 1728, 2688), + "stage_idxs": [4, 10, 30, 34], + "out_channels": [3, 96, 336, 768, 1728, 2688], "groups": 40, "inc_sec": (16, 32, 32, 128), "k_r": 160, @@ -142,10 +186,15 @@ def load_state_dict(self, state_dict, **kwargs): }, "dpn107": { "encoder": DPNEncoder, - "pretrained_settings": pretrained_settings["dpn107"], + "pretrained_settings": { + "imagenet+5k": { + "repo_id": "smp-hub/dpn107.imagenet-5k", + "revision": "dab4cd6b8b79de3db970f2dbff85359a8847db05", + } + }, "params": { - "stage_idxs": (5, 13, 33, 37), - "out_channels": (3, 128, 376, 1152, 2432, 2688), + "stage_idxs": [5, 13, 33, 37], + "out_channels": [3, 128, 376, 1152, 2432, 2688], "groups": 50, "inc_sec": (20, 64, 64, 128), "k_r": 200, @@ -157,10 +206,15 @@ def load_state_dict(self, state_dict, **kwargs): }, "dpn131": { "encoder": DPNEncoder, - "pretrained_settings": pretrained_settings["dpn131"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/dpn131.imagenet", + "revision": "04bbb9f415ca2bb59f3d8227857967b74698515e", + 
} + }, "params": { - "stage_idxs": (5, 13, 41, 45), - "out_channels": (3, 128, 352, 832, 1984, 2688), + "stage_idxs": [5, 13, 41, 45], + "out_channels": [3, 128, 352, 832, 1984, 2688], "groups": 40, "inc_sec": (16, 32, 32, 128), "k_r": 160, diff --git a/segmentation_models_pytorch/encoders/efficientnet.py b/segmentation_models_pytorch/encoders/efficientnet.py index 4a7af6b4..70046e44 100644 --- a/segmentation_models_pytorch/encoders/efficientnet.py +++ b/segmentation_models_pytorch/encoders/efficientnet.py @@ -23,56 +23,68 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ -import torch.nn as nn -from efficientnet_pytorch import EfficientNet -from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params +import torch +from typing import List, Dict, Sequence from ._base import EncoderMixin +from ._efficientnet import EfficientNet, get_model_params class EfficientNetEncoder(EfficientNet, EncoderMixin): - def __init__(self, stage_idxs, out_channels, model_name, depth=5): + def __init__( + self, + out_indexes: List[int], + out_channels: List[int], + model_name: str, + depth: int = 5, + output_stride: int = 32, + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + blocks_args, global_params = get_model_params(model_name, override_params=None) super().__init__(blocks_args, global_params) - self._stage_idxs = stage_idxs - self._out_channels = out_channels + self._out_indexes = out_indexes self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride + self._drop_connect_rate = self._global_params.drop_connect_rate del self._fc - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self._conv_stem, self._bn0, self._swish), - self._blocks[: self._stage_idxs[0]], - self._blocks[self._stage_idxs[0] : self._stage_idxs[1]], - self._blocks[self._stage_idxs[1] 
: self._stage_idxs[2]], - self._blocks[self._stage_idxs[2] :], - ] - - def forward(self, x): - stages = self.get_stages() - - block_number = 0.0 - drop_connect_rate = self._global_params.drop_connect_rate - - features = [] - for i in range(self._depth + 1): - # Identity and Sequential stages - if i < 2: - x = stages[i](x) - - # Block stages need drop_connect rate - else: - for module in stages[i]: - drop_connect = drop_connect_rate * block_number / len(self._blocks) - block_number += 1.0 - x = module(x, drop_connect) + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: + return { + 16: [self._blocks[self._out_indexes[1] + 1 : self._out_indexes[2] + 1]], + 32: [self._blocks[self._out_indexes[2] + 1 :]], + } + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + features = [x] + + if self._depth >= 1: + x = self._conv_stem(x) + x = self._bn0(x) + x = self._swish(x) features.append(x) + depth = 1 + for i, block in enumerate(self._blocks): + drop_connect_prob = self._drop_connect_rate * i / len(self._blocks) + x = block(x, drop_connect_prob) + + if i in self._out_indexes: + features.append(x) + depth += 1 + + if not torch.jit.is_scripting() and depth > self._depth: + break + + features = features[: self._depth + 1] + return features def load_state_dict(self, state_dict, **kwargs): @@ -81,96 +93,148 @@ def load_state_dict(self, state_dict, **kwargs): super().load_state_dict(state_dict, **kwargs) -def _get_pretrained_settings(encoder): - pretrained_settings = { - "imagenet": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "url": url_map[encoder], - "input_space": "RGB", - "input_range": [0, 1], - }, - "advprop": { - "mean": [0.5, 0.5, 0.5], - "std": [0.5, 0.5, 0.5], - "url": url_map_advprop[encoder], - "input_space": "RGB", - "input_range": [0, 1], - }, - } - return pretrained_settings - - efficient_net_encoders = { "efficientnet-b0": { "encoder": EfficientNetEncoder, - "pretrained_settings": 
_get_pretrained_settings("efficientnet-b0"), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/efficientnet-b0.imagenet", + "revision": "1bbe7ecc1d5ea1d2058de1a2db063b8701aff314", + }, + "advprop": { + "repo_id": "smp-hub/efficientnet-b0.advprop", + "revision": "29043c08140d9c6ee7de1468d55923f2b06bcec2", + }, + }, "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (3, 5, 9, 16), + "out_channels": [3, 32, 24, 40, 112, 320], + "out_indexes": [2, 4, 8, 15], "model_name": "efficientnet-b0", }, }, "efficientnet-b1": { "encoder": EfficientNetEncoder, - "pretrained_settings": _get_pretrained_settings("efficientnet-b1"), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/efficientnet-b1.imagenet", + "revision": "5d637466a5215de300a8ccb13a39357df2df2bf4", + }, + "advprop": { + "repo_id": "smp-hub/efficientnet-b1.advprop", + "revision": "2e518b8b0955bbab467f50525578dab6b6086afc", + }, + }, "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (5, 8, 16, 23), + "out_channels": [3, 32, 24, 40, 112, 320], + "out_indexes": [4, 7, 15, 22], "model_name": "efficientnet-b1", }, }, "efficientnet-b2": { "encoder": EfficientNetEncoder, - "pretrained_settings": _get_pretrained_settings("efficientnet-b2"), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/efficientnet-b2.imagenet", + "revision": "a96d4f0295ffbae18ebba173bf7f3c0c8f21990e", + }, + "advprop": { + "repo_id": "smp-hub/efficientnet-b2.advprop", + "revision": "be788c20dfb0bbe83b4c439f9cfe0dd937c0783e", + }, + }, "params": { - "out_channels": (3, 32, 24, 48, 120, 352), - "stage_idxs": (5, 8, 16, 23), + "out_channels": [3, 32, 24, 48, 120, 352], + "out_indexes": [4, 7, 15, 22], "model_name": "efficientnet-b2", }, }, "efficientnet-b3": { "encoder": EfficientNetEncoder, - "pretrained_settings": _get_pretrained_settings("efficientnet-b3"), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/efficientnet-b3.imagenet", + "revision": 
"074c54a6c473e0d294690d49cedb6cf463e7127d", + }, + "advprop": { + "repo_id": "smp-hub/efficientnet-b3.advprop", + "revision": "9ccc166d87bd9c08d6bed4477638c7f4bb3eec78", + }, + }, "params": { - "out_channels": (3, 40, 32, 48, 136, 384), - "stage_idxs": (5, 8, 18, 26), + "out_channels": [3, 40, 32, 48, 136, 384], + "out_indexes": [4, 7, 17, 25], "model_name": "efficientnet-b3", }, }, "efficientnet-b4": { "encoder": EfficientNetEncoder, - "pretrained_settings": _get_pretrained_settings("efficientnet-b4"), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/efficientnet-b4.imagenet", + "revision": "05cd5dde5dab658f00c463f9b9aa0ced76784f40", + }, + "advprop": { + "repo_id": "smp-hub/efficientnet-b4.advprop", + "revision": "f04caa809ea4eb08ee9e7fd555f5514ebe2a9ef5", + }, + }, "params": { - "out_channels": (3, 48, 32, 56, 160, 448), - "stage_idxs": (6, 10, 22, 32), + "out_channels": [3, 48, 32, 56, 160, 448], + "out_indexes": [5, 9, 21, 31], "model_name": "efficientnet-b4", }, }, "efficientnet-b5": { "encoder": EfficientNetEncoder, - "pretrained_settings": _get_pretrained_settings("efficientnet-b5"), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/efficientnet-b5.imagenet", + "revision": "69f4d28460a4e421b7860bc26ee7d832e03e01ca", + }, + "advprop": { + "repo_id": "smp-hub/efficientnet-b5.advprop", + "revision": "dabe78fc8ab7ce93ddc2bb156b01db227caede88", + }, + }, "params": { - "out_channels": (3, 48, 40, 64, 176, 512), - "stage_idxs": (8, 13, 27, 39), + "out_channels": [3, 48, 40, 64, 176, 512], + "out_indexes": [7, 12, 26, 38], "model_name": "efficientnet-b5", }, }, "efficientnet-b6": { "encoder": EfficientNetEncoder, - "pretrained_settings": _get_pretrained_settings("efficientnet-b6"), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/efficientnet-b6.imagenet", + "revision": "8570752016f7c62ae149cffa058550fe44e21c8b", + }, + "advprop": { + "repo_id": "smp-hub/efficientnet-b6.advprop", + "revision": 
"c2dbb4d1359151165ec7b96cfe54a9cac2142a31", + }, + }, "params": { - "out_channels": (3, 56, 40, 72, 200, 576), - "stage_idxs": (9, 15, 31, 45), + "out_channels": [3, 56, 40, 72, 200, 576], + "out_indexes": [8, 14, 30, 44], "model_name": "efficientnet-b6", }, }, "efficientnet-b7": { "encoder": EfficientNetEncoder, - "pretrained_settings": _get_pretrained_settings("efficientnet-b7"), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/efficientnet-b7.imagenet", + "revision": "5a5dbe687d612ebc3dca248274fd1191111deda6", + }, + "advprop": { + "repo_id": "smp-hub/efficientnet-b7.advprop", + "revision": "ce33edb4e80c0cde268f098ae2299e23f615577d", + }, + }, "params": { - "out_channels": (3, 64, 48, 80, 224, 640), - "stage_idxs": (11, 18, 38, 55), + "out_channels": [3, 64, 48, 80, 224, 640], + "out_indexes": [10, 17, 37, 54], "model_name": "efficientnet-b7", }, }, diff --git a/segmentation_models_pytorch/encoders/inceptionresnetv2.py b/segmentation_models_pytorch/encoders/inceptionresnetv2.py index 5d90c7f4..d7f83f9d 100644 --- a/segmentation_models_pytorch/encoders/inceptionresnetv2.py +++ b/segmentation_models_pytorch/encoders/inceptionresnetv2.py @@ -23,20 +23,33 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
""" +import torch import torch.nn as nn -from pretrainedmodels.models.inceptionresnetv2 import InceptionResNetV2 -from pretrainedmodels.models.inceptionresnetv2 import pretrained_settings +from typing import List from ._base import EncoderMixin +from ._inceptionresnetv2 import InceptionResNetV2 class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + super().__init__(**kwargs) - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride # correct paddings for m in self.modules(): @@ -46,6 +59,9 @@ def __init__(self, out_channels, depth=5, **kwargs): if isinstance(m, nn.MaxPool2d): m.padding = (1, 1) + # for torchscript, block8 does not have relu defined + self.block8.relu = nn.Identity() + # remove linear layers del self.avgpool_1a del self.last_linear @@ -56,22 +72,37 @@ def make_dilated(self, *args, **kwargs): "due to pooling operation for downsampling!" 
) - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv2d_1a, self.conv2d_2a, self.conv2d_2b), - nn.Sequential(self.maxpool_3a, self.conv2d_3b, self.conv2d_4a), - nn.Sequential(self.maxpool_5a, self.mixed_5b, self.repeat), - nn.Sequential(self.mixed_6a, self.repeat_1), - nn.Sequential(self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b), - ] - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + features = [x] + + if self._depth >= 1: + x = self.conv2d_1a(x) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + features.append(x) + + if self._depth >= 2: + x = self.maxpool_3a(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + features.append(x) + + if self._depth >= 3: + x = self.maxpool_5a(x) + x = self.mixed_5b(x) + x = self.repeat(x) + features.append(x) + + if self._depth >= 4: + x = self.mixed_6a(x) + x = self.repeat_1(x) + features.append(x) + + if self._depth >= 5: + x = self.mixed_7a(x) + x = self.repeat_2(x) + x = self.block8(x) + x = self.conv2d_7b(x) features.append(x) return features @@ -85,7 +116,16 @@ def load_state_dict(self, state_dict, **kwargs): inceptionresnetv2_encoders = { "inceptionresnetv2": { "encoder": InceptionResNetV2Encoder, - "pretrained_settings": pretrained_settings["inceptionresnetv2"], - "params": {"out_channels": (3, 64, 192, 320, 1088, 1536), "num_classes": 1000}, + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/inceptionresnetv2.imagenet", + "revision": "120c5afdbb80a1c989db0a7423ebb7a9db9b1e6c", + }, + "imagenet+background": { + "repo_id": "smp-hub/inceptionresnetv2.imagenet-background", + "revision": "3ecf3491658dc0f6a76d69c9d1cb36511b1ee56c", + }, + }, + "params": {"out_channels": [3, 64, 192, 320, 1088, 1536], "num_classes": 1000}, } } diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py 
b/segmentation_models_pytorch/encoders/inceptionv4.py index 96540f9a..3c335042 100644 --- a/segmentation_models_pytorch/encoders/inceptionv4.py +++ b/segmentation_models_pytorch/encoders/inceptionv4.py @@ -23,20 +23,34 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ +import torch import torch.nn as nn -from pretrainedmodels.models.inceptionv4 import InceptionV4 -from pretrainedmodels.models.inceptionv4 import pretrained_settings + +from typing import List from ._base import EncoderMixin +from ._inceptionv4 import InceptionV4 class InceptionV4Encoder(InceptionV4, EncoderMixin): - def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + def __init__( + self, + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) - self._stage_idxs = stage_idxs - self._out_channels = out_channels + self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride + self._out_indexes = [2, 4, 8, 14, len(self.features) - 1] # correct paddings for m in self.modules(): @@ -55,24 +69,23 @@ def make_dilated(self, *args, **kwargs): "due to pooling operation for downsampling!" 
) - def get_stages(self): - return [ - nn.Identity(), - self.features[: self._stage_idxs[0]], - self.features[self._stage_idxs[0] : self._stage_idxs[1]], - self.features[self._stage_idxs[1] : self._stage_idxs[2]], - self.features[self._stage_idxs[2] : self._stage_idxs[3]], - self.features[self._stage_idxs[3] :], - ] + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + depth = 0 + features = [x] + + for i, module in enumerate(self.features): + x = module(x) - def forward(self, x): - stages = self.get_stages() + if i in self._out_indexes: + features.append(x) + depth += 1 - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) + # torchscript does not support break in cycle, so we just + # go over all modules and then slice number of features + if not torch.jit.is_scripting() and depth > self._depth: + break + features = features[: self._depth + 1] return features def load_state_dict(self, state_dict, **kwargs): @@ -84,10 +97,18 @@ def load_state_dict(self, state_dict, **kwargs): inceptionv4_encoders = { "inceptionv4": { "encoder": InceptionV4Encoder, - "pretrained_settings": pretrained_settings["inceptionv4"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/inceptionv4.imagenet", + "revision": "918fb54f07811d82a4ecde3a51156041d0facba9", + }, + "imagenet+background": { + "repo_id": "smp-hub/inceptionv4.imagenet-background", + "revision": "8c2a48e20d2709ee64f8421c61be309f05bfa536", + }, + }, "params": { - "stage_idxs": (3, 5, 9, 15), - "out_channels": (3, 64, 192, 384, 1024, 1536), + "out_channels": [3, 64, 192, 384, 1024, 1536], "num_classes": 1001, }, } diff --git a/segmentation_models_pytorch/encoders/mix_transformer.py b/segmentation_models_pytorch/encoders/mix_transformer.py index 0cc3fb21..d5dca7fd 100644 --- a/segmentation_models_pytorch/encoders/mix_transformer.py +++ b/segmentation_models_pytorch/encoders/mix_transformer.py @@ -11,20 +11,22 @@ import math import torch import torch.nn as nn +import 
torch.nn.functional as F from functools import partial +from typing import Dict, Sequence, List from timm.layers import DropPath, to_2tuple, trunc_normal_ class LayerNorm(nn.LayerNorm): - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: if x.ndim == 4: - B, C, H, W = x.shape - x = x.view(B, C, -1).transpose(1, 2) - x = super().forward(x) - x = x.transpose(1, 2).view(B, C, H, W) + batch_size, channels, height, width = x.shape + x = x.view(batch_size, channels, -1).transpose(1, 2) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.transpose(1, 2).view(batch_size, channels, height, width) else: - x = super().forward(x) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return x @@ -60,9 +62,9 @@ def _init_weights(self, m): if m.bias is not None: m.bias.data.zero_() - def forward(self, x, H, W): + def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor: x = self.fc1(x) - x = self.dwconv(x, H, W) + x = self.dwconv(x, height, width) x = self.act(x) x = self.drop(x) x = self.fc2(x) @@ -82,9 +84,9 @@ def __init__( sr_ratio=1, ): super().__init__() - assert ( - dim % num_heads == 0 - ), f"dim {dim} should be divided by num_heads {num_heads}." + assert dim % num_heads == 0, ( + f"dim {dim} should be divided by num_heads {num_heads}." 
+ ) self.dim = dim self.num_heads = num_heads @@ -101,6 +103,10 @@ def __init__( if sr_ratio > 1: self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) self.norm = LayerNorm(dim) + else: + # for torchscript compatibility + self.sr = nn.Identity() + self.norm = nn.Identity() self.apply(self._init_weights) @@ -119,27 +125,27 @@ def _init_weights(self, m): if m.bias is not None: m.bias.data.zero_() - def forward(self, x, H, W): - B, N, C = x.shape + def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor: + batch_size, N, C = x.shape q = ( self.q(x) - .reshape(B, N, self.num_heads, C // self.num_heads) + .reshape(batch_size, N, self.num_heads, C // self.num_heads) .permute(0, 2, 1, 3) ) if self.sr_ratio > 1: - x_ = x.permute(0, 2, 1).reshape(B, C, H, W) - x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) + x_ = x.permute(0, 2, 1).reshape(batch_size, C, height, width) + x_ = self.sr(x_).reshape(batch_size, C, -1).permute(0, 2, 1) x_ = self.norm(x_) kv = ( self.kv(x_) - .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .reshape(batch_size, -1, 2, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) else: kv = ( self.kv(x) - .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .reshape(batch_size, -1, 2, self.num_heads, C // self.num_heads) .permute(2, 0, 3, 1, 4) ) k, v = kv[0], kv[1] @@ -148,7 +154,7 @@ def forward(self, x, H, W): attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = (attn @ v).transpose(1, 2).reshape(batch_size, N, C) x = self.proj(x) x = self.proj_drop(x) @@ -209,12 +215,12 @@ def _init_weights(self, m): if m.bias is not None: m.bias.data.zero_() - def forward(self, x): - B, _, H, W = x.shape + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, _, height, width = x.shape x = x.flatten(2).transpose(1, 2) - x = x + self.drop_path(self.attn(self.norm1(x), H, W)) - x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) - x = 
x.transpose(1, 2).view(B, -1, H, W) + x = x + self.drop_path(self.attn(self.norm1(x), height, width)) + x = x + self.drop_path(self.mlp(self.norm2(x), height, width)) + x = x.transpose(1, 2).view(batch_size, -1, height, width) return x @@ -256,7 +262,7 @@ def _init_weights(self, m): if m.bias is not None: m.bias.data.zero_() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) x = self.norm(x) return x @@ -462,7 +468,7 @@ def reset_classifier(self, num_classes, global_pool=""): nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) - def forward_features(self, x): + def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]: outs = [] # stage 1 @@ -491,11 +497,11 @@ def forward_features(self, x): return outs - def forward(self, x): - x = self.forward_features(x) + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + features = self.forward_features(x) # x = self.head(x) - return x + return features class DWConv(nn.Module): @@ -503,9 +509,9 @@ def __init__(self, dim=768): super(DWConv, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) - def forward(self, x, H, W): - B, _, C = x.shape - x = x.transpose(1, 2).view(B, C, H, W) + def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor: + batch_size, _, channels = x.shape + x = x.transpose(1, 2).view(batch_size, channels, height, width) x = self.dwconv(x) x = x.flatten(2).transpose(1, 2) @@ -520,36 +526,63 @@ def forward(self, x, H, W): class MixVisionTransformerEncoder(MixVisionTransformer, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) - self._out_channels = out_channels + self._depth = depth self._in_channels 
= 3 + self._out_channels = out_channels + self._output_stride = output_stride - def get_stages(self): - return [ - nn.Identity(), - nn.Identity(), - nn.Sequential(self.patch_embed1, self.block1, self.norm1), - nn.Sequential(self.patch_embed2, self.block2, self.norm2), - nn.Sequential(self.patch_embed3, self.block3, self.norm3), - nn.Sequential(self.patch_embed4, self.block4, self.norm4), - ] - - def forward(self, x): - stages = self.get_stages() + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: + return { + 16: [self.patch_embed3, self.block3, self.norm3], + 32: [self.patch_embed4, self.block4, self.norm4], + } + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: # create dummy output for the first block - B, _, H, W = x.shape - dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device) - - features = [] - for i in range(self._depth + 1): - if i == 1: - features.append(dummy) - else: - x = stages[i](x).contiguous() - features.append(x) + batch_size, _, height, width = x.shape + dummy = torch.empty( + [batch_size, 0, height // 2, width // 2], dtype=x.dtype, device=x.device + ) + + features = [x, dummy] + + if self._depth >= 2: + x = self.patch_embed1(x) + x = self.block1(x) + x = self.norm1(x) + x = x.contiguous() + features.append(x) + + if self._depth >= 3: + x = self.patch_embed2(x) + x = self.block2(x) + x = self.norm2(x) + x = x.contiguous() + features.append(x) + + if self._depth >= 4: + x = self.patch_embed3(x) + x = self.block3(x) + x = self.norm3(x) + x = x.contiguous() + features.append(x) + + if self._depth >= 5: + x = self.patch_embed4(x) + x = self.block4(x) + x = self.norm4(x) + x = x.contiguous() + features.append(x) + return features def load_state_dict(self, state_dict): @@ -558,120 +591,137 @@ def load_state_dict(self, state_dict): return super().load_state_dict(state_dict) -def get_pretrained_cfg(name): - return { - "url": 
"https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/{}.pth".format( - name - ), - "input_space": "RGB", - "input_size": [3, 224, 224], - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - } - - mix_transformer_encoders = { "mit_b0": { "encoder": MixVisionTransformerEncoder, - "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b0")}, - "params": dict( - out_channels=(3, 0, 32, 64, 160, 256), - patch_size=4, - embed_dims=[32, 64, 160, 256], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[2, 2, 2, 2], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/mit_b0.imagenet", + "revision": "9ce53d104d92d75aabb00aae70677aaab67e7c84", + } + }, + "params": { + "out_channels": [3, 0, 32, 64, 160, 256], + "patch_size": 4, + "embed_dims": [32, 64, 160, 256], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [2, 2, 2, 2], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, "mit_b1": { "encoder": MixVisionTransformerEncoder, - "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b1")}, - "params": dict( - out_channels=(3, 0, 64, 128, 320, 512), - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[2, 2, 2, 2], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/mit_b1.imagenet", + "revision": "a04bf4f13a549bce677cf79b04852e7510782817", + } + }, + "params": { + "out_channels": [3, 0, 64, 128, 320, 512], + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": 
True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [2, 2, 2, 2], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, "mit_b2": { "encoder": MixVisionTransformerEncoder, - "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b2")}, - "params": dict( - out_channels=(3, 0, 64, 128, 320, 512), - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[3, 4, 6, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/mit_b2.imagenet", + "revision": "868ab6f13871dcf8c3d9f90ee4519403475b65ef", + } + }, + "params": { + "out_channels": [3, 0, 64, 128, 320, 512], + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [3, 4, 6, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, "mit_b3": { "encoder": MixVisionTransformerEncoder, - "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b3")}, - "params": dict( - out_channels=(3, 0, 64, 128, 320, 512), - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[3, 4, 18, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/mit_b3.imagenet", + "revision": "32558d12a65f1daa0ebcf4f4053c4285e2c1cbda", + } + }, + "params": { + "out_channels": [3, 0, 64, 128, 320, 512], + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [3, 4, 18, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, 
+ "drop_path_rate": 0.1, + }, }, "mit_b4": { "encoder": MixVisionTransformerEncoder, - "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b4")}, - "params": dict( - out_channels=(3, 0, 64, 128, 320, 512), - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[3, 8, 27, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/mit_b4.imagenet", + "revision": "3a3454e900a4b4f11dd60eeb59101a9a1a36b017", + } + }, + "params": { + "out_channels": [3, 0, 64, 128, 320, 512], + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [3, 8, 27, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, "mit_b5": { "encoder": MixVisionTransformerEncoder, - "pretrained_settings": {"imagenet": get_pretrained_cfg("mit_b5")}, - "params": dict( - out_channels=(3, 0, 64, 128, 320, 512), - patch_size=4, - embed_dims=[64, 128, 320, 512], - num_heads=[1, 2, 5, 8], - mlp_ratios=[4, 4, 4, 4], - qkv_bias=True, - norm_layer=partial(LayerNorm, eps=1e-6), - depths=[3, 6, 40, 3], - sr_ratios=[8, 4, 2, 1], - drop_rate=0.0, - drop_path_rate=0.1, - ), + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/mit_b5.imagenet", + "revision": "ced04d96c586b6297fd59a7a1e244fc78fdb6531", + } + }, + "params": { + "out_channels": [3, 0, 64, 128, 320, 512], + "patch_size": 4, + "embed_dims": [64, 128, 320, 512], + "num_heads": [1, 2, 5, 8], + "mlp_ratios": [4, 4, 4, 4], + "qkv_bias": True, + "norm_layer": partial(LayerNorm, eps=1e-6), + "depths": [3, 6, 40, 3], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + }, }, } diff --git a/segmentation_models_pytorch/encoders/mobilenet.py 
b/segmentation_models_pytorch/encoders/mobilenet.py index dd30f142..793a9be2 100644 --- a/segmentation_models_pytorch/encoders/mobilenet.py +++ b/segmentation_models_pytorch/encoders/mobilenet.py @@ -23,37 +23,54 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ +import torch import torchvision -import torch.nn as nn +from typing import Dict, Sequence, List from ._base import EncoderMixin class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) + self._depth = depth - self._out_channels = out_channels self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride + self._out_indexes = [1, 3, 6, 13, len(self.features) - 1] + del self.classifier - def get_stages(self): - return [ - nn.Identity(), - self.features[:2], - self.features[2:4], - self.features[4:7], - self.features[7:14], - self.features[14:], - ] + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: + return { + 16: [self.features[7:14]], + 32: [self.features[14:]], + } + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + features = [x] + + depth = 0 + for i, module in enumerate(self.features): + x = module(x) + + if i in self._out_indexes: + features.append(x) + depth += 1 - def forward(self, x): - stages = self.get_stages() + # torchscript does not support break in cycle, so we just + # go over all modules and then slice number of features + if not torch.jit.is_scripting() and depth > self._depth: + break - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) + features = features[: self._depth + 1] return features @@ -68,13 +85,10 @@ 
def load_state_dict(self, state_dict, **kwargs): "encoder": MobileNetV2Encoder, "pretrained_settings": { "imagenet": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "url": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", - "input_space": "RGB", - "input_range": [0, 1], + "repo_id": "smp-hub/mobilenet_v2.imagenet", + "revision": "e67aa804e17f7b404b629127eabbd224c4e0690b", } }, - "params": {"out_channels": (3, 16, 24, 32, 96, 1280)}, + "params": {"out_channels": [3, 16, 24, 32, 96, 1280]}, } } diff --git a/segmentation_models_pytorch/encoders/mobileone.py b/segmentation_models_pytorch/encoders/mobileone.py index 76f50053..ba2947d0 100644 --- a/segmentation_models_pytorch/encoders/mobileone.py +++ b/segmentation_models_pytorch/encoders/mobileone.py @@ -3,7 +3,7 @@ # Copyright (C) 2022 Apple Inc. All Rights Reserved. # import copy -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Sequence import torch import torch.nn as nn @@ -120,6 +120,8 @@ def __init__( bias=True, ) else: + self.reparam_conv = nn.Identity() + # Re-parameterizable skip connection self.rbr_skip = ( nn.BatchNorm2d(num_features=in_channels) @@ -157,8 +159,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Other branches out = scale_out + identity_out - for ix in range(self.num_conv_branches): - out += self.rbr_conv[ix](x) + for module in self.rbr_conv: + out += module(x) return self.activation(self.se(out)) @@ -298,13 +300,14 @@ class MobileOne(nn.Module, EncoderMixin): def __init__( self, - out_channels, + out_channels: List[int], num_blocks_per_stage: List[int] = [2, 8, 10, 1], width_multipliers: Optional[List[float]] = None, inference_mode: bool = False, use_se: bool = False, depth=5, in_channels=3, + output_stride=32, num_conv_branches: int = 1, ) -> None: """Construct MobileOne model. @@ -316,17 +319,23 @@ def __init__( :param use_se: Whether to use SE-ReLU activations. 
:param num_conv_branches: Number of linear conv branches. """ + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + super().__init__() assert len(width_multipliers) == 4 self.inference_mode = inference_mode - self._out_channels = out_channels self.in_planes = min(64, int(64 * width_multipliers[0])) self.use_se = use_se self.num_conv_branches = num_conv_branches + self._depth = depth self._in_channels = in_channels - self.set_in_channels(self._in_channels) + self._out_channels = out_channels + self._output_stride = output_stride # Build stages self.stage0 = MobileOneBlock( @@ -355,15 +364,11 @@ def __init__( num_se_blocks=num_blocks_per_stage[3] if use_se else 0, ) - def get_stages(self): - return [ - nn.Identity(), - self.stage0, - self.stage1, - self.stage2, - self.stage3, - self.stage4, - ] + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: + return { + 16: [self.stage3], + 32: [self.stage4], + } def _make_stage( self, planes: int, num_blocks: int, num_se_blocks: int @@ -381,9 +386,7 @@ def _make_stage( for ix, stride in enumerate(strides): use_se = False if num_se_blocks > num_blocks: - raise ValueError( - "Number of SE blocks cannot " "exceed number of layers." 
- ) + raise ValueError("Number of SE blocks cannot exceed number of layers.") if ix >= (num_blocks - num_se_blocks): use_se = True @@ -419,13 +422,30 @@ def _make_stage( self.cur_layer_idx += 1 return nn.Sequential(*blocks) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: """Apply forward pass.""" - stages = self.get_stages() - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + features = [x] + + if self._depth >= 1: + x = self.stage0(x) features.append(x) + + if self._depth >= 2: + x = self.stage1(x) + features.append(x) + + if self._depth >= 3: + x = self.stage2(x) + features.append(x) + + if self._depth >= 4: + x = self.stage3(x) + features.append(x) + + if self._depth >= 5: + x = self.stage4(x) + features.append(x) + return features def load_state_dict(self, state_dict, **kwargs): @@ -473,15 +493,12 @@ def reparameterize_model(model: torch.nn.Module) -> nn.Module: "encoder": MobileOne, "pretrained_settings": { "imagenet": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar", # noqa - "input_space": "RGB", - "input_range": [0, 1], + "repo_id": "smp-hub/mobileone_s0.imagenet", + "revision": "f52815cf0ad29278a9860c9cd5fabf19f904bedf", } }, "params": { - "out_channels": (3, 48, 48, 128, 256, 1024), + "out_channels": [3, 48, 48, 128, 256, 1024], "width_multipliers": (0.75, 1.0, 1.0, 2.0), "num_conv_branches": 4, "inference_mode": False, @@ -491,15 +508,12 @@ def reparameterize_model(model: torch.nn.Module) -> nn.Module: "encoder": MobileOne, "pretrained_settings": { "imagenet": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar", # noqa - "input_space": "RGB", - "input_range": [0, 1], + "repo_id": "smp-hub/mobileone_s1.imagenet", + 
"revision": "5707a98852b762cd8e0c43b5c8c729cd28496677", } }, "params": { - "out_channels": (3, 64, 96, 192, 512, 1280), + "out_channels": [3, 64, 96, 192, 512, 1280], "width_multipliers": (1.5, 1.5, 2.0, 2.5), "inference_mode": False, }, @@ -508,15 +522,12 @@ def reparameterize_model(model: torch.nn.Module) -> nn.Module: "encoder": MobileOne, "pretrained_settings": { "imagenet": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar", # noqa - "input_space": "RGB", - "input_range": [0, 1], + "repo_id": "smp-hub/mobileone_s2.imagenet", + "revision": "ddc3db8fa40d271902c7a8c95cee6691f617d551", } }, "params": { - "out_channels": (3, 64, 96, 256, 640, 2048), + "out_channels": [3, 64, 96, 256, 640, 2048], "width_multipliers": (1.5, 2.0, 2.5, 4.0), "inference_mode": False, }, @@ -525,15 +536,12 @@ def reparameterize_model(model: torch.nn.Module) -> nn.Module: "encoder": MobileOne, "pretrained_settings": { "imagenet": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar", # noqa - "input_space": "RGB", - "input_range": [0, 1], + "repo_id": "smp-hub/mobileone_s3.imagenet", + "revision": "da89b84a91b7400c366c358bfbf8dd0b2fa4dde2", } }, "params": { - "out_channels": (3, 64, 128, 320, 768, 2048), + "out_channels": [3, 64, 128, 320, 768, 2048], "width_multipliers": (2.0, 2.5, 3.0, 4.0), "inference_mode": False, }, @@ -542,15 +550,12 @@ def reparameterize_model(model: torch.nn.Module) -> nn.Module: "encoder": MobileOne, "pretrained_settings": { "imagenet": { - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "url": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar", # noqa - "input_space": "RGB", - "input_range": [0, 1], + "repo_id": "smp-hub/mobileone_s4.imagenet", + 
"revision": "16197c55d599076b6aae67a83d3b3f70c31b097c", } }, "params": { - "out_channels": (3, 64, 192, 448, 896, 2048), + "out_channels": [3, 64, 192, 448, 896, 2048], "width_multipliers": (3.0, 3.5, 3.5, 4.0), "use_se": True, "inference_mode": False, diff --git a/segmentation_models_pytorch/encoders/resnet.py b/segmentation_models_pytorch/encoders/resnet.py index 2040a42c..d4f2db4e 100644 --- a/segmentation_models_pytorch/encoders/resnet.py +++ b/segmentation_models_pytorch/encoders/resnet.py @@ -23,44 +23,65 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ -from copy import deepcopy - -import torch.nn as nn - +import torch +from typing import Dict, Sequence, List from torchvision.models.resnet import ResNet from torchvision.models.resnet import BasicBlock from torchvision.models.resnet import Bottleneck -from pretrainedmodels.models.torchvision_models import pretrained_settings from ._base import EncoderMixin class ResNetEncoder(ResNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + """ResNet encoder implementation.""" + + def __init__( + self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) + self._depth = depth - self._out_channels = out_channels self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.fc del self.avgpool - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv1, self.bn1, self.relu), - nn.Sequential(self.maxpool, self.layer1), - self.layer2, - self.layer3, - self.layer4, - ] - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: + return { + 16: [self.layer3], + 32: 
[self.layer4], + } + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + features = [x] + + if self._depth >= 1: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + features.append(x) + + if self._depth >= 2: + x = self.maxpool(x) + x = self.layer1(x) + features.append(x) + + if self._depth >= 3: + x = self.layer2(x) + features.append(x) + + if self._depth >= 4: + x = self.layer3(x) + features.append(x) + + if self._depth >= 5: + x = self.layer4(x) features.append(x) return features @@ -71,110 +92,111 @@ def load_state_dict(self, state_dict, **kwargs): super().load_state_dict(state_dict, **kwargs) -new_settings = { - "resnet18": { - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth", # noqa - }, - "resnet50": { - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth", # noqa - }, - "resnext50_32x4d": { - "imagenet": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth", # noqa - }, - "resnext101_32x4d": { - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth", # noqa - }, - "resnext101_32x8d": { - "imagenet": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", - "instagram": 
"https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth", - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth", # noqa - }, - "resnext101_32x16d": { - "instagram": "https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth", - "ssl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth", # noqa - "swsl": "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth", # noqa - }, - "resnext101_32x32d": { - "instagram": "https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth" - }, - "resnext101_32x48d": { - "instagram": "https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth" - }, -} - -pretrained_settings = deepcopy(pretrained_settings) -for model_name, sources in new_settings.items(): - if model_name not in pretrained_settings: - pretrained_settings[model_name] = {} - - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_size": [3, 224, 224], - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "num_classes": 1000, - } - - resnet_encoders = { "resnet18": { "encoder": ResNetEncoder, - "pretrained_settings": pretrained_settings["resnet18"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/resnet18.imagenet", + "revision": "3f2325ff978283d47aa6a1d6878ca20565622683", + }, + "ssl": { + "repo_id": "smp-hub/resnet18.ssl", + "revision": "d600d5116aac2e6e595f99f40612074c723c00b2", + }, + "swsl": { + "repo_id": "smp-hub/resnet18.swsl", + "revision": "0e3a35d4d8e344088c14a96eee502a88ac70eae1", + }, + }, "params": { - "out_channels": (3, 64, 64, 128, 256, 512), + "out_channels": [3, 64, 
64, 128, 256, 512], "block": BasicBlock, "layers": [2, 2, 2, 2], }, }, "resnet34": { "encoder": ResNetEncoder, - "pretrained_settings": pretrained_settings["resnet34"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/resnet34.imagenet", + "revision": "7a57b34f723329ff020b3f8bc41771163c519d0c", + }, + }, "params": { - "out_channels": (3, 64, 64, 128, 256, 512), + "out_channels": [3, 64, 64, 128, 256, 512], "block": BasicBlock, "layers": [3, 4, 6, 3], }, }, "resnet50": { "encoder": ResNetEncoder, - "pretrained_settings": pretrained_settings["resnet50"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/resnet50.imagenet", + "revision": "00cb74e366966d59cd9a35af57e618af9f88efe9", + }, + "ssl": { + "repo_id": "smp-hub/resnet50.ssl", + "revision": "d07daf5b4377f3700c6ac61906b0aafbc4eca46b", + }, + "swsl": { + "repo_id": "smp-hub/resnet50.swsl", + "revision": "b9520cce124f91c6fe7eee45721a2c7954f0d8c0", + }, + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 6, 3], }, }, "resnet101": { "encoder": ResNetEncoder, - "pretrained_settings": pretrained_settings["resnet101"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/resnet101.imagenet", + "revision": "cd7c15e8c51da86ae6a084515fdb962d0c94e7d1", + }, + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], }, }, "resnet152": { "encoder": ResNetEncoder, - "pretrained_settings": pretrained_settings["resnet152"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/resnet152.imagenet", + "revision": "951dd835e9d086628e447b484584c8983f9e1dd0", + }, + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 8, 36, 3], }, }, "resnext50_32x4d": { "encoder": ResNetEncoder, - 
"pretrained_settings": pretrained_settings["resnext50_32x4d"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/resnext50_32x4d.imagenet", + "revision": "329793c85d62fd340ae42ae39fb905a63df872e7", + }, + "ssl": { + "repo_id": "smp-hub/resnext50_32x4d.ssl", + "revision": "9b67cff77d060c7044493a58c24d1007c1eb06c3", + }, + "swsl": { + "repo_id": "smp-hub/resnext50_32x4d.swsl", + "revision": "52e6e49da61b8e26ca691e1aef2cbb952884057d", + }, + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 6, 3], "groups": 32, @@ -183,9 +205,18 @@ def load_state_dict(self, state_dict, **kwargs): }, "resnext101_32x4d": { "encoder": ResNetEncoder, - "pretrained_settings": pretrained_settings["resnext101_32x4d"], + "pretrained_settings": { + "ssl": { + "repo_id": "smp-hub/resnext101_32x4d.ssl", + "revision": "b39796c8459084d13523b7016c3ef13a2e9e472b", + }, + "swsl": { + "repo_id": "smp-hub/resnext101_32x4d.swsl", + "revision": "3f8355b4892a31f001a832b49b2b01484d48516a", + }, + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], "groups": 32, @@ -194,9 +225,26 @@ def load_state_dict(self, state_dict, **kwargs): }, "resnext101_32x8d": { "encoder": ResNetEncoder, - "pretrained_settings": pretrained_settings["resnext101_32x8d"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/resnext101_32x8d.imagenet", + "revision": "221af6198d03a4ee88992f78a1ee81b46a52d339", + }, + "instagram": { + "repo_id": "smp-hub/resnext101_32x8d.instagram", + "revision": "44cd927aa6e64673ffe9d31230bad44abc18b823", + }, + "ssl": { + "repo_id": "smp-hub/resnext101_32x8d.ssl", + "revision": "723a95ddeed335c9488c37c6cbef13d779ac8f97", + }, + "swsl": { + "repo_id": "smp-hub/resnext101_32x8d.swsl", + "revision": "58cf0bb65f91365470398080d9588b187d1777c4", + }, + }, "params": { - 
"out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], "groups": 32, @@ -205,9 +253,22 @@ def load_state_dict(self, state_dict, **kwargs): }, "resnext101_32x16d": { "encoder": ResNetEncoder, - "pretrained_settings": pretrained_settings["resnext101_32x16d"], + "pretrained_settings": { + "instagram": { + "repo_id": "smp-hub/resnext101_32x16d.instagram", + "revision": "64e8e320eeae6501185b0627b2429a68e52d050c", + }, + "ssl": { + "repo_id": "smp-hub/resnext101_32x16d.ssl", + "revision": "1283fe03fbb6aa2599b2df24095255acb93c3d5c", + }, + "swsl": { + "repo_id": "smp-hub/resnext101_32x16d.swsl", + "revision": "30ba61bbd4d6af0d955c513dbb4f557b84eb094f", + }, + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], "groups": 32, @@ -216,9 +277,14 @@ def load_state_dict(self, state_dict, **kwargs): }, "resnext101_32x32d": { "encoder": ResNetEncoder, - "pretrained_settings": pretrained_settings["resnext101_32x32d"], + "pretrained_settings": { + "instagram": { + "repo_id": "smp-hub/resnext101_32x32d.instagram", + "revision": "c9405de121fdaa275a89de470fb19409e3eeaa86", + }, + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], "groups": 32, @@ -227,9 +293,14 @@ def load_state_dict(self, state_dict, **kwargs): }, "resnext101_32x48d": { "encoder": ResNetEncoder, - "pretrained_settings": pretrained_settings["resnext101_32x48d"], + "pretrained_settings": { + "instagram": { + "repo_id": "smp-hub/resnext101_32x48d.instagram", + "revision": "53e61a962b824ad7027409821f9ac3e3336dd024", + }, + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": Bottleneck, "layers": [3, 4, 23, 3], "groups": 32, diff --git 
a/segmentation_models_pytorch/encoders/senet.py b/segmentation_models_pytorch/encoders/senet.py index 8e0f6fd8..da509f5a 100644 --- a/segmentation_models_pytorch/encoders/senet.py +++ b/segmentation_models_pytorch/encoders/senet.py @@ -23,45 +23,72 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ -import torch.nn as nn +import torch +from typing import List, Dict, Sequence -from pretrainedmodels.models.senet import ( +from ._base import EncoderMixin +from ._senet import ( SENet, SEBottleneck, SEResNetBottleneck, SEResNeXtBottleneck, - pretrained_settings, ) -from ._base import EncoderMixin class SENetEncoder(SENet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride + + # for compatibility with torchscript + self.layer0_pool = self.layer0.pool + self.layer0.pool = torch.nn.Identity() del self.last_linear del self.avg_pool - def get_stages(self): - return [ - nn.Identity(), - self.layer0[:-1], - nn.Sequential(self.layer0[-1], self.layer1), - self.layer2, - self.layer3, - self.layer4, - ] - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: + return { + 16: [self.layer3], + 32: [self.layer4], + } + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + features = [x] + + if self._depth >= 1: + x = self.layer0(x) + features.append(x) + + if self._depth >= 2: + x = self.layer0_pool(x) + x = self.layer1(x) + features.append(x) + + if 
self._depth >= 3: + x = self.layer2(x) + features.append(x) + + if self._depth >= 4: + x = self.layer3(x) + features.append(x) + + if self._depth >= 5: + x = self.layer4(x) features.append(x) return features @@ -75,9 +102,14 @@ def load_state_dict(self, state_dict, **kwargs): senet_encoders = { "senet154": { "encoder": SENetEncoder, - "pretrained_settings": pretrained_settings["senet154"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/senet154.imagenet", + "revision": "249f45efc9881ba560a0c480128edbc34ab87e40", + } + }, "params": { - "out_channels": (3, 128, 256, 512, 1024, 2048), + "out_channels": [3, 128, 256, 512, 1024, 2048], "block": SEBottleneck, "dropout_p": 0.2, "groups": 64, @@ -88,9 +120,14 @@ def load_state_dict(self, state_dict, **kwargs): }, "se_resnet50": { "encoder": SENetEncoder, - "pretrained_settings": pretrained_settings["se_resnet50"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/se_resnet50.imagenet", + "revision": "e6b4bc2dc85226c3d3474544410724a485455459", + } + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SEResNetBottleneck, "layers": [3, 4, 6, 3], "downsample_kernel_size": 1, @@ -105,9 +142,14 @@ def load_state_dict(self, state_dict, **kwargs): }, "se_resnet101": { "encoder": SENetEncoder, - "pretrained_settings": pretrained_settings["se_resnet101"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/se_resnet101.imagenet", + "revision": "71fe95cc0a27f444cf83671f354de02dc741b18b", + } + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SEResNetBottleneck, "layers": [3, 4, 23, 3], "downsample_kernel_size": 1, @@ -122,9 +164,14 @@ def load_state_dict(self, state_dict, **kwargs): }, "se_resnet152": { "encoder": SENetEncoder, - "pretrained_settings": pretrained_settings["se_resnet152"], + "pretrained_settings": { + "imagenet": { + "repo_id": 
"smp-hub/se_resnet152.imagenet", + "revision": "e79fc3d9d76f197bd76a2593c2054edf1083fe32", + } + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SEResNetBottleneck, "layers": [3, 8, 36, 3], "downsample_kernel_size": 1, @@ -139,9 +186,14 @@ def load_state_dict(self, state_dict, **kwargs): }, "se_resnext50_32x4d": { "encoder": SENetEncoder, - "pretrained_settings": pretrained_settings["se_resnext50_32x4d"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/se_resnext50_32x4d.imagenet", + "revision": "73246406d879a2b0e3fdfe6fddd56347d38f38ae", + } + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SEResNeXtBottleneck, "layers": [3, 4, 6, 3], "downsample_kernel_size": 1, @@ -156,9 +208,14 @@ def load_state_dict(self, state_dict, **kwargs): }, "se_resnext101_32x4d": { "encoder": SENetEncoder, - "pretrained_settings": pretrained_settings["se_resnext101_32x4d"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/se_resnext101_32x4d.imagenet", + "revision": "18808a4276f46421d358a9de554e0b93c2795df4", + } + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], "block": SEResNeXtBottleneck, "layers": [3, 4, 23, 3], "downsample_kernel_size": 1, diff --git a/segmentation_models_pytorch/encoders/timm_efficientnet.py b/segmentation_models_pytorch/encoders/timm_efficientnet.py index fc248575..a1c36491 100644 --- a/segmentation_models_pytorch/encoders/timm_efficientnet.py +++ b/segmentation_models_pytorch/encoders/timm_efficientnet.py @@ -1,9 +1,11 @@ -from functools import partial - +import torch import torch.nn as nn +from typing import List, Dict, Sequence +from functools import partial + from timm.models.efficientnet import EfficientNet -from timm.models.efficientnet import decode_arch_def, round_channels, default_cfgs +from 
timm.models.efficientnet import decode_arch_def, round_channels from timm.layers.activations import Swish from ._base import EncoderMixin @@ -95,32 +97,59 @@ def gen_efficientnet_lite_kwargs( class EfficientNetBaseEncoder(EfficientNet, EncoderMixin): - def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): + def __init__( + self, + stage_idxs: List[int], + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) self._stage_idxs = stage_idxs - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.classifier - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv_stem, self.bn1), - self.blocks[: self._stage_idxs[0]], - self.blocks[self._stage_idxs[0] : self._stage_idxs[1]], - self.blocks[self._stage_idxs[1] : self._stage_idxs[2]], - self.blocks[self._stage_idxs[2] :], - ] + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: + return { + 16: [self.blocks[self._stage_idxs[1] : self._stage_idxs[2]]], + 32: [self.blocks[self._stage_idxs[2] :]], + } - def forward(self, x): - stages = self.get_stages() + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + features = [x] - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + if self._depth >= 1: + x = self.conv_stem(x) + x = self.bn1(x) + features.append(x) + + if self._depth >= 2: + x = self.blocks[0](x) + x = self.blocks[1](x) + features.append(x) + + if self._depth >= 3: + x = self.blocks[2](x) + features.append(x) + + if self._depth >= 4: + x = self.blocks[3](x) + x = self.blocks[4](x) + features.append(x) + + if self._depth >= 5: + x = self.blocks[5](x) + x = self.blocks[6](x) features.append(x) return features @@ -134,33 +163,47 @@ def load_state_dict(self, 
state_dict, **kwargs): class EfficientNetEncoder(EfficientNetBaseEncoder): def __init__( self, - stage_idxs, - out_channels, - depth=5, - channel_multiplier=1.0, - depth_multiplier=1.0, - drop_rate=0.2, + stage_idxs: List[int], + out_channels: List[int], + depth: int = 5, + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + drop_rate: float = 0.2, + output_stride: int = 32, ): kwargs = get_efficientnet_kwargs( channel_multiplier, depth_multiplier, drop_rate ) - super().__init__(stage_idxs, out_channels, depth, **kwargs) + super().__init__( + stage_idxs=stage_idxs, + depth=depth, + out_channels=out_channels, + output_stride=output_stride, + **kwargs, + ) class EfficientNetLiteEncoder(EfficientNetBaseEncoder): def __init__( self, - stage_idxs, - out_channels, - depth=5, - channel_multiplier=1.0, - depth_multiplier=1.0, - drop_rate=0.2, + stage_idxs: List[int], + out_channels: List[int], + depth: int = 5, + channel_multiplier: float = 1.0, + depth_multiplier: float = 1.0, + drop_rate: float = 0.2, + output_stride: int = 32, ): kwargs = gen_efficientnet_lite_kwargs( channel_multiplier, depth_multiplier, drop_rate ) - super().__init__(stage_idxs, out_channels, depth, **kwargs) + super().__init__( + stage_idxs=stage_idxs, + depth=depth, + out_channels=out_channels, + output_stride=output_stride, + **kwargs, + ) def prepare_settings(settings): @@ -177,19 +220,22 @@ def prepare_settings(settings): "timm-efficientnet-b0": { "encoder": EfficientNetEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_b0"].cfgs["in1k"] - ), - "advprop": prepare_settings( - default_cfgs["tf_efficientnet_b0"].cfgs["ap_in1k"] - ), - "noisy-student": prepare_settings( - default_cfgs["tf_efficientnet_b0"].cfgs["ns_jft_in1k"] - ), + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b0.imagenet", + "revision": "8419e9cc19da0b68dcd7bb12f19b7c92407ad7c4", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b0.advprop", + 
"revision": "a5870af2d24ce79e0cc7fae2bbd8e0a21fcfa6d8", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b0.noisy-student", + "revision": "bea8b0ff726a50e48774d2d360c5fb1ac4815836", + }, }, "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.0, "depth_multiplier": 1.0, "drop_rate": 0.2, @@ -198,19 +244,22 @@ def prepare_settings(settings): "timm-efficientnet-b1": { "encoder": EfficientNetEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_b1"].cfgs["in1k"] - ), - "advprop": prepare_settings( - default_cfgs["tf_efficientnet_b1"].cfgs["ap_in1k"] - ), - "noisy-student": prepare_settings( - default_cfgs["tf_efficientnet_b1"].cfgs["ns_jft_in1k"] - ), + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b1.imagenet", + "revision": "63bdd65ef6596ef24f1cadc7dd4f46b624442349", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b1.advprop", + "revision": "79b3d102080ef679b16c2748e608a871112233d0", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b1.noisy-student", + "revision": "36856124a699f6032574ceeefab02040daa90a9a", + }, }, "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.0, "depth_multiplier": 1.1, "drop_rate": 0.2, @@ -219,19 +268,22 @@ def prepare_settings(settings): "timm-efficientnet-b2": { "encoder": EfficientNetEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_b2"].cfgs["in1k"] - ), - "advprop": prepare_settings( - default_cfgs["tf_efficientnet_b2"].cfgs["ap_in1k"] - ), - "noisy-student": prepare_settings( - default_cfgs["tf_efficientnet_b2"].cfgs["ns_jft_in1k"] - ), + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b2.imagenet", + "revision": 
"e693adb39d3cb3847e71e3700a0c2aa58072cff1", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b2.advprop", + "revision": "b58479bf78007cfbb365091d64eeee369bddfa21", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b2.noisy-student", + "revision": "67c558827c6d3e0975ff9b4bce8557bc2ca80931", + }, }, "params": { - "out_channels": (3, 32, 24, 48, 120, 352), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 48, 120, 352], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.1, "depth_multiplier": 1.2, "drop_rate": 0.3, @@ -240,19 +292,22 @@ def prepare_settings(settings): "timm-efficientnet-b3": { "encoder": EfficientNetEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_b3"].cfgs["in1k"] - ), - "advprop": prepare_settings( - default_cfgs["tf_efficientnet_b3"].cfgs["ap_in1k"] - ), - "noisy-student": prepare_settings( - default_cfgs["tf_efficientnet_b3"].cfgs["ns_jft_in1k"] - ), + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b3.imagenet", + "revision": "1666b835b5151d6bb2067c7cd67e67ada6c39edf", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b3.advprop", + "revision": "70474cdb9f1ff4fcbd7434e66560ead1ab8e506b", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b3.noisy-student", + "revision": "2367bc9f61e79ee97684169a71a87db280bcf4db", + }, }, "params": { - "out_channels": (3, 40, 32, 48, 136, 384), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 40, 32, 48, 136, 384], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.2, "depth_multiplier": 1.4, "drop_rate": 0.3, @@ -261,19 +316,22 @@ def prepare_settings(settings): "timm-efficientnet-b4": { "encoder": EfficientNetEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_b4"].cfgs["in1k"] - ), - "advprop": prepare_settings( - default_cfgs["tf_efficientnet_b4"].cfgs["ap_in1k"] - ), - "noisy-student": prepare_settings( - 
default_cfgs["tf_efficientnet_b4"].cfgs["ns_jft_in1k"] - ), + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b4.imagenet", + "revision": "07868c28ab308f4de4cf1e7ec54b33b8b002ccdb", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b4.advprop", + "revision": "8ea1772ee9a2a0d18c1b56dce0dfac8dd33d537d", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b4.noisy-student", + "revision": "faeb77b6e8292a700380c840d39442d7ce4d6443", + }, }, "params": { - "out_channels": (3, 48, 32, 56, 160, 448), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 48, 32, 56, 160, 448], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.4, "depth_multiplier": 1.8, "drop_rate": 0.4, @@ -282,19 +340,22 @@ def prepare_settings(settings): "timm-efficientnet-b5": { "encoder": EfficientNetEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_b5"].cfgs["in1k"] - ), - "advprop": prepare_settings( - default_cfgs["tf_efficientnet_b5"].cfgs["ap_in1k"] - ), - "noisy-student": prepare_settings( - default_cfgs["tf_efficientnet_b5"].cfgs["ns_jft_in1k"] - ), + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b5.imagenet", + "revision": "004153b4ddd93d30afd9bbf34329d7f57396d413", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b5.advprop", + "revision": "1d1c5f05aab5ed9a1d5052847ddd4024c06a464d", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b5.noisy-student", + "revision": "9bc3a1e5490de92b1af061d5c2c474ab3129e38c", + }, }, "params": { - "out_channels": (3, 48, 40, 64, 176, 512), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 48, 40, 64, 176, 512], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.6, "depth_multiplier": 2.2, "drop_rate": 0.4, @@ -303,19 +364,22 @@ def prepare_settings(settings): "timm-efficientnet-b6": { "encoder": EfficientNetEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_b6"].cfgs["aa_in1k"] - ), - "advprop": 
prepare_settings( - default_cfgs["tf_efficientnet_b6"].cfgs["ap_in1k"] - ), - "noisy-student": prepare_settings( - default_cfgs["tf_efficientnet_b6"].cfgs["ns_jft_in1k"] - ), + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b6.imagenet", + "revision": "dbbf28a5c33f021486db4070de693caad6b56c3d", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b6.advprop", + "revision": "3b5d3412047f7711c56ffde997911cfefe79f835", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b6.noisy-student", + "revision": "9b899ea9e8e0ce2ccada0f34a8cb8b5028e9bb36", + }, }, "params": { - "out_channels": (3, 56, 40, 72, 200, 576), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 56, 40, 72, 200, 576], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.8, "depth_multiplier": 2.6, "drop_rate": 0.5, @@ -324,19 +388,22 @@ def prepare_settings(settings): "timm-efficientnet-b7": { "encoder": EfficientNetEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_b7"].cfgs["aa_in1k"] - ), - "advprop": prepare_settings( - default_cfgs["tf_efficientnet_b7"].cfgs["ap_in1k"] - ), - "noisy-student": prepare_settings( - default_cfgs["tf_efficientnet_b7"].cfgs["ns_jft_in1k"] - ), + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b7.imagenet", + "revision": "8ef7ffccf54dad9baceb21d05b7ef86b6b70f4cc", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b7.advprop", + "revision": "fcbc576ffb939c12d5cd8dad523fdae6eb0177ca", + }, + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-b7.noisy-student", + "revision": "6b1dd73e61bf934d485d7bd4381dc3e2ab374664", + }, }, "params": { - "out_channels": (3, 64, 48, 80, 224, 640), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 64, 48, 80, 224, 640], + "stage_idxs": [2, 3, 5], "channel_multiplier": 2.0, "depth_multiplier": 3.1, "drop_rate": 0.5, @@ -345,16 +412,18 @@ def prepare_settings(settings): "timm-efficientnet-b8": { "encoder": EfficientNetEncoder, "pretrained_settings": { 
- "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_b8"].cfgs["ra_in1k"] - ), - "advprop": prepare_settings( - default_cfgs["tf_efficientnet_b8"].cfgs["ap_in1k"] - ), + "imagenet": { + "repo_id": "smp-hub/timm-efficientnet-b8.imagenet", + "revision": "b5e9dde35605a3a6d17ea2a727382625f9066a37", + }, + "advprop": { + "repo_id": "smp-hub/timm-efficientnet-b8.advprop", + "revision": "e43f381de72e7467383c2c80bacbb7fcb9572866", + }, }, "params": { - "out_channels": (3, 72, 56, 88, 248, 704), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 72, 56, 88, 248, 704], + "stage_idxs": [2, 3, 5], "channel_multiplier": 2.2, "depth_multiplier": 3.6, "drop_rate": 0.5, @@ -363,16 +432,18 @@ def prepare_settings(settings): "timm-efficientnet-l2": { "encoder": EfficientNetEncoder, "pretrained_settings": { - "noisy-student": prepare_settings( - default_cfgs["tf_efficientnet_l2"].cfgs["ns_jft_in1k"] - ), - "noisy-student-475": prepare_settings( - default_cfgs["tf_efficientnet_l2"].cfgs["ns_jft_in1k_475"] - ), + "noisy-student": { + "repo_id": "smp-hub/timm-efficientnet-l2.noisy-student", + "revision": "cdc711e76d1becdd9197169f1a8bb1b2094e980c", + }, + "noisy-student-475": { + "repo_id": "smp-hub/timm-efficientnet-l2.noisy-student-475", + "revision": "35f5ba667a64bf4f3f0689daf84fc6d0f8e1311b", + }, }, "params": { - "out_channels": (3, 136, 104, 176, 480, 1376), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 136, 104, 176, 480, 1376], + "stage_idxs": [2, 3, 5], "channel_multiplier": 4.3, "depth_multiplier": 5.3, "drop_rate": 0.5, @@ -381,13 +452,14 @@ def prepare_settings(settings): "timm-tf_efficientnet_lite0": { "encoder": EfficientNetLiteEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_lite0"].cfgs["in1k"] - ) + "imagenet": { + "repo_id": "smp-hub/timm-tf_efficientnet_lite0.imagenet", + "revision": "f5729249af07e5d923fb8b16922256ce2865d108", + }, }, "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": 
(2, 3, 5), + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.0, "depth_multiplier": 1.0, "drop_rate": 0.2, @@ -396,13 +468,14 @@ def prepare_settings(settings): "timm-tf_efficientnet_lite1": { "encoder": EfficientNetLiteEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_lite1"].cfgs["in1k"] - ) + "imagenet": { + "repo_id": "smp-hub/timm-tf_efficientnet_lite1.imagenet", + "revision": "7b5e3f8dbb0c13b74101773584bba7523721be72", + }, }, "params": { - "out_channels": (3, 32, 24, 40, 112, 320), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 40, 112, 320], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.0, "depth_multiplier": 1.1, "drop_rate": 0.2, @@ -411,13 +484,14 @@ def prepare_settings(settings): "timm-tf_efficientnet_lite2": { "encoder": EfficientNetLiteEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_lite2"].cfgs["in1k"] - ) + "imagenet": { + "repo_id": "smp-hub/timm-tf_efficientnet_lite2.imagenet", + "revision": "cc5f6cd4c7409ebacc13292f09d369ae88547f6a", + }, }, "params": { - "out_channels": (3, 32, 24, 48, 120, 352), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 24, 48, 120, 352], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.1, "depth_multiplier": 1.2, "drop_rate": 0.3, @@ -426,13 +500,14 @@ def prepare_settings(settings): "timm-tf_efficientnet_lite3": { "encoder": EfficientNetLiteEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_lite3"].cfgs["in1k"] - ) + "imagenet": { + "repo_id": "smp-hub/timm-tf_efficientnet_lite3.imagenet", + "revision": "ab29c8402991591d66f813bbb1f061565d9b0cd0", + }, }, "params": { - "out_channels": (3, 32, 32, 48, 136, 384), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 32, 48, 136, 384], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.2, "depth_multiplier": 1.4, "drop_rate": 0.3, @@ -441,13 +516,14 @@ def 
prepare_settings(settings): "timm-tf_efficientnet_lite4": { "encoder": EfficientNetLiteEncoder, "pretrained_settings": { - "imagenet": prepare_settings( - default_cfgs["tf_efficientnet_lite4"].cfgs["in1k"] - ) + "imagenet": { + "repo_id": "smp-hub/timm-tf_efficientnet_lite4.imagenet", + "revision": "91a822e0f03c255b34dfb7846d3858397e50ba39", + }, }, "params": { - "out_channels": (3, 32, 32, 56, 160, 448), - "stage_idxs": (2, 3, 5), + "out_channels": [3, 32, 32, 56, 160, 448], + "stage_idxs": [2, 3, 5], "channel_multiplier": 1.4, "depth_multiplier": 1.8, "drop_rate": 0.4, diff --git a/segmentation_models_pytorch/encoders/timm_gernet.py b/segmentation_models_pytorch/encoders/timm_gernet.py deleted file mode 100644 index e0c3354d..00000000 --- a/segmentation_models_pytorch/encoders/timm_gernet.py +++ /dev/null @@ -1,124 +0,0 @@ -from timm.models import ByoModelCfg, ByoBlockCfg, ByobNet - -from ._base import EncoderMixin -import torch.nn as nn - - -class GERNetEncoder(ByobNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): - super().__init__(**kwargs) - self._depth = depth - self._out_channels = out_channels - self._in_channels = 3 - - del self.head - - def get_stages(self): - return [ - nn.Identity(), - self.stem, - self.stages[0], - self.stages[1], - self.stages[2], - nn.Sequential(self.stages[3], self.stages[4], self.final_conv), - ] - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("head.fc.weight", None) - state_dict.pop("head.fc.bias", None) - super().load_state_dict(state_dict, **kwargs) - - -regnet_weights = { - "timm-gernet_s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth" # noqa - }, - "timm-gernet_m": { - "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth" # noqa - }, - "timm-gernet_l": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in regnet_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "num_classes": 1000, - } - -timm_gernet_encoders = { - "timm-gernet_s": { - "encoder": GERNetEncoder, - "pretrained_settings": pretrained_settings["timm-gernet_s"], - "params": { - "out_channels": (3, 13, 48, 48, 384, 1920), - "cfg": ByoModelCfg( - blocks=( - ByoBlockCfg(type="basic", d=1, c=48, s=2, gs=0, br=1.0), - ByoBlockCfg(type="basic", d=3, c=48, s=2, gs=0, br=1.0), - ByoBlockCfg(type="bottle", d=7, c=384, s=2, gs=0, br=1 / 4), - ByoBlockCfg(type="bottle", d=2, c=560, s=2, gs=1, br=3.0), - ByoBlockCfg(type="bottle", d=1, c=256, s=1, gs=1, br=3.0), - ), - stem_chs=13, - stem_pool=None, - num_features=1920, - ), - }, - }, - "timm-gernet_m": { - "encoder": GERNetEncoder, - "pretrained_settings": pretrained_settings["timm-gernet_m"], - "params": { - "out_channels": (3, 32, 128, 192, 640, 2560), - "cfg": ByoModelCfg( - blocks=( - ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0), - ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0), - ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4), - ByoBlockCfg(type="bottle", d=4, c=640, s=2, gs=1, br=3.0), - ByoBlockCfg(type="bottle", d=1, c=640, s=1, gs=1, br=3.0), - ), - stem_chs=32, - stem_pool=None, - num_features=2560, - ), - }, - }, - "timm-gernet_l": { - "encoder": GERNetEncoder, - "pretrained_settings": pretrained_settings["timm-gernet_l"], - "params": { - "out_channels": (3, 32, 128, 192, 
640, 2560), - "cfg": ByoModelCfg( - blocks=( - ByoBlockCfg(type="basic", d=1, c=128, s=2, gs=0, br=1.0), - ByoBlockCfg(type="basic", d=2, c=192, s=2, gs=0, br=1.0), - ByoBlockCfg(type="bottle", d=6, c=640, s=2, gs=0, br=1 / 4), - ByoBlockCfg(type="bottle", d=5, c=640, s=2, gs=1, br=3.0), - ByoBlockCfg(type="bottle", d=4, c=640, s=1, gs=1, br=3.0), - ), - stem_chs=32, - stem_pool=None, - num_features=2560, - ), - }, - }, -} diff --git a/segmentation_models_pytorch/encoders/timm_mobilenetv3.py b/segmentation_models_pytorch/encoders/timm_mobilenetv3.py deleted file mode 100644 index ff733ab9..00000000 --- a/segmentation_models_pytorch/encoders/timm_mobilenetv3.py +++ /dev/null @@ -1,151 +0,0 @@ -import timm -import numpy as np -import torch.nn as nn - -from ._base import EncoderMixin - - -def _make_divisible(x, divisible_by=8): - return int(np.ceil(x * 1.0 / divisible_by) * divisible_by) - - -class MobileNetV3Encoder(nn.Module, EncoderMixin): - def __init__(self, model_name, width_mult, depth=5, **kwargs): - super().__init__() - if "large" not in model_name and "small" not in model_name: - raise ValueError("MobileNetV3 wrong model name {}".format(model_name)) - - self._mode = "small" if "small" in model_name else "large" - self._depth = depth - self._out_channels = self._get_channels(self._mode, width_mult) - self._in_channels = 3 - - # minimal models replace hardswish with relu - self.model = timm.create_model( - model_name=model_name, - scriptable=True, # torch.jit scriptable - exportable=True, # onnx export - features_only=True, - ) - - def _get_channels(self, mode, width_mult): - if mode == "small": - channels = [16, 16, 24, 48, 576] - else: - channels = [16, 24, 40, 112, 960] - channels = [3] + [_make_divisible(x * width_mult) for x in channels] - return tuple(channels) - - def get_stages(self): - if self._mode == "small": - return [ - nn.Identity(), - nn.Sequential(self.model.conv_stem, self.model.bn1, self.model.act1), - self.model.blocks[0], - 
self.model.blocks[1], - self.model.blocks[2:4], - self.model.blocks[4:], - ] - elif self._mode == "large": - return [ - nn.Identity(), - nn.Sequential( - self.model.conv_stem, - self.model.bn1, - self.model.act1, - self.model.blocks[0], - ), - self.model.blocks[1], - self.model.blocks[2], - self.model.blocks[3:5], - self.model.blocks[5:], - ] - else: - ValueError( - "MobileNetV3 mode should be small or large, got {}".format(self._mode) - ) - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("conv_head.weight", None) - state_dict.pop("conv_head.bias", None) - state_dict.pop("classifier.weight", None) - state_dict.pop("classifier.bias", None) - self.model.load_state_dict(state_dict, **kwargs) - - -mobilenetv3_weights = { - "tf_mobilenetv3_large_075": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth" # noqa - }, - "tf_mobilenetv3_large_100": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth" # noqa - }, - "tf_mobilenetv3_large_minimal_100": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth" # noqa - }, - "tf_mobilenetv3_small_075": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth" # noqa - }, - "tf_mobilenetv3_small_100": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth" # noqa - }, - "tf_mobilenetv3_small_minimal_100": { - "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in mobilenetv3_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "input_space": "RGB", - } - - -timm_mobilenetv3_encoders = { - "timm-mobilenetv3_large_075": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_075"], - "params": {"model_name": "tf_mobilenetv3_large_075", "width_mult": 0.75}, - }, - "timm-mobilenetv3_large_100": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_100"], - "params": {"model_name": "tf_mobilenetv3_large_100", "width_mult": 1.0}, - }, - "timm-mobilenetv3_large_minimal_100": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_large_minimal_100"], - "params": {"model_name": "tf_mobilenetv3_large_minimal_100", "width_mult": 1.0}, - }, - "timm-mobilenetv3_small_075": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_075"], - "params": {"model_name": "tf_mobilenetv3_small_075", "width_mult": 0.75}, - }, - "timm-mobilenetv3_small_100": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_100"], - "params": {"model_name": "tf_mobilenetv3_small_100", "width_mult": 1.0}, - }, - "timm-mobilenetv3_small_minimal_100": { - "encoder": MobileNetV3Encoder, - "pretrained_settings": pretrained_settings["tf_mobilenetv3_small_minimal_100"], - "params": {"model_name": "tf_mobilenetv3_small_minimal_100", "width_mult": 1.0}, - }, -} diff --git a/segmentation_models_pytorch/encoders/timm_regnet.py 
b/segmentation_models_pytorch/encoders/timm_regnet.py deleted file mode 100644 index cc60b8ba..00000000 --- a/segmentation_models_pytorch/encoders/timm_regnet.py +++ /dev/null @@ -1,350 +0,0 @@ -from ._base import EncoderMixin -from timm.models.regnet import RegNet, RegNetCfg -import torch.nn as nn - - -class RegNetEncoder(RegNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): - kwargs["cfg"] = RegNetCfg(**kwargs["cfg"]) - super().__init__(**kwargs) - self._depth = depth - self._out_channels = out_channels - self._in_channels = 3 - - del self.head - - def get_stages(self): - return [nn.Identity(), self.stem, self.s1, self.s2, self.s3, self.s4] - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("head.fc.weight", None) - state_dict.pop("head.fc.bias", None) - super().load_state_dict(state_dict, **kwargs) - - -regnet_weights = { - "timm-regnetx_002": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth" # noqa - }, - "timm-regnetx_004": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth" # noqa - }, - "timm-regnetx_006": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth" # noqa - }, - "timm-regnetx_008": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth" # noqa - }, - "timm-regnetx_016": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth" # noqa - }, - "timm-regnetx_032": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth" # noqa - }, - 
"timm-regnetx_040": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth" # noqa - }, - "timm-regnetx_064": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth" # noqa - }, - "timm-regnetx_080": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth" # noqa - }, - "timm-regnetx_120": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth" # noqa - }, - "timm-regnetx_160": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth" # noqa - }, - "timm-regnetx_320": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth" # noqa - }, - "timm-regnety_002": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth" # noqa - }, - "timm-regnety_004": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth" # noqa - }, - "timm-regnety_006": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth" # noqa - }, - "timm-regnety_008": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth" # noqa - }, - "timm-regnety_016": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth" # noqa - }, - "timm-regnety_032": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth" # noqa - }, - "timm-regnety_040": { - "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth" # noqa - }, - "timm-regnety_064": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth" # noqa - }, - "timm-regnety_080": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth" # noqa - }, - "timm-regnety_120": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth" # noqa - }, - "timm-regnety_160": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth" # noqa - }, - "timm-regnety_320": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in regnet_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_size": [3, 224, 224], - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "num_classes": 1000, - } - -# at this point I am too lazy to copy configs, so I just used the same configs from timm's repo - - -def _mcfg(**kwargs): - cfg = dict(se_ratio=0.0, bottle_ratio=1.0, stem_width=32) - cfg.update(**kwargs) - return cfg - - -timm_regnet_encoders = { - "timm-regnetx_002": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_002"], - "params": { - "out_channels": (3, 32, 24, 56, 152, 368), - "cfg": _mcfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13), - }, - }, - "timm-regnetx_004": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_004"], - "params": { - "out_channels": (3, 32, 32, 64, 160, 384), - "cfg": _mcfg(w0=24, 
wa=24.48, wm=2.54, group_size=16, depth=22), - }, - }, - "timm-regnetx_006": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_006"], - "params": { - "out_channels": (3, 32, 48, 96, 240, 528), - "cfg": _mcfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16), - }, - }, - "timm-regnetx_008": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_008"], - "params": { - "out_channels": (3, 32, 64, 128, 288, 672), - "cfg": _mcfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16), - }, - }, - "timm-regnetx_016": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_016"], - "params": { - "out_channels": (3, 32, 72, 168, 408, 912), - "cfg": _mcfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18), - }, - }, - "timm-regnetx_032": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_032"], - "params": { - "out_channels": (3, 32, 96, 192, 432, 1008), - "cfg": _mcfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25), - }, - }, - "timm-regnetx_040": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_040"], - "params": { - "out_channels": (3, 32, 80, 240, 560, 1360), - "cfg": _mcfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23), - }, - }, - "timm-regnetx_064": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_064"], - "params": { - "out_channels": (3, 32, 168, 392, 784, 1624), - "cfg": _mcfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17), - }, - }, - "timm-regnetx_080": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_080"], - "params": { - "out_channels": (3, 32, 80, 240, 720, 1920), - "cfg": _mcfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23), - }, - }, - "timm-regnetx_120": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_120"], - "params": { - "out_channels": (3, 
32, 224, 448, 896, 2240), - "cfg": _mcfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19), - }, - }, - "timm-regnetx_160": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_160"], - "params": { - "out_channels": (3, 32, 256, 512, 896, 2048), - "cfg": _mcfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22), - }, - }, - "timm-regnetx_320": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnetx_320"], - "params": { - "out_channels": (3, 32, 336, 672, 1344, 2520), - "cfg": _mcfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23), - }, - }, - # regnety - "timm-regnety_002": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_002"], - "params": { - "out_channels": (3, 32, 24, 56, 152, 368), - "cfg": _mcfg( - w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25 - ), - }, - }, - "timm-regnety_004": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_004"], - "params": { - "out_channels": (3, 32, 48, 104, 208, 440), - "cfg": _mcfg( - w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25 - ), - }, - }, - "timm-regnety_006": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_006"], - "params": { - "out_channels": (3, 32, 48, 112, 256, 608), - "cfg": _mcfg( - w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25 - ), - }, - }, - "timm-regnety_008": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_008"], - "params": { - "out_channels": (3, 32, 64, 128, 320, 768), - "cfg": _mcfg( - w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25 - ), - }, - }, - "timm-regnety_016": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_016"], - "params": { - "out_channels": (3, 32, 48, 120, 336, 888), - "cfg": _mcfg( - w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25 - ), 
- }, - }, - "timm-regnety_032": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_032"], - "params": { - "out_channels": (3, 32, 72, 216, 576, 1512), - "cfg": _mcfg( - w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25 - ), - }, - }, - "timm-regnety_040": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_040"], - "params": { - "out_channels": (3, 32, 128, 192, 512, 1088), - "cfg": _mcfg( - w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25 - ), - }, - }, - "timm-regnety_064": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_064"], - "params": { - "out_channels": (3, 32, 144, 288, 576, 1296), - "cfg": _mcfg( - w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25 - ), - }, - }, - "timm-regnety_080": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_080"], - "params": { - "out_channels": (3, 32, 168, 448, 896, 2016), - "cfg": _mcfg( - w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25 - ), - }, - }, - "timm-regnety_120": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_120"], - "params": { - "out_channels": (3, 32, 224, 448, 896, 2240), - "cfg": _mcfg( - w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25 - ), - }, - }, - "timm-regnety_160": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_160"], - "params": { - "out_channels": (3, 32, 224, 448, 1232, 3024), - "cfg": _mcfg( - w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25 - ), - }, - }, - "timm-regnety_320": { - "encoder": RegNetEncoder, - "pretrained_settings": pretrained_settings["timm-regnety_320"], - "params": { - "out_channels": (3, 32, 232, 696, 1392, 3712), - "cfg": _mcfg( - w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25 - ), - }, - }, -} diff --git 
a/segmentation_models_pytorch/encoders/timm_res2net.py b/segmentation_models_pytorch/encoders/timm_res2net.py deleted file mode 100644 index e97043e3..00000000 --- a/segmentation_models_pytorch/encoders/timm_res2net.py +++ /dev/null @@ -1,163 +0,0 @@ -from ._base import EncoderMixin -from timm.models.resnet import ResNet -from timm.models.res2net import Bottle2neck -import torch.nn as nn - - -class Res2NetEncoder(ResNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): - super().__init__(**kwargs) - self._depth = depth - self._out_channels = out_channels - self._in_channels = 3 - - del self.fc - del self.global_pool - - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv1, self.bn1, self.act1), - nn.Sequential(self.maxpool, self.layer1), - self.layer2, - self.layer3, - self.layer4, - ] - - def make_dilated(self, *args, **kwargs): - raise ValueError("Res2Net encoders do not support dilated mode") - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("fc.bias", None) - state_dict.pop("fc.weight", None) - super().load_state_dict(state_dict, **kwargs) - - -res2net_weights = { - "timm-res2net50_26w_4s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth" # noqa - }, - "timm-res2net50_48w_2s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth" # noqa - }, - "timm-res2net50_14w_8s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth" # noqa - }, - "timm-res2net50_26w_6s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth" # noqa - }, - 
"timm-res2net50_26w_8s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth" # noqa - }, - "timm-res2net101_26w_4s": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth" # noqa - }, - "timm-res2next50": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in res2net_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_size": [3, 224, 224], - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "num_classes": 1000, - } - - -timm_res2net_encoders = { - "timm-res2net50_26w_4s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 26, - "block_args": {"scale": 4}, - }, - }, - "timm-res2net101_26w_4s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 23, 3], - "base_width": 26, - "block_args": {"scale": 4}, - }, - }, - "timm-res2net50_26w_6s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 26, - "block_args": {"scale": 6}, - }, - }, - "timm-res2net50_26w_8s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"], - "params": { - "out_channels": (3, 64, 
256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 26, - "block_args": {"scale": 8}, - }, - }, - "timm-res2net50_48w_2s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 48, - "block_args": {"scale": 2}, - }, - }, - "timm-res2net50_14w_8s": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 14, - "block_args": {"scale": 8}, - }, - }, - "timm-res2next50": { - "encoder": Res2NetEncoder, - "pretrained_settings": pretrained_settings["timm-res2next50"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": Bottle2neck, - "layers": [3, 4, 6, 3], - "base_width": 4, - "cardinality": 8, - "block_args": {"scale": 4}, - }, - }, -} diff --git a/segmentation_models_pytorch/encoders/timm_resnest.py b/segmentation_models_pytorch/encoders/timm_resnest.py deleted file mode 100644 index 1599b6c8..00000000 --- a/segmentation_models_pytorch/encoders/timm_resnest.py +++ /dev/null @@ -1,208 +0,0 @@ -from ._base import EncoderMixin -from timm.models.resnet import ResNet -from timm.models.resnest import ResNestBottleneck -import torch.nn as nn - - -class ResNestEncoder(ResNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): - super().__init__(**kwargs) - self._depth = depth - self._out_channels = out_channels - self._in_channels = 3 - - del self.fc - del self.global_pool - - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv1, self.bn1, self.act1), - nn.Sequential(self.maxpool, self.layer1), - self.layer2, - self.layer3, - self.layer4, - ] - - def make_dilated(self, *args, **kwargs): - raise ValueError("ResNest encoders do not 
support dilated mode") - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) - - return features - - def load_state_dict(self, state_dict, **kwargs): - state_dict.pop("fc.bias", None) - state_dict.pop("fc.weight", None) - super().load_state_dict(state_dict, **kwargs) - - -resnest_weights = { - "timm-resnest14d": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth" # noqa - }, - "timm-resnest26d": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth" # noqa - }, - "timm-resnest50d": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth" # noqa - }, - "timm-resnest101e": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth" # noqa - }, - "timm-resnest200e": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth" # noqa - }, - "timm-resnest269e": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth" # noqa - }, - "timm-resnest50d_4s2x40d": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth" # noqa - }, - "timm-resnest50d_1s4x24d": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in resnest_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_size": [3, 224, 224], - "input_range": [0, 1], - "mean": 
[0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "num_classes": 1000, - } - - -timm_resnest_encoders = { - "timm-resnest14d": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest14d"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [1, 1, 1, 1], - "stem_type": "deep", - "stem_width": 32, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest26d": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest26d"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [2, 2, 2, 2], - "stem_type": "deep", - "stem_width": 32, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest50d": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest50d"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 4, 6, 3], - "stem_type": "deep", - "stem_width": 32, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest101e": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest101e"], - "params": { - "out_channels": (3, 128, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 4, 23, 3], - "stem_type": "deep", - "stem_width": 64, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest200e": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest200e"], - "params": { - "out_channels": (3, 128, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 24, 
36, 3], - "stem_type": "deep", - "stem_width": 64, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest269e": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest269e"], - "params": { - "out_channels": (3, 128, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 30, 48, 8], - "stem_type": "deep", - "stem_width": 64, - "avg_down": True, - "base_width": 64, - "cardinality": 1, - "block_args": {"radix": 2, "avd": True, "avd_first": False}, - }, - }, - "timm-resnest50d_4s2x40d": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest50d_4s2x40d"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 4, 6, 3], - "stem_type": "deep", - "stem_width": 32, - "avg_down": True, - "base_width": 40, - "cardinality": 2, - "block_args": {"radix": 4, "avd": True, "avd_first": True}, - }, - }, - "timm-resnest50d_1s4x24d": { - "encoder": ResNestEncoder, - "pretrained_settings": pretrained_settings["timm-resnest50d_1s4x24d"], - "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), - "block": ResNestBottleneck, - "layers": [3, 4, 6, 3], - "stem_type": "deep", - "stem_width": 32, - "avg_down": True, - "base_width": 24, - "cardinality": 4, - "block_args": {"radix": 1, "avd": True, "avd_first": True}, - }, - }, -} diff --git a/segmentation_models_pytorch/encoders/timm_sknet.py b/segmentation_models_pytorch/encoders/timm_sknet.py index 14d6d2b0..49fda0e8 100644 --- a/segmentation_models_pytorch/encoders/timm_sknet.py +++ b/segmentation_models_pytorch/encoders/timm_sknet.py @@ -1,35 +1,63 @@ -from ._base import EncoderMixin +import torch +from typing import Dict, List, Sequence from timm.models.resnet import ResNet from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic -import torch.nn as nn + +from ._base import 
EncoderMixin class SkNetEncoder(ResNet, EncoderMixin): - def __init__(self, out_channels, depth=5, **kwargs): + def __init__( + self, + out_channels: List[int], + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(**kwargs) + self._depth = depth - self._out_channels = out_channels self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride del self.fc del self.global_pool - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential(self.conv1, self.bn1, self.act1), - nn.Sequential(self.maxpool, self.layer1), - self.layer2, - self.layer3, - self.layer4, - ] - - def forward(self, x): - stages = self.get_stages() - - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: + return { + 16: [self.layer3], + 32: [self.layer4], + } + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + features = [x] + + if self._depth >= 1: + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + features.append(x) + + if self._depth >= 2: + x = self.maxpool(x) + x = self.layer1(x) + features.append(x) + + if self._depth >= 3: + x = self.layer2(x) + features.append(x) + + if self._depth >= 4: + x = self.layer3(x) + features.append(x) + + if self._depth >= 5: + x = self.layer4(x) features.append(x) return features @@ -40,37 +68,17 @@ def load_state_dict(self, state_dict, **kwargs): super().load_state_dict(state_dict, **kwargs) -sknet_weights = { - "timm-skresnet18": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth" # noqa - }, - "timm-skresnet34": { - "imagenet": "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth" # noqa - }, - "timm-skresnext50_32x4d": { - "imagenet": 
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth" # noqa - }, -} - -pretrained_settings = {} -for model_name, sources in sknet_weights.items(): - pretrained_settings[model_name] = {} - for source_name, source_url in sources.items(): - pretrained_settings[model_name][source_name] = { - "url": source_url, - "input_size": [3, 224, 224], - "input_range": [0, 1], - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225], - "num_classes": 1000, - } - timm_sknet_encoders = { "timm-skresnet18": { "encoder": SkNetEncoder, - "pretrained_settings": pretrained_settings["timm-skresnet18"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/timm-skresnet18.imagenet", + "revision": "6c97652bb744d89177b68274d2fda3923a7d1f95", + }, + }, "params": { - "out_channels": (3, 64, 64, 128, 256, 512), + "out_channels": [3, 64, 64, 128, 256, 512], "block": SelectiveKernelBasic, "layers": [2, 2, 2, 2], "zero_init_last": False, @@ -79,9 +87,14 @@ def load_state_dict(self, state_dict, **kwargs): }, "timm-skresnet34": { "encoder": SkNetEncoder, - "pretrained_settings": pretrained_settings["timm-skresnet34"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/timm-skresnet34.imagenet", + "revision": "2367796924a8182cc835ef6b5dc303917f923f99", + }, + }, "params": { - "out_channels": (3, 64, 64, 128, 256, 512), + "out_channels": [3, 64, 64, 128, 256, 512], "block": SelectiveKernelBasic, "layers": [3, 4, 6, 3], "zero_init_last": False, @@ -90,9 +103,14 @@ def load_state_dict(self, state_dict, **kwargs): }, "timm-skresnext50_32x4d": { "encoder": SkNetEncoder, - "pretrained_settings": pretrained_settings["timm-skresnext50_32x4d"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/timm-skresnext50_32x4d.imagenet", + "revision": "50207e407cc4c6ea9e6872963db6844ca7b7b9de", + }, + }, "params": { - "out_channels": (3, 64, 256, 512, 1024, 2048), + "out_channels": [3, 64, 256, 512, 1024, 2048], 
"block": SelectiveKernelBottleneck, "layers": [3, 4, 6, 3], "zero_init_last": False, diff --git a/segmentation_models_pytorch/encoders/timm_universal.py b/segmentation_models_pytorch/encoders/timm_universal.py index 9bdcb188..138b2ef8 100644 --- a/segmentation_models_pytorch/encoders/timm_universal.py +++ b/segmentation_models_pytorch/encoders/timm_universal.py @@ -44,6 +44,10 @@ class TimmUniversalEncoder(nn.Module): - Compatible with convolutional and transformer-like backbones. """ + _is_torch_scriptable = True + _is_torch_exportable = True + _is_torch_compilable = True + def __init__( self, name: str, @@ -64,7 +68,15 @@ def __init__( output_stride (int): Desired output stride (default: 32). **kwargs: Additional arguments passed to `timm.create_model`. """ + # At the moment we do not support models with more than 5 stages, + # but can be reconfigured in the future. + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) + super().__init__() + self.name = name # Default model configuration for feature extraction common_kwargs = dict( @@ -118,9 +130,9 @@ def __init__( # Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5. common_kwargs["out_indices"] = tuple(range(depth - 1)) - self.model = timm.create_model( - name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs) - ) + timm_model_kwargs = _merge_kwargs_no_duplicates(common_kwargs, kwargs) + self.model = timm.create_model(name, **timm_model_kwargs) + # Add a dummy output channel (0) to align with traditional encoder structures. self._out_channels = ( [in_channels] + [0] + self.model.feature_info.channels() @@ -193,7 +205,28 @@ def output_stride(self) -> int: Returns: int: The effective output stride. 
""" - return min(self._output_stride, 2**self._depth) + return int(min(self._output_stride, 2**self._depth)) + + def load_state_dict(self, state_dict, **kwargs): + # for compatibility of weights for + # timm- ported encoders with TimmUniversalEncoder + patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"] + + is_deprecated_encoder = any( + self.name.startswith(pattern) for pattern in patterns + ) + + if is_deprecated_encoder: + keys = list(state_dict.keys()) + for key in keys: + new_key = key + if not key.startswith("model."): + new_key = "model." + key + if "gernet" in self.name: + new_key = new_key.replace(".stages.", ".stages_") + state_dict[new_key] = state_dict.pop(key) + + return super().load_state_dict(state_dict, **kwargs) def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]: diff --git a/segmentation_models_pytorch/encoders/timm_vit.py b/segmentation_models_pytorch/encoders/timm_vit.py new file mode 100644 index 00000000..5519d897 --- /dev/null +++ b/segmentation_models_pytorch/encoders/timm_vit.py @@ -0,0 +1,191 @@ +from typing import Any, Optional + +import timm +import torch +import torch.nn as nn + +from .timm_universal import _merge_kwargs_no_duplicates + + +def sample_block_indices_uniformly(n: int, total_num_blocks: int) -> list[int]: + """ + Sample N block indices uniformly from the total number of blocks. + """ + return [ + int(total_num_blocks / n * block_depth) - 1 for block_depth in range(1, n + 1) + ] + + +def validate_output_indices( + output_indices: list[int], model_num_blocks: int, depth: int +): + """ + Validate the output indices are within the valid range of the model and the + length of the output indices is equal to the depth of the encoder. 
+ """ + for output_index in output_indices: + if output_index < -model_num_blocks or output_index >= model_num_blocks: + raise ValueError( + f"Output indices for feature extraction should be in range " + f"[-{model_num_blocks}, {model_num_blocks}), because the model has {model_num_blocks} blocks, " + f"got index = {output_index}." + ) + + +def preprocess_output_indices( + output_indices: Optional[list[int]], model_num_blocks: int, depth: int +) -> list[int]: + """ + Preprocess the output indices for the encoder. + """ + + # Refine encoder output indices + if output_indices is None: + output_indices = sample_block_indices_uniformly(depth, model_num_blocks) + elif not isinstance(output_indices, (list, tuple)): + raise ValueError( + f"`output_indices` for encoder should be a list/tuple/None, got {type(output_indices)}" + ) + validate_output_indices(output_indices, model_num_blocks, depth) + + return output_indices + + +class TimmViTEncoder(nn.Module): + """ + A universal encoder leveraging the `timm` library for feature extraction from + ViT style models + + Features: + - Supports configurable depth. + - Ensures consistent multi-level feature extraction across all ViT models. + """ + + # prefix tokens are not supported for scripting + _is_torch_scriptable = False + _is_torch_exportable = True + _is_torch_compilable = True + + def __init__( + self, + name: str, + pretrained: bool = True, + in_channels: int = 3, + depth: int = 4, + output_indices: Optional[list[int]] = None, + **kwargs: dict[str, Any], + ): + """ + Initialize the encoder. + + Args: + name (str): ViT model name to load from `timm`. + pretrained (bool): Load pretrained weights (default: True). + in_channels (int): Number of input channels (default: 3 for RGB). + depth (int): Number of feature stages to extract (default: 4). + output_indices (Optional[list[int] | int]): Indices of blocks in the model to be used for feature extraction. + **kwargs: Additional arguments passed to `timm.create_model`. 
+ """ + super().__init__() + + if depth < 1: + raise ValueError(f"`encoder_depth` should be greater than 1, got {depth}.") + + # Output stride validation needed for smp encoder test consistency + output_stride = kwargs.pop("output_stride", None) + if output_stride is not None: + raise ValueError("Dilated mode not supported, set output stride to None") + + if isinstance(output_indices, (list, tuple)) and len(output_indices) != depth: + raise ValueError( + f"Length of output indices for feature extraction should be equal to the depth of the encoder " + f"architecture, got output indices length - {len(output_indices)}, encoder depth - {depth}" + ) + + self.name = name + + # Load a timm model + encoder_kwargs = dict(in_chans=in_channels, pretrained=pretrained) + encoder_kwargs = _merge_kwargs_no_duplicates(encoder_kwargs, kwargs) + self.model = timm.create_model(name, **encoder_kwargs) + + if not hasattr(self.model, "forward_intermediates"): + raise ValueError( + f"Encoder `{name}` does not support `forward_intermediates` for feature extraction. " + f"Please update `timm` or use another encoder." 
+ ) + + # Get all the necessary information about the model + feature_info = self.model.feature_info + + # Additional checks + model_num_blocks = len(feature_info) + if depth > model_num_blocks: + raise ValueError( + f"Depth of the encoder cannot exceed the number of blocks in the model " + f"got {depth} depth, model has {model_num_blocks} blocks" + ) + + # Preprocess the output indices, uniformly sample from model_num_blocks if None + output_indices = preprocess_output_indices( + output_indices, model_num_blocks, depth + ) + + # Private attributes for model forward + self._num_prefix_tokens = getattr(self.model, "num_prefix_tokens", 0) + self._has_cls_token = getattr(self.model, "has_cls_token", False) + self._output_indices = output_indices + + # Public attributes + self.output_strides = [feature_info[i]["reduction"] for i in output_indices] + self.output_stride = self.output_strides[-1] + self.out_channels = [feature_info[i]["num_chs"] for i in output_indices] + self.has_prefix_tokens = self._num_prefix_tokens > 0 + self.input_size = self.model.pretrained_cfg.get("input_size", None) + self.is_fixed_input_size = self.model.pretrained_cfg.get( + "fixed_input_size", False + ) + + def _forward_with_prefix_tokens( + self, x: torch.Tensor + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + intermediate_outputs = self.model.forward_intermediates( + x, + indices=self._output_indices, + intermediates_only=True, + return_prefix_tokens=True, + ) + + features = [output[0] for output in intermediate_outputs] + prefix_tokens = [output[1] for output in intermediate_outputs] + + return features, prefix_tokens + + def _forward_without_prefix_tokens(self, x: torch.Tensor) -> list[torch.Tensor]: + features = self.model.forward_intermediates( + x, + indices=self._output_indices, + intermediates_only=True, + ) + return features + + def forward( + self, x: torch.Tensor + ) -> tuple[list[torch.Tensor], list[Optional[torch.Tensor]]]: + """ + Forward pass to extract multi-stage 
features. + + Args: + x (torch.Tensor): Input tensor of shape (B, C, H, W). + + Returns: + tuple[list[torch.Tensor], list[torch.Tensor]]: Tuple of feature maps and cls tokens (if supported) at different scales. + """ + + if self.has_prefix_tokens: + features, prefix_tokens = self._forward_with_prefix_tokens(x) + else: + features = self._forward_without_prefix_tokens(x) + prefix_tokens = [None] * len(features) + + return features, prefix_tokens diff --git a/segmentation_models_pytorch/encoders/vgg.py b/segmentation_models_pytorch/encoders/vgg.py index cbc602c8..1bb577fe 100644 --- a/segmentation_models_pytorch/encoders/vgg.py +++ b/segmentation_models_pytorch/encoders/vgg.py @@ -23,29 +23,53 @@ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). """ +import torch import torch.nn as nn + from torchvision.models.vgg import VGG from torchvision.models.vgg import make_layers -from pretrainedmodels.models.torchvision_models import pretrained_settings + +from typing import List, Union from ._base import EncoderMixin # fmt: off cfg = { - 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], - 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], + "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], + "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], } # fmt: on class VGGEncoder(VGG, EncoderMixin): - def __init__(self, out_channels, config, batch_norm=False, depth=5, **kwargs): + def __init__( + self, + out_channels: List[int], 
+ config: List[Union[int, str]], + batch_norm: bool = False, + depth: int = 5, + output_stride: int = 32, + **kwargs, + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs) - self._out_channels = out_channels + self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride + self._out_indexes = [ + i - 1 + for i, module in enumerate(self.features) + if isinstance(module, nn.MaxPool2d) + ] + self._out_indexes.append(len(self.features) - 1) + del self.classifier def make_dilated(self, *args, **kwargs): @@ -54,24 +78,23 @@ def make_dilated(self, *args, **kwargs): " operations for downsampling!" ) - def get_stages(self): - stages = [] - stage_modules = [] - for module in self.features: - if isinstance(module, nn.MaxPool2d): - stages.append(nn.Sequential(*stage_modules)) - stage_modules = [] - stage_modules.append(module) - stages.append(nn.Sequential(*stage_modules)) - return stages + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + features = [] + depth = 0 + + for i, module in enumerate(self.features): + x = module(x) - def forward(self, x): - stages = self.get_stages() + if i in self._out_indexes: + features.append(x) + depth += 1 - features = [] - for i in range(self._depth + 1): - x = stages[i](x) - features.append(x) + # torchscript does not support break in cycle, so we just + # go over all modules and then slice number of features + if not torch.jit.is_scripting() and depth > self._depth: + break + + features = features[: self._depth + 1] return features @@ -83,75 +106,206 @@ def load_state_dict(self, state_dict, **kwargs): super().load_state_dict(state_dict, **kwargs) +pretrained_settings = { + "vgg11": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + 
"input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg11_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg13": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg13-c768596a.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg13_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg16": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg16-397923af.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg16_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg19": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, + "vgg19_bn": { + "imagenet": { + "url": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", + "input_space": "RGB", + "input_size": [3, 224, 224], + "input_range": [0, 1], + "mean": [0.485, 0.456, 0.406], + "std": 
[0.229, 0.224, 0.225], + "num_classes": 1000, + } + }, +} + vgg_encoders = { "vgg11": { "encoder": VGGEncoder, - "pretrained_settings": pretrained_settings["vgg11"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/vgg11.imagenet", + "revision": "ad8b90e1051c38fdbf399cf5016886a1be357390", + }, + }, "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["A"], "batch_norm": False, }, }, "vgg11_bn": { "encoder": VGGEncoder, - "pretrained_settings": pretrained_settings["vgg11_bn"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/vgg11_bn.imagenet", + "revision": "59757f9215032c9f092977092d57d26a9df7fd9c", + }, + }, "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["A"], "batch_norm": True, }, }, "vgg13": { "encoder": VGGEncoder, - "pretrained_settings": pretrained_settings["vgg13"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/vgg13.imagenet", + "revision": "1b70ff2580f101a8007a48b51e2b5d1e5925dc42", + }, + }, "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["B"], "batch_norm": False, }, }, "vgg13_bn": { "encoder": VGGEncoder, - "pretrained_settings": pretrained_settings["vgg13_bn"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/vgg13_bn.imagenet", + "revision": "9be454515193af6612261b7614fe90607e27b143", + }, + }, "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["B"], "batch_norm": True, }, }, "vgg16": { "encoder": VGGEncoder, - "pretrained_settings": pretrained_settings["vgg16"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/vgg16.imagenet", + "revision": "49d74b799006ee252b86e25acd6f1fd8ac9a99c1", + }, + }, "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 
128, 256, 512, 512, 512], "config": cfg["D"], "batch_norm": False, }, }, "vgg16_bn": { "encoder": VGGEncoder, - "pretrained_settings": pretrained_settings["vgg16_bn"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/vgg16_bn.imagenet", + "revision": "2c186d02fb519e93219a99a1c2af6295aef0bf0d", + }, + }, "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["D"], "batch_norm": True, }, }, "vgg19": { "encoder": VGGEncoder, - "pretrained_settings": pretrained_settings["vgg19"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/vgg19.imagenet", + "revision": "2853d00d7bca364dbb98be4d6afa347e5aeec1f6", + }, + }, "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["E"], "batch_norm": False, }, }, "vgg19_bn": { "encoder": VGGEncoder, - "pretrained_settings": pretrained_settings["vgg19_bn"], + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/vgg19_bn.imagenet", + "revision": "f09a924cb0d201ea6f61601df9559141382271d7", + }, + }, "params": { - "out_channels": (64, 128, 256, 512, 512, 512), + "out_channels": [64, 128, 256, 512, 512, 512], "config": cfg["E"], "batch_norm": True, }, diff --git a/segmentation_models_pytorch/encoders/xception.py b/segmentation_models_pytorch/encoders/xception.py index c8c476ce..af3a26d4 100644 --- a/segmentation_models_pytorch/encoders/xception.py +++ b/segmentation_models_pytorch/encoders/xception.py @@ -1,18 +1,28 @@ -import torch.nn as nn - -from pretrainedmodels.models.xception import pretrained_settings -from pretrainedmodels.models.xception import Xception +from typing import List from ._base import EncoderMixin +from ._xception import Xception class XceptionEncoder(Xception, EncoderMixin): - def __init__(self, out_channels, *args, depth=5, **kwargs): + def __init__( + self, + out_channels: List[int], + *args, + depth: int = 5, + output_stride: int = 32, + 
**kwargs, + ): + if depth > 5 or depth < 1: + raise ValueError( + f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" + ) super().__init__(*args, **kwargs) - self._out_channels = out_channels self._depth = depth self._in_channels = 3 + self._out_channels = out_channels + self._output_stride = output_stride # modify padding to maintain output shape self.conv1.padding = (1, 1) @@ -26,36 +36,45 @@ def make_dilated(self, *args, **kwargs): "due to pooling operation for downsampling!" ) - def get_stages(self): - return [ - nn.Identity(), - nn.Sequential( - self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.relu - ), - self.block1, - self.block2, - nn.Sequential( - self.block3, - self.block4, - self.block5, - self.block6, - self.block7, - self.block8, - self.block9, - self.block10, - self.block11, - ), - nn.Sequential( - self.block12, self.conv3, self.bn3, self.relu, self.conv4, self.bn4 - ), - ] - def forward(self, x): - stages = self.get_stages() + features = [x] + + if self._depth >= 1: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu2(x) + features.append(x) + + if self._depth >= 2: + x = self.block1(x) + features.append(x) + + if self._depth >= 3: + x = self.block2(x) + features.append(x) + + if self._depth >= 4: + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + features.append(x) - features = [] - for i in range(self._depth + 1): - x = stages[i](x) + if self._depth >= 5: + x = self.block12(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu3(x) + x = self.conv4(x) + x = self.bn4(x) features.append(x) return features @@ -71,7 +90,12 @@ def load_state_dict(self, state_dict): xception_encoders = { "xception": { "encoder": XceptionEncoder, - "pretrained_settings": pretrained_settings["xception"], - "params": {"out_channels": (3, 64, 
128, 256, 728, 2048)}, + "pretrained_settings": { + "imagenet": { + "repo_id": "smp-hub/xception.imagenet", + "revision": "01cfaf27c11353b1f0c578e7e26d2c000ea91049", + }, + }, + "params": {"out_channels": [3, 64, 128, 256, 728, 2048]}, } } diff --git a/segmentation_models_pytorch/losses/dice.py b/segmentation_models_pytorch/losses/dice.py index d9283161..b8baae98 100644 --- a/segmentation_models_pytorch/losses/dice.py +++ b/segmentation_models_pytorch/losses/dice.py @@ -44,9 +44,9 @@ def __init__( super(DiceLoss, self).__init__() self.mode = mode if classes is not None: - assert ( - mode != BINARY_MODE - ), "Masking classes is not supported with mode=binary" + assert mode != BINARY_MODE, ( + "Masking classes is not supported with mode=binary" + ) classes = to_tensor(classes, dtype=torch.long) self.classes = classes diff --git a/segmentation_models_pytorch/losses/jaccard.py b/segmentation_models_pytorch/losses/jaccard.py index d6aba280..b250cacf 100644 --- a/segmentation_models_pytorch/losses/jaccard.py +++ b/segmentation_models_pytorch/losses/jaccard.py @@ -43,9 +43,9 @@ def __init__( self.mode = mode if classes is not None: - assert ( - mode != BINARY_MODE - ), "Masking classes is not supported with mode=binary" + assert mode != BINARY_MODE, ( + "Masking classes is not supported with mode=binary" + ) classes = to_tensor(classes, dtype=torch.long) self.classes = classes diff --git a/tests/base/test_modules.py b/tests/base/test_modules.py new file mode 100644 index 00000000..5afa8e4f --- /dev/null +++ b/tests/base/test_modules.py @@ -0,0 +1,64 @@ +import pytest +from torch import nn +from segmentation_models_pytorch.base.modules import Conv2dReLU + + +def test_conv2drelu_batchnorm(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="batchnorm") + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.BatchNorm2d) + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_batchnorm_with_keywords(): + module = Conv2dReLU( + 3, 
+ 16, + kernel_size=3, + padding=1, + use_norm={"type": "batchnorm", "momentum": 1e-4, "affine": False}, + ) + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.BatchNorm2d) + assert module[1].momentum == 1e-4 and module[1].affine is False + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_identity(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="identity") + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.Identity) + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_layernorm(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="layernorm") + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.LayerNorm) + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_instancenorm(): + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="instancenorm") + + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], nn.InstanceNorm2d) + assert isinstance(module[2], nn.ReLU) + + +def test_conv2drelu_inplace(): + try: + from inplace_abn import InPlaceABN + except ImportError: + pytest.skip("InPlaceABN is not installed") + + module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="inplace") + + assert len(module) == 3 + assert isinstance(module[0], nn.Conv2d) + assert isinstance(module[1], InPlaceABN) + assert isinstance(module[2], nn.Identity) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..688fd00b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,16 @@ +def pytest_addoption(parser): + parser.addoption( + "--non-marked-only", action="store_true", help="Run only non-marked tests" + ) + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--non-marked-only"): + non_marked_items = [] + for item in items: + # Check if the test has no marks + if not item.own_markers: + non_marked_items.append(item) + + # Update the test collection to only include 
non-marked tests + items[:] = non_marked_items diff --git a/tests/encoders/base.py b/tests/encoders/base.py index 39cd4164..b18be2a9 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -1,14 +1,22 @@ +import pytest import unittest import torch import segmentation_models_pytorch as smp from functools import lru_cache -from tests.utils import default_device +from tests.utils import ( + default_device, + check_run_test_on_diff_or_main, + requires_torch_greater_or_equal, +) class BaseEncoderTester(unittest.TestCase): encoder_names = [] + # some tests might be slow, running them only on diff + files_for_diff = [] + # standard encoder configuration num_output_features = 6 output_strides = [1, 2, 4, 8, 16, 32] @@ -25,8 +33,15 @@ class BaseEncoderTester(unittest.TestCase): depth_to_test = [3, 4, 5] strides_to_test = [8, 16] # 32 is a default one + def get_tiny_encoder(self): + return smp.encoders.get_encoder(self.encoder_names[0], encoder_weights=None) + @lru_cache - def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): + def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None): + batch_size = batch_size or self.default_batch_size + num_channels = num_channels or self.default_num_channels + height = height or self.default_height + width = width or self.default_width return torch.rand(batch_size, num_channels, height, width) def get_features_output_strides(self, sample, features): @@ -36,12 +51,7 @@ def get_features_output_strides(self, sample, features): return height_strides, width_strides def test_forward_backward(self): - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) for encoder_name in self.encoder_names: with self.subTest(encoder_name=encoder_name): # init encoder @@ -68,12 +78,7 @@ def test_in_channels(self): ] for 
encoder_name, in_channels in cases: - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=in_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample(num_channels=in_channels).to(default_device) with self.subTest(encoder_name=encoder_name, in_channels=in_channels): encoder = smp.encoders.get_encoder( @@ -82,16 +87,11 @@ def test_in_channels(self): encoder.eval() # forward - with torch.no_grad(): + with torch.inference_mode(): encoder.forward(sample) def test_depth(self): - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) cases = [ (encoder_name, depth) @@ -110,7 +110,7 @@ def test_depth(self): encoder.eval() # forward - with torch.no_grad(): + with torch.inference_mode(): features = encoder.forward(sample) # check number of features @@ -127,12 +127,12 @@ def test_depth(self): self.assertEqual( height_strides, self.output_strides[: depth + 1], - f"Encoder `{encoder_name}` should have output strides {self.output_strides[:depth + 1]}, but has {height_strides}", + f"Encoder `{encoder_name}` should have output strides {self.output_strides[: depth + 1]}, but has {height_strides}", ) self.assertEqual( width_strides, self.output_strides[: depth + 1], - f"Encoder `{encoder_name}` should have output strides {self.output_strides[:depth + 1]}, but has {width_strides}", + f"Encoder `{encoder_name}` should have output strides {self.output_strides[: depth + 1]}, but has {width_strides}", ) # check encoder output stride property @@ -149,13 +149,14 @@ def test_depth(self): f"Encoder `{encoder_name}` should have {depth + 1} out_channels, but has {len(encoder.out_channels)}", ) + def test_invalid_depth(self): + with self.assertRaises(ValueError): + smp.encoders.get_encoder(self.encoder_names[0], depth=6) + with 
self.assertRaises(ValueError): + smp.encoders.get_encoder(self.encoder_names[0], depth=0) + def test_dilated(self): - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) cases = [ (encoder_name, stride) @@ -187,7 +188,7 @@ def test_dilated(self): encoder.eval() # forward - with torch.no_grad(): + with torch.inference_mode(): features = encoder.forward(sample) height_strides, width_strides = self.get_features_output_strides( @@ -206,3 +207,78 @@ def test_dilated(self): expected_width_strides, f"Encoder `{encoder_name}` should have width output strides {expected_width_strides}, but has {width_strides}", ) + + @pytest.mark.compile + def test_compile(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample().to(default_device) + + encoder = self.get_tiny_encoder() + encoder = encoder.eval().to(default_device) + + torch.compiler.reset() + compiled_encoder = torch.compile( + encoder, fullgraph=True, dynamic=True, backend="eager" + ) + + if encoder._is_torch_compilable: + compiled_encoder(sample) + else: + with self.assertRaises(Exception): + compiled_encoder(sample) + + @pytest.mark.torch_export + @requires_torch_greater_or_equal("2.4.0") + def test_torch_export(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample().to(default_device) + + encoder = self.get_tiny_encoder() + encoder = encoder.eval().to(default_device) + + if not encoder._is_torch_exportable: + with self.assertRaises(Exception): + exported_encoder = torch.export.export( + encoder, + args=(sample,), + strict=True, + ) + return + + exported_encoder = torch.export.export( + encoder, + args=(sample,), + strict=True, + ) + + with torch.inference_mode(): 
+ eager_output = encoder(sample) + exported_output = exported_encoder.module().forward(sample) + + for eager_feature, exported_feature in zip(eager_output, exported_output): + torch.testing.assert_close(eager_feature, exported_feature) + + @pytest.mark.torch_script + def test_torch_script(self): + sample = self._get_sample().to(default_device) + + encoder = self.get_tiny_encoder() + encoder = encoder.eval().to(default_device) + + if not encoder._is_torch_scriptable: + with self.assertRaises(RuntimeError, msg="not torch scriptable"): + scripted_encoder = torch.jit.script(encoder) + return + + scripted_encoder = torch.jit.script(encoder) + + with torch.inference_mode(): + eager_output = encoder(sample) + scripted_output = scripted_encoder(sample) + + for eager_feature, scripted_feature in zip(eager_output, scripted_output): + torch.testing.assert_close(eager_feature, scripted_feature) diff --git a/tests/encoders/test_batchnorm_deprecation.py b/tests/encoders/test_batchnorm_deprecation.py new file mode 100644 index 00000000..ff53563f --- /dev/null +++ b/tests/encoders/test_batchnorm_deprecation.py @@ -0,0 +1,54 @@ +import pytest + +import torch + +import segmentation_models_pytorch as smp +from tests.utils import check_two_models_strictly_equal + + +@pytest.mark.parametrize("model_name", ["unet", "unetplusplus", "linknet", "manet"]) +@pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) +def test_seg_models_before_after_use_norm(model_name, decoder_option): + torch.manual_seed(42) + with pytest.warns(DeprecationWarning): + model_decoder_batchnorm = smp.create_model( + model_name, + "mobilenet_v2", + encoder_weights=None, + decoder_use_batchnorm=decoder_option, + ) + model_decoder_norm = smp.create_model( + model_name, + "mobilenet_v2", + encoder_weights=None, + decoder_use_norm=decoder_option, + ) + + model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict()) + + check_two_models_strictly_equal( + model_decoder_batchnorm, 
model_decoder_norm, torch.rand(1, 3, 224, 224) + ) + + +@pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) +def test_pspnet_before_after_use_norm(decoder_option): + torch.manual_seed(42) + with pytest.warns(DeprecationWarning): + model_decoder_batchnorm = smp.create_model( + "pspnet", + "mobilenet_v2", + encoder_weights=None, + psp_use_batchnorm=decoder_option, + ) + model_decoder_norm = smp.create_model( + "pspnet", + "mobilenet_v2", + encoder_weights=None, + decoder_use_norm=decoder_option, + ) + model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict()) + + check_two_models_strictly_equal( + model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224) + ) diff --git a/tests/encoders/test_common.py b/tests/encoders/test_common.py new file mode 100644 index 00000000..f94fd303 --- /dev/null +++ b/tests/encoders/test_common.py @@ -0,0 +1,15 @@ +import pytest +import segmentation_models_pytorch as smp +from tests.utils import slow_test + + +@pytest.mark.parametrize( + "encoder_name_and_weights", + [ + ("resnet18", "imagenet"), + ], +) +@slow_test +def test_load_encoder_from_hub(encoder_name_and_weights): + encoder_name, weights = encoder_name_and_weights + smp.encoders.get_encoder(encoder_name, weights=weights) diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py index bbde576c..2dcc7a52 100644 --- a/tests/encoders/test_pretrainedmodels_encoders.py +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -1,54 +1,45 @@ +import segmentation_models_pytorch as smp + from tests.encoders import base from tests.utils import RUN_ALL_ENCODERS -class TestDenseNetEncoder(base.BaseEncoderTester): - supports_dilated = False - encoder_names = ( - ["densenet121"] - if not RUN_ALL_ENCODERS - else ["densenet121", "densenet169", "densenet161"] - ) - - class TestDPNEncoder(base.BaseEncoderTester): encoder_names = ( ["dpn68"] if not RUN_ALL_ENCODERS else ["dpn68", "dpn68b", "dpn92", 
"dpn98", "dpn107", "dpn131"] ) + files_for_diff = ["encoders/dpn.py"] + + def get_tiny_encoder(self): + params = { + "stage_idxs": [2, 3, 4, 6], + "out_channels": [3, 2, 70, 134, 262, 518], + "groups": 2, + "inc_sec": (2, 2, 2, 2), + "k_r": 2, + "k_sec": (1, 1, 1, 1), + "num_classes": 1000, + "num_init_features": 2, + "small": True, + "test_time_pool": True, + } + return smp.encoders.dpn.DPNEncoder(**params) class TestInceptionResNetV2Encoder(base.BaseEncoderTester): - supports_dilated = False encoder_names = ( ["inceptionresnetv2"] if not RUN_ALL_ENCODERS else ["inceptionresnetv2"] ) + files_for_diff = ["encoders/inceptionresnetv2.py"] + supports_dilated = False class TestInceptionV4Encoder(base.BaseEncoderTester): - supports_dilated = False encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"] - - -class TestResNetEncoder(base.BaseEncoderTester): - encoder_names = ( - ["resnet18"] - if not RUN_ALL_ENCODERS - else [ - "resnet18", - "resnet34", - "resnet50", - "resnet101", - "resnet152", - "resnext50_32x4d", - "resnext101_32x4d", - "resnext101_32x8d", - "resnext101_32x16d", - "resnext101_32x32d", - "resnext101_32x48d", - ] - ) + files_for_diff = ["encoders/inceptionv4.py"] + supports_dilated = False class TestSeNetEncoder(base.BaseEncoderTester): @@ -64,8 +55,26 @@ class TestSeNetEncoder(base.BaseEncoderTester): # "senet154", # extra large model ] ) + files_for_diff = ["encoders/senet.py"] + + def get_tiny_encoder(self): + params = { + "out_channels": [3, 2, 256, 512, 1024, 2048], + "block": smp.encoders.senet.SEResNetBottleneck, + "layers": [1, 1, 1, 1], + "downsample_kernel_size": 1, + "downsample_padding": 0, + "dropout_p": None, + "groups": 1, + "inplanes": 2, + "input_3x3": False, + "num_classes": 1000, + "reduction": 2, + } + return smp.encoders.senet.SENetEncoder(**params) class TestXceptionEncoder(base.BaseEncoderTester): supports_dilated = False encoder_names = ["xception"] if not RUN_ALL_ENCODERS else ["xception"] + files_for_diff = 
["encoders/xception.py"] diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py index 863537bf..29e2f416 100644 --- a/tests/encoders/test_smp_encoders.py +++ b/tests/encoders/test_smp_encoders.py @@ -1,3 +1,6 @@ +import segmentation_models_pytorch as smp +from functools import partial + from tests.encoders import base from tests.utils import RUN_ALL_ENCODERS @@ -14,6 +17,7 @@ class TestMobileoneEncoder(base.BaseEncoderTester): "mobileone_s4", ] ) + files_for_diff = ["encoders/mobileone.py"] class TestMixTransformerEncoder(base.BaseEncoderTester): @@ -22,6 +26,24 @@ class TestMixTransformerEncoder(base.BaseEncoderTester): if not RUN_ALL_ENCODERS else ["mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5"] ) + files_for_diff = ["encoders/mix_transformer.py"] + + def get_tiny_encoder(self): + params = { + "out_channels": [3, 0, 4, 4, 4, 4], + "patch_size": 4, + "embed_dims": [4, 4, 4, 4], + "num_heads": [1, 1, 1, 1], + "mlp_ratios": [1, 1, 1, 1], + "qkv_bias": True, + "norm_layer": partial(smp.encoders.mix_transformer.LayerNorm, eps=1e-6), + "depths": [1, 1, 1, 1], + "sr_ratios": [8, 4, 2, 1], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + } + + return smp.encoders.mix_transformer.MixVisionTransformerEncoder(**params) class TestEfficientNetEncoder(base.BaseEncoderTester): @@ -39,3 +61,4 @@ class TestEfficientNetEncoder(base.BaseEncoderTester): # "efficientnet-b7", # extra large model ] ) + files_for_diff = ["encoders/efficientnet.py"] diff --git a/tests/encoders/test_timm_ported_encoders.py b/tests/encoders/test_timm_ported_encoders.py index b467c968..3793606e 100644 --- a/tests/encoders/test_timm_ported_encoders.py +++ b/tests/encoders/test_timm_ported_encoders.py @@ -24,6 +24,7 @@ class TestTimmEfficientNetEncoder(base.BaseEncoderTester): "timm-tf_efficientnet_lite4", ] ) + files_for_diff = ["encoders/timm_efficientnet.py"] class TestTimmGERNetEncoder(base.BaseEncoderTester): @@ -33,6 +34,9 @@ class 
TestTimmGERNetEncoder(base.BaseEncoderTester): else ["timm-gernet_s", "timm-gernet_m", "timm-gernet_l"] ) + def test_compile(self): + self.skipTest("Test to be removed") + class TestTimmMobileNetV3Encoder(base.BaseEncoderTester): encoder_names = ( @@ -48,6 +52,9 @@ class TestTimmMobileNetV3Encoder(base.BaseEncoderTester): ] ) + def test_compile(self): + self.skipTest("Test to be removed") + class TestTimmRegNetEncoder(base.BaseEncoderTester): encoder_names = ( @@ -81,9 +88,11 @@ class TestTimmRegNetEncoder(base.BaseEncoderTester): ] ) + def test_compile(self): + self.skipTest("Test to be removed") + class TestTimmRes2NetEncoder(base.BaseEncoderTester): - supports_dilated = False encoder_names = ( ["timm-res2net50_26w_4s"] if not RUN_ALL_ENCODERS @@ -98,10 +107,12 @@ class TestTimmRes2NetEncoder(base.BaseEncoderTester): ] ) + def test_compile(self): + self.skipTest("Test to be removed") + class TestTimmResnestEncoder(base.BaseEncoderTester): default_batch_size = 2 - supports_dilated = False encoder_names = ( ["timm-resnest14d"] if not RUN_ALL_ENCODERS @@ -117,6 +128,9 @@ class TestTimmResnestEncoder(base.BaseEncoderTester): ] ) + def test_compile(self): + self.skipTest("Test to be removed") + class TestTimmSkNetEncoder(base.BaseEncoderTester): default_batch_size = 2 @@ -129,3 +143,4 @@ class TestTimmSkNetEncoder(base.BaseEncoderTester): "timm-skresnext50_32x4d", ] ) + files_for_diff = ["encoders/timm_sknet.py"] diff --git a/tests/encoders/test_timm_universal.py b/tests/encoders/test_timm_universal.py index 753ee4de..99f8990f 100644 --- a/tests/encoders/test_timm_universal.py +++ b/tests/encoders/test_timm_universal.py @@ -9,8 +9,9 @@ ] if has_timm_test_models: - timm_encoders.append("tu-test_resnet.r160_in1k") + timm_encoders.insert(0, "tu-test_resnet.r160_in1k") class TestTimmUniversalEncoder(base.BaseEncoderTester): encoder_names = timm_encoders + files_for_diff = ["encoders/timm_universal.py"] diff --git a/tests/encoders/test_timm_vit_encoders.py 
b/tests/encoders/test_timm_vit_encoders.py new file mode 100644 index 00000000..260d926f --- /dev/null +++ b/tests/encoders/test_timm_vit_encoders.py @@ -0,0 +1,236 @@ +import timm +import torch +import pytest + +from segmentation_models_pytorch.encoders import TimmViTEncoder +from segmentation_models_pytorch.encoders.timm_vit import sample_block_indices_uniformly + +from tests.encoders import base +from tests.utils import ( + default_device, + check_run_test_on_diff_or_main, + requires_torch_greater_or_equal, + requires_timm_greater_or_equal, +) + +timm_vit_encoders = ["vit_tiny_patch16_224"] + + +@requires_timm_greater_or_equal("1.0.0") +class TestTimmViTEncoders(base.BaseEncoderTester): + encoder_names = timm_vit_encoders + tiny_encoder_patch_size = 224 + default_height = 224 + default_width = 224 + + files_for_diff = ["encoders/dpt.py"] + + num_output_features = 4 + default_depth = 4 + output_strides = None + supports_dilated = False + + depth_to_test = [2, 3, 4] + + def get_tiny_encoder(self) -> TimmViTEncoder: + return TimmViTEncoder( + name=self.encoder_names[0], + pretrained=False, + depth=self.default_depth, + in_channels=3, + ) + + def get_encoder(self, encoder_name: str, **kwargs) -> TimmViTEncoder: + default_kwargs = { + "name": encoder_name, + "pretrained": False, + "depth": self.default_depth, + "in_channels": 3, + } + default_kwargs.update(kwargs) + return TimmViTEncoder(**default_kwargs) + + def test_forward_backward(self): + for encoder_name in self.encoder_names: + sample = self._get_sample().to(default_device) + with self.subTest(encoder_name=encoder_name): + # init encoder + encoder = self.get_encoder(encoder_name).to(default_device) + + # forward + features, prefix_tokens = encoder.forward(sample) + self.assertEqual( + len(features), + self.num_output_features, + f"Encoder `{encoder_name}` should have {self.num_output_features} output feature maps, but has {len(features)}", + ) + if encoder.has_prefix_tokens: + self.assertEqual( + 
len(prefix_tokens), + self.num_output_features, + f"Encoder `{encoder_name}` should have {self.num_output_features} prefix tokens, but has {len(prefix_tokens)}", + ) + + # backward + features[-1].mean().backward() + + def test_in_channels(self): + cases = [ + (encoder_name, in_channels) + for encoder_name in self.encoder_names + for in_channels in self.in_channels_to_test + ] + + for encoder_name, in_channels in cases: + sample = self._get_sample(num_channels=in_channels).to(default_device) + + with self.subTest(encoder_name=encoder_name, in_channels=in_channels): + encoder = self.get_encoder(encoder_name, in_channels=in_channels).to( + default_device + ) + encoder.eval() + + # forward + with torch.inference_mode(): + encoder.forward(sample) + + def test_depth(self): + cases = [ + (encoder_name, depth) + for encoder_name in self.encoder_names + for depth in self.depth_to_test + ] + + for encoder_name, depth in cases: + sample = self._get_sample().to(default_device) + with self.subTest(encoder_name=encoder_name, depth=depth): + encoder = self.get_encoder(encoder_name, depth=depth).to(default_device) + encoder.eval() + + # forward + with torch.inference_mode(): + features, _ = encoder.forward(sample) + + # check number of features + self.assertEqual( + len(features), + depth, + f"Encoder `{encoder_name}` should have {depth} output feature maps, but has {len(features)}", + ) + + # check feature strides + height_strides, width_strides = self.get_features_output_strides( + sample, features + ) + + encoder_out_indices = sample_block_indices_uniformly(depth, 12) + feature_info = timm.create_model(model_name=encoder_name).feature_info + output_strides = [ + feature_info[i]["reduction"] for i in encoder_out_indices + ] + + self.assertEqual( + height_strides, + output_strides, + f"Encoder `{encoder_name}` should have output strides {output_strides}, but has {height_strides}", + ) + self.assertEqual( + width_strides, + output_strides, + f"Encoder `{encoder_name}` should have 
output strides {output_strides}, but has {width_strides}", + ) + + # check encoder output stride property + self.assertEqual( + encoder.output_strides, + output_strides, + f"Encoder `{encoder_name}` last feature map should have output stride {output_strides[depth - 1]}, but has {encoder.output_stride}", + ) + + # check out channels also have proper length + self.assertEqual( + len(encoder.out_channels), + depth, + f"Encoder `{encoder_name}` should have {depth} out_channels, but has {len(encoder.out_channels)}", + ) + + def test_invalid_depth(self): + with self.assertRaises(ValueError): + self.get_encoder(self.encoder_names[0], depth=0) + with self.assertRaises(ValueError): + self.get_encoder(self.encoder_names[0], depth=25) + + def test_invalid_out_indices(self): + # out of range + with self.assertRaises(ValueError): + self.get_encoder(self.encoder_names[0], depth=1, output_indices=-25) + with self.assertRaises(ValueError): + self.get_encoder(self.encoder_names[0], depth=3, output_indices=[1, 2, 25]) + + # invalid length + with self.assertRaises(ValueError): + self.get_encoder( + self.encoder_names[0], + depth=2, + output_indices=[ + 2, + ], + ) + + def test_dilated(self): + pytest.skip("Dilation is not supported for ViT encoders") + + @pytest.mark.compile + def test_compile(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + encoder = self.get_tiny_encoder() + encoder = encoder.eval().to(default_device) + + sample = self._get_sample( + height=self.tiny_encoder_patch_size, width=self.tiny_encoder_patch_size + ).to(default_device) + + torch.compiler.reset() + compiled_encoder = torch.compile( + encoder, fullgraph=True, dynamic=True, backend="eager" + ) + + if encoder._is_torch_compilable: + compiled_encoder(sample) + else: + with self.assertRaises(Exception): + compiled_encoder(sample) + + @pytest.mark.torch_export + @requires_torch_greater_or_equal("2.4.0") + def test_torch_export(self): + if not 
check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample( + height=self.tiny_encoder_patch_size, width=self.tiny_encoder_patch_size + ).to(default_device) + + encoder = self.get_tiny_encoder() + encoder = encoder.eval().to(default_device) + + exported_encoder = torch.export.export( + encoder, + args=(sample,), + strict=True, + ) + + with torch.inference_mode(): + eager_output = encoder(sample) + exported_output = exported_encoder.module().forward(sample) + + for eager_feature, exported_feature in zip(eager_output, exported_output): + torch.testing.assert_close(eager_feature, exported_feature) + + @pytest.mark.torch_script + def test_torch_script(self): + pytest.skip( + "Encoder with prefix tokens are not supported for scripting, due to poor type handling" + ) diff --git a/tests/encoders/test_torchvision_encoders.py b/tests/encoders/test_torchvision_encoders.py index 99b8b9d5..c0d7c64f 100644 --- a/tests/encoders/test_torchvision_encoders.py +++ b/tests/encoders/test_torchvision_encoders.py @@ -1,9 +1,60 @@ +import segmentation_models_pytorch as smp + from tests.encoders import base from tests.utils import RUN_ALL_ENCODERS -class TestMobileoneEncoder(base.BaseEncoderTester): +class TestResNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["resnet18"] + if not RUN_ALL_ENCODERS + else [ + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x4d", + "resnext101_32x8d", + "resnext101_32x16d", + "resnext101_32x32d", + "resnext101_32x48d", + ] + ) + files_for_diff = ["encoders/resnet.py"] + + def get_tiny_encoder(self): + params = { + "out_channels": [3, 64, 64, 128, 256, 512], + "block": smp.encoders.resnet.BasicBlock, + "layers": [1, 1, 1, 1], + } + return smp.encoders.resnet.ResNetEncoder(**params) + + +class TestDenseNetEncoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ( + ["densenet121"] + if not 
RUN_ALL_ENCODERS + else ["densenet121", "densenet169", "densenet161"] + ) + files_for_diff = ["encoders/densenet.py"] + + def get_tiny_encoder(self): + params = { + "out_channels": [3, 2, 3, 2, 2, 2], + "num_init_features": 2, + "growth_rate": 1, + "block_config": (1, 1, 1, 1), + } + return smp.encoders.densenet.DenseNetEncoder(**params) + + +class TestMobileNetEncoder(base.BaseEncoderTester): encoder_names = ["mobilenet_v2"] if not RUN_ALL_ENCODERS else ["mobilenet_v2"] + files_for_diff = ["encoders/mobilenet.py"] class TestVggEncoder(base.BaseEncoderTester): @@ -22,3 +73,12 @@ class TestVggEncoder(base.BaseEncoderTester): "vgg19_bn", ] ) + files_for_diff = ["encoders/vgg.py"] + + def get_tiny_encoder(self): + params = { + "out_channels": [4, 4, 4, 4, 4, 4], + "config": [4, "M", 4, "M", 4, "M", 4, "M", 4, "M"], + "batch_norm": False, + } + return smp.encoders.vgg.VGGEncoder(**params) diff --git a/tests/models/base.py b/tests/models/base.py index 02e17303..2f317348 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -14,6 +14,7 @@ default_device, slow_test, requires_torch_greater_or_equal, + check_run_test_on_diff_or_main, ) @@ -21,6 +22,7 @@ class BaseModelTester(unittest.TestCase): test_encoder_name = ( "tu-test_resnet.r160_in1k" if has_timm_test_models else "resnet18" ) + files_for_diff = [r".*"] # should be overriden test_model_type = None @@ -31,6 +33,8 @@ class BaseModelTester(unittest.TestCase): default_height = 64 default_width = 64 + compile_dynamic = True + @property def model_type(self): if self.test_model_type is None: @@ -54,19 +58,23 @@ def decoder_channels(self): return None @lru_cache - def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): + def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None): + batch_size = batch_size or self.default_batch_size + num_channels = num_channels or self.default_num_channels + height = height or self.default_height + width = width or self.default_width return 
torch.rand(batch_size, num_channels, height, width) + @lru_cache + def get_default_model(self): + model = smp.create_model(self.model_type, self.test_encoder_name) + model = model.to(default_device) + return model + def test_forward_backward(self): - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) - model = smp.create_model( - arch=self.model_type, encoder_name=self.test_encoder_name - ).to(default_device) + sample = self._get_sample().to(default_device) + + model = self.get_default_model() # check default in_channels=3 output = model(sample) @@ -91,23 +99,26 @@ def test_in_channels_and_depth_and_out_classes( if self.model_type in ["unet", "unetplusplus", "manet"]: kwargs = {"decoder_channels": self.decoder_channels[:depth]} - model = smp.create_model( - arch=self.model_type, - encoder_name=self.test_encoder_name, - encoder_depth=depth, - in_channels=in_channels, - classes=classes, - **kwargs, - ).to(default_device) - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=in_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + if self.model_type == "dpt": + kwargs = {"decoder_intermediate_channels": self.decoder_channels[:depth]} + + model = ( + smp.create_model( + arch=self.model_type, + encoder_name=self.test_encoder_name, + encoder_depth=depth, + in_channels=in_channels, + classes=classes, + **kwargs, + ) + .to(default_device) + .eval() + ) + + sample = self._get_sample(num_channels=in_channels).to(default_device) # check in channels correctly set - with torch.no_grad(): + with torch.inference_mode(): output = model(sample) self.assertEqual(output.shape[1], classes) @@ -122,7 +133,8 @@ def test_classification_head(self): "dropout": 0.5, "activation": "sigmoid", }, - ).to(default_device) + ) + model = model.to(default_device).eval() 
self.assertIsNotNone(model.classification_head) self.assertIsInstance(model.classification_head[0], torch.nn.AdaptiveAvgPool2d) @@ -132,24 +144,37 @@ def test_classification_head(self): self.assertIsInstance(model.classification_head[3], torch.nn.Linear) self.assertIsInstance(model.classification_head[4].activation, torch.nn.Sigmoid) - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) - with torch.no_grad(): + with torch.inference_mode(): _, cls_probs = model(sample) self.assertEqual(cls_probs.shape[1], 10) + def test_any_resolution(self): + model = self.get_default_model() + + sample = self._get_sample( + height=self.default_height + 3, + width=self.default_width + 7, + ).to(default_device) + + if model.requires_divisible_input_shape: + with self.assertRaises(RuntimeError, msg="Wrong input shape"): + output = model(sample) + return + + with torch.inference_mode(): + output = model(sample) + + self.assertEqual(output.shape[2], self.default_height + 3) + self.assertEqual(output.shape[3], self.default_width + 7) + @requires_torch_greater_or_equal("2.0.1") def test_save_load_with_hub_mixin(self): # instantiate model - model = smp.create_model( - arch=self.model_type, encoder_name=self.test_encoder_name - ).to(default_device) + model = self.get_default_model() + model.eval() # save model with tempfile.TemporaryDirectory() as tmpdir: @@ -157,18 +182,15 @@ def test_save_load_with_hub_mixin(self): tmpdir, dataset="test_dataset", metrics={"my_awesome_metric": 0.99} ) restored_model = smp.from_pretrained(tmpdir).to(default_device) + restored_model.eval() + with open(os.path.join(tmpdir, "README.md"), "r") as f: readme = f.read() # check inference is correct - sample = self._get_sample( - batch_size=self.default_batch_size, - num_channels=self.default_num_channels, - 
height=self.default_height, - width=self.default_width, - ).to(default_device) + sample = self._get_sample().to(default_device) - with torch.no_grad(): + with torch.inference_mode(): output = model(sample) restored_output = restored_model(sample) @@ -197,10 +219,80 @@ def test_preserve_forward_output(self): output_tensor = torch.load(output_tensor_path, weights_only=True) output_tensor = output_tensor.to(default_device) - with torch.no_grad(): + with torch.inference_mode(): output = model(input_tensor) self.assertEqual(output.shape, output_tensor.shape) is_close = torch.allclose(output, output_tensor, atol=5e-2) max_diff = torch.max(torch.abs(output - output_tensor)) self.assertTrue(is_close, f"Max diff: {max_diff}") + + @pytest.mark.compile + def test_compile(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample().to(default_device) + model = self.get_default_model() + model = model.eval().to(default_device) + + if not model._is_torch_compilable: + with self.assertRaises((RuntimeError)): + torch.compiler.reset() + compiled_model = torch.compile( + model, fullgraph=True, dynamic=self.compile_dynamic, backend="eager" + ) + return + + torch.compiler.reset() + compiled_model = torch.compile( + model, fullgraph=True, dynamic=self.compile_dynamic, backend="eager" + ) + with torch.inference_mode(): + compiled_model(sample) + + @pytest.mark.torch_export + def test_torch_export(self, eps=1e-5): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + torch.manual_seed(42) + sample = self._get_sample().to(default_device) + model = self.get_default_model() + model.eval() + + exported_model = torch.export.export( + model, + args=(sample,), + strict=True, + ) + + with torch.inference_mode(): + eager_output = model(sample) + exported_output = exported_model.module().forward(sample) + + self.assertEqual(eager_output.shape, 
exported_output.shape) + torch.testing.assert_close(eager_output, exported_output, rtol=eps, atol=eps) + + @pytest.mark.torch_script + def test_torch_script(self): + if not check_run_test_on_diff_or_main(self.files_for_diff): + self.skipTest("No diff and not on `main`.") + + sample = self._get_sample().to(default_device) + model = self.get_default_model() + model.eval() + + if not model._is_torch_scriptable: + with self.assertRaises(RuntimeError): + scripted_model = torch.jit.script(model) + return + + scripted_model = torch.jit.script(model) + + with torch.inference_mode(): + scripted_output = scripted_model(sample) + eager_output = model(sample) + + self.assertEqual(scripted_output.shape, eager_output.shape) + torch.testing.assert_close(scripted_output, eager_output, rtol=1e-3, atol=1e-3) diff --git a/tests/models/test_deeplab.py b/tests/models/test_deeplab.py index d3d350e9..de112633 100644 --- a/tests/models/test_deeplab.py +++ b/tests/models/test_deeplab.py @@ -1,16 +1,15 @@ -import pytest from tests.models import base -@pytest.mark.deeplabv3 class TestDeeplabV3Model(base.BaseModelTester): test_model_type = "deeplabv3" + files_for_diff = [r"decoders/deeplabv3/", r"base/"] default_batch_size = 2 -@pytest.mark.deeplabv3plus class TestDeeplabV3PlusModel(base.BaseModelTester): test_model_type = "deeplabv3plus" + files_for_diff = [r"decoders/deeplabv3plus/", r"base/"] default_batch_size = 2 diff --git a/tests/models/test_dpt.py b/tests/models/test_dpt.py new file mode 100644 index 00000000..40df1e38 --- /dev/null +++ b/tests/models/test_dpt.py @@ -0,0 +1,60 @@ +import pytest +import inspect +import torch +import segmentation_models_pytorch as smp + +from tests.models import base +from tests.utils import ( + slow_test, + default_device, + requires_torch_greater_or_equal, +) + + +class TestDPTModel(base.BaseModelTester): + test_encoder_name = "tu-vit_tiny_patch16_224" + files_for_diff = [r"decoders/dpt/", r"base/"] + + default_height = 224 + default_width = 224 + + # 
should be overriden + test_model_type = "dpt" + + compile_dynamic = False + + @property + def decoder_channels(self): + signature = inspect.signature(self.model_class) + return signature.parameters["decoder_intermediate_channels"].default + + @property + def hub_checkpoint(self): + return "smp-test-models/dpt-tu-test_vit" + + @slow_test + @requires_torch_greater_or_equal("2.0.1") + @pytest.mark.logits_match + def test_load_pretrained(self): + hub_checkpoint = "smp-hub/dpt-large-ade20k" + + model = smp.from_pretrained(hub_checkpoint) + model = model.eval().to(default_device) + + input_tensor = torch.ones((1, 3, 384, 384)) + input_tensor = input_tensor.to(default_device) + + expected_logits_slice = torch.tensor( + [3.4166, 3.4422, 3.4677, 3.2784, 3.0880, 2.9497] + ) + with torch.inference_mode(): + output = model(input_tensor) + + resulted_logits_slice = output[0, 0, 0, 0:6].cpu() + + self.assertEqual(expected_logits_slice.shape, resulted_logits_slice.shape) + is_close = torch.allclose( + expected_logits_slice, resulted_logits_slice, atol=5e-2 + ) + max_diff = torch.max(torch.abs(expected_logits_slice - resulted_logits_slice)) + self.assertTrue(is_close, f"Max diff: {max_diff}") diff --git a/tests/models/test_fpn.py b/tests/models/test_fpn.py index 15ae1f6a..e0db74bc 100644 --- a/tests/models/test_fpn.py +++ b/tests/models/test_fpn.py @@ -1,7 +1,29 @@ -import pytest +import segmentation_models_pytorch as smp + from tests.models import base -@pytest.mark.fpn class TestFpnModel(base.BaseModelTester): test_model_type = "fpn" + files_for_diff = [r"decoders/fpn/", r"base/"] + + def test_interpolation(self): + # test bilinear + model_1 = smp.create_model( + self.test_model_type, + self.test_encoder_name, + decoder_interpolation="bilinear", + ) + assert model_1.decoder.p2.interpolation_mode == "bilinear" + assert model_1.decoder.p3.interpolation_mode == "bilinear" + assert model_1.decoder.p4.interpolation_mode == "bilinear" + + # test bicubic + model_2 = smp.create_model( + 
self.test_model_type, + self.test_encoder_name, + decoder_interpolation="bicubic", + ) + assert model_2.decoder.p2.interpolation_mode == "bicubic" + assert model_2.decoder.p3.interpolation_mode == "bicubic" + assert model_2.decoder.p4.interpolation_mode == "bicubic" diff --git a/tests/models/test_linknet.py b/tests/models/test_linknet.py index 1ab5eb4e..6f9490d9 100644 --- a/tests/models/test_linknet.py +++ b/tests/models/test_linknet.py @@ -1,7 +1,6 @@ -import pytest from tests.models import base -@pytest.mark.linknet class TestLinknetModel(base.BaseModelTester): test_model_type = "linknet" + files_for_diff = [r"decoders/linknet/", r"base/"] diff --git a/tests/models/test_manet.py b/tests/models/test_manet.py index 33a8ae3b..0e2dbf9b 100644 --- a/tests/models/test_manet.py +++ b/tests/models/test_manet.py @@ -1,7 +1,27 @@ -import pytest +import segmentation_models_pytorch as smp + from tests.models import base -@pytest.mark.manet class TestManetModel(base.BaseModelTester): test_model_type = "manet" + files_for_diff = [r"decoders/manet/", r"base/"] + + def test_interpolation(self): + # test bilinear + model_1 = smp.create_model( + self.test_model_type, + self.test_encoder_name, + decoder_interpolation="bilinear", + ) + for block in model_1.decoder.blocks: + assert block.interpolation_mode == "bilinear" + + # test bicubic + model_2 = smp.create_model( + self.test_model_type, + self.test_encoder_name, + decoder_interpolation="bicubic", + ) + for block in model_2.decoder.blocks: + assert block.interpolation_mode == "bicubic" diff --git a/tests/models/test_pan.py b/tests/models/test_pan.py index d66fefe0..8edb833a 100644 --- a/tests/models/test_pan.py +++ b/tests/models/test_pan.py @@ -1,11 +1,48 @@ import pytest +import segmentation_models_pytorch as smp + from tests.models import base -@pytest.mark.pan class TestPanModel(base.BaseModelTester): test_model_type = "pan" + files_for_diff = [r"decoders/pan/", r"base/"] default_batch_size = 2 default_height = 128 
default_width = 128 + + def test_interpolation(self): + # test bilinear + model_1 = smp.create_model( + self.test_model_type, + self.test_encoder_name, + decoder_interpolation="bilinear", + ) + assert model_1.decoder.gau1.interpolation_mode == "bilinear" + assert model_1.decoder.gau1.align_corners is True + assert model_1.decoder.gau2.interpolation_mode == "bilinear" + assert model_1.decoder.gau2.align_corners is True + assert model_1.decoder.gau3.interpolation_mode == "bilinear" + assert model_1.decoder.gau3.align_corners is True + + # test bicubic + model_2 = smp.create_model( + self.test_model_type, + self.test_encoder_name, + decoder_interpolation="bicubic", + ) + assert model_2.decoder.gau1.interpolation_mode == "bicubic" + assert model_2.decoder.gau1.align_corners is None + assert model_2.decoder.gau2.interpolation_mode == "bicubic" + assert model_2.decoder.gau2.align_corners is None + assert model_2.decoder.gau3.interpolation_mode == "bicubic" + assert model_2.decoder.gau3.align_corners is None + + with pytest.warns(DeprecationWarning): + smp.create_model( + self.test_model_type, + self.test_encoder_name, + upscale_mode="bicubic", + ) + assert model_2.decoder.gau1.interpolation_mode == "bicubic" diff --git a/tests/models/test_psp.py b/tests/models/test_psp.py index 2603cdda..c29b5e99 100644 --- a/tests/models/test_psp.py +++ b/tests/models/test_psp.py @@ -1,9 +1,8 @@ -import pytest from tests.models import base -@pytest.mark.psp class TestPspModel(base.BaseModelTester): test_model_type = "pspnet" + files_for_diff = [r"decoders/pspnet/", r"base/"] default_batch_size = 2 diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index 3ca5016c..b0f288ef 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -6,9 +6,9 @@ from tests.utils import slow_test, default_device, requires_torch_greater_or_equal -@pytest.mark.segformer class TestSegformerModel(base.BaseModelTester): test_model_type = "segformer" + 
files_for_diff = [r"decoders/segformer/", r"base/"] @slow_test @requires_torch_greater_or_equal("2.0.1") @@ -21,7 +21,7 @@ def test_load_pretrained(self): sample = torch.ones([1, 3, 512, 512]).to(default_device) - with torch.no_grad(): + with torch.inference_mode(): output = model(sample) self.assertEqual(output.shape, (1, 150, 512, 512)) diff --git a/tests/models/test_unet.py b/tests/models/test_unet.py index 54c69bf0..98e37206 100644 --- a/tests/models/test_unet.py +++ b/tests/models/test_unet.py @@ -1,7 +1,26 @@ -import pytest +import segmentation_models_pytorch as smp from tests.models import base -@pytest.mark.unet class TestUnetModel(base.BaseModelTester): test_model_type = "unet" + files_for_diff = [r"decoders/unet/", r"base/"] + + def test_interpolation(self): + # test bilinear + model_1 = smp.create_model( + self.test_model_type, + self.test_encoder_name, + decoder_interpolation="bilinear", + ) + for block in model_1.decoder.blocks: + assert block.interpolation_mode == "bilinear" + + # test bicubic + model_2 = smp.create_model( + self.test_model_type, + self.test_encoder_name, + decoder_interpolation="bicubic", + ) + for block in model_2.decoder.blocks: + assert block.interpolation_mode == "bicubic" diff --git a/tests/models/test_unetplusplus.py b/tests/models/test_unetplusplus.py index 9e67f2ed..1d958ae3 100644 --- a/tests/models/test_unetplusplus.py +++ b/tests/models/test_unetplusplus.py @@ -1,7 +1,35 @@ -import pytest +import segmentation_models_pytorch as smp + from tests.models import base -@pytest.mark.unetplusplus class TestUnetPlusPlusModel(base.BaseModelTester): test_model_type = "unetplusplus" + files_for_diff = [r"decoders/unetplusplus/", r"base/"] + + def test_interpolation(self): + # test bilinear + model_1 = smp.create_model( + self.test_model_type, + self.test_encoder_name, + decoder_interpolation="bilinear", + ) + is_tested = False + for module in model_1.decoder.modules(): + if module.__class__.__name__ == "DecoderBlock": + assert 
module.interpolation_mode == "bilinear" + is_tested = True + assert is_tested + + # test bicubic + model_2 = smp.create_model( + self.test_model_type, + self.test_encoder_name, + decoder_interpolation="bicubic", + ) + is_tested = False + for module in model_2.decoder.modules(): + if module.__class__.__name__ == "DecoderBlock": + assert module.interpolation_mode == "bicubic" + is_tested = True + assert is_tested diff --git a/tests/models/test_upernet.py b/tests/models/test_upernet.py index 71d703f9..a69062ae 100644 --- a/tests/models/test_upernet.py +++ b/tests/models/test_upernet.py @@ -1,8 +1,14 @@ import pytest + from tests.models import base -@pytest.mark.upernet class TestUnetModel(base.BaseModelTester): test_model_type = "upernet" + files_for_diff = [r"decoders/upernet/", r"base/"] + default_batch_size = 2 + + @pytest.mark.torch_export + def test_torch_export(self): + super().test_torch_export(eps=1e-3) diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 00000000..1078c493 --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,36 @@ +import torch +import tempfile +import segmentation_models_pytorch as smp + +import pytest + + +def test_from_pretrained_with_mismatched_keys(): + original_model = smp.Unet(classes=1) + + with tempfile.TemporaryDirectory() as temp_dir: + original_model.save_pretrained(temp_dir) + + # we should catch warning here and check if there specific keys there + with pytest.warns(UserWarning): + restored_model = smp.from_pretrained(temp_dir, classes=2, strict=False) + + assert restored_model.segmentation_head[0].out_channels == 2 + + # verify all the weight are the same expect mismatched ones + original_state_dict = original_model.state_dict() + restored_state_dict = restored_model.state_dict() + + expected_mismatched_keys = [ + "segmentation_head.0.weight", + "segmentation_head.0.bias", + ] + mismatched_keys = [] + for key in original_state_dict: + if key not in expected_mismatched_keys: + assert 
torch.allclose(original_state_dict[key], restored_state_dict[key]) + else: + mismatched_keys.append(key) + + assert len(mismatched_keys) == 2 + assert sorted(mismatched_keys) == sorted(expected_mismatched_keys) diff --git a/tests/test_losses.py b/tests/test_losses.py index 5c3ad75a..94d85d5c 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -93,7 +93,7 @@ def test_soft_tversky_score(y_true, y_pred, expected, eps, alpha, beta): assert float(actual) == pytest.approx(expected, eps) -@torch.no_grad() +@torch.inference_mode() def test_dice_loss_binary(): eps = 1e-5 criterion = DiceLoss(mode=smp.losses.BINARY_MODE, from_logits=False) @@ -131,7 +131,7 @@ def test_dice_loss_binary(): assert float(loss) == pytest.approx(1.0, abs=eps) -@torch.no_grad() +@torch.inference_mode() def test_tversky_loss_binary(): eps = 1e-5 # with alpha=0.5; beta=0.5 it is equal to DiceLoss @@ -172,7 +172,7 @@ def test_tversky_loss_binary(): assert float(loss) == pytest.approx(1.0, abs=eps) -@torch.no_grad() +@torch.inference_mode() def test_binary_jaccard_loss(): eps = 1e-5 criterion = JaccardLoss(mode=smp.losses.BINARY_MODE, from_logits=False) @@ -210,7 +210,7 @@ def test_binary_jaccard_loss(): assert float(loss) == pytest.approx(1.0, eps) -@torch.no_grad() +@torch.inference_mode() def test_multiclass_jaccard_loss(): eps = 1e-5 criterion = JaccardLoss(mode=smp.losses.MULTICLASS_MODE, from_logits=False) @@ -237,7 +237,7 @@ def test_multiclass_jaccard_loss(): assert float(loss) == pytest.approx(1.0 - 1.0 / 3.0, abs=eps) -@torch.no_grad() +@torch.inference_mode() def test_multilabel_jaccard_loss(): eps = 1e-5 criterion = JaccardLoss(mode=smp.losses.MULTILABEL_MODE, from_logits=False) @@ -263,7 +263,7 @@ def test_multilabel_jaccard_loss(): assert float(loss) == pytest.approx(1.0 - 1.0 / 3.0, abs=eps) -@torch.no_grad() +@torch.inference_mode() def test_soft_ce_loss(): criterion = SoftCrossEntropyLoss(smooth_factor=0.1, ignore_index=-100) @@ -276,7 +276,7 @@ def test_soft_ce_loss(): 
assert float(loss) == pytest.approx(1.0125, abs=0.0001) -@torch.no_grad() +@torch.inference_mode() def test_soft_bce_loss(): criterion = SoftBCEWithLogitsLoss(smooth_factor=0.1, ignore_index=-100) @@ -287,7 +287,7 @@ def test_soft_bce_loss(): assert float(loss) == pytest.approx(0.7201, abs=0.0001) -@torch.no_grad() +@torch.inference_mode() def test_binary_mcc_loss(): eps = 1e-5 criterion = MCCLoss(eps=eps) diff --git a/tests/utils.py b/tests/utils.py index e8bce88e..f9e50fc2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,31 +1,21 @@ import os +import re import timm import torch import unittest +from git import Repo +from typing import List from packaging.version import Version has_timm_test_models = Version(timm.__version__) >= Version("1.0.12") default_device = "cuda" if torch.cuda.is_available() else "cpu" - -def get_commit_message(): - commit_msg = os.getenv("COMMIT_MESSAGE", "") - return commit_msg.lower() - - -# Check both environment variables and commit message -commit_message = get_commit_message() -RUN_ALL_ENCODERS = ( - os.getenv("RUN_ALL_ENCODERS", "false").lower() in ["true", "1", "y", "yes"] - or "run-all-encoders" in commit_message -) - -RUN_SLOW = ( - os.getenv("RUN_SLOW", "false").lower() in ["true", "1", "y", "yes"] - or "run-slow" in commit_message -) +YES_LIST = ["true", "1", "y", "yes"] +RUN_ALL_ENCODERS = os.getenv("RUN_ALL_ENCODERS", "false").lower() in YES_LIST +RUN_SLOW = os.getenv("RUN_SLOW", "false").lower() in YES_LIST +RUN_ALL = os.getenv("RUN_ALL", "false").lower() in YES_LIST def slow_test(test_case): @@ -38,6 +28,15 @@ def slow_test(test_case): return unittest.skipUnless(RUN_SLOW, "test is slow")(test_case) +def requires_timm_greater_or_equal(version: str): + timm_version = Version(timm.__version__) + provided_version = Version(version) + return unittest.skipUnless( + timm_version >= provided_version, + f"timm version {timm_version} is less than {provided_version}", + ) + + def requires_torch_greater_or_equal(version: str): 
torch_version = Version(torch.__version__) provided_version = Version(version) @@ -45,3 +44,46 @@ def requires_torch_greater_or_equal(version: str): torch_version >= provided_version, f"torch version {torch_version} is less than {provided_version}", ) + + +def check_run_test_on_diff_or_main(filepath_patterns: List[str]): + if RUN_ALL: + return True + + try: + repo = Repo(".") + current_branch = repo.active_branch.name + diff_files = repo.git.diff("main", name_only=True).splitlines() + + except Exception: + return True + + if current_branch == "main": + return True + + for pattern in filepath_patterns: + for file_path in diff_files: + if re.search(pattern, file_path): + return True + + return False + + +def check_two_models_strictly_equal( + model_a: torch.nn.Module, model_b: torch.nn.Module, input_data: torch.Tensor +) -> None: + for (k1, v1), (k2, v2) in zip( + model_a.state_dict().items(), model_b.state_dict().items() + ): + assert k1 == k2, f"Key mismatch: {k1} != {k2}" + torch.testing.assert_close( + v1, v2, msg=f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}" + ) + + model_a.eval() + model_b.eval() + with torch.inference_mode(): + output_a = model_a(input_data) + output_b = model_b(input_data) + + torch.testing.assert_close(output_a, output_b)