# Source code for monai.networks.layers.factories

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
handles spatial 1D, 2D, 3D network components with a factory pattern.
"""

from torch import nn as nn


def get_conv_type(dim, is_transpose):
    """Return the convolution layer class for the given spatial dimensionality.

    Args:
        dim: number of spatial dimensions; must be 1, 2 or 3.
        is_transpose: if truthy, return the transposed-convolution class
            instead of the regular convolution class.

    Returns:
        One of ``nn.Conv{1,2,3}d`` or ``nn.ConvTranspose{1,2,3}d`` (the class,
        not an instance).

    Raises:
        ValueError: if ``dim`` is not 1, 2 or 3. Without this check ``dim=0``
            or a negative value would silently wrap around the lookup sequence.
    """
    if dim not in (1, 2, 3):
        raise ValueError(f"dim must be 1, 2 or 3, got {dim!r}.")
    if is_transpose:
        types = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    else:
        types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    return types[dim - 1]


def get_dropout_type(dim):
    """Return the dropout layer class for the given spatial dimensionality.

    Args:
        dim: number of spatial dimensions; must be 1, 2 or 3.

    Returns:
        ``nn.Dropout`` for ``dim=1``, ``nn.Dropout2d`` for ``dim=2`` and
        ``nn.Dropout3d`` for ``dim=3`` (the class, not an instance). Note that
        the 1-D case uses the plain element-wise ``nn.Dropout`` rather than a
        channel-wise variant.

    Raises:
        ValueError: if ``dim`` is not 1, 2 or 3. Without this check ``dim=0``
            or a negative value would silently wrap around the lookup sequence.
    """
    if dim not in (1, 2, 3):
        raise ValueError(f"dim must be 1, 2 or 3, got {dim!r}.")
    types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d)
    return types[dim - 1]


def get_normalize_type(dim, is_instance):
    """Return the normalization layer class for the given spatial dimensionality.

    Args:
        dim: number of spatial dimensions; must be 1, 2 or 3.
        is_instance: if truthy, return the instance-norm class; otherwise
            return the batch-norm class.

    Returns:
        One of ``nn.InstanceNorm{1,2,3}d`` or ``nn.BatchNorm{1,2,3}d`` (the
        class, not an instance).

    Raises:
        ValueError: if ``dim`` is not 1, 2 or 3. Without this check ``dim=0``
            or a negative value would silently wrap around the lookup sequence.
    """
    if dim not in (1, 2, 3):
        raise ValueError(f"dim must be 1, 2 or 3, got {dim!r}.")
    if is_instance:
        types = (nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)
    else:
        types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    return types[dim - 1]


def get_maxpooling_type(dim, is_adaptive):
    """Return the max-pooling layer class for the given spatial dimensionality.

    Args:
        dim: number of spatial dimensions; must be 1, 2 or 3.
        is_adaptive: if truthy, return the adaptive max-pooling class;
            otherwise return the regular max-pooling class.

    Returns:
        One of ``nn.AdaptiveMaxPool{1,2,3}d`` or ``nn.MaxPool{1,2,3}d`` (the
        class, not an instance).

    Raises:
        ValueError: if ``dim`` is not 1, 2 or 3. Without this check ``dim=0``
            or a negative value would silently wrap around the lookup sequence.
    """
    if dim not in (1, 2, 3):
        raise ValueError(f"dim must be 1, 2 or 3, got {dim!r}.")
    if is_adaptive:
        types = (nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d)
    else:
        types = (nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d)
    return types[dim - 1]


def get_avgpooling_type(dim, is_adaptive):
    """Return the average-pooling layer class for the given spatial dimensionality.

    Args:
        dim: number of spatial dimensions; must be 1, 2 or 3.
        is_adaptive: if truthy, return the adaptive average-pooling class;
            otherwise return the regular average-pooling class.

    Returns:
        One of ``nn.AdaptiveAvgPool{1,2,3}d`` or ``nn.AvgPool{1,2,3}d`` (the
        class, not an instance).

    Raises:
        ValueError: if ``dim`` is not 1, 2 or 3. Without this check ``dim=0``
            or a negative value would silently wrap around the lookup sequence.
    """
    if dim not in (1, 2, 3):
        raise ValueError(f"dim must be 1, 2 or 3, got {dim!r}.")
    if is_adaptive:
        types = (nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d)
    else:
        types = (nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d)
    return types[dim - 1]