cyto_dl.nn.head.gan_head_superres module
- class cyto_dl.nn.head.gan_head_superres.GANHead_resize(in_channels: int, out_channels: int, gan_loss=Pix2PixHD((gan_loss): GANLoss((loss): BCEWithLogitsLoss()), (feature_matching_loss): L1Loss()), reconstruction_loss=MSELoss(), reconstruction_loss_weight=100, postprocess={'input': <function detach>, 'prediction': <function detach>}, final_act: typing.Callable = Identity(), resolution='lr', spatial_dims=3, n_convs=1, dropout=0.0, upsample_method='pixelshuffle', upsample_ratio=None, first_layer=Identity(), dense: bool = False)
Bases: GANHead, ResBlocksHead
Inherits run_head from GANHead and uses the __init__ and forward methods of ResBlocksHead.
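The method split described above can be illustrated with a minimal sketch of Python multiple inheritance. The classes below are hypothetical stand-ins, not the cyto_dl implementations: their names, arguments, and return values are assumptions chosen only to show how one base can supply run_head while the other supplies __init__ and forward.

```python
class GANHeadSketch:
    """Hypothetical stand-in for GANHead (not the real cyto_dl class)."""

    def __init__(self, gan_loss="gan"):
        self.gan_loss = gan_loss

    def forward(self, x):
        return f"GANHead.forward({x})"

    def run_head(self, x):
        # Calls whichever forward() the concrete class resolves to.
        return {"y_hat": self.forward(x), "via": "GANHead.run_head"}


class ResBlocksHeadSketch:
    """Hypothetical stand-in for ResBlocksHead (not the real cyto_dl class)."""

    def __init__(self, n_convs=1):
        self.n_convs = n_convs

    def forward(self, x):
        return f"ResBlocksHead.forward({x}, n_convs={self.n_convs})"


class ResizeHeadSketch(GANHeadSketch, ResBlocksHeadSketch):
    """Same base order as the documented class: GAN head first."""

    def __init__(self, gan_loss="gan", n_convs=1):
        # Bypass the default method resolution order and initialise
        # with the second base explicitly.
        ResBlocksHeadSketch.__init__(self, n_convs=n_convs)
        self.gan_loss = gan_loss

    def forward(self, x):
        # Explicitly delegate to the second base's forward().
        return ResBlocksHeadSketch.forward(self, x)

    # run_head is not overridden, so it resolves to GANHeadSketch.run_head.


head = ResizeHeadSketch(n_convs=2)
result = head.run_head("img")
```

Because the GAN head comes first in the base list, run_head resolves to it automatically, while __init__ and forward have to be redirected explicitly to the second base.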
- Parameters:
gan_loss=Pix2PixHD(scales=1) – Loss for optimizing GAN
reconstruction_loss=torch.nn.MSELoss() – Loss for optimizing generator’s image reconstructions
reconstruction_loss_weight=100 – Weighting of reconstruction loss
postprocess={"input": detach, "prediction": detach} – Postprocessing for input and predictions of head
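To make the role of reconstruction_loss_weight concrete, here is a rough sketch of how a weighted reconstruction term is commonly added to an adversarial term in pix2pix-style training. This page does not show how cyto_dl actually combines the two losses, so the weighted-sum formula, the helper names mse and generator_loss, and the sample numbers are all assumptions for illustration.

```python
def mse(pred, target):
    """Mean squared error over two equal-length sequences of floats."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def generator_loss(adversarial_term, pred, target, reconstruction_loss_weight=100):
    """Assumed combination: adversarial term plus weighted reconstruction term."""
    return adversarial_term + reconstruction_loss_weight * mse(pred, target)


# Hypothetical values: an adversarial term of 0.7 and a two-pixel prediction.
loss = generator_loss(0.7, pred=[0.0, 1.0], target=[0.0, 0.5])
```

Under this assumed formula, a weight of 100 lets even a small pixel-wise error dominate the adversarial term, which is the usual motivation for a large reconstruction weight.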