Upload 7 files
Browse files
models/__pycache__/afwm.cpython-310.pyc
ADDED
|
Binary file (6.52 kB). View file
|
|
|
models/__pycache__/networks.cpython-310.pyc
ADDED
|
Binary file (4.98 kB). View file
|
|
|
models/afwm.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from .correlation import correlation
|
| 5 |
+
|
| 6 |
+
def apply_offset(offset):
|
| 7 |
+
|
| 8 |
+
sizes = list(offset.size()[2:])
|
| 9 |
+
grid_list = torch.meshgrid([torch.arange(size, device=offset.device) for size in sizes])
|
| 10 |
+
grid_list = reversed(grid_list)
|
| 11 |
+
|
| 12 |
+
grid_list = [grid.float().unsqueeze(0) + offset[:, dim, ...]
|
| 13 |
+
for dim, grid in enumerate(grid_list)]
|
| 14 |
+
|
| 15 |
+
grid_list = [grid / ((size - 1.0) / 2.0) - 1.0
|
| 16 |
+
for grid, size in zip(grid_list, reversed(sizes))]
|
| 17 |
+
|
| 18 |
+
return torch.stack(grid_list, dim=-1)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ResBlock(nn.Module):
|
| 22 |
+
def __init__(self, in_channels):
|
| 23 |
+
super(ResBlock, self).__init__()
|
| 24 |
+
self.block = nn.Sequential(
|
| 25 |
+
nn.BatchNorm2d(in_channels),
|
| 26 |
+
nn.ReLU(inplace=True),
|
| 27 |
+
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
|
| 28 |
+
nn.BatchNorm2d(in_channels),
|
| 29 |
+
nn.ReLU(inplace=True),
|
| 30 |
+
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
return self.block(x) + x
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DownSample(nn.Module):
|
| 38 |
+
def __init__(self, in_channels, out_channels):
|
| 39 |
+
super(DownSample, self).__init__()
|
| 40 |
+
self.block= nn.Sequential(
|
| 41 |
+
nn.BatchNorm2d(in_channels),
|
| 42 |
+
nn.ReLU(inplace=True),
|
| 43 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
return self.block(x)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class FeatureEncoder(nn.Module):
|
| 52 |
+
def __init__(self, in_channels, chns=[64,128,256,256,256]):
|
| 53 |
+
super(FeatureEncoder, self).__init__()
|
| 54 |
+
self.encoders = []
|
| 55 |
+
for i, out_chns in enumerate(chns):
|
| 56 |
+
if i == 0:
|
| 57 |
+
encoder = nn.Sequential(DownSample(in_channels, out_chns),
|
| 58 |
+
ResBlock(out_chns),
|
| 59 |
+
ResBlock(out_chns))
|
| 60 |
+
else:
|
| 61 |
+
encoder = nn.Sequential(DownSample(chns[i-1], out_chns),
|
| 62 |
+
ResBlock(out_chns),
|
| 63 |
+
ResBlock(out_chns))
|
| 64 |
+
|
| 65 |
+
self.encoders.append(encoder)
|
| 66 |
+
|
| 67 |
+
self.encoders = nn.ModuleList(self.encoders)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
encoder_features = []
|
| 72 |
+
for encoder in self.encoders:
|
| 73 |
+
x = encoder(x)
|
| 74 |
+
encoder_features.append(x)
|
| 75 |
+
return encoder_features
|
| 76 |
+
|
| 77 |
+
class RefinePyramid(nn.Module):
|
| 78 |
+
def __init__(self, chns=[64,128,256,256,256], fpn_dim=256):
|
| 79 |
+
super(RefinePyramid, self).__init__()
|
| 80 |
+
self.chns = chns
|
| 81 |
+
|
| 82 |
+
self.adaptive = []
|
| 83 |
+
for in_chns in list(reversed(chns)):
|
| 84 |
+
adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1)
|
| 85 |
+
self.adaptive.append(adaptive_layer)
|
| 86 |
+
self.adaptive = nn.ModuleList(self.adaptive)
|
| 87 |
+
|
| 88 |
+
self.smooth = []
|
| 89 |
+
for i in range(len(chns)):
|
| 90 |
+
smooth_layer = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1)
|
| 91 |
+
self.smooth.append(smooth_layer)
|
| 92 |
+
self.smooth = nn.ModuleList(self.smooth)
|
| 93 |
+
|
| 94 |
+
def forward(self, x):
|
| 95 |
+
conv_ftr_list = x
|
| 96 |
+
|
| 97 |
+
feature_list = []
|
| 98 |
+
last_feature = None
|
| 99 |
+
for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))):
|
| 100 |
+
feature = self.adaptive[i](conv_ftr)
|
| 101 |
+
|
| 102 |
+
if last_feature is not None:
|
| 103 |
+
feature = feature + F.interpolate(last_feature, scale_factor=2, mode='nearest')
|
| 104 |
+
|
| 105 |
+
feature = self.smooth[i](feature)
|
| 106 |
+
last_feature = feature
|
| 107 |
+
feature_list.append(feature)
|
| 108 |
+
|
| 109 |
+
return tuple(reversed(feature_list))
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class AFlowNet(nn.Module):
|
| 113 |
+
def __init__(self, num_pyramid, fpn_dim=256):
|
| 114 |
+
super(AFlowNet, self).__init__()
|
| 115 |
+
self.netMain = []
|
| 116 |
+
self.netRefine = []
|
| 117 |
+
for i in range(num_pyramid):
|
| 118 |
+
netMain_layer = torch.nn.Sequential(
|
| 119 |
+
torch.nn.Conv2d(in_channels=49, out_channels=128, kernel_size=3, stride=1, padding=1),
|
| 120 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
| 121 |
+
torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
|
| 122 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
| 123 |
+
torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
|
| 124 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
| 125 |
+
torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
netRefine_layer = torch.nn.Sequential(
|
| 129 |
+
torch.nn.Conv2d(2 * fpn_dim, out_channels=128, kernel_size=3, stride=1, padding=1),
|
| 130 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
| 131 |
+
torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
|
| 132 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
| 133 |
+
torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
|
| 134 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
| 135 |
+
torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
|
| 136 |
+
)
|
| 137 |
+
self.netMain.append(netMain_layer)
|
| 138 |
+
self.netRefine.append(netRefine_layer)
|
| 139 |
+
|
| 140 |
+
self.netMain = nn.ModuleList(self.netMain)
|
| 141 |
+
self.netRefine = nn.ModuleList(self.netRefine)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def forward(self, x, x_warps, x_conds, warp_feature=True):
|
| 145 |
+
last_flow = None
|
| 146 |
+
|
| 147 |
+
for i in range(len(x_warps)):
|
| 148 |
+
x_warp = x_warps[len(x_warps) - 1 - i]
|
| 149 |
+
x_cond = x_conds[len(x_warps) - 1 - i]
|
| 150 |
+
|
| 151 |
+
if last_flow is not None and warp_feature:
|
| 152 |
+
x_warp_after = F.grid_sample(x_warp, last_flow.detach().permute(0, 2, 3, 1),
|
| 153 |
+
mode='bilinear', padding_mode='border')
|
| 154 |
+
else:
|
| 155 |
+
x_warp_after = x_warp
|
| 156 |
+
|
| 157 |
+
tenCorrelation = F.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=x_warp_after, tenSecond=x_cond, intStride=1), negative_slope=0.1, inplace=False)
|
| 158 |
+
flow = self.netMain[i](tenCorrelation)
|
| 159 |
+
flow = apply_offset(flow)
|
| 160 |
+
|
| 161 |
+
if last_flow is not None:
|
| 162 |
+
flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border')
|
| 163 |
+
else:
|
| 164 |
+
flow = flow.permute(0, 3, 1, 2)
|
| 165 |
+
|
| 166 |
+
last_flow = flow
|
| 167 |
+
x_warp = F.grid_sample(x_warp, flow.permute(0, 2, 3, 1),mode='bilinear', padding_mode='border')
|
| 168 |
+
concat = torch.cat([x_warp,x_cond],1)
|
| 169 |
+
flow = self.netRefine[i](concat)
|
| 170 |
+
flow = apply_offset(flow)
|
| 171 |
+
flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border')
|
| 172 |
+
|
| 173 |
+
last_flow = F.interpolate(flow, scale_factor=2, mode='bilinear')
|
| 174 |
+
|
| 175 |
+
x_warp = F.grid_sample(x, last_flow.permute(0, 2, 3, 1),
|
| 176 |
+
mode='bilinear', padding_mode='border')
|
| 177 |
+
return x_warp, last_flow,
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class AFWM(nn.Module):
|
| 181 |
+
|
| 182 |
+
def __init__(self, opt, input_nc):
|
| 183 |
+
super(AFWM, self).__init__()
|
| 184 |
+
num_filters = [64,128,256,256,256]
|
| 185 |
+
self.image_features = FeatureEncoder(3, num_filters)
|
| 186 |
+
self.cond_features = FeatureEncoder(input_nc, num_filters)
|
| 187 |
+
self.image_FPN = RefinePyramid(num_filters)
|
| 188 |
+
self.cond_FPN = RefinePyramid(num_filters)
|
| 189 |
+
self.aflow_net = AFlowNet(len(num_filters))
|
| 190 |
+
|
| 191 |
+
def forward(self, cond_input, image_input):
|
| 192 |
+
cond_pyramids = self.cond_FPN(self.cond_features(cond_input)) # maybe use nn.Sequential
|
| 193 |
+
image_pyramids = self.image_FPN(self.image_features(image_input))
|
| 194 |
+
|
| 195 |
+
x_warp, last_flow = self.aflow_net(image_input, image_pyramids, cond_pyramids)
|
| 196 |
+
|
| 197 |
+
return x_warp, last_flow
|
| 198 |
+
|
models/correlation/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
This is an adaptation of the <a href="https://github.com/lmb-freiburg/flownet2">FlowNet2 implementation</a> in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the <a href="https://github.com/lmb-freiburg/flownet2#license-and-citation">licensing terms</a> of the original authors. Should you be making use or modify this particular implementation, please acknowledge it appropriately.
|
models/correlation/__pycache__/correlation.cpython-310.pyc
ADDED
|
Binary file (13.7 kB). View file
|
|
|
models/correlation/correlation.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import cupy
|
| 6 |
+
import math
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
kernel_Correlation_rearrange = '''
|
| 10 |
+
extern "C" __global__ void kernel_Correlation_rearrange(
|
| 11 |
+
const int n,
|
| 12 |
+
const float* input,
|
| 13 |
+
float* output
|
| 14 |
+
) {
|
| 15 |
+
int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
|
| 16 |
+
|
| 17 |
+
if (intIndex >= n) {
|
| 18 |
+
return;
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
int intSample = blockIdx.z;
|
| 22 |
+
int intChannel = blockIdx.y;
|
| 23 |
+
|
| 24 |
+
float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
|
| 25 |
+
|
| 26 |
+
__syncthreads();
|
| 27 |
+
|
| 28 |
+
int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}};
|
| 29 |
+
int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}};
|
| 30 |
+
int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX;
|
| 31 |
+
|
| 32 |
+
output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
|
| 33 |
+
}
|
| 34 |
+
'''
|
| 35 |
+
|
| 36 |
+
kernel_Correlation_updateOutput = '''
|
| 37 |
+
extern "C" __global__ void kernel_Correlation_updateOutput(
|
| 38 |
+
const int n,
|
| 39 |
+
const float* rbot0,
|
| 40 |
+
const float* rbot1,
|
| 41 |
+
float* top
|
| 42 |
+
) {
|
| 43 |
+
extern __shared__ char patch_data_char[];
|
| 44 |
+
|
| 45 |
+
float *patch_data = (float *)patch_data_char;
|
| 46 |
+
|
| 47 |
+
// First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
|
| 48 |
+
int x1 = (blockIdx.x + 3) * {{intStride}};
|
| 49 |
+
int y1 = (blockIdx.y + 3) * {{intStride}};
|
| 50 |
+
int item = blockIdx.z;
|
| 51 |
+
int ch_off = threadIdx.x;
|
| 52 |
+
|
| 53 |
+
// Load 3D patch into shared shared memory
|
| 54 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
| 55 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
| 56 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
| 57 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
| 58 |
+
int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
|
| 59 |
+
int idxPatchData = ji_off + ch;
|
| 60 |
+
patch_data[idxPatchData] = rbot0[idx1];
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
__syncthreads();
|
| 66 |
+
|
| 67 |
+
__shared__ float sum[32];
|
| 68 |
+
|
| 69 |
+
// Compute correlation
|
| 70 |
+
for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
|
| 71 |
+
sum[ch_off] = 0;
|
| 72 |
+
|
| 73 |
+
int s2o = (top_channel % 7 - 3) * {{intStride}};
|
| 74 |
+
int s2p = (top_channel / 7 - 3) * {{intStride}};
|
| 75 |
+
|
| 76 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
| 77 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
| 78 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
| 79 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
| 80 |
+
int x2 = x1 + s2o;
|
| 81 |
+
int y2 = y1 + s2p;
|
| 82 |
+
|
| 83 |
+
int idxPatchData = ji_off + ch;
|
| 84 |
+
int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
|
| 85 |
+
|
| 86 |
+
sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
__syncthreads();
|
| 92 |
+
|
| 93 |
+
if (ch_off == 0) {
|
| 94 |
+
float total_sum = 0;
|
| 95 |
+
for (int idx = 0; idx < 32; idx++) {
|
| 96 |
+
total_sum += sum[idx];
|
| 97 |
+
}
|
| 98 |
+
const int sumelems = SIZE_3(rbot0);
|
| 99 |
+
const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
|
| 100 |
+
top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
}
|
| 104 |
+
'''
|
| 105 |
+
|
| 106 |
+
kernel_Correlation_updateGradFirst = '''
|
| 107 |
+
#define ROUND_OFF 50000
|
| 108 |
+
|
| 109 |
+
extern "C" __global__ void kernel_Correlation_updateGradFirst(
|
| 110 |
+
const int n,
|
| 111 |
+
const int intSample,
|
| 112 |
+
const float* rbot0,
|
| 113 |
+
const float* rbot1,
|
| 114 |
+
const float* gradOutput,
|
| 115 |
+
float* gradFirst,
|
| 116 |
+
float* gradSecond
|
| 117 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
| 118 |
+
int n = intIndex % SIZE_1(gradFirst); // channels
|
| 119 |
+
int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 3*{{intStride}}; // w-pos
|
| 120 |
+
int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 3*{{intStride}}; // h-pos
|
| 121 |
+
|
| 122 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
| 123 |
+
// We use a large offset, for the inner part not to become negative.
|
| 124 |
+
const int round_off = ROUND_OFF;
|
| 125 |
+
const int round_off_s1 = {{intStride}} * round_off;
|
| 126 |
+
|
| 127 |
+
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
| 128 |
+
int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}}
|
| 129 |
+
int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}}
|
| 130 |
+
|
| 131 |
+
// Same here:
|
| 132 |
+
int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}}
|
| 133 |
+
int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}}
|
| 134 |
+
|
| 135 |
+
float sum = 0;
|
| 136 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
| 137 |
+
xmin = max(0,xmin);
|
| 138 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
| 139 |
+
|
| 140 |
+
ymin = max(0,ymin);
|
| 141 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
| 142 |
+
|
| 143 |
+
for (int p = -3; p <= 3; p++) {
|
| 144 |
+
for (int o = -3; o <= 3; o++) {
|
| 145 |
+
// Get rbot1 data:
|
| 146 |
+
int s2o = {{intStride}} * o;
|
| 147 |
+
int s2p = {{intStride}} * p;
|
| 148 |
+
int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
|
| 149 |
+
float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
|
| 150 |
+
|
| 151 |
+
// Index offset for gradOutput in following loops:
|
| 152 |
+
int op = (p+3) * 7 + (o+3); // index[o,p]
|
| 153 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
| 154 |
+
|
| 155 |
+
for (int y = ymin; y <= ymax; y++) {
|
| 156 |
+
for (int x = xmin; x <= xmax; x++) {
|
| 157 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
| 158 |
+
sum += gradOutput[idxgradOutput] * bot1tmp;
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
}
|
| 164 |
+
const int sumelems = SIZE_1(gradFirst);
|
| 165 |
+
const int bot0index = ((n * SIZE_2(gradFirst)) + (m-3*{{intStride}})) * SIZE_3(gradFirst) + (l-3*{{intStride}});
|
| 166 |
+
gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
|
| 167 |
+
} }
|
| 168 |
+
'''
|
| 169 |
+
|
| 170 |
+
kernel_Correlation_updateGradSecond = '''
|
| 171 |
+
#define ROUND_OFF 50000
|
| 172 |
+
|
| 173 |
+
extern "C" __global__ void kernel_Correlation_updateGradSecond(
|
| 174 |
+
const int n,
|
| 175 |
+
const int intSample,
|
| 176 |
+
const float* rbot0,
|
| 177 |
+
const float* rbot1,
|
| 178 |
+
const float* gradOutput,
|
| 179 |
+
float* gradFirst,
|
| 180 |
+
float* gradSecond
|
| 181 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
| 182 |
+
int n = intIndex % SIZE_1(gradSecond); // channels
|
| 183 |
+
int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 3*{{intStride}}; // w-pos
|
| 184 |
+
int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 3*{{intStride}}; // h-pos
|
| 185 |
+
|
| 186 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
| 187 |
+
// We use a large offset, for the inner part not to become negative.
|
| 188 |
+
const int round_off = ROUND_OFF;
|
| 189 |
+
const int round_off_s1 = {{intStride}} * round_off;
|
| 190 |
+
|
| 191 |
+
float sum = 0;
|
| 192 |
+
for (int p = -3; p <= 3; p++) {
|
| 193 |
+
for (int o = -3; o <= 3; o++) {
|
| 194 |
+
int s2o = {{intStride}} * o;
|
| 195 |
+
int s2p = {{intStride}} * p;
|
| 196 |
+
|
| 197 |
+
//Get X,Y ranges and clamp
|
| 198 |
+
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
| 199 |
+
int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}}
|
| 200 |
+
int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}}
|
| 201 |
+
|
| 202 |
+
// Same here:
|
| 203 |
+
int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}}
|
| 204 |
+
int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}}
|
| 205 |
+
|
| 206 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
| 207 |
+
xmin = max(0,xmin);
|
| 208 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
| 209 |
+
|
| 210 |
+
ymin = max(0,ymin);
|
| 211 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
| 212 |
+
|
| 213 |
+
// Get rbot0 data:
|
| 214 |
+
int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
|
| 215 |
+
float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
|
| 216 |
+
|
| 217 |
+
// Index offset for gradOutput in following loops:
|
| 218 |
+
int op = (p+3) * 7 + (o+3); // index[o,p]
|
| 219 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
| 220 |
+
|
| 221 |
+
for (int y = ymin; y <= ymax; y++) {
|
| 222 |
+
for (int x = xmin; x <= xmax; x++) {
|
| 223 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
| 224 |
+
sum += gradOutput[idxgradOutput] * bot0tmp;
|
| 225 |
+
}
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
const int sumelems = SIZE_1(gradSecond);
|
| 231 |
+
const int bot1index = ((n * SIZE_2(gradSecond)) + (m-3*{{intStride}})) * SIZE_3(gradSecond) + (l-3*{{intStride}});
|
| 232 |
+
gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
|
| 233 |
+
} }
|
| 234 |
+
'''
|
| 235 |
+
|
| 236 |
+
def cupy_kernel(strFunction, objVariables):
|
| 237 |
+
strKernel = globals()[strFunction].replace('{{intStride}}', str(objVariables['intStride']))
|
| 238 |
+
|
| 239 |
+
while True:
|
| 240 |
+
objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
|
| 241 |
+
|
| 242 |
+
if objMatch is None:
|
| 243 |
+
break
|
| 244 |
+
# end
|
| 245 |
+
|
| 246 |
+
intArg = int(objMatch.group(2))
|
| 247 |
+
|
| 248 |
+
strTensor = objMatch.group(4)
|
| 249 |
+
intSizes = objVariables[strTensor].size()
|
| 250 |
+
|
| 251 |
+
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
|
| 252 |
+
# end
|
| 253 |
+
|
| 254 |
+
while True:
|
| 255 |
+
objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
|
| 256 |
+
|
| 257 |
+
if objMatch is None:
|
| 258 |
+
break
|
| 259 |
+
# end
|
| 260 |
+
|
| 261 |
+
intArgs = int(objMatch.group(2))
|
| 262 |
+
strArgs = objMatch.group(4).split(',')
|
| 263 |
+
|
| 264 |
+
strTensor = strArgs[0]
|
| 265 |
+
intStrides = objVariables[strTensor].stride()
|
| 266 |
+
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
|
| 267 |
+
|
| 268 |
+
strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
|
| 269 |
+
# end
|
| 270 |
+
|
| 271 |
+
return strKernel
|
| 272 |
+
# end
|
| 273 |
+
|
| 274 |
+
@cupy.util.memoize(for_each_device=True)
|
| 275 |
+
def cupy_launch(strFunction, strKernel):
|
| 276 |
+
return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
|
| 277 |
+
# end
|
| 278 |
+
|
| 279 |
+
class _FunctionCorrelation(torch.autograd.Function):
|
| 280 |
+
@staticmethod
|
| 281 |
+
def forward(self, first, second, intStride):
|
| 282 |
+
rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ])
|
| 283 |
+
rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ])
|
| 284 |
+
|
| 285 |
+
self.save_for_backward(first, second, rbot0, rbot1)
|
| 286 |
+
|
| 287 |
+
self.intStride = intStride
|
| 288 |
+
|
| 289 |
+
assert(first.is_contiguous() == True)
|
| 290 |
+
assert(second.is_contiguous() == True)
|
| 291 |
+
|
| 292 |
+
output = first.new_zeros([ first.shape[0], 49, int(math.ceil(first.shape[2] / intStride)), int(math.ceil(first.shape[3] / intStride)) ])
|
| 293 |
+
|
| 294 |
+
if first.is_cuda == True:
|
| 295 |
+
n = first.shape[2] * first.shape[3]
|
| 296 |
+
cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
|
| 297 |
+
'intStride': self.intStride,
|
| 298 |
+
'input': first,
|
| 299 |
+
'output': rbot0
|
| 300 |
+
}))(
|
| 301 |
+
grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]),
|
| 302 |
+
block=tuple([ 16, 1, 1 ]),
|
| 303 |
+
args=[ n, first.data_ptr(), rbot0.data_ptr() ]
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
n = second.shape[2] * second.shape[3]
|
| 307 |
+
cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
|
| 308 |
+
'intStride': self.intStride,
|
| 309 |
+
'input': second,
|
| 310 |
+
'output': rbot1
|
| 311 |
+
}))(
|
| 312 |
+
grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]),
|
| 313 |
+
block=tuple([ 16, 1, 1 ]),
|
| 314 |
+
args=[ n, second.data_ptr(), rbot1.data_ptr() ]
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
n = output.shape[1] * output.shape[2] * output.shape[3]
|
| 318 |
+
cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
|
| 319 |
+
'intStride': self.intStride,
|
| 320 |
+
'rbot0': rbot0,
|
| 321 |
+
'rbot1': rbot1,
|
| 322 |
+
'top': output
|
| 323 |
+
}))(
|
| 324 |
+
grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
|
| 325 |
+
block=tuple([ 32, 1, 1 ]),
|
| 326 |
+
shared_mem=first.shape[1] * 4,
|
| 327 |
+
args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
elif first.is_cuda == False:
|
| 331 |
+
raise NotImplementedError()
|
| 332 |
+
|
| 333 |
+
# end
|
| 334 |
+
|
| 335 |
+
return output
|
| 336 |
+
# end
|
| 337 |
+
|
| 338 |
+
@staticmethod
|
| 339 |
+
def backward(self, gradOutput):
|
| 340 |
+
first, second, rbot0, rbot1 = self.saved_tensors
|
| 341 |
+
|
| 342 |
+
assert(gradOutput.is_contiguous() == True)
|
| 343 |
+
|
| 344 |
+
gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None
|
| 345 |
+
gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None
|
| 346 |
+
|
| 347 |
+
if first.is_cuda == True:#
|
| 348 |
+
if gradFirst is not None:
|
| 349 |
+
for intSample in range(first.shape[0]):
|
| 350 |
+
n = first.shape[1] * first.shape[2] * first.shape[3]
|
| 351 |
+
cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', {
|
| 352 |
+
'intStride': self.intStride,
|
| 353 |
+
'rbot0': rbot0,
|
| 354 |
+
'rbot1': rbot1,
|
| 355 |
+
'gradOutput': gradOutput,
|
| 356 |
+
'gradFirst': gradFirst,
|
| 357 |
+
'gradSecond': None
|
| 358 |
+
}))(
|
| 359 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
| 360 |
+
block=tuple([ 512, 1, 1 ]),
|
| 361 |
+
args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ]
|
| 362 |
+
)
|
| 363 |
+
# end
|
| 364 |
+
# end
|
| 365 |
+
|
| 366 |
+
if gradSecond is not None:
|
| 367 |
+
for intSample in range(first.shape[0]):
|
| 368 |
+
n = first.shape[1] * first.shape[2] * first.shape[3]
|
| 369 |
+
cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', {
|
| 370 |
+
'intStride': self.intStride,
|
| 371 |
+
'rbot0': rbot0,
|
| 372 |
+
'rbot1': rbot1,
|
| 373 |
+
'gradOutput': gradOutput,
|
| 374 |
+
'gradFirst': None,
|
| 375 |
+
'gradSecond': gradSecond
|
| 376 |
+
}))(
|
| 377 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
| 378 |
+
block=tuple([ 512, 1, 1 ]),
|
| 379 |
+
args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ]
|
| 380 |
+
)
|
| 381 |
+
# end
|
| 382 |
+
# end
|
| 383 |
+
|
| 384 |
+
elif first.is_cuda == False:
|
| 385 |
+
raise NotImplementedError()
|
| 386 |
+
|
| 387 |
+
# end
|
| 388 |
+
|
| 389 |
+
return gradFirst, gradSecond, None
|
| 390 |
+
# end
|
| 391 |
+
# end
|
| 392 |
+
|
| 393 |
+
def FunctionCorrelation(tenFirst, tenSecond, intStride):
|
| 394 |
+
return _FunctionCorrelation.apply(tenFirst, tenSecond, intStride)
|
| 395 |
+
# end
|
| 396 |
+
|
| 397 |
+
class ModuleCorrelation(torch.nn.Module):
|
| 398 |
+
def __init__(self):
|
| 399 |
+
super(ModuleCorrelation, self).__init__()
|
| 400 |
+
# end
|
| 401 |
+
|
| 402 |
+
def forward(self, tenFirst, tenSecond, intStride):
|
| 403 |
+
return _FunctionCorrelation.apply(tenFirst, tenSecond, intStride)
|
| 404 |
+
# end
|
| 405 |
+
# end
|
models/networks.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.parallel
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
class UnetSkipConnectionBlock(nn.Module):
|
| 7 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
| 8 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
| 9 |
+
super(UnetSkipConnectionBlock, self).__init__()
|
| 10 |
+
self.outermost = outermost
|
| 11 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
| 12 |
+
|
| 13 |
+
if input_nc is None:
|
| 14 |
+
input_nc = outer_nc
|
| 15 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
| 16 |
+
stride=2, padding=1, bias=use_bias)
|
| 17 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
| 18 |
+
uprelu = nn.ReLU(True)
|
| 19 |
+
if norm_layer != None:
|
| 20 |
+
downnorm = norm_layer(inner_nc)
|
| 21 |
+
upnorm = norm_layer(outer_nc)
|
| 22 |
+
|
| 23 |
+
if outermost:
|
| 24 |
+
upsample = nn.Upsample(scale_factor=2, mode='bilinear')
|
| 25 |
+
upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
|
| 26 |
+
down = [downconv]
|
| 27 |
+
up = [uprelu, upsample, upconv]
|
| 28 |
+
model = down + [submodule] + up
|
| 29 |
+
elif innermost:
|
| 30 |
+
upsample = nn.Upsample(scale_factor=2, mode='bilinear')
|
| 31 |
+
upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
|
| 32 |
+
down = [downrelu, downconv]
|
| 33 |
+
if norm_layer == None:
|
| 34 |
+
up = [uprelu, upsample, upconv]
|
| 35 |
+
else:
|
| 36 |
+
up = [uprelu, upsample, upconv, upnorm]
|
| 37 |
+
model = down + up
|
| 38 |
+
else:
|
| 39 |
+
upsample = nn.Upsample(scale_factor=2, mode='bilinear')
|
| 40 |
+
upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
|
| 41 |
+
if norm_layer == None:
|
| 42 |
+
down = [downrelu, downconv]
|
| 43 |
+
up = [uprelu, upsample, upconv]
|
| 44 |
+
else:
|
| 45 |
+
down = [downrelu, downconv, downnorm]
|
| 46 |
+
up = [uprelu, upsample, upconv, upnorm]
|
| 47 |
+
|
| 48 |
+
if use_dropout:
|
| 49 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
| 50 |
+
else:
|
| 51 |
+
model = down + [submodule] + up
|
| 52 |
+
|
| 53 |
+
self.model = nn.Sequential(*model)
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
if self.outermost:
|
| 57 |
+
return self.model(x)
|
| 58 |
+
else:
|
| 59 |
+
return torch.cat([x, self.model(x)], 1)
|
| 60 |
+
|
| 61 |
+
class ResidualBlock(nn.Module):
|
| 62 |
+
def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d):
|
| 63 |
+
super(ResidualBlock, self).__init__()
|
| 64 |
+
self.relu = nn.ReLU(True)
|
| 65 |
+
if norm_layer == None:
|
| 66 |
+
self.block = nn.Sequential(
|
| 67 |
+
nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
|
| 68 |
+
nn.ReLU(inplace=True),
|
| 69 |
+
nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
self.block = nn.Sequential(
|
| 73 |
+
nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
|
| 74 |
+
norm_layer(in_features),
|
| 75 |
+
nn.ReLU(inplace=True),
|
| 76 |
+
nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
|
| 77 |
+
norm_layer(in_features)
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
def forward(self, x):
|
| 81 |
+
residual = x
|
| 82 |
+
out = self.block(x)
|
| 83 |
+
out += residual
|
| 84 |
+
out = self.relu(out)
|
| 85 |
+
return out
|
| 86 |
+
|
| 87 |
+
class ResUnetGenerator(nn.Module):
|
| 88 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
| 89 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
| 90 |
+
super(ResUnetGenerator, self).__init__()
|
| 91 |
+
unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
|
| 92 |
+
|
| 93 |
+
for i in range(num_downs - 5):
|
| 94 |
+
unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
| 95 |
+
unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
| 96 |
+
unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
| 97 |
+
unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
| 98 |
+
unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
| 99 |
+
|
| 100 |
+
self.model = unet_block
|
| 101 |
+
|
| 102 |
+
def forward(self, input):
|
| 103 |
+
return self.model(input)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class ResUnetSkipConnectionBlock(nn.Module):
|
| 107 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
| 108 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
| 109 |
+
super(ResUnetSkipConnectionBlock, self).__init__()
|
| 110 |
+
self.outermost = outermost
|
| 111 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
| 112 |
+
|
| 113 |
+
if input_nc is None:
|
| 114 |
+
input_nc = outer_nc
|
| 115 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3,
|
| 116 |
+
stride=2, padding=1, bias=use_bias)
|
| 117 |
+
|
| 118 |
+
res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)]
|
| 119 |
+
res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)]
|
| 120 |
+
|
| 121 |
+
downrelu = nn.ReLU(True)
|
| 122 |
+
uprelu = nn.ReLU(True)
|
| 123 |
+
if norm_layer != None:
|
| 124 |
+
downnorm = norm_layer(inner_nc)
|
| 125 |
+
upnorm = norm_layer(outer_nc)
|
| 126 |
+
|
| 127 |
+
if outermost:
|
| 128 |
+
upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
| 129 |
+
upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
|
| 130 |
+
down = [downconv, downrelu] + res_downconv
|
| 131 |
+
up = [upsample, upconv]
|
| 132 |
+
model = down + [submodule] + up
|
| 133 |
+
elif innermost:
|
| 134 |
+
upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
| 135 |
+
upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
|
| 136 |
+
down = [downconv, downrelu] + res_downconv
|
| 137 |
+
if norm_layer == None:
|
| 138 |
+
up = [upsample, upconv, uprelu] + res_upconv
|
| 139 |
+
else:
|
| 140 |
+
up = [upsample, upconv, upnorm, uprelu] + res_upconv
|
| 141 |
+
model = down + up
|
| 142 |
+
else:
|
| 143 |
+
upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
| 144 |
+
upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
|
| 145 |
+
if norm_layer == None:
|
| 146 |
+
down = [downconv, downrelu] + res_downconv
|
| 147 |
+
up = [upsample, upconv, uprelu] + res_upconv
|
| 148 |
+
else:
|
| 149 |
+
down = [downconv, downnorm, downrelu] + res_downconv
|
| 150 |
+
up = [upsample, upconv, upnorm, uprelu] + res_upconv
|
| 151 |
+
|
| 152 |
+
if use_dropout:
|
| 153 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
| 154 |
+
else:
|
| 155 |
+
model = down + [submodule] + up
|
| 156 |
+
|
| 157 |
+
self.model = nn.Sequential(*model)
|
| 158 |
+
|
| 159 |
+
def forward(self, x):
|
| 160 |
+
if self.outermost:
|
| 161 |
+
return self.model(x)
|
| 162 |
+
else:
|
| 163 |
+
return torch.cat([x, self.model(x)], 1)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def save_checkpoint(model, save_path):
|
| 167 |
+
if not os.path.exists(os.path.dirname(save_path)):
|
| 168 |
+
os.makedirs(os.path.dirname(save_path))
|
| 169 |
+
torch.save(model.state_dict(), save_path)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def load_checkpoint(model, checkpoint_path):
|
| 173 |
+
|
| 174 |
+
if not os.path.exists(checkpoint_path):
|
| 175 |
+
print('No checkpoint!')
|
| 176 |
+
return
|
| 177 |
+
|
| 178 |
+
checkpoint = torch.load(checkpoint_path)
|
| 179 |
+
checkpoint_new = model.state_dict()
|
| 180 |
+
for param in checkpoint_new:
|
| 181 |
+
checkpoint_new[param] = checkpoint[param]
|
| 182 |
+
|
| 183 |
+
model.load_state_dict(checkpoint_new)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|