Spaces:
Runtime error
Runtime error
Commit
·
a870b86
1
Parent(s):
af92eb2
Create raft_core_utils_utils.py
Browse files- raft_core_utils_utils.py +82 -0
raft_core_utils_utils.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy import interpolate
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class InputPadder:
|
| 8 |
+
""" Pads images such that dimensions are divisible by 8 """
|
| 9 |
+
def __init__(self, dims, mode='sintel'):
|
| 10 |
+
self.ht, self.wd = dims[-2:]
|
| 11 |
+
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
| 12 |
+
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
| 13 |
+
if mode == 'sintel':
|
| 14 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
| 15 |
+
else:
|
| 16 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
| 17 |
+
|
| 18 |
+
def pad(self, *inputs):
|
| 19 |
+
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
| 20 |
+
|
| 21 |
+
def unpad(self,x):
|
| 22 |
+
ht, wd = x.shape[-2:]
|
| 23 |
+
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
| 24 |
+
return x[..., c[0]:c[1], c[2]:c[3]]
|
| 25 |
+
|
| 26 |
+
def forward_interpolate(flow):
|
| 27 |
+
flow = flow.detach().cpu().numpy()
|
| 28 |
+
dx, dy = flow[0], flow[1]
|
| 29 |
+
|
| 30 |
+
ht, wd = dx.shape
|
| 31 |
+
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
|
| 32 |
+
|
| 33 |
+
x1 = x0 + dx
|
| 34 |
+
y1 = y0 + dy
|
| 35 |
+
|
| 36 |
+
x1 = x1.reshape(-1)
|
| 37 |
+
y1 = y1.reshape(-1)
|
| 38 |
+
dx = dx.reshape(-1)
|
| 39 |
+
dy = dy.reshape(-1)
|
| 40 |
+
|
| 41 |
+
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
|
| 42 |
+
x1 = x1[valid]
|
| 43 |
+
y1 = y1[valid]
|
| 44 |
+
dx = dx[valid]
|
| 45 |
+
dy = dy[valid]
|
| 46 |
+
|
| 47 |
+
flow_x = interpolate.griddata(
|
| 48 |
+
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
|
| 49 |
+
|
| 50 |
+
flow_y = interpolate.griddata(
|
| 51 |
+
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
|
| 52 |
+
|
| 53 |
+
flow = np.stack([flow_x, flow_y], axis=0)
|
| 54 |
+
return torch.from_numpy(flow).float()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
| 58 |
+
""" Wrapper for grid_sample, uses pixel coordinates """
|
| 59 |
+
H, W = img.shape[-2:]
|
| 60 |
+
xgrid, ygrid = coords.split([1,1], dim=-1)
|
| 61 |
+
xgrid = 2*xgrid/(W-1) - 1
|
| 62 |
+
ygrid = 2*ygrid/(H-1) - 1
|
| 63 |
+
|
| 64 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
| 65 |
+
img = F.grid_sample(img, grid, align_corners=True)
|
| 66 |
+
|
| 67 |
+
if mask:
|
| 68 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
| 69 |
+
return img, mask.float()
|
| 70 |
+
|
| 71 |
+
return img
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def coords_grid(batch, ht, wd, device):
|
| 75 |
+
coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
|
| 76 |
+
coords = torch.stack(coords[::-1], dim=0).float()
|
| 77 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def upflow8(flow, mode='bilinear'):
|
| 81 |
+
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
|
| 82 |
+
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|