cyto_dl.nn.losses.gan_loss module
- class cyto_dl.nn.losses.gan_loss.GANLoss(gan_mode: str = 'vanilla', target_real_label: float = 1.0, target_fake_label: float = 0.0)
Bases: Module
Define different GAN objectives.
The GANLoss class abstracts away the need to create the target label tensor that has the same size as the input.
Initialize the GANLoss class.
- Parameters:
gan_mode (str='vanilla') – Type of GAN objective. Supported values are 'vanilla', 'lsgan', and 'wgangp'.
target_real_label (float=1.0) – label for a real image
target_fake_label (float=0.0) – label for a fake image
Note: Do not use a sigmoid as the last layer of the discriminator. LSGAN needs no sigmoid, and vanilla GANs apply it internally through BCEWithLogitsLoss.
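The docstring does not describe the loss computation itself. As a minimal sketch, assuming GANLoss dispatches on gan_mode the same way as the GANLoss in the pytorch-CycleGAN-and-pix2pix repository (which this docstring closely resembles), the three objectives might map to PyTorch criteria as shown below. `sketch_gan_loss` is a hypothetical standalone function, not part of cyto_dl, and the WGAN-GP gradient penalty term is omitted. The example also checks why the discriminator should output raw logits in vanilla mode.

```python
import torch
import torch.nn as nn


def sketch_gan_loss(prediction: torch.Tensor, target_is_real: bool,
                    gan_mode: str = "vanilla",
                    real_label: float = 1.0, fake_label: float = 0.0) -> torch.Tensor:
    """Hypothetical sketch; assumes pix2pix/CycleGAN-style dispatch, not cyto_dl's code.

    `prediction` is the raw (pre-sigmoid) discriminator output.
    """
    if gan_mode in ("vanilla", "lsgan"):
        # Build a target tensor with the same shape as the prediction.
        value = real_label if target_is_real else fake_label
        target = torch.full_like(prediction, value)
        # vanilla: BCEWithLogitsLoss applies the sigmoid internally.
        # lsgan: plain MSE on raw outputs, so no sigmoid is needed.
        criterion = nn.BCEWithLogitsLoss() if gan_mode == "vanilla" else nn.MSELoss()
        return criterion(prediction, target)
    if gan_mode == "wgangp":
        # WGAN critic: raise scores for real samples, lower them for fakes.
        # The gradient penalty is assumed to be computed elsewhere.
        return -prediction.mean() if target_is_real else prediction.mean()
    raise NotImplementedError(f"gan_mode {gan_mode!r} is not supported")


logits = torch.tensor([[2.0, -1.0], [0.5, 0.0]])

# Why no sigmoid layer: BCEWithLogitsLoss on logits equals BCELoss on sigmoid(logits).
with_logits = sketch_gan_loss(logits, target_is_real=True, gan_mode="vanilla")
manual = nn.BCELoss()(torch.sigmoid(logits), torch.ones_like(logits))
print(torch.allclose(with_logits, manual))  # True

# Discriminator loss: average of the real-batch and fake-batch terms.
d_loss = 0.5 * (sketch_gan_loss(logits, True, "lsgan")
                + sketch_gan_loss(-logits, False, "lsgan"))
```

Fusing the sigmoid into BCEWithLogitsLoss is also more numerically stable than a separate sigmoid followed by BCELoss, which is the usual reason for the note above.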
- get_target_tensor(prediction: Tensor, target_is_real: bool)
Create label tensors with the same size as the input.
- Parameters:
prediction (torch.Tensor) – Prediction output from a discriminator
target_is_real (bool) – whether the ground-truth label is for real images or fake images
- Returns:
A label tensor filled with the ground-truth label, with the same size as the input
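A minimal sketch of how such a target tensor can be built, using a hypothetical `sketch_get_target_tensor` function rather than cyto_dl's own method. It assumes the scalar label is broadcast with `Tensor.expand_as`, which matches the prediction's shape without allocating a full tensor of repeated values; the PatchGAN-style output shape is likewise only an illustration.

```python
import torch


def sketch_get_target_tensor(prediction: torch.Tensor, target_is_real: bool,
                             real_label: float = 1.0,
                             fake_label: float = 0.0) -> torch.Tensor:
    """Hypothetical stand-in for GANLoss.get_target_tensor (an assumption, not cyto_dl's code)."""
    # Create the label as a 0-d tensor on the prediction's device and dtype.
    label = torch.tensor(real_label if target_is_real else fake_label,
                         dtype=prediction.dtype, device=prediction.device)
    # expand_as returns a broadcast view with the prediction's shape.
    return label.expand_as(prediction)


# A PatchGAN-style discriminator emits one score per patch, e.g. (N, 1, H, W).
patch_scores = torch.randn(4, 1, 30, 30)
target = sketch_get_target_tensor(patch_scores, target_is_real=False)
print(target.shape)     # torch.Size([4, 1, 30, 30])
print(target.unique())  # tensor([0.])
```

Because `expand_as` returns a view with zero strides, the result should be treated as read-only; call `.clone()` before modifying it in place. Reference implementations often store the two labels with `register_buffer` instead, so they follow the module across `.to(device)` calls.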