cyto_dl.nn.head.gan_head_superres module

class cyto_dl.nn.head.gan_head_superres.GANHead_resize(
    in_channels: int,
    out_channels: int,
    gan_loss=Pix2PixHD((gan_loss): GANLoss((loss): BCEWithLogitsLoss()), (feature_matching_loss): L1Loss()),
    reconstruction_loss=MSELoss(),
    reconstruction_loss_weight=100,
    postprocess={'input': <function detach>, 'prediction': <function detach>},
    final_act: typing.Callable = Identity(),
    resolution='lr',
    spatial_dims=3,
    n_convs=1,
    dropout=0.0,
    upsample_method='pixelshuffle',
    upsample_ratio=None,
    first_layer=Identity(),
    dense: bool = False,
)

Bases: GANHead, ResBlocksHead

Inherits run_head from GANHead and uses the __init__ and forward methods of ResBlocksHead.
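The inheritance arrangement described above can be sketched in plain Python. The classes below are hypothetical stand-ins (no PyTorch, no cyto_dl) that only mimic the method-resolution pattern: with GANHead listed first among the bases, run_head resolves to GANHead, while __init__ and forward are explicitly routed to ResBlocksHead. How the real class performs this delegation is an assumption here, not taken from the library's source.

```python
class GANHeadStub:
    """Hypothetical stand-in for GANHead: supplies run_head."""

    def __init__(self, *args, **kwargs):
        self.built_by = "GANHeadStub"

    def run_head(self, x):
        # Calls whichever forward the concrete subclass resolves to.
        return ("GANHeadStub.run_head", self.forward(x))

    def forward(self, x):
        return "GANHeadStub.forward"


class ResBlocksHeadStub:
    """Hypothetical stand-in for ResBlocksHead: supplies __init__ and forward."""

    def __init__(self, in_channels, out_channels):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.built_by = "ResBlocksHeadStub"

    def forward(self, x):
        return "ResBlocksHeadStub.forward"


class GANHeadResizeSketch(GANHeadStub, ResBlocksHeadStub):
    """run_head comes from the first base via the MRO; __init__/forward are delegated explicitly."""

    def __init__(self, in_channels, out_channels):
        # Skip GANHeadStub.__init__ (first in the MRO) and build like ResBlocksHeadStub.
        ResBlocksHeadStub.__init__(self, in_channels, out_channels)

    def forward(self, x):
        # Route forward to the second base instead of the first.
        return ResBlocksHeadStub.forward(self, x)


head = GANHeadResizeSketch(in_channels=4, out_channels=1)
print(head.built_by)        # ResBlocksHeadStub
print(head.run_head(None))  # ('GANHeadStub.run_head', 'ResBlocksHeadStub.forward')
```

Because run_head calls self.forward, the method inherited from the first base still ends up using the second base's forward pass.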

Parameters:
  • gan_loss=Pix2PixHD(scales=1) – Loss used to optimize the GAN

  • reconstruction_loss=torch.nn.MSELoss() – Loss used to optimize the generator's image reconstructions

  • reconstruction_loss_weight=100 – Weight applied to the reconstruction loss

  • postprocess={"input": detach, "prediction": detach} – Postprocessing functions applied to the head's input and predictions
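The parameter list does not spell out how reconstruction_loss_weight combines the two losses or how postprocess is applied. The pure-Python sketch below assumes a pix2pix-style generator objective (adversarial term plus a weighted reconstruction term) and a per-key dictionary of postprocessing callables. The helpers mse, generator_loss, apply_postprocess and detach_stub are hypothetical; detach_stub only stands in for torch's detach, and none of this is cyto_dl's implementation.

```python
from typing import Callable, Dict, Sequence


def mse(prediction: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error, standing in for torch.nn.MSELoss()."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


def generator_loss(gan_term: float, reconstruction_term: float,
                   reconstruction_loss_weight: float = 100) -> float:
    # Assumed combination: adversarial term plus weighted reconstruction term.
    return gan_term + reconstruction_loss_weight * reconstruction_term


def apply_postprocess(outputs: Dict[str, object],
                      postprocess: Dict[str, Callable]) -> Dict[str, object]:
    # Apply each key's postprocessing function; keys without one pass through unchanged.
    return {k: postprocess.get(k, lambda x: x)(v) for k, v in outputs.items()}


def detach_stub(values):
    # Hypothetical stand-in for torch.Tensor.detach: there is no autograd here,
    # so it simply returns a new list with the same values.
    return list(values)


prediction = [0.5, 1.0, 1.5]
target = [0.0, 1.0, 2.0]
recon = mse(prediction, target)  # (0.25 + 0.0 + 0.25) / 3
total = generator_loss(gan_term=0.7, reconstruction_term=recon)
post = apply_postprocess({"input": target, "prediction": prediction},
                         {"input": detach_stub, "prediction": detach_stub})
print(round(recon, 4), round(total, 4))  # 0.1667 17.3667
```

With the default weight of 100, even a small reconstruction error dominates the adversarial term, which keeps the generator's output close to the target.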

forward(x)