# Author: DRgaddam
# Commit message: "adding file"
# Commit: 247a0eb (verified)
import os
import sys
import tqdm
import wget
import gdown
import torch
import shutil
import base64
import warnings
import importlib
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
import albumentations as A
import albumentations.pytorch as AP
from PIL import Image
from io import BytesIO
from packaging import version
# Absolute path of this module and the directory that contains it; that
# directory is appended to sys.path before the package imports that follow.
filepath = os.path.abspath(__file__)
repopath = os.path.dirname(filepath)
sys.path.append(repopath)
from transparent_background.InSPyReNet import InSPyReNet_SwinB
from transparent_background.utils import *
class Remover:
    """Background remover wrapping an ``InSPyReNet_SwinB`` saliency model.

    ``__init__`` loads the user config (seeding it from the packaged copy),
    resolves or downloads the checkpoint, builds the model (optionally as
    TorchScript) and prepares the preprocessing pipelines. ``process`` runs a
    single image through the model and composites the result.
    """

    # Normalization mean/std shared by the torchvision and albumentations pipelines.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]
    # RGB color used by the 'green' and 'overlay' output types.
    _GREEN = [120, 255, 155]

    def __init__(self, mode="base", jit=False, device=None, ckpt=None, resize='static'):
        """
        Args:
            mode (str): section of config.yaml to use. Choose among below options
                base -> slow & large gpu memory required, high quality results
                fast -> resize input into small size for fast computation
                base-nightly -> nightly release for base mode
            jit (bool): use TorchScript for fast computation. The traced model is
                cached next to the checkpoint, one file per device.
            device (str, optional): specifying device for computation. If not
                specified, uses "cuda:0" when available, else "mps:0"
                (torch >= 1.13 only), else "cpu".
            ckpt (str, optional): specifying model checkpoint. If not specified,
                the checkpoint in the config directory is used; it is downloaded
                when missing, or when the config provides an md5 that the local
                file does not match.
            resize (str): 'static' (resize to the configured base size) or
                'dynamic'. TorchScript mode always falls back to 'static'.

        Raises:
            NotImplementedError: a download is required but the configured URL
                is neither a Google Drive nor a GitHub link.
            AttributeError: ``resize`` is neither 'static' nor 'dynamic'.
        """
        cfg_path = os.environ.get('TRANSPARENT_BACKGROUND_FILE_PATH', os.path.abspath(os.path.expanduser('~')))
        home_dir = os.path.join(cfg_path, ".transparent-background")
        os.makedirs(home_dir, exist_ok=True)
        config_file = os.path.join(home_dir, "config.yaml")
        # Seed the user-level config from the copy that ships next to this module.
        if not os.path.isfile(config_file):
            shutil.copy(os.path.join(repopath, "config.yaml"), config_file)
        self.meta = load_config(config_file)[mode]

        self.device = device if device is not None else self._default_device()

        if ckpt is None:
            ckpt_dir = home_dir
            ckpt_name = self.meta.ckpt_name
            ckpt_path = os.path.join(ckpt_dir, ckpt_name)
            if not os.path.isfile(ckpt_path):
                download = True
            else:
                # A missing md5 in the config disables the integrity check, so the
                # (potentially large) file is only hashed when there is something
                # to compare against.
                download = self.meta.md5 is not None and self.meta.md5 != self._file_md5(ckpt_path)
            if download:
                if 'drive.google.com' in self.meta.url:
                    gdown.download(self.meta.url, ckpt_path, fuzzy=True, proxy=self.meta.http_proxy)
                elif 'github.com' in self.meta.url:
                    wget.download(self.meta.url, ckpt_path)
                else:
                    raise NotImplementedError('Please use valid URL')
        else:
            ckpt_dir, ckpt_name = os.path.split(os.path.abspath(ckpt))

        self.model = InSPyReNet_SwinB(depth=64, pretrained=False, threshold=None, **self.meta)
        self.model.eval()
        self.model.load_state_dict(
            torch.load(os.path.join(ckpt_dir, ckpt_name), map_location="cpu", weights_only=True),
            strict=True,
        )
        self.model = self.model.to(self.device)

        if jit:
            # One cached trace per device, stored next to the checkpoint.
            jit_path = os.path.join(
                ckpt_dir, self.meta.ckpt_name.replace(".pth", "_{}.pt".format(self.device))
            )
            try:
                traced_model = torch.jit.load(jit_path, map_location=self.device)
            except Exception:
                # No usable cached trace: trace the eager model now and cache it.
                traced_model = torch.jit.trace(
                    self.model,
                    torch.rand(1, 3, *self.meta.base_size).to(self.device),
                    strict=True,
                )
                torch.jit.save(traced_model, jit_path)
            del self.model
            self.model = traced_model

            if resize != 'static':
                warnings.warn('Resizing method for TorchScript mode only supports static resize. Fallback to static.')
                resize = 'static'

        if resize == 'static':
            resize_tf = static_resize(self.meta.base_size)
            resize_fn = A.Resize(*self.meta.base_size)
        elif resize == 'dynamic':
            if 'base' not in mode:
                warnings.warn('Dynamic resizing only supports base and base-nightly mode. It will cause severe performance degradation with fast mode.')
            resize_tf = dynamic_resize(L=1280)
            resize_fn = dynamic_resize_a(L=1280)
        else:
            raise AttributeError(f'Unsupported resizing method {resize}')

        # Pipeline for non-ndarray (PIL) inputs.
        self.transform = transforms.Compose(
            [
                resize_tf,
                tonumpy(),
                normalize(mean=self._MEAN, std=self._STD),
                totensor(),
            ]
        )
        # Pipeline for np.ndarray inputs.
        self.cv2_transform = A.Compose(
            [
                resize_fn,
                A.Normalize(mean=self._MEAN, std=self._STD),
                AP.ToTensorV2(),
            ]
        )
        # Cache for a background image loaded from disk (see ``process``).
        self.background = {'img': None, 'name': None, 'shape': None}
        desc = "Mode={}, Device={}, Torchscript={}".format(
            mode, self.device, "enabled" if jit else "disabled"
        )
        print("Settings -> {}".format(desc))

    @staticmethod
    def _default_device():
        """Return "cuda:0" if available, else "mps:0" (torch >= 1.13), else "cpu"."""
        if torch.cuda.is_available():
            return "cuda:0"
        if (
            version.parse(torch.__version__) >= version.parse("1.13")
            and torch.backends.mps.is_available()
        ):
            return "mps:0"
        return "cpu"

    @staticmethod
    def _file_md5(path, chunk_size=1 << 20):
        """Return the hex MD5 digest of the file at *path*, reading it in chunks."""
        import hashlib

        digest = hashlib.md5()
        with open(path, "rb") as fh:
            for chunk in iter(lambda: fh.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def process(self, img, type="rgba", threshold=None, reverse=False):
        """
        Args:
            img (PIL.Image or np.ndarray): input image. ndarray input is
                preprocessed with ``cv2_transform`` and returned as an ndarray;
                any other input uses ``transform`` and is returned as PIL.Image.
            type (str): output type option as below.
                'rgba' will generate RGBA output regarding saliency score as an alpha map.
                'map' will return the saliency map replicated to 3 channels.
                'green' will change the background with green screen.
                'white' will change the background with white color.
                '[255, 0, 0]' will change the background with color code [255, 0, 0].
                'blur' will blur the background.
                'overlay' will cover the salient object with translucent green color, and highlight the edges.
                Another image file (e.g., 'samples/background.png') will be used as a background, and the object will be overlapped on it.
                Any other value returns the input image unchanged.
            threshold (float or str, optional): produce hard prediction w.r.t specified threshold value (0.0 ~ 1.0)
            reverse (bool): invert the saliency map (``1 - pred``) before compositing.

        Returns:
            PIL.Image or np.ndarray (uint8): output image, same container type as ``img``.

        Raises:
            ValueError: a bracketed color code does not have exactly 3 components.
            FileNotFoundError: the background image file cannot be read.
        """
        is_numpy = isinstance(img, np.ndarray)
        if is_numpy:
            shape = img.shape[:2]
            x = self.cv2_transform(image=img)["image"]
        else:
            shape = img.size[::-1]  # PIL size is (w, h); interpolate wants (h, w)
            x = self.transform(img)
        x = x.unsqueeze(0).to(self.device)

        with torch.no_grad():
            pred = self.model(x)

        # Bring the prediction back to the input resolution.
        pred = F.interpolate(pred, shape, mode="bilinear", align_corners=True)
        pred = pred.data.cpu().numpy().squeeze()

        if threshold is not None:
            pred = (pred > float(threshold)).astype(np.float64)
        if reverse:
            pred = 1 - pred

        img = np.array(img)
        alpha = pred[..., np.newaxis]
        green = self._GREEN

        def solid(rgb):
            # Full-frame background filled with a single RGB color.
            return np.stack([np.ones_like(pred)] * 3, axis=-1) * rgb

        color = None
        if type.startswith("["):
            color = [int(i) for i in type[1:-1].split(",")]
            if len(color) != 3:
                raise ValueError("Color code must have exactly 3 components, got {}".format(type))

        if type == "map":
            img = (np.stack([pred] * 3, axis=-1) * 255).astype(np.uint8)
        elif type == "rgba":
            if threshold is None:
                # pymatting is imported here to avoid the overhead in other cases.
                try:
                    from pymatting.foreground.estimate_foreground_ml_cupy import estimate_foreground_ml_cupy as estimate_foreground_ml
                except ImportError:
                    try:
                        from pymatting.foreground.estimate_foreground_ml_pyopencl import estimate_foreground_ml_pyopencl as estimate_foreground_ml
                    except ImportError:
                        from pymatting import estimate_foreground_ml
                img = estimate_foreground_ml(img / 255.0, pred)
                img = 255 * np.clip(img, 0., 1.) + 0.5
                img = img.astype(np.uint8)
            r, g, b = cv2.split(img)
            pred = (pred * 255).astype(np.uint8)
            img = cv2.merge([r, g, b, pred])
        elif type == "green":
            img = img * alpha + solid(green) * (1 - alpha)
        elif type == "white":
            img = img * alpha + solid([255, 255, 255]) * (1 - alpha)
        elif color is not None:
            img = img * alpha + solid(color) * (1 - alpha)
        elif type == "blur":
            img = img * alpha + cv2.GaussianBlur(img, (0, 0), 15) * (1 - alpha)
        elif type == "overlay":
            tinted = (solid(green) + img) // 2
            img = tinted * alpha + img * (1 - alpha)
            border = cv2.Canny(((pred > 0.5) * 255).astype(np.uint8), 50, 100)
            img[border != 0] = green
        elif type.lower().endswith(IMG_EXTS):
            target = img.shape[:2][::-1]  # cv2.resize expects (w, h)
            if self.background['name'] != type:
                background_img = cv2.imread(type)
                if background_img is None:
                    raise FileNotFoundError("Background image {} could not be read.".format(type))
                background_img = cv2.cvtColor(background_img, cv2.COLOR_BGR2RGB)
                self.background['img'] = cv2.resize(background_img, target)
                self.background['shape'] = target
                self.background['name'] = type
            elif self.background['shape'] != target:
                self.background['img'] = cv2.resize(self.background['img'], target)
                self.background['shape'] = target
            img = img * alpha + self.background['img'] * (1 - alpha)

        img = img.astype(np.uint8)
        return img if is_numpy else Image.fromarray(img)
def to_base64(image):
    """Serialize *image* as JPEG and return it as a base64-encoded text string.

    ``image`` only needs a ``save(stream, format=...)`` method (e.g. PIL.Image).
    """
    with BytesIO() as stream:
        image.save(stream, format="JPEG")
        encoded = base64.b64encode(stream.getvalue())
    return encoded.decode("utf-8")
def entry_point(out_type, mode, device, ckpt, source, dest, jit, threshold, resize, save_format=None, reverse=False, flet_progress=None, flet_page=None, preview=None, preview_out=None, options=None):
    """Run background removal over a webcam, a single file, or a directory.

    Args:
        out_type: output type, forwarded to ``Remover.process`` as ``type``.
        mode, device, ckpt, jit, resize: forwarded to ``Remover``.
        source: a numeric string selects a webcam; otherwise a file or directory path.
        dest: output directory. Defaults to the current directory for a file
            source, ``<cwd>/<source directory name>`` for a directory source,
            and no directory for a webcam.
        threshold, reverse: forwarded to ``Remover.process``.
        save_format: output extension override; defaults to each input's own extension.
        flet_progress, flet_page, preview, preview_out: optional GUI widgets
            (names suggest Flet) whose ``value`` / ``src_base64`` are updated per
            frame; ``preview`` and ``preview_out`` are required when
            ``flet_progress`` is given.
        options: optional dict; processing stops early once ``options['abort']`` is truthy.

    Raises:
        ImportError: webcam source but pyvirtualcam is not installed.
        FileNotFoundError: ``source`` is neither numeric, a file, nor a directory.
        AttributeError: 'rgba' output requested for video input.
    """
    import importlib.util  # `import importlib` alone does not guarantee `importlib.util`

    warnings.filterwarnings("ignore")
    remover = Remover(mode=mode, jit=jit, device=device, ckpt=ckpt, resize=resize)

    vcam = None
    if source.isnumeric():
        save_dir = None
        _format = "Webcam"
        if importlib.util.find_spec('pyvirtualcam') is None:
            raise ImportError("pyvirtualcam not found. Install with \"pip install transparent-background[webcam]\"")
        try:
            import pyvirtualcam
            vcam = pyvirtualcam.Camera(width=640, height=480, fps=30)
        except Exception:
            # No usable virtual camera: frames are shown with cv2.imshow instead.
            vcam = None
    elif os.path.isdir(source):
        # normpath drops a trailing separator so "images/" still maps to "images".
        save_dir = os.path.join(os.getcwd(), os.path.basename(os.path.normpath(source)))
        _format = get_format(os.listdir(source))
    elif os.path.isfile(source):
        save_dir = os.getcwd()
        _format = get_format([source])
    else:
        raise FileNotFoundError("File or directory {} is invalid.".format(source))

    if out_type == "rgba" and _format == "Video":
        raise AttributeError("type 'rgba' cannot be applied to video input.")

    if dest is not None:
        save_dir = dest
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    # Resolve the module-level `<format>Loader` class by name without eval().
    loader = globals()[_format + "Loader"](source)

    multi_video = _format == "Video" and len(loader) > 1
    bar_format = "{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}"
    frame_progress = tqdm.tqdm(
        total=len(loader),
        position=1 if multi_video else 0,
        leave=False,
        bar_format=bar_format,
    )
    sample_progress = (
        tqdm.tqdm(total=len(loader), desc="Total:", position=0, bar_format=bar_format)
        if multi_video
        else None
    )

    flet_step = 0
    if flet_progress is not None:
        assert flet_page is not None
        flet_progress.value = 0
        # max(..., 1) avoids ZeroDivisionError on an empty source.
        flet_step = 1 / max(frame_progress.total, 1)

    writer = None
    try:
        for img, name in loader:
            filename, ext = os.path.splitext(name)
            ext = save_format if save_format is not None else ext[1:]
            frame_progress.set_description("{}".format(name))

            if out_type.lower().endswith(IMG_EXTS):
                outname = "{}_{}".format(
                    filename,
                    os.path.splitext(os.path.split(out_type)[-1])[0],
                )
            else:
                outname = "{}_{}".format(filename, out_type)
            if reverse:
                outname += '_reverse'

            # First frame of a video: open its writer and reset per-video progress.
            if _format == "Video" and writer is None:
                writer = cv2.VideoWriter(
                    os.path.join(save_dir, f"{outname}.{ext}"),
                    cv2.VideoWriter_fourcc(*"mp4v"),
                    loader.fps,
                    img.size,
                )
                writer.set(cv2.VIDEOWRITER_PROP_QUALITY, 100)
                frame_progress.refresh()
                frame_progress.reset()
                frame_progress.total = int(loader.cap.get(cv2.CAP_PROP_FRAME_COUNT))
                if sample_progress is not None:
                    sample_progress.update()
                if flet_progress is not None:
                    assert flet_page is not None
                    flet_progress.value = 0
                    flet_step = 1 / max(frame_progress.total, 1)
                    flet_progress.update()

            # A None frame from the video loader closes the current video.
            if _format == "Video" and img is None:
                if writer is not None:
                    writer.release()
                writer = None
                continue

            out = remover.process(img, type=out_type, threshold=threshold, reverse=reverse)

            if _format == "Image":
                if out_type == "rgba" and ext.lower() != 'png':
                    warnings.warn('Output format for rgba mode only supports png format. Fallback to png output.')
                    ext = 'png'
                out.save(os.path.join(save_dir, f"{outname}.{ext}"))
            elif _format == "Video" and writer is not None:
                writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR))
            elif _format == "Webcam":
                if vcam is not None:
                    vcam.send(np.array(out))
                    vcam.sleep_until_next_frame()
                else:
                    cv2.imshow(
                        "transparent-background", cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR)
                    )

            frame_progress.update()
            if flet_progress is not None:
                flet_progress.value += flet_step
                flet_progress.update()
                if out_type == 'rgba':
                    # Multiply RGB by alpha so the preview shows the cut-out on black.
                    o = np.array(out).astype(np.float64)
                    o[:, :, :3] *= (o[:, :, -1:] / 255)
                    out = Image.fromarray(o[:, :, :3].astype(np.uint8))
                preview.src_base64 = to_base64(img.resize((480, 300)).convert('RGB'))
                preview_out.src_base64 = to_base64(out.resize((480, 300)).convert('RGB'))
                preview.update()
                preview_out.update()

            if options is not None and options['abort']:
                break
    finally:
        # An abort (or error) mid-video must still finalize the output file.
        if writer is not None:
            writer.release()
        if vcam is not None:
            vcam.close()

    if save_dir is not None:
        print("\nDone. Results are saved in {}".format(os.path.abspath(save_dir)))
    else:
        print("\nDone.")
def console():
    """Command-line entry point: parse CLI arguments and hand them to ``entry_point``."""
    cli = parse_args()
    entry_point(
        out_type=cli.type,
        mode=cli.mode,
        device=cli.device,
        ckpt=cli.ckpt,
        source=cli.source,
        dest=cli.dest,
        jit=cli.jit,
        threshold=cli.threshold,
        resize=cli.resize,
        save_format=cli.format,
        reverse=cli.reverse,
    )