# Source code for cyto_dl.models.im2im.utils.noise_annealer
import torch
class NoiseAnnealer:
    """Linearly anneal the scale of Gaussian "instance noise" added to images.

    Instance noise is added to real and fake examples passed to a
    discriminator. It makes the generator's task harder and increases the
    support of the real and fake distributions so they overlap, which has
    useful theoretical implications for the quality of the discriminator.
    It can also serve as a curriculum-learning schedule: the noise on a
    target is reduced over time, making the task progressively harder.

    Each call advances the schedule by one step. Call ``k`` (1-based) adds
    standard-normal noise scaled by ``init_variance * (1 - k / annealing_steps)``.
    From call ``annealing_steps`` onward, the input is returned unchanged.

    NOTE: although the parameters are named "variance", ``noise`` multiplies
    standard-normal samples, so it acts as the noise *standard deviation*
    (the resulting variance is ``noise ** 2``). The names are kept for
    backward compatibility.
    """

    def __init__(self, annealing_steps: int = 5000, init_variance: float = 0.3):
        """
        Parameters
        ----------
        annealing_steps: int = 5000
            Number of calls over which the noise scale is linearly annealed
            from `init_variance` to 0. ``0`` disables noise entirely.
        init_variance: float = 0.3
            Initial noise scale (see the class-level note on std vs. variance).

        Raises
        ------
        ValueError
            If `annealing_steps` is negative.
        """
        if annealing_steps < 0:
            raise ValueError(f"annealing_steps must be >= 0, got {annealing_steps}")
        self.init_variance = init_variance
        self.noise = init_variance
        # Avoid ZeroDivisionError: with zero annealing steps there is no noise at all.
        self.step_size = init_variance / annealing_steps if annealing_steps > 0 else 0.0
        self.annealing_steps = annealing_steps
        self._done = False
        self.steps = 0

    def update_noise(self):
        """Advance the schedule by one step and update ``self.noise``.

        After ``annealing_steps`` updates, the schedule is finished: ``noise``
        is pinned to 0 and ``_done`` is set, so the scale never goes negative.
        """
        self.steps += 1
        if self._done or self.steps >= self.annealing_steps:
            self.noise = 0.0
            self._done = True
        else:
            # Recompute from the initial value instead of subtracting repeatedly,
            # so floating-point error does not accumulate over many steps.
            self.noise = self.init_variance - self.steps * self.step_size

    def __call__(self, img):
        """Advance the schedule, then return `img` with scaled Gaussian noise added.

        Once annealing has finished, the input object itself is returned unchanged.
        """
        self.update_noise()
        if self._done:
            return img
        # Sample directly on the input's device to avoid a CPU-to-device copy;
        # type_as then matches the input's dtype, as before.
        noise_tensor = torch.randn(img.shape, device=img.device) * self.noise
        return torch.add(img, noise_tensor.type_as(img))