cyto_dl.nn.head.gan_head module

class cyto_dl.nn.head.gan_head.GANHead(gan_loss=Pix2PixHD((gan_loss): GANLoss((loss): BCEWithLogitsLoss()), (feature_matching_loss): L1Loss()), reconstruction_loss=MSELoss(), reconstruction_loss_weight=100, postprocess={'input': <function detach>, 'prediction': <function detach>})

Bases: BaseHead

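GAN task head. Trains a generator with an adversarial loss (gan_loss) plus a weighted reconstruction loss (reconstruction_loss).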
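For orientation, the default gan_loss in the signature is a Pix2PixHD loss made of a GANLoss built on BCEWithLogitsLoss and an L1Loss feature-matching term. The pure-Python sketch below shows what those two kinds of terms usually compute in the pix2pix/pix2pixHD formulation. It is an illustration under stated assumptions, not the cyto_dl implementation: the helpers bce_with_logits, adversarial_losses, and feature_matching_loss are hypothetical, and the 0.5 scaling of the discriminator loss is the pix2pix convention, not something confirmed for GANHead.

```python
import math
from typing import Sequence, Tuple


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit, in the numerically stable form
    used by torch.nn.BCEWithLogitsLoss: max(x, 0) - x*t + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mean(values: Sequence[float]) -> float:
    return sum(values) / len(values)


def adversarial_losses(
    real_logits: Sequence[float], fake_logits: Sequence[float]
) -> Tuple[float, float]:
    """Hypothetical helper returning (discriminator_loss, generator_adversarial_loss)."""
    # Discriminator: push logits on real images toward label 1, generated toward 0.
    d_real = mean([bce_with_logits(x, 1.0) for x in real_logits])
    d_fake = mean([bce_with_logits(x, 0.0) for x in fake_logits])
    # pix2pix halves the discriminator objective; assumed here, not confirmed for cyto_dl.
    d_loss = 0.5 * (d_real + d_fake)
    # Generator: rewarded when the discriminator labels its outputs as real (1).
    g_adv = mean([bce_with_logits(x, 1.0) for x in fake_logits])
    return d_loss, g_adv


def feature_matching_loss(
    real_feats: Sequence[Sequence[float]], fake_feats: Sequence[Sequence[float]]
) -> float:
    """Hypothetical helper: L1 distance between discriminator features for real and
    generated inputs, averaged within each layer and then across layers."""
    per_layer = [
        mean([abs(r - f) for r, f in zip(real_layer, fake_layer)])
        for real_layer, fake_layer in zip(real_feats, fake_feats)
    ]
    return mean(per_layer)


d_loss, g_adv = adversarial_losses(real_logits=[2.0, 3.0], fake_logits=[-1.0, 0.5])
fm = feature_matching_loss([[1.0, 2.0], [0.5]], [[0.5, 2.5], [0.0]])  # 0.5
```

In a real training loop these quantities would be tensors, and pix2pixHD-style implementations typically detach the real-image features in the feature-matching term so that only the generator receives gradient from it.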

Parameters:
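  • gan_loss=Pix2PixHD(scales=1) – Loss for optimizing the GAN. The default Pix2PixHD loss combines an adversarial GANLoss (built on BCEWithLogitsLoss) with an L1Loss feature-matching loss.

  • reconstruction_loss=torch.nn.MSELoss() – Loss for optimizing the generator’s image reconstructions.

  • reconstruction_loss_weight=100 – Weight applied to the reconstruction loss.

  • postprocess={'input': detach, 'prediction': detach} – Postprocessing functions applied to the head’s input and predictions.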
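One way to read reconstruction_loss_weight is as a scalar on the reconstruction term of the generator objective. The sketch below assumes the generator's total loss is the adversarial term plus reconstruction_loss_weight times the reconstruction loss (the pix2pix convention). The function generator_objective is hypothetical, and this is not taken from the cyto_dl source; mse mirrors the mean reduction that torch.nn.MSELoss uses by default.

```python
from typing import Sequence


def mse(prediction: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error with mean reduction, like the torch.nn.MSELoss default."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


def generator_objective(
    adversarial_loss: float,
    prediction: Sequence[float],
    target: Sequence[float],
    reconstruction_loss_weight: float = 100.0,
) -> float:
    """Hypothetical: adversarial term plus weighted reconstruction term."""
    return adversarial_loss + reconstruction_loss_weight * mse(prediction, target)


# 0.7 adversarial + 100 * 0.025 reconstruction, i.e. roughly 3.2
total = generator_objective(0.7, prediction=[0.1, 0.2], target=[0.0, 0.4])
```

Under this assumption, the default weight of 100 lets even a small per-pixel error outweigh the adversarial term, which keeps predictions anchored to the target while the adversarial loss pushes toward realistic texture. Lowering reconstruction_loss_weight shifts the balance toward the adversarial term.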

forward(x)

run_head(backbone_features, batch, stage, n_postprocess=1, discriminator=None, run_forward=True, y_hat=None)