Source code for cyto_dl.nn.discriminators.multi_scale_discriminator

from torch import nn

from .n_layer_discriminator import NLayerDiscriminator


class MultiScaleDiscriminator(nn.Module):
    """Modified version of Pix2PixHD discriminator, which returns discriminator activations at
    multiple spatial scales.

    One ``NLayerDiscriminator`` is built per scale. Scale ``0`` sees the inputs at full
    resolution; every following scale sees them after one more average-pooling step
    (kernel 3, stride 2, padding 1, padding excluded from the average).
    """

    def __init__(self, n_scales: int = 2, dim: int = 3, **kwargs):
        """
        Parameters
        ----------
        n_scales: int = 2
            Number of spatial scales at which a separate discriminator is run. Must be >= 1.
        dim: int = 3
            Spatial dimensionality of the inputs (2 or 3). Selects 2D vs. 3D average pooling
            and is forwarded to every ``NLayerDiscriminator``.
        **kwargs
            Arguments to pass to NLayerDiscriminator

        Raises
        ------
        ValueError
            If ``dim`` is not 2 or 3, or if ``n_scales`` is less than 1.
        """
        super().__init__()
        if dim not in (2, 3):
            raise ValueError(f"dim must be 2 or 3, got {dim}")
        # Without at least one scale no discriminator would be built and forward would
        # silently return empty results.
        if n_scales < 1:
            raise ValueError(f"n_scales must be >= 1, got {n_scales}")
        self.scales = range(n_scales)
        kwargs.update({"dim": dim})
        # ModuleDict keys must be strings, hence str(scale).
        self.discriminators = nn.ModuleDict(
            {str(scale): NLayerDiscriminator(**kwargs) for scale in self.scales}
        )
        self.pooling_fn = nn.AvgPool3d if dim == 3 else nn.AvgPool2d

    @staticmethod
    def _build_pyramid(img, downsample, n_levels):
        """Return ``[img, downsample(img), downsample(downsample(img)), ...]``.

        The list has ``n_levels`` entries (always at least ``img`` itself). No extra
        downsampling is done past the last level, since nothing would consume it.
        """
        levels = [img]
        for _ in range(n_levels - 1):
            levels.append(downsample(levels[-1]))
        return levels

    def forward(self, x, real, pred):
        """Run the per-scale discriminators on ``real`` and on ``pred``.

        Parameters
        ----------
        x
            Source image passed to every discriminator alongside the image being judged.
            It is detached, so no gradient flows back into ``x`` through this module.
        real
            Image scored under the ``"real"`` key.
        pred
            Image scored under the ``"pred"`` key. It is not detached here, so gradients
            can propagate back through it.

        Returns
        -------
        dict
            ``{"real": {scale: out, ...}, "pred": {scale: out, ...}}`` where ``scale`` is an
            int in ``self.scales`` and ``out`` is whatever the scale's ``NLayerDiscriminator``
            returns when called with ``requires_features=True``.
        """
        # Parameter-free layer, built once per call instead of once per scale and tensor.
        downsample = self.pooling_fn(
            kernel_size=3, padding=1, stride=2, count_include_pad=False
        )
        n_levels = len(self.scales)

        # The source pyramid is identical for "real" and "pred", so build it only once.
        # NOTE(review): the same source tensors are now handed to the discriminators for
        # both keys; this assumes NLayerDiscriminator does not modify its inputs in place.
        source_pyramid = self._build_pyramid(x.detach().clone(), downsample, n_levels)

        features = {}
        for key, img in (("real", real), ("pred", pred)):
            img_pyramid = self._build_pyramid(img, downsample, n_levels)
            features[key] = {
                scale: self.discriminators[str(scale)](
                    img_level,
                    source_level,
                    requires_features=True,
                )
                for scale, img_level, source_level in zip(
                    self.scales, img_pyramid, source_pyramid
                )
            }
        return features