# Source code for cyto_dl.nn.res_unit

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ADAPTED FROM: https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/convolutions.py

from typing import Sequence, Union

import numpy as np
import torch
import torch.nn as nn
from monai.networks.blocks import ADN, Convolution
from monai.networks.layers.convutils import same_padding, stride_minus_kernel_padding
from monai.networks.layers.factories import Conv


class ResidualUnit(nn.Module):
    """Stack of convolution subunits joined to its input by an additive shortcut.

    The main path is a ``nn.Sequential`` of ``subunits`` MONAI ``Convolution``
    blocks (registered as ``unit0``, ``unit1``, ...). The shortcut is
    ``nn.Identity`` when the main path keeps both channel count and spatial
    size; otherwise it is a plain convolution that reshapes the input so the
    two paths can be summed.

    For example:

    .. code-block:: python

        from cyto_dl.nn.res_unit import ResidualUnit

        convs = ResidualUnit(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            adn_ordering="AN",
            act=("prelu", {"init": 0.2}),
            norm=("layer", {"normalized_shape": (10, 10, 10)}),
        )
        print(convs)

    output::

        ResidualUnit(
          (conv): Sequential(
            (unit0): Convolution(
              (conv): Conv3d(1, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (adn): ADN(
                (A): PReLU(num_parameters=1)
                (N): LayerNorm((10, 10, 10), eps=1e-05, elementwise_affine=True)
              )
            )
            (unit1): Convolution(
              (conv): Conv3d(1, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (adn): ADN(
                (A): PReLU(num_parameters=1)
                (N): LayerNorm((10, 10, 10), eps=1e-05, elementwise_affine=True)
              )
            )
          )
          (residual): Identity()
        )

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        strides: convolution stride, applied by the first subunit only (later
            subunits use stride 1). Defaults to 1.
        kernel_size: convolution kernel size. Defaults to 3.
        subunits: number of convolutions; values below 1 are treated as 1.
            Defaults to 2.
        adn_ordering: a string representing the ordering of activation,
            normalization, and dropout. Defaults to "NDA".
        act: activation type and arguments. Defaults to PReLU.
        norm: feature normalization type and arguments. Defaults to instance norm.
        dropout: dropout ratio. Defaults to no dropout.
        dropout_dim: determine the dimensions of dropout. Defaults to 1.

            - When dropout_dim = 1, randomly zeroes some of the elements for each channel.
            - When dropout_dim = 2, randomly zero out entire channels (a channel is a 2D feature map).
            - When dropout_dim = 3, randomly zero out entire channels (a channel is a 3D feature map).

            The value of dropout_dim should be no larger than the value of `spatial_dims`.
        dilation: dilation rate. Defaults to 1.
        bias: whether to have a bias term. Defaults to True.
        last_conv_only: for the last subunit, whether to use the convolutional
            layer only. Defaults to False.
        padding: implicit zero-padding for the first subunit only; later
            subunits always use "same" padding derived from ``kernel_size``
            and ``dilation``. Defaults to None, which also means "same"
            padding. If it differs from "same" padding (or ``strides`` is not
            1), the shortcut becomes a convolution using ``kernel_size``,
            ``strides`` and this padding, so its output shape follows the
            first subunit.

    See also:

        :py:class:`monai.networks.blocks.Convolution`

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        strides: Union[Sequence[int], int] = 1,
        kernel_size: Union[Sequence[int], int] = 3,
        subunits: int = 2,
        adn_ordering: str = "NDA",
        act: Union[tuple, str, None] = "PRELU",
        norm: Union[tuple, str, None] = "INSTANCE",
        dropout: Union[tuple, str, float, None] = None,
        dropout_dim: Union[int, None] = 1,
        dilation: Union[Sequence[int], int] = 1,
        bias: bool = True,
        last_conv_only: bool = False,
        padding: Union[Sequence[int], int, None] = None,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Registered in this order so the module tree / state_dict lists
        # `conv` before `residual`; `residual` may be replaced further down.
        self.conv = nn.Sequential()
        self.residual = nn.Identity()

        same_pad = same_padding(kernel_size, dilation)
        first_padding = same_pad if padding is None else padding
        n_units = max(1, subunits)

        for idx in range(n_units):
            is_first = idx == 0
            # Only the first subunit changes channels, applies the stride and
            # uses the caller's padding; the rest keep shape with stride 1.
            self.conv.add_module(
                f"unit{idx:d}",
                Convolution(
                    self.spatial_dims,
                    in_channels if is_first else out_channels,
                    out_channels,
                    strides=strides if is_first else 1,
                    kernel_size=kernel_size,
                    adn_ordering=adn_ordering,
                    act=act,
                    norm=norm,
                    dropout=dropout,
                    dropout_dim=dropout_dim,
                    dilation=dilation,
                    bias=bias,
                    conv_only=last_conv_only and idx == (n_units - 1),
                    padding=first_padding if is_first else same_pad,
                ),
            )

        # The main path keeps spatial size only with unit stride and "same"
        # padding on its first subunit.
        preserves_size = np.prod(strides) == 1 and first_padding == same_pad
        if preserves_size and in_channels == out_channels:
            return  # shapes already agree: keep the identity shortcut

        if preserves_size:
            # Only the channel count differs: a 1x1 projection without padding suffices.
            res_kernel, res_padding = 1, 0
        else:
            # Mirror the first subunit's kernel/padding (with the same stride)
            # so the shortcut's spatial size matches the main path.
            res_kernel, res_padding = kernel_size, first_padding

        conv_cls = Conv[Conv.CONV, self.spatial_dims]
        self.residual = conv_cls(
            in_channels, out_channels, res_kernel, strides, res_padding, bias=bias
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``conv(x) + residual(x)``."""
        shortcut: torch.Tensor = self.residual(x)
        out: torch.Tensor = self.conv(x)
        return out + shortcut