Source code for cyto_dl.nn.hr_skip

import torch
from monai.networks.blocks import Convolution


[docs]class HRSkip(torch.nn.Module): def __init__( self, in_channels, out_channels, dropout=0.3, scale_factor=[2, 2, 2], mode="nearest", align_corners=None, antialias=False, ): super().__init__() self.model = Convolution( spatial_dims=3, in_channels=in_channels, out_channels=out_channels, act=("leakyrelu", {"inplace": True, "negative_slope": 0.2}), norm="INSTANCE", dropout=dropout, ) self.scale_factor = scale_factor self.mode = mode self.align_corners = align_corners self.antialias = antialias
[docs] def forward(self, x): x = torch.nn.functional.interpolate( input=x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners, antialias=self.antialias, ) return self.model(x)