| | |
| |
|
| | |
| | """ |
| | |
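Purpose : Defines a 3D Attention U-Net (AttUnet) together with the convolution, up-convolution and attention-gate modules it is built from.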
| | |
| | """ |
| | import torch.nn |
| | import torch |
| | import torch.nn as nn |
| |
|
| | __author__ = "Chethan Radhakrishna and Soumick Chatterjee" |
| | __credits__ = ["Chethan Radhakrishna", "Soumick Chatterjee"] |
| | __license__ = "GPL" |
| | __version__ = "1.0.0" |
| | __maintainer__ = "Chethan Radhakrishna" |
| | __email__ = "[email protected]" |
| | __status__ = "Development" |
| |
|
| |
|
class ConvBlock(nn.Module):
    """
    Double 3D convolution block.

    Runs two stages back to back, each being
    ``Conv3d -> PReLU -> BatchNorm3d``. The first stage maps ``in_channels``
    to ``out_channels``; the second keeps ``out_channels``.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count produced by both stages.
        k_size: kernel size used by both convolutions.
        stride: stride used by both convolutions.
        padding: padding used by both convolutions.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
        super().__init__()
        # Both convolutions share every hyper-parameter except the input width.
        conv_kwargs = dict(kernel_size=k_size, stride=stride, padding=padding, bias=bias)
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend((
                nn.Conv3d(in_channels=stage_in, out_channels=out_channels, **conv_kwargs),
                nn.PReLU(num_parameters=out_channels, init=0.25),
                nn.BatchNorm3d(num_features=out_channels),
            ))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the result."""
        return self.conv(x)
| |
|
| |
|
class SeparableConvBlock(nn.Module):
    """
    Pointwise-then-spatial double 3D convolution block.

    Runs two stages back to back, each being
    ``Conv3d(kernel 1) -> Conv3d(kernel k_size) -> PReLU -> BatchNorm3d``.
    The first stage maps ``in_channels`` to ``out_channels`` in its 1x1x1
    convolution; everything after that works on ``out_channels``.

    NOTE(review): the ``k_size`` convolution leaves ``groups`` at its
    default, so it is a dense convolution rather than a depthwise one -
    confirm that this is the intended meaning of "separable" here.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count produced by every layer of the block.
        k_size: kernel size of the spatial convolutions.
        stride: stride of the spatial convolutions.
        padding: padding of the spatial convolutions.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
        super().__init__()

        def build_stage(stage_in):
            # One pointwise conv to set the width, then the spatial conv,
            # activation and normalisation.
            return [
                nn.Conv3d(in_channels=stage_in, out_channels=out_channels, kernel_size=1,
                          bias=bias),
                nn.Conv3d(in_channels=out_channels, out_channels=out_channels,
                          kernel_size=k_size, stride=stride, padding=padding, bias=bias),
                nn.PReLU(num_parameters=out_channels, init=0.25),
                nn.BatchNorm3d(num_features=out_channels),
            ]

        self.conv = nn.Sequential(*build_stage(in_channels), *build_stage(out_channels))

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.conv(x)
| |
|
| |
|
class UpConv(nn.Module):
    """
    Up-convolution block.

    Upsamples the input by a factor of 2 (``nn.Upsample`` with its default
    mode), then applies ``Conv3d -> BatchNorm3d -> PReLU``.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count produced by the convolution.
        k_size: kernel size of the convolution.
        stride: stride of the convolution.
        padding: padding of the convolution.
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2)
        project = nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
                            kernel_size=k_size, stride=stride, padding=padding)
        normalise = nn.BatchNorm3d(num_features=out_channels)
        activate = nn.PReLU(num_parameters=out_channels, init=0.25)
        self.up = nn.Sequential(enlarge, project, normalise, activate)

    def forward(self, x):
        """Upsample ``x``, convolve it and return the result."""
        return self.up(x)
| |
|
| |
|
class AttentionBlock(nn.Module):
    """
    Additive attention gate.

    Projects the gating tensor ``g`` and the feature tensor ``x`` to
    ``f_int`` channels with 1x1x1 convolutions, sums them, applies ReLU and
    reduces the result to a single-channel sigmoid map that rescales ``x``.

    Args:
        f_g: channel count of the gating tensor ``g`` passed to ``forward``.
        f_l: channel count of the feature tensor ``x`` passed to ``forward``.
        f_int: channel count of the intermediate projections.
    """

    def __init__(self, f_g, f_l, f_int):
        super().__init__()

        # W_g consumes ``g``, so its input width must be f_g. Building it
        # from f_l (and W_x from f_g) only works when f_g == f_l and fails
        # with a channel mismatch otherwise.
        self.W_g = nn.Sequential(
            nn.Conv3d(f_g, f_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm3d(f_int)
        )

        # W_x consumes ``x``, so its input width must be f_l.
        self.W_x = nn.Sequential(
            nn.Conv3d(f_l, f_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm3d(f_int)
        )

        # Collapses the fused projection to one channel of coefficients in (0, 1).
        self.psi = nn.Sequential(
            nn.Conv3d(f_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm3d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """
        Return ``x`` scaled element-wise by the attention coefficients.

        ``g`` and ``x`` must have matching batch and spatial sizes so that
        their projections can be added; the single-channel coefficient map
        is broadcast over the channels of ``x``.
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        # The sum is a fresh tensor, so the in-place ReLU does not touch g1 or x1.
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi
| |
|
| |
|
class AttUnet(nn.Module):
    """
    Attention Unet implementation
    Paper: https://arxiv.org/abs/1804.03999

    Five-level 3D encoder/decoder. Each decoder level upsamples the deeper
    features, gates the matching encoder output with an ``AttentionBlock``,
    concatenates the gated skip with the upsampled features along the
    channel axis and fuses them with a convolution block.

    Args:
        in_ch: channel count of the input volume.
        out_ch: channel count of the output produced by the final 1x1x1
            convolution.
        init_features: channel width of the first encoder level; deeper
            levels use 2x, 4x, 8x and 16x this value.
    """

    def __init__(self, in_ch=1, out_ch=6, init_features=64):
        super().__init__()

        # Channel widths of the five levels, shallowest to deepest.
        c1, c2, c3, c4, c5 = [init_features * factor for factor in (1, 2, 4, 8, 16)]

        # One 2x pooling layer between each pair of encoder levels.
        self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)

        # Encoder: a plain block at full resolution, separable blocks below.
        self.Conv1 = ConvBlock(in_ch, c1)
        self.Conv2 = SeparableConvBlock(c1, c2)
        self.Conv3 = SeparableConvBlock(c2, c3)
        self.Conv4 = SeparableConvBlock(c3, c4)
        self.Conv5 = SeparableConvBlock(c4, c5)

        # Decoder level 5 -> 4. The fusion block sees gated skip + upsampled
        # features concatenated, hence twice the level's width on input.
        self.Up5 = UpConv(c5, c4)
        self.Att5 = AttentionBlock(f_g=c4, f_l=c4, f_int=c3)
        self.Up_conv5 = SeparableConvBlock(c5, c4)

        # Decoder level 4 -> 3.
        self.Up4 = UpConv(c4, c3)
        self.Att4 = AttentionBlock(f_g=c3, f_l=c3, f_int=c2)
        self.Up_conv4 = SeparableConvBlock(c4, c3)

        # Decoder level 3 -> 2.
        self.Up3 = UpConv(c3, c2)
        self.Att3 = AttentionBlock(f_g=c2, f_l=c2, f_int=c1)
        self.Up_conv3 = SeparableConvBlock(c3, c2)

        # Decoder level 2 -> 1.
        # NOTE(review): f_int is the literal 32 here rather than a value
        # derived from init_features like the other levels - confirm intent.
        self.Up2 = UpConv(c2, c1)
        self.Att2 = AttentionBlock(f_g=c1, f_l=c1, f_int=32)
        self.Up_conv2 = ConvBlock(c2, c1)

        # Final 1x1x1 projection to the requested number of output channels.
        self.Conv = nn.Conv3d(c1, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """
        Run the encoder, then the attention-gated decoder, and return the
        output of the final 1x1x1 convolution.

        NOTE(review): the input is pooled by 2 four times and upsampled by 2
        four times, so spatial sizes presumably need to be multiples of 16
        for the skip concatenations to line up - confirm against callers.
        """
        # Encoder path; every level's output is kept as a skip connection.
        skip1 = self.Conv1(x)
        skip2 = self.Conv2(self.Maxpool1(skip1))
        skip3 = self.Conv3(self.Maxpool2(skip2))
        skip4 = self.Conv4(self.Maxpool3(skip3))
        features = self.Conv5(self.Maxpool4(skip4))

        # Decoder path, deepest level first.
        decoder_levels = (
            (self.Up5, self.Att5, self.Up_conv5, skip4),
            (self.Up4, self.Att4, self.Up_conv4, skip3),
            (self.Up3, self.Att3, self.Up_conv3, skip2),
            (self.Up2, self.Att2, self.Up_conv2, skip1),
        )
        for upsample, gate, fuse, skip in decoder_levels:
            features = upsample(features)
            # The upsampled features act as the gating signal for the skip.
            gated_skip = gate(features, skip)
            features = fuse(torch.cat((gated_skip, features), dim=1))

        return self.Conv(features)