cyto_dl.nn.head.gan_head module
- class cyto_dl.nn.head.gan_head.GANHead(gan_loss=Pix2PixHD( (gan_loss): GANLoss( (loss): BCEWithLogitsLoss() ) (feature_matching_loss): L1Loss() ), reconstruction_loss=MSELoss(), reconstruction_loss_weight=100, postprocess={'input': <function detach>, 'prediction': <function detach>})
Bases: BaseHead
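GAN task head. Optimizes a generator with an adversarial loss (gan_loss) together with a reconstruction loss (reconstruction_loss) scaled by reconstruction_loss_weight.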
- Parameters:
gan_loss=Pix2PixHD(scales=1) – Adversarial loss used to optimize the GAN
reconstruction_loss=torch.nn.MSELoss() – Loss used to optimize the generator's image reconstructions
reconstruction_loss_weight=100 – Weight applied to the reconstruction loss
postprocess={"input": detach, "prediction": detach} – Postprocessing functions applied to the head's inputs and predictions
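The parameters suggest that the generator objective pairs an adversarial term with a reconstruction term scaled by reconstruction_loss_weight. The sketch below assumes the two terms are simply summed, as in pix2pix-style training. This assumption is not taken from the GANHead source. To stay dependency-free, the sketch uses a pure-Python mean squared error as a stand-in for torch.nn.MSELoss() and treats the adversarial loss as a precomputed number. The helpers mse, generator_loss, and apply_postprocess are hypothetical and are not part of cyto_dl. The sketch also shows the postprocess pattern: a dict that maps output names to functions, where detach would be the function in real use.

```python
from typing import Callable, Dict, List


def mse(prediction: List[float], target: List[float]) -> float:
    """Mean squared error; a pure-Python stand-in for torch.nn.MSELoss()."""
    if not prediction or len(prediction) != len(target):
        raise ValueError("prediction and target must be non-empty and equal length")
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(prediction)


def generator_loss(
    adversarial_loss: float,
    prediction: List[float],
    target: List[float],
    reconstruction_loss: Callable[[List[float], List[float]], float] = mse,
    reconstruction_loss_weight: float = 100.0,
) -> float:
    """Hypothetical combined generator objective.

    ASSUMPTION: adversarial term + weight * reconstruction term. The actual
    GANHead implementation may combine or log these terms differently.
    """
    return adversarial_loss + reconstruction_loss_weight * reconstruction_loss(
        prediction, target
    )


def apply_postprocess(
    outputs: Dict[str, object], postprocess: Dict[str, Callable]
) -> Dict[str, object]:
    """Apply the function registered for each key; unlisted keys pass through unchanged."""
    return {key: postprocess.get(key, lambda x: x)(value) for key, value in outputs.items()}


if __name__ == "__main__":
    # MSE is (0 + 4) / 2 = 2.0, so the total is 0.5 + 100 * 2.0
    print(generator_loss(0.5, [1.0, 2.0], [1.0, 4.0]))  # 200.5
    # tuple stands in for detach here; "target" has no entry and is left as-is
    print(apply_postprocess({"input": [1], "prediction": [2], "target": [3]},
                            {"input": tuple, "prediction": tuple}))
```

With the default weight of 100, a small pixel-wise error can dominate the adversarial term. This is the usual reason a large reconstruction weight is paired with a GAN loss.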