"""Source code for cyto_dl.nn.head.gan_head_superres."""

from typing import Callable

import numpy as np
import torch

from cyto_dl.models.im2im.utils.postprocessing import detach
from cyto_dl.nn.losses import Pix2PixHD

from .gan_head import GANHead
from .res_blocks_head import ResBlocksHead


class GANHead_resize(GANHead, ResBlocksHead):
    """Inherit run_head from GANHead, use __init__ and forward of ResBlocksHead.

    GAN head whose layers come from ``ResBlocksHead``. Prediction and target
    may differ slightly in spatial size, so both are cropped to their common
    spatial extent before the GAN / reconstruction losses are computed.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        gan_loss=Pix2PixHD(scales=1),
        reconstruction_loss=torch.nn.MSELoss(),
        reconstruction_loss_weight=100,
        postprocess={"input": detach, "prediction": detach},
        final_act: Callable = torch.nn.Identity(),
        resolution="lr",
        spatial_dims=3,
        n_convs=1,
        dropout=0.0,
        upsample_method="pixelshuffle",
        upsample_ratio=None,
        first_layer=torch.nn.Identity(),
        dense: bool = False,
    ):
        """
        Parameters
        ----------
        in_channels, out_channels
            Channel counts, forwarded to ResBlocksHead.
        gan_loss=Pix2PixHD(scales=1)
            Loss for optimizing GAN
        reconstruction_loss=torch.nn.MSELoss()
            Loss for optimizing generator's image reconstructions
        reconstruction_loss_weight=100
            Weighting of reconstruction loss
        postprocess={"input": detach, "prediction": detach}
            Postprocessing for `input` and `predictions` of head
        final_act, resolution, spatial_dims, n_convs, dropout,
        upsample_method, upsample_ratio, first_layer, dense
            Forwarded unchanged to ResBlocksHead.__init__.

        Notes
        -----
        The default loss / module / dict arguments are created once at import
        time and shared by every instance that relies on them.
        """
        # loss=None: this head supplies its own gan/reconstruction losses
        # below (presumably consumed by GANHead._calculate_loss).
        ResBlocksHead.__init__(
            self,
            loss=None,
            in_channels=in_channels,
            out_channels=out_channels,
            final_act=final_act,
            postprocess=postprocess,
            resolution=resolution,
            spatial_dims=spatial_dims,
            n_convs=n_convs,
            dropout=dropout,
            upsample_method=upsample_method,
            upsample_ratio=upsample_ratio,
            first_layer=first_layer,
            dense=dense,
        )
        self.gan_loss = gan_loss
        self.reconstruction_loss = reconstruction_loss
        self.reconstruction_loss_weight = reconstruction_loss_weight

    def _ensure_same_shape(self, x, y):
        """Crop ``x`` and ``y`` to their common spatial extent.

        The first two axes (assumed batch and channel) are left untouched;
        every remaining axis is cropped to ``min(x.shape[i], y.shape[i])``.
        Works for any number of spatial dimensions (e.g. 2-D or 3-D images).

        Raises
        ------
        ValueError
            If ``x`` and ``y`` do not have the same number of dimensions.
        """
        if x.ndim != y.ndim:
            raise ValueError(
                f"Cannot match shapes of tensors with different ranks: "
                f"{tuple(x.shape)} vs {tuple(y.shape)}"
            )
        spatial = tuple(
            slice(0, min(int(a), int(b))) for a, b in zip(x.shape[2:], y.shape[2:])
        )
        index = (slice(None), slice(None)) + spatial
        return x[index], y[index]

    def _calculate_loss(self, y_hat, batch, discriminator):
        """Crop target and prediction to a common shape, then defer to GANHead.

        Note: the cropped target is written back into ``batch[self.head_name]``
        (the batch dict is mutated in place) so that GANHead._calculate_loss,
        which receives ``batch``, sees the shape-matched target.
        """
        batch[self.head_name], y_hat = self._ensure_same_shape(batch[self.head_name], y_hat)
        return GANHead._calculate_loss(self, y_hat, batch, discriminator)

    def forward(self, x):
        """Run the ResBlocksHead layers on ``x``.

        Explicit override: GANHead precedes ResBlocksHead in the MRO, so without
        this, ``forward`` would resolve to GANHead's implementation first.
        """
        return ResBlocksHead.forward(self, x)