| | |
| |
|
| | |
| | |
| |
|
| | import torch.nn as nn |
| | from tricorder.torch.transforms import Interpolator |
| |
|
| | __author__ = "Soumick Chatterjee" |
| | __copyright__ = "Copyright 2019, Soumick Chatterjee & OvGU:ESF:MEMoRIAL" |
| | __credits__ = ["Soumick Chatterjee"] |
| |
|
| | __license__ = "apache-2.0" |
| | __version__ = "1.0.0" |
| | __email__ = "[email protected]" |
| | __status__ = "Published" |
| |
|
| |
|
class ResidualBlock(nn.Module):
    """Residual unit: two padded 3-wide convolutions wrapped by an identity skip.

    The layer factories used here (``layer_pad``, ``layer_conv``,
    ``layer_norm``, ``act_relu``, ``layer_drop``) are resolved as module-level
    names when the block is constructed; this class does not define them.

    Args:
        in_features: Channel count used as both input and output size of
            each convolution, so the skip addition in ``forward`` lines up.
        drop_prob: Probability passed to ``layer_drop`` between the two
            convolution stages.
    """

    def __init__(self, in_features, drop_prob=0.2):
        super().__init__()

        # pad -> conv -> norm -> activation
        first_stage = (
            layer_pad(1),
            layer_conv(in_features, in_features, 3),
            layer_norm(in_features),
            act_relu(),
        )
        dropout = layer_drop(p=drop_prob, inplace=True)
        # pad -> conv -> norm (no activation before the skip addition)
        second_stage = (
            layer_pad(1),
            layer_conv(in_features, in_features, 3),
            layer_norm(in_features),
        )

        self.conv_block = nn.Sequential(*first_stage, dropout, *second_stage)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual
| |
|
| |
|
class DownsamplingBlock(nn.Module):
    """Stride-2 convolution followed by normalisation and activation.

    ``layer_conv``, ``layer_norm`` and ``act_relu`` are resolved as
    module-level names when the block is constructed.

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels.
    """

    def __init__(self, in_features, out_features):
        super().__init__()

        self.conv_block = nn.Sequential(
            layer_conv(in_features, out_features, 3, stride=2, padding=1),
            layer_norm(out_features),
            act_relu(),
        )

    def forward(self, x):
        """Apply the conv -> norm -> activation stack to ``x``."""
        downsampled = self.conv_block(x)
        return downsampled
| |
|
| |
|
class UpsamplingBlock(nn.Module):
    """Upsampling stage using a transposed convolution or interpolate-then-conv.

    ``layer_conv``, ``layer_convtrans``, ``layer_pad``, ``layer_norm`` and
    ``act_relu`` are resolved as module-level names at construction time.

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels.
        mode: ``"convtrans"`` selects a stride-2 transposed convolution; any
            other value selects "interpolate, then padded 3-wide convolution".
        interpolator: Callable invoked as ``interpolator(x, out_shape)``.
            Required whenever ``forward`` takes an interpolating path.
        post_interp_convtrans: When true, a 1-wide ``post_conv`` is created
            and, in ``"convtrans"`` mode, applied after resizing an output
            whose spatial shape differs from the requested ``out_shape``.
    """

    def __init__(self, in_features, out_features, mode="convtrans", interpolator=None, post_interp_convtrans=False):
        super().__init__()

        self.interpolator = interpolator
        self.mode = mode
        self.post_interp_convtrans = post_interp_convtrans
        if self.post_interp_convtrans:
            self.post_conv = layer_conv(out_features, out_features, 1)

        if mode == "convtrans":
            stem = [
                layer_convtrans(in_features, out_features, 3,
                                stride=2, padding=1, output_padding=1),
            ]
        else:
            stem = [
                layer_pad(1),
                layer_conv(in_features, out_features, 3),
            ]
        self.conv_block = nn.Sequential(
            *stem,
            layer_norm(out_features),
            act_relu(),
        )

    def forward(self, x, out_shape=None):
        """Upsample ``x``; ``out_shape`` is the requested spatial shape.

        In ``"convtrans"`` mode ``out_shape`` is only consulted when
        ``post_interp_convtrans`` is set; in every other mode it is handed
        straight to the interpolator.
        """
        if self.mode != "convtrans":
            # Resize first, then convolve.
            return self.conv_block(self.interpolator(x, out_shape))

        upsampled = self.conv_block(x)
        if self.post_interp_convtrans and upsampled.shape[2:] != out_shape:
            # Transposed conv missed the target size: resize and smooth.
            return self.post_conv(self.interpolator(upsampled, out_shape))
        return upsampled
| |
|
| |
|
class ReconResNetBase(nn.Module):
    """Residual encoder/decoder network, buildable in 2D or 3D.

    Layout: initial 7-wide convolution -> ``updown_blocks`` downsampling
    blocks -> ``res_blocks`` residual blocks -> ``updown_blocks`` upsampling
    blocks -> final 7-wide convolution with an optional output activation.

    Args:
        in_channels: Channels of the network input.
        out_channels: Channels of the network output.
        res_blocks: Number of ``ResidualBlock`` instances in the bottleneck.
        starting_nfeatures: Channels after the initial convolution; doubled
            by each downsampling block and halved by each upsampling block.
        updown_blocks: Number of downsampling (and upsampling) blocks.
        is_relu_leaky: ``nn.PReLU`` when true, ``nn.ReLU`` otherwise.
        do_batchnorm: Batch normalisation when true, instance normalisation
            otherwise.
        res_drop_prob: Dropout probability inside each residual block.
        is_replicatepad: ``0`` for reflection padding, ``1`` for replication
            padding.
        out_act: ``"sigmoid"``, ``"relu"`` (uses the activation chosen by
            ``is_relu_leaky``) or ``"tanh"``; any other value leaves the
            final convolution without an activation.
        forwardV: Integer 0-5 selecting which ``forwardV*`` method is bound
            as ``forward``. For any other value nothing is bound and the
            inherited ``forward`` stays in place.
        upinterp_algo: ``"convtrans"`` for transposed-convolution upsampling;
            any other value is passed on as the interpolation mode.
        post_interp_convtrans: Forwarded to every ``UpsamplingBlock``.
        is3D: Build with 3D layers when true, 2D layers otherwise.

    Raises:
        ValueError: If ``is_replicatepad`` is neither 0 nor 1.
    """

    def __init__(self, in_channels=1, out_channels=1, res_blocks=14, starting_nfeatures=64, updown_blocks=2, is_relu_leaky=True, do_batchnorm=False, res_drop_prob=0.2,
                 is_replicatepad=0, out_act="sigmoid", forwardV=0, upinterp_algo='convtrans', post_interp_convtrans=False, is3D=False):
        super(ReconResNetBase, self).__init__()

        # Without this check an unsupported value leaves ``layer_pad`` unset,
        # so construction either fails with NameError or silently reuses the
        # padding class a previous instantiation stored in module globals.
        if is_replicatepad not in (0, 1):
            raise ValueError(
                "is_replicatepad must be 0 (reflection) or 1 (replication), "
                "got {!r}".format(is_replicatepad))

        if is3D:
            layers = {
                "layer_conv": nn.Conv3d,
                "layer_convtrans": nn.ConvTranspose3d,
                "layer_norm": nn.BatchNorm3d if do_batchnorm else nn.InstanceNorm3d,
                "layer_drop": nn.Dropout3d,
                "layer_pad": nn.ReplicationPad3d if is_replicatepad == 1 else nn.ReflectionPad3d,
                "interp_mode": 'trilinear',
            }
        else:
            layers = {
                "layer_conv": nn.Conv2d,
                "layer_convtrans": nn.ConvTranspose2d,
                "layer_norm": nn.BatchNorm2d if do_batchnorm else nn.InstanceNorm2d,
                "layer_drop": nn.Dropout2d,
                "layer_pad": nn.ReplicationPad2d if is_replicatepad == 1 else nn.ReflectionPad2d,
                "interp_mode": 'bilinear',
            }
        layers["act_relu"] = nn.PReLU if is_relu_leaky else nn.ReLU

        # The block classes in this module look these factories up as
        # module-level names, so they are published into the module globals.
        # NOTE: this is shared state - the most recently constructed network
        # decides which layer types the next block construction will use.
        globals().update(layers)

        self.forwardV = forwardV
        self.upinterp_algo = upinterp_algo

        interpolator = Interpolator(
            mode=layers["interp_mode"] if self.upinterp_algo == "convtrans" else self.upinterp_algo)

        # Initial convolution block
        intialConv = [layer_pad(3),
                      layer_conv(in_channels, starting_nfeatures, 7),
                      layer_norm(starting_nfeatures),
                      act_relu()]

        # Downsampling: every block doubles the channel count
        downsam = []
        in_features = starting_nfeatures
        for _ in range(updown_blocks):
            downsam.append(DownsamplingBlock(in_features, in_features * 2))
            in_features = in_features * 2

        # Residual bottleneck
        resblocks = [ResidualBlock(in_features, res_drop_prob)
                     for _ in range(res_blocks)]

        # Upsampling: every block halves the channel count
        upsam = []
        for _ in range(updown_blocks):
            out_features = in_features // 2
            upsam.append(UpsamplingBlock(in_features, out_features,
                                         self.upinterp_algo, interpolator, post_interp_convtrans))
            in_features = out_features

        # Output layer
        finalconv = [layer_pad(3),
                     layer_conv(starting_nfeatures, out_channels, 7)]

        if out_act == "sigmoid":
            finalconv.append(nn.Sigmoid())
        elif out_act == "relu":
            finalconv.append(act_relu())
        elif out_act == "tanh":
            finalconv.append(nn.Tanh())

        # Attribute names (including the ``intialConv`` spelling) are kept
        # because they are part of the state_dict keys.
        self.intialConv = nn.Sequential(*intialConv)
        self.downsam = nn.ModuleList(downsam)
        self.resblocks = nn.Sequential(*resblocks)
        self.upsam = nn.ModuleList(upsam)
        self.finalconv = nn.Sequential(*finalconv)

        # Bind the requested variant as ``forward``; unknown values bind nothing.
        forward_variants = (self.forwardV0, self.forwardV1, self.forwardV2,
                            self.forwardV3, self.forwardV4, self.forwardV5)
        for version, variant in enumerate(forward_variants):
            if self.forwardV == version:
                self.forward = variant
                break

    def _run_downsampling(self, out):
        """Apply every downsampling block; return the result and the spatial shape seen before each block."""
        shapes = []
        for downblock in self.downsam:
            shapes.append(out.shape[2:])
            out = downblock(out)
        return out, shapes

    def _run_upsampling(self, out, shapes):
        """Apply every upsampling block, requesting the recorded shapes in reverse order."""
        for i, upblock in enumerate(self.upsam):
            out = upblock(out, shapes[-1 - i])
        return out

    def forwardV0(self, x):
        """Plain pass: no skip connections outside the residual blocks."""
        out, shapes = self._run_downsampling(self.intialConv(x))
        out = self.resblocks(out)
        out = self._run_upsampling(out, shapes)
        return self.finalconv(out)

    def forwardV1(self, x):
        """Like V0, with the network input added to the final output."""
        out, shapes = self._run_downsampling(self.intialConv(x))
        out = self.resblocks(out)
        out = self._run_upsampling(out, shapes)
        return x + self.finalconv(out)

    def forwardV2(self, x):
        """Like V1, plus a skip connection around the residual bottleneck."""
        out, shapes = self._run_downsampling(self.intialConv(x))
        out = out + self.resblocks(out)
        out = self._run_upsampling(out, shapes)
        return x + self.finalconv(out)

    def forwardV3(self, x):
        """Like V2, plus the input added to the initial convolution output."""
        out, shapes = self._run_downsampling(x + self.intialConv(x))
        out = out + self.resblocks(out)
        out = self._run_upsampling(out, shapes)
        return x + self.finalconv(out)

    def forwardV4(self, x):
        """Like V3, plus a skip from the initial features to the upsampled features."""
        iniconv = x + self.intialConv(x)
        out, shapes = self._run_downsampling(iniconv)
        out = out + self.resblocks(out)
        out = iniconv + self._run_upsampling(out, shapes)
        return x + self.finalconv(out)

    def forwardV5(self, x):
        """Like V3, plus a skip at every resolution level of the decoder."""
        # ``outs`` is a stack of features, one entry per resolution level.
        outs = [x + self.intialConv(x)]
        shapes = []
        for downblock in self.downsam:
            shapes.append(outs[-1].shape[2:])
            outs.append(downblock(outs[-1]))
        outs[-1] = outs[-1] + self.resblocks(outs[-1])
        for i, upblock in enumerate(self.upsam):
            # Pop the deepest level, upsample it and merge it into the
            # encoder features of the level above.
            upsampled = upblock(outs.pop(), shapes[-1 - i])
            outs[-1] = outs[-1] + upsampled
        return x + self.finalconv(outs.pop())