cyto_dl.nn.losses.gan_loss module

class cyto_dl.nn.losses.gan_loss.GANLoss(gan_mode: str = 'vanilla', target_real_label: float = 1.0, target_fake_label: float = 0.0)

Bases: Module

Define different GAN objectives.

The GANLoss class abstracts away the need to create the target label tensor that has the same size as the input.

Initialize the GANLoss class.

Parameters:

  • gan_mode (str='vanilla') – Type of GAN objective. Supported values are 'vanilla', 'lsgan', and 'wgangp'.

  • target_real_label (float=1.0) – label for a real image

  • target_fake_label (float=0.0) – label for a fake image

Note: Do not use a sigmoid as the last layer of the discriminator. LSGAN needs no sigmoid, and the vanilla objective applies it internally through BCEWithLogitsLoss.
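The per-mode loss arithmetic can be sketched without a framework. This is a minimal pure-Python sketch, assuming the standard definitions used by the pix2pix/CycleGAN code base whose docstrings this class mirrors: binary cross-entropy on raw logits for 'vanilla', mean squared error for 'lsgan', and the negated (real) or plain (fake) mean critic score for 'wgangp'. The names `gan_objective` and `_bce_with_logits` are hypothetical and not part of cyto_dl, which operates on torch tensors instead of lists.

```python
import math
from typing import List


def _bce_with_logits(logit: float, target: float) -> float:
    # Numerically stable BCE(sigmoid(logit), target):
    # max(x, 0) - x * t + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def gan_objective(predictions: List[float], target_is_real: bool,
                  gan_mode: str = "vanilla",
                  real_label: float = 1.0, fake_label: float = 0.0) -> float:
    """Hypothetical sketch of a GAN loss over flattened discriminator outputs."""
    if not predictions:
        raise ValueError("predictions must be non-empty")
    n = len(predictions)
    if gan_mode in ("vanilla", "lsgan"):
        # Both modes compare every score against one constant label.
        target = real_label if target_is_real else fake_label
        if gan_mode == "vanilla":
            # Inputs are raw logits; the sigmoid lives inside the loss.
            return sum(_bce_with_logits(p, target) for p in predictions) / n
        return sum((p - target) ** 2 for p in predictions) / n
    if gan_mode == "wgangp":
        # Wasserstein critic: no labels, just the signed mean score.
        mean = sum(predictions) / n
        return -mean if target_is_real else mean
    raise NotImplementedError(f"gan mode {gan_mode} not implemented")
```

Because 'vanilla' consumes raw logits, a sigmoid at the end of the discriminator would squash the scores a second time, which is why the note above advises against it.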

get_target_tensor(prediction: Tensor, target_is_real: bool)

Create a label tensor with the same size as the input.

Parameters:

  • prediction (torch.Tensor) – Prediction output from a discriminator

  • target_is_real (bool) – whether the ground-truth label is for real images or fake images


Returns:

A label tensor filled with the ground-truth label, with the same size as the input.
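Conceptually, the method picks the real or fake label and broadcasts it to the shape of the prediction. In PyTorch this is commonly done with something like `torch.tensor(label).expand_as(prediction)`; the dependency-free sketch below mirrors that with nested lists. The helper name `make_target` is hypothetical and is not the cyto_dl implementation.

```python
from typing import Any


def make_target(prediction: Any, target_is_real: bool,
                real_label: float = 1.0, fake_label: float = 0.0) -> Any:
    """Return a nested list shaped like `prediction`, filled with one label."""
    label = real_label if target_is_real else fake_label

    def fill(x: Any) -> Any:
        # Recurse into sub-lists so the output keeps the input's nesting.
        if isinstance(x, list):
            return [fill(item) for item in x]
        return label

    return fill(prediction)
```

For example, a discriminator that emits a 2×2 grid of patch scores gets a 2×2 grid of targets rather than a single scalar, so the loss can be computed element-wise.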

class cyto_dl.nn.losses.gan_loss.Pix2PixHD(scales, loss_weights={'FM': 10, 'GAN': 1})

Bases: Module

get_gan_loss(features, feature_type)
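This class has no docstring here. The default `loss_weights` appear to follow the pix2pixHD paper, where a feature-matching (FM) term, an L1 distance between discriminator intermediate features of real and generated images across several discriminator scales, is weighted by 10 and added to the adversarial (GAN) term. The sketch below is a pure-Python illustration under loud assumptions: `features` is taken to be nested as scale → layer → flattened values, FM is summed over layers and averaged over scales (a simplification of the pix2pixHD reference weighting), and `feature_matching_loss` and `combine` are hypothetical names, not the cyto_dl `get_gan_loss`.

```python
from typing import Dict, List


def l1(a: List[float], b: List[float]) -> float:
    # Mean absolute difference between two flattened feature maps.
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def feature_matching_loss(real_feats: List[List[List[float]]],
                          fake_feats: List[List[List[float]]]) -> float:
    # real_feats[s][k]: k-th intermediate feature map of the discriminator at scale s.
    total = 0.0
    for real_scale, fake_scale in zip(real_feats, fake_feats):
        for real_layer, fake_layer in zip(real_scale, fake_scale):
            # Real features act as fixed targets for the generated ones.
            total += l1(fake_layer, real_layer)
    return total / len(real_feats)  # average over discriminator scales


def combine(losses: Dict[str, float], loss_weights: Dict[str, float]) -> float:
    # Weighted sum, e.g. loss_weights = {"FM": 10, "GAN": 1}.
    return sum(loss_weights[name] * value for name, value in losses.items())
```

With the default weights, an FM loss of 1.0 and a GAN loss of 0.5 would combine to 10 × 1.0 + 1 × 0.5 = 10.5, so the feature-matching term dominates unless the weights are changed.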