import torch
[docs]class Net(torch.nn.Module):
def __init__(self):
super().__init__()
mult_chan = 32
depth = 4
self.net_recurse = _Net_recurse(
n_in_channels=1, mult_chan=mult_chan, depth=depth
)
self.conv_out = torch.nn.Conv2d(mult_chan, 1, kernel_size=3, padding=1)
[docs] def forward(self, x):
x_rec = self.net_recurse(x)
return self.conv_out(x_rec)
class _Net_recurse(torch.nn.Module):
def __init__(self, n_in_channels, mult_chan=2, depth=0):
"""Class for recursive definition of U-network.p
Parameters
----------
in_channels
Number of channels for input.
mult_chan
Factor to determine number of output channels
depth
If 0, this subnet will only be convolutions that double the channel
count.
"""
super().__init__()
self.depth = depth
n_out_channels = n_in_channels * mult_chan
self.sub_2conv_more = SubNet2Conv(n_in_channels, n_out_channels)
if depth > 0:
self.sub_2conv_less = SubNet2Conv(2 * n_out_channels, n_out_channels)
self.conv_down = torch.nn.Conv2d(
n_out_channels, n_out_channels, 2, stride=2
)
self.bn0 = torch.nn.BatchNorm2d(n_out_channels)
self.relu0 = torch.nn.ReLU()
self.convt = torch.nn.ConvTranspose2d(
2 * n_out_channels, n_out_channels, kernel_size=2, stride=2
)
self.bn1 = torch.nn.BatchNorm2d(n_out_channels)
self.relu1 = torch.nn.ReLU()
self.sub_u = _Net_recurse(n_out_channels, mult_chan=2, depth=(depth - 1))
def forward(self, x):
if self.depth == 0:
return self.sub_2conv_more(x)
else: # depth > 0
x_2conv_more = self.sub_2conv_more(x)
x_conv_down = self.conv_down(x_2conv_more)
x_bn0 = self.bn0(x_conv_down)
x_relu0 = self.relu0(x_bn0)
x_sub_u = self.sub_u(x_relu0)
x_convt = self.convt(x_sub_u)
x_bn1 = self.bn1(x_convt)
x_relu1 = self.relu1(x_bn1)
x_cat = torch.cat((x_2conv_more, x_relu1), 1) # concatenate
x_2conv_less = self.sub_2conv_less(x_cat)
return x_2conv_less
[docs]class SubNet2Conv(torch.nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv1 = torch.nn.Conv2d(n_in, n_out, kernel_size=3, padding=1)
self.bn1 = torch.nn.BatchNorm2d(n_out)
self.relu1 = torch.nn.ReLU()
self.conv2 = torch.nn.Conv2d(n_out, n_out, kernel_size=3, padding=1)
self.bn2 = torch.nn.BatchNorm2d(n_out)
self.relu2 = torch.nn.ReLU()
[docs] def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
return x