Update pintar.py
Browse files
pintar.py
CHANGED
|
@@ -29,6 +29,8 @@ def Normalize(inputs):
|
|
| 29 |
l = inputs[:, :, 0:1]
|
| 30 |
ab = inputs[:, :, 1:3]
|
| 31 |
l = l - 50
|
|
|
|
|
|
|
| 32 |
lab = np.concatenate((l, ab), 2)
|
| 33 |
return lab.astype('float32')
|
| 34 |
|
|
@@ -77,12 +79,12 @@ if __name__ == "__main__":
|
|
| 77 |
img_lab = img_lab.to(device).unsqueeze(0)
|
| 78 |
|
| 79 |
with torch.no_grad():
|
| 80 |
-
img_resize = F.interpolate(img_lab
|
| 81 |
-
img_L_resize = F.interpolate(img_resize[:, :1, :, :]
|
| 82 |
|
| 83 |
color_vector = colorEncoder(img_resize)
|
| 84 |
fake_ab = colorUNet((img_L_resize, color_vector))
|
| 85 |
-
fake_ab = F.interpolate(fake_ab, size=(img.shape[0], img.shape[1]), mode='bilinear',
|
| 86 |
|
| 87 |
fake_img = torch.cat((img_lab[:, :1, :, :], fake_ab), 1)
|
| 88 |
fake_img = Lab2RGB_out(fake_img)
|
|
|
|
| 29 |
l = inputs[:, :, 0:1]
|
| 30 |
ab = inputs[:, :, 1:3]
|
| 31 |
l = l - 50
|
| 32 |
+
l = l / 50 # Normalizar L al rango [-1, 1]
|
| 33 |
+
ab = ab / 110 # Normalizar ab al rango [-1, 1]
|
| 34 |
lab = np.concatenate((l, ab), 2)
|
| 35 |
return lab.astype('float32')
|
| 36 |
|
|
|
|
| 79 |
img_lab = img_lab.to(device).unsqueeze(0)
|
| 80 |
|
| 81 |
with torch.no_grad():
|
| 82 |
+
img_resize = F.interpolate(img_lab, size=(256, 256), mode='bilinear', align_corners=False)
|
| 83 |
+
img_L_resize = F.interpolate(img_resize[:, :1, :, :], size=(256, 256), mode='bilinear', align_corners=False)
|
| 84 |
|
| 85 |
color_vector = colorEncoder(img_resize)
|
| 86 |
fake_ab = colorUNet((img_L_resize, color_vector))
|
| 87 |
+
fake_ab = F.interpolate(fake_ab, size=(img.shape[0], img.shape[1]), mode='bilinear', align_corners=False)
|
| 88 |
|
| 89 |
fake_img = torch.cat((img_lab[:, :1, :, :], fake_ab), 1)
|
| 90 |
fake_img = Lab2RGB_out(fake_img)
|