cyto_dl.nn.losses.gan_loss module

class cyto_dl.nn.losses.gan_loss.GANLoss(gan_mode: str = 'vanilla', target_real_label: float = 1.0, target_fake_label: float = 0.0)

Bases: Module

Define different GAN objectives.

The GANLoss class removes the need to manually create a target label tensor of the same size as the input.

Initialize the GANLoss class.

Parameters:
  • gan_mode (str='vanilla') – Type of GAN objective. Supported values are vanilla, lsgan, and wgangp.

  • target_real_label (float=1.0) – Label for a real image.

  • target_fake_label (float=0.0) – Label for a fake image.

Note: Do not use a sigmoid as the last layer of the discriminator. LSGAN needs no sigmoid, and for vanilla GANs the sigmoid is applied internally by BCEWithLogitsLoss.
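The per-mode behavior described above can be sketched as follows. This is a minimal, hypothetical stand-in (gan_objective is not part of cyto_dl), assuming GANLoss follows the common pix2pix/CycleGAN pattern that its docstrings mirror: vanilla applies BCEWithLogitsLoss to raw logits, lsgan is assumed to use a mean squared error, and wgangp is assumed to use the mean critic score. The sketch also shows why the discriminator should output raw logits rather than sigmoid probabilities.

```python
import torch
from torch import nn


def gan_objective(prediction: torch.Tensor, target_is_real: bool,
                  gan_mode: str = "vanilla",
                  real_label: float = 1.0, fake_label: float = 0.0) -> torch.Tensor:
    """Hypothetical stand-in for GANLoss.

    Assumption: mirrors the pix2pix/CycleGAN GANLoss, not cyto_dl's actual code.
    """
    if gan_mode in ("vanilla", "lsgan"):
        # Build a target tensor with the same shape as the discriminator output.
        label = real_label if target_is_real else fake_label
        target = torch.full_like(prediction, label)
        # vanilla: BCEWithLogitsLoss applies the sigmoid internally, so the
        # discriminator must output raw logits. lsgan: plain MSE, no sigmoid.
        criterion = nn.BCEWithLogitsLoss() if gan_mode == "vanilla" else nn.MSELoss()
        return criterion(prediction, target)
    if gan_mode == "wgangp":
        # Wasserstein critic: raise scores for real samples, lower them for fake ones.
        return -prediction.mean() if target_is_real else prediction.mean()
    raise NotImplementedError(f"gan_mode {gan_mode!r} is not supported")


# Example: discriminator loss on hypothetical PatchGAN-style outputs (raw logits).
real_logits = torch.randn(4, 1, 30, 30)
fake_logits = torch.randn(4, 1, 30, 30)
d_loss = 0.5 * (gan_objective(real_logits, True) + gan_objective(fake_logits, False))
```

Folding the sigmoid into BCEWithLogitsLoss is numerically more stable than applying a sigmoid layer and then BCELoss, which is the reason behind the note above.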

get_target_tensor(prediction: Tensor, target_is_real: bool)

Create label tensors with the same size as the input.

Parameters:
  • prediction (torch.Tensor) – Prediction output from a discriminator

  • target_is_real (bool) – Whether the ground-truth label is for real images (True) or fake images (False).

Returns:

A label tensor filled with the ground-truth label, with the same size as the input.
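One way to build such a label tensor is sketched below. TargetLabels is a hypothetical class, and storing the labels as registered buffers and broadcasting them with expand_as is an assumption based on the common pix2pix/CycleGAN implementation, not a description of cyto_dl's internals.

```python
import torch
from torch import nn


class TargetLabels(nn.Module):
    """Hypothetical module illustrating the idea behind get_target_tensor."""

    def __init__(self, target_real_label: float = 1.0, target_fake_label: float = 0.0):
        super().__init__()
        # Buffers follow .to(device) and are saved in state_dict,
        # but are not trainable parameters.
        self.register_buffer("real_label", torch.tensor(target_real_label))
        self.register_buffer("fake_label", torch.tensor(target_fake_label))

    def get_target_tensor(self, prediction: torch.Tensor, target_is_real: bool) -> torch.Tensor:
        label = self.real_label if target_is_real else self.fake_label
        # expand_as broadcasts the 0-d label to the prediction's shape
        # without copying memory.
        return label.expand_as(prediction)


labels = TargetLabels(target_real_label=0.9)  # e.g. one-sided label smoothing
pred = torch.randn(2, 1, 30, 30)
target = labels.get_target_tensor(pred, target_is_real=True)
print(target.shape)  # torch.Size([2, 1, 30, 30])
```

Because expand_as returns a view, the target costs no extra memory regardless of the discriminator's output size.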

class cyto_dl.nn.losses.gan_loss.Pix2PixHD(scales, loss_weights={'FM': 10, 'GAN': 1})

Bases: Module

get_feature_matching_loss(features)
get_gan_loss(features, feature_type)
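The reference above does not describe how Pix2PixHD computes its losses, so the following is only a generic sketch of a pix2pixHD-style feature matching loss, not cyto_dl's implementation. The function feature_matching_loss and the nested-list layout of its inputs (one list of intermediate discriminator activations per scale, for real and fake inputs separately) are hypothetical; the real get_feature_matching_loss takes a single features argument whose structure is not documented here. The FM weight of 10 is taken from the default loss_weights in the signature.

```python
import torch
import torch.nn.functional as F


def feature_matching_loss(real_features, fake_features):
    """Hypothetical pix2pixHD-style feature matching loss.

    Assumes real_features[s][l] and fake_features[s][l] hold the activations
    of discriminator layer l at scale s for real and generated inputs.
    """
    num_scales = len(fake_features)
    loss = fake_features[0][0].new_zeros(())
    for real_scale, fake_scale in zip(real_features, fake_features):
        for real_feat, fake_feat in zip(real_scale, fake_scale):
            # Real features act only as targets: detach them so this term
            # does not backpropagate into the real branch.
            loss = loss + F.l1_loss(fake_feat, real_feat.detach())
    # Average over discriminator scales.
    return loss / num_scales


torch.manual_seed(0)
# Two hypothetical scales, each with three layers of features.
shapes = [(1, 8, 16, 16), (1, 16, 8, 8), (1, 1, 4, 4)]
real = [[torch.randn(s) for s in shapes] for _ in range(2)]
fake = [[torch.randn(s, requires_grad=True) for s in shapes] for _ in range(2)]
loss_weights = {"FM": 10, "GAN": 1}  # defaults from the Pix2PixHD signature
fm = loss_weights["FM"] * feature_matching_loss(real, fake)
fm.backward()
```

The weighted feature matching term would then be added to the weighted adversarial term; whether cyto_dl also reweights individual layers (as the original pix2pixHD code does) is not stated in this reference.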