cyto_dl.nn.torchvision_wrapper module

class cyto_dl.nn.torchvision_wrapper.TorchVisionWrapper(base_encoder, in_channels=1)

Bases: Module

Wrap a torchvision model to accept a different number of input channels.

Parameters:
  • base_encoder – An initialized torchvision model. The following models are supported: ShuffleNetV2, ConvNeXt, QuantizableShuffleNetV2, MaxVit, MobileNetV2, QuantizableResNet, QuantizableMobileNetV2, VGG, GoogLeNet, FCN, MNASNet, SwinTransformer, SqueezeNet, VisionTransformer, AlexNet, DenseNet, QuantizableGoogLeNet, ResNet, LRASPP, QuantizableMobileNetV3, Inception3, QuantizableInception3, RegNet, MobileNetV3, EfficientNet

  • in_channels – Number of input channels the wrapped model should accept (default: 1).
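The idea of "accepting a different number of input channels" can be illustrated with a generic PyTorch sketch. This is not cyto_dl's implementation: the helper `adapt_input_channels` is hypothetical, and it assumes that the first `nn.Conv2d` in registration order is the model's input (stem) convolution with `groups=1`, and that averaging the old filters over their input channels is an acceptable way to reuse pretrained weights. Some models (e.g. Inception3 with `transform_input=True`) may make additional 3-channel assumptions that this sketch does not handle.

```python
import torch
from torch import nn


def adapt_input_channels(model: nn.Module, in_channels: int) -> nn.Module:
    """Hypothetical helper: replace the first Conv2d so it takes `in_channels`.

    A sketch of the general technique, not cyto_dl's code.
    """
    # Assumption: the first Conv2d in registration order is the input conv.
    name, old = next(
        (n, m) for n, m in model.named_modules() if isinstance(m, nn.Conv2d)
    )
    # Build a replacement with the same geometry but a new input width
    # (assumes an ordinary, groups=1 stem convolution).
    new = nn.Conv2d(
        in_channels,
        old.out_channels,
        kernel_size=old.kernel_size,
        stride=old.stride,
        padding=old.padding,
        dilation=old.dilation,
        bias=old.bias is not None,
        padding_mode=old.padding_mode,
    )
    # Assumed heuristic: average existing filters over their input channels
    # and repeat the result, so any pretrained weights are reused.
    with torch.no_grad():
        mean_w = old.weight.mean(dim=1, keepdim=True)
        new.weight.copy_(mean_w.repeat(1, in_channels, 1, 1))
        if old.bias is not None:
            new.bias.copy_(old.bias)
    # Swap the new conv in on its parent module.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)
    setattr(parent, child_name, new)
    return model


# Usage on a tiny stand-in network (3-channel input -> 1-channel input).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)
model = adapt_input_channels(model, in_channels=1)
out = model(torch.randn(4, 1, 32, 32))
print(out.shape)  # torch.Size([4, 2])
```

For the wrapper documented on this page, the corresponding call follows the signature above, e.g. `TorchVisionWrapper(torchvision.models.resnet18(), in_channels=1)`; how it rewires the model internally is not described here.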

VALID_MODELS = ('ShuffleNetV2', 'ConvNeXt', 'QuantizableShuffleNetV2', 'MaxVit', 'MobileNetV2', 'QuantizableResNet', 'QuantizableMobileNetV2', 'VGG', 'GoogLeNet', 'FCN', 'MNASNet', 'SwinTransformer', 'SqueezeNet', 'VisionTransformer', 'AlexNet', 'DenseNet', 'QuantizableGoogLeNet', 'ResNet', 'LRASPP', 'QuantizableMobileNetV3', 'Inception3', 'QuantizableInception3', 'RegNet', 'MobileNetV3', 'EfficientNet')
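Because VALID_MODELS stores class names as strings, one plausible way to validate `base_encoder` is to compare its class name against the tuple. The sketch below assumes that exact-name matching; the helper `is_supported` and the stand-in classes are hypothetical and are not part of cyto_dl. Under this assumption, subclasses such as `QuantizableResNet` have to be listed separately from `ResNet`, which would be consistent with the tuple above.

```python
# Copied from the VALID_MODELS attribute documented above.
VALID_MODELS = (
    "ShuffleNetV2", "ConvNeXt", "QuantizableShuffleNetV2", "MaxVit",
    "MobileNetV2", "QuantizableResNet", "QuantizableMobileNetV2", "VGG",
    "GoogLeNet", "FCN", "MNASNet", "SwinTransformer", "SqueezeNet",
    "VisionTransformer", "AlexNet", "DenseNet", "QuantizableGoogLeNet",
    "ResNet", "LRASPP", "QuantizableMobileNetV3", "Inception3",
    "QuantizableInception3", "RegNet", "MobileNetV3", "EfficientNet",
)


def is_supported(base_encoder: object) -> bool:
    """Hypothetical helper: check the encoder's class name against VALID_MODELS."""
    # Exact class-name comparison (not isinstance), so a subclass is only
    # accepted if its own name appears in the tuple.
    return type(base_encoder).__name__ in VALID_MODELS


# Stand-in classes so the sketch runs without torchvision installed.
class ResNet:
    pass


class MyCustomNet:
    pass


print(is_supported(ResNet()))       # True
print(is_supported(MyCustomNet()))  # False
```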
forward(x)