| | import os |
| | import sys |
| | import tqdm |
| | import wget |
| | import gdown |
| | import torch |
| | import shutil |
| | import base64 |
| | import warnings |
| | import importlib |
| |
|
| | import numpy as np |
| | import torch.nn.functional as F |
| | import torchvision.transforms as transforms |
| | import albumentations as A |
| | import albumentations.pytorch as AP |
| |
|
| | from PIL import Image |
| | from io import BytesIO |
| | from packaging import version |
| |
|
# Absolute location of this module and of the directory that contains it.
# The directory is appended to the import search path, and it is also where
# Remover.__init__ looks for the bundled default "config.yaml".
filepath = os.path.abspath(__file__)
repopath = os.path.dirname(filepath)
sys.path.append(repopath)
| |
|
| | from transparent_background.InSPyReNet import InSPyReNet_SwinB |
| | from transparent_background.utils import * |
| |
|
class Remover:
    """Salient-object background remover built around ``InSPyReNet_SwinB``.

    Constructing an instance resolves the configuration and checkpoint
    (downloading the checkpoint when needed), loads the model on the chosen
    device and prepares the input transforms. ``process`` then turns a single
    image into the requested output type.
    """

    @staticmethod
    def _file_md5(path, chunk_size=1 << 20):
        """Return the hexadecimal MD5 digest of the file at ``path``.

        The file is read in ``chunk_size``-byte pieces so that large
        checkpoints are never held in memory as a whole, and the handle is
        always closed.
        """
        # Imported locally so this helper does not rely on `hashlib` reaching
        # the module namespace through a star import.
        import hashlib

        digest = hashlib.md5()
        with open(path, "rb") as stream:
            for chunk in iter(lambda: stream.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def __init__(self, mode="base", jit=False, device=None, ckpt=None, resize='static'):
        """
        Args:
            mode (str): Choose among below options
                base -> slow & large gpu memory required, high quality results
                fast -> resize input into small size for fast computation
                base-nightly -> nightly release for base mode
                The value is used as a key into the loaded ``config.yaml``.
            jit (bool): use TorchScript for fast computation. A traced model is
                loaded from the checkpoint directory if present, otherwise the
                model is traced and the result is saved there.
            device (str, optional): specifying device for computation. find available GPU resource if not specified
                ("cuda:0", then "mps:0" on torch >= 1.13, else "cpu").
            ckpt (str, optional): specifying model checkpoint. find downloaded checkpoint or try download if not specified.
            resize (str): 'static' resizes every input to the configured
                ``base_size``; 'dynamic' uses ``dynamic_resize(L=1280)``.
                TorchScript mode always falls back to 'static'.

        Raises:
            NotImplementedError: a download is required but the configured URL
                is neither a Google Drive nor a GitHub URL.
            AttributeError: ``resize`` is not 'static' or 'dynamic'.

        The configuration directory is ``<base>/.transparent-background`` where
        ``<base>`` is the ``TRANSPARENT_BACKGROUND_FILE_PATH`` environment
        variable if set, otherwise the user's home directory.
        """
        cfg_path = os.environ.get('TRANSPARENT_BACKGROUND_FILE_PATH', os.path.abspath(os.path.expanduser('~')))
        home_dir = os.path.join(cfg_path, ".transparent-background")
        os.makedirs(home_dir, exist_ok=True)

        # Seed the user's config with the packaged default on first use.
        config_file = os.path.join(home_dir, "config.yaml")
        if not os.path.isfile(config_file):
            shutil.copy(os.path.join(repopath, "config.yaml"), config_file)
        self.meta = load_config(config_file)[mode]

        if device is not None:
            self.device = device
        else:
            self.device = "cpu"
            if torch.cuda.is_available():
                self.device = "cuda:0"
            elif (
                version.parse(torch.__version__) >= version.parse("1.13")
                and torch.backends.mps.is_available()
            ):
                self.device = "mps:0"

        download = False
        if ckpt is None:
            ckpt_dir = home_dir
            ckpt_name = self.meta.ckpt_name
            ckpt_path = os.path.join(ckpt_dir, ckpt_name)

            if not os.path.isfile(ckpt_path):
                download = True
            elif self.meta.md5 is not None and self.meta.md5 != self._file_md5(ckpt_path):
                # A configured checksum that does not match means the local
                # file is stale or corrupted; without a checksum the existing
                # file is trusted (and no hashing work is done).
                download = True

            if download:
                if 'drive.google.com' in self.meta.url:
                    gdown.download(self.meta.url, ckpt_path, fuzzy=True, proxy=self.meta.http_proxy)
                elif 'github.com' in self.meta.url:
                    wget.download(self.meta.url, ckpt_path)
                else:
                    raise NotImplementedError('Please use valid URL')
        else:
            ckpt_dir, ckpt_name = os.path.split(os.path.abspath(ckpt))

        self.model = InSPyReNet_SwinB(depth=64, pretrained=False, threshold=None, **self.meta)
        self.model.eval()
        self.model.load_state_dict(
            torch.load(os.path.join(ckpt_dir, ckpt_name), map_location="cpu", weights_only=True),
            strict=True,
        )
        self.model = self.model.to(self.device)

        if jit:
            # The traced model is cached next to the checkpoint, one file per device.
            ckpt_name = self.meta.ckpt_name.replace(
                ".pth", "_{}.pt".format(self.device)
            )
            try:
                traced_model = torch.jit.load(
                    os.path.join(ckpt_dir, ckpt_name), map_location=self.device
                )
                del self.model
                self.model = traced_model
            except Exception:
                # Any failure to load the cached trace (typically: it does not
                # exist yet) falls back to tracing. KeyboardInterrupt and
                # SystemExit are deliberately not swallowed.
                traced_model = torch.jit.trace(
                    self.model,
                    torch.rand(1, 3, *self.meta.base_size).to(self.device),
                    strict=True,
                )
                del self.model
                self.model = traced_model
                torch.jit.save(self.model, os.path.join(ckpt_dir, ckpt_name))
            if resize != 'static':
                warnings.warn('Resizing method for TorchScript mode only supports static resize. Fallback to static.')
                resize = 'static'

        # resize_tf feeds the PIL pipeline, resize_fn the albumentations (ndarray) pipeline.
        resize_tf = None
        resize_fn = None
        if resize == 'static':
            resize_tf = static_resize(self.meta.base_size)
            resize_fn = A.Resize(*self.meta.base_size)
        elif resize == 'dynamic':
            if 'base' not in mode:
                warnings.warn('Dynamic resizing only supports base and base-nightly mode. It will cause severe performance degradation with fast mode.')
            resize_tf = dynamic_resize(L=1280)
            resize_fn = dynamic_resize_a(L=1280)
        else:
            raise AttributeError(f'Unsupported resizing method {resize}')

        # Used by process() for non-ndarray (PIL) inputs.
        self.transform = transforms.Compose(
            [
                resize_tf,
                tonumpy(),
                normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                totensor(),
            ]
        )

        # Used by process() for np.ndarray inputs.
        self.cv2_transform = A.Compose(
            [
                resize_fn,
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                AP.ToTensorV2(),
            ]
        )

        # Cache for the most recently used background image (see process()).
        self.background = {'img': None, 'name': None, 'shape': None}
        desc = "Mode={}, Device={}, Torchscript={}".format(
            mode, self.device, "enabled" if jit else "disabled"
        )
        print("Settings -> {}".format(desc))

    def process(self, img, type="rgba", threshold=None, reverse=False):
        """
        Args:
            img (PIL.Image or np.ndarray): input image as PIL.Image or np.ndarray type
            type (str): output type option as below.
                'rgba' will generate RGBA output regarding saliency score as an alpha map.
                'map' will return the saliency map replicated over three channels.
                'green' will change the background with green screen.
                'white' will change the background with white color.
                '[255, 0, 0]' will change the background with color code [255, 0, 0].
                'blur' will blur the background.
                'overlay' will cover the salient object with translucent green color, and highlight the edges.
                A path ending with one of IMG_EXTS will be used as a background image, and the object will be overlapped on it.
                Any other string leaves the image unchanged.
            threshold (float or str, optional): produce hard prediction w.r.t specified threshold value (0.0 ~ 1.0)
            reverse (bool): invert the prediction (``1 - pred``) before compositing.
        Returns:
            np.ndarray (uint8) when ``img`` was an np.ndarray, otherwise PIL.Image.
        """

        if isinstance(img, np.ndarray):
            is_numpy = True
            shape = img.shape[:2]
            x = self.cv2_transform(image=img)["image"]
        else:
            is_numpy = False
            shape = img.size[::-1]  # PIL size is (width, height)
            x = self.transform(img)

        x = x.unsqueeze(0)
        x = x.to(self.device)

        with torch.no_grad():
            pred = self.model(x)

        # Bring the prediction back to the input resolution.
        pred = F.interpolate(pred, shape, mode="bilinear", align_corners=True)
        pred = pred.data.cpu()
        pred = pred.numpy().squeeze()

        if threshold is not None:
            pred = (pred > float(threshold)).astype(np.float64)
        if reverse:
            pred = 1 - pred

        img = np.array(img)

        # A bracketed string such as "[255, 0, 0]" is parsed into a list of ints.
        if type.startswith("["):
            type = [int(i) for i in type[1:-1].split(",")]

        if type == "map":
            img = (np.stack([pred] * 3, axis=-1) * 255).astype(np.uint8)

        elif type == "rgba":
            if threshold is None:
                # Foreground estimation is only run for soft alpha. pymatting is
                # imported here, preferring the cupy and then the pyopencl
                # implementation when they can be imported.
                try:
                    from pymatting.foreground.estimate_foreground_ml_cupy import estimate_foreground_ml_cupy as estimate_foreground_ml
                except ImportError:
                    try:
                        from pymatting.foreground.estimate_foreground_ml_pyopencl import estimate_foreground_ml_pyopencl as estimate_foreground_ml
                    except ImportError:
                        from pymatting import estimate_foreground_ml
                img = estimate_foreground_ml(img / 255.0, pred)
                img = 255 * np.clip(img, 0., 1.) + 0.5
                img = img.astype(np.uint8)

            r, g, b = cv2.split(img)
            pred = (pred * 255).astype(np.uint8)
            img = cv2.merge([r, g, b, pred])

        elif type == "green":
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155]
            img = img * pred[..., np.newaxis] + bg * (1 - pred[..., np.newaxis])

        elif type == "white":
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * [255, 255, 255]
            img = img * pred[..., np.newaxis] + bg * (1 - pred[..., np.newaxis])

        elif len(type) == 3:
            # Reached for a parsed three-component colour code.
            print(type)
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * type
            img = img * pred[..., np.newaxis] + bg * (1 - pred[..., np.newaxis])

        elif type == "blur":
            img = img * pred[..., np.newaxis] + cv2.GaussianBlur(img, (0, 0), 15) * (
                1 - pred[..., np.newaxis]
            )

        elif type == "overlay":
            bg = (
                np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155] + img
            ) // 2
            img = bg * pred[..., np.newaxis] + img * (1 - pred[..., np.newaxis])
            border = cv2.Canny(((pred > 0.5) * 255).astype(np.uint8), 50, 100)
            img[border != 0] = [120, 255, 155]

        elif type.lower().endswith(IMG_EXTS):
            target_size = img.shape[:2][::-1]  # (width, height) as expected by cv2.resize
            if self.background['name'] != type:
                # New background file: load it and cache it at the current size.
                background_img = cv2.cvtColor(cv2.imread(type), cv2.COLOR_BGR2RGB)
                background_img = cv2.resize(background_img, target_size)

                self.background['img'] = background_img
                self.background['shape'] = target_size
                self.background['name'] = type

            elif self.background['shape'] != target_size:
                # Same file, different frame size: resize the cached copy.
                self.background['img'] = cv2.resize(self.background['img'], target_size)
                self.background['shape'] = target_size

            img = img * pred[..., np.newaxis] + self.background['img'] * (
                1 - pred[..., np.newaxis]
            )

        if is_numpy:
            return img.astype(np.uint8)
        else:
            return Image.fromarray(img.astype(np.uint8))
| |
|
def to_base64(image):
    """Encode ``image`` as JPEG and return the bytes as a base64 text string.

    ``image`` must provide ``save(fp, format=...)``; the result is the
    UTF-8 decoded base64 representation of what it wrote.
    """
    with BytesIO() as stream:
        image.save(stream, format="JPEG")
        jpeg_bytes = stream.getvalue()
    return base64.b64encode(jpeg_bytes).decode("utf-8")
| |
|
def entry_point(out_type, mode, device, ckpt, source, dest, jit, threshold, resize, save_format=None, reverse=False, flet_progress=None, flet_page=None, preview=None, preview_out=None, options=None):
    """Run background removal over a file, a directory or a webcam stream.

    Args:
        out_type (str): output type forwarded to ``Remover.process`` as ``type``.
        mode, device, ckpt, jit, resize: forwarded to the ``Remover`` constructor.
        source (str): a numeric string selects webcam mode; otherwise an
            existing directory or file path.
        dest (str or None): output directory. Defaults to the current working
            directory (for a file) or a sub-directory of it named after the
            source directory.
        threshold: forwarded to ``Remover.process``.
        save_format (str, optional): output file extension overriding the
            input file's extension.
        reverse (bool): forwarded to ``Remover.process``; also appends
            ``_reverse`` to output names.
        flet_progress, flet_page, preview, preview_out: optional GUI objects.
            When ``flet_progress`` is given its ``value`` is advanced per frame
            and, if both preview objects are given, their ``src_base64`` is
            refreshed. (Presumably Flet controls; only the attributes used
            below are required.)
        options (mapping, optional): processing stops when ``options['abort']``
            is truthy.

    Raises:
        ImportError: webcam mode requested but ``pyvirtualcam`` is not installed.
        FileNotFoundError: ``source`` is neither numeric, a directory nor a file.
        AttributeError: ``out_type`` is 'rgba' and the input is a video.
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # has been loaded, so import it explicitly before using find_spec.
    import importlib.util

    warnings.filterwarnings("ignore")

    remover = Remover(mode=mode, jit=jit, device=device, ckpt=ckpt, resize=resize)

    vcam = None
    if source.isnumeric():
        save_dir = None
        _format = "Webcam"
        if importlib.util.find_spec('pyvirtualcam') is not None:
            try:
                import pyvirtualcam
                vcam = pyvirtualcam.Camera(width=640, height=480, fps=30)
            except Exception:
                # No usable virtual camera: frames are shown in a window instead.
                vcam = None
        else:
            raise ImportError("pyvirtualcam not found. Install with \"pip install transparent-background[webcam]\"")

    elif os.path.isdir(source):
        # normpath drops a trailing separator so the directory name is never
        # empty; basename copes with either separator style.
        save_dir = os.path.join(os.getcwd(), os.path.basename(os.path.normpath(source)))
        _format = get_format(os.listdir(source))

    elif os.path.isfile(source):
        save_dir = os.getcwd()
        _format = get_format([source])

    else:
        raise FileNotFoundError("File or directory {} is invalid.".format(source))

    if out_type == "rgba" and _format == "Video":
        raise AttributeError("type 'rgba' cannot be applied to video input.")

    if dest is not None:
        save_dir = dest

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    # Look the loader class up by name in this module's namespace
    # ("<format>Loader") instead of evaluating a string as code.
    loader = globals()[_format + "Loader"](source)
    multi_video = _format == "Video" and len(loader) > 1
    frame_progress = tqdm.tqdm(
        total=len(loader),
        position=1 if multi_video else 0,
        leave=False,
        bar_format="{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}",
    )
    sample_progress = (
        tqdm.tqdm(
            total=len(loader),
            desc="Total:",
            position=0,
            bar_format="{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}",
        )
        if multi_video
        else None
    )
    if flet_progress is not None:
        assert flet_page is not None
        flet_progress.value = 0
        # Guard against an empty loader (total == 0).
        flet_step = 1 / frame_progress.total if frame_progress.total else 0

    writer = None

    try:
        for img, name in loader:
            filename, ext = os.path.splitext(name)
            ext = ext[1:]
            ext = save_format if save_format is not None else ext
            frame_progress.set_description("{}".format(name))
            if out_type.lower().endswith(IMG_EXTS):
                # Background-image mode: tag outputs with the background's base name.
                outname = "{}_{}".format(
                    filename,
                    os.path.splitext(os.path.split(out_type)[-1])[0],
                )
            else:
                outname = "{}_{}".format(filename, out_type)

            if reverse:
                outname += '_reverse'

            # First real frame of a video: open its writer and reset the
            # per-frame progress. `img is not None` keeps an end-of-video
            # marker from being used to size a new writer.
            if _format == "Video" and writer is None and img is not None:
                writer = cv2.VideoWriter(
                    os.path.join(save_dir, f"{outname}.{ext}"),
                    cv2.VideoWriter_fourcc(*"mp4v"),
                    loader.fps,
                    img.size,
                )
                writer.set(cv2.VIDEOWRITER_PROP_QUALITY, 100)
                frame_progress.refresh()
                frame_progress.reset()
                frame_progress.total = int(loader.cap.get(cv2.CAP_PROP_FRAME_COUNT))
                if sample_progress is not None:
                    sample_progress.update()

                if flet_progress is not None:
                    assert flet_page is not None
                    flet_progress.value = 0
                    flet_step = 1 / frame_progress.total if frame_progress.total else 0
                    flet_progress.update()

            # For video input a None frame marks the end of the current video.
            if _format == "Video" and img is None:
                if writer is not None:
                    writer.release()
                writer = None
                continue

            out = remover.process(img, type=out_type, threshold=threshold, reverse=reverse)

            if _format == "Image":
                if out_type == "rgba" and ext.lower() != 'png':
                    warnings.warn('Output format for rgba mode only supports png format. Fallback to png output.')
                    ext = 'png'
                out.save(os.path.join(save_dir, f"{outname}.{ext}"))
            elif _format == "Video" and writer is not None:
                # Swapping R and B is symmetric, so this also maps RGB to the
                # BGR order OpenCV writes (assumes `out` is RGB).
                writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_BGR2RGB))
            elif _format == "Webcam":
                if vcam is not None:
                    vcam.send(np.array(out))
                    vcam.sleep_until_next_frame()
                else:
                    cv2.imshow(
                        "transparent-background", cv2.cvtColor(np.array(out), cv2.COLOR_BGR2RGB)
                    )
            frame_progress.update()
            if flet_progress is not None:
                flet_progress.value += flet_step
                flet_progress.update()

                if preview is not None and preview_out is not None:
                    if out_type == 'rgba':
                        # Premultiply colour by alpha so the RGB-only preview
                        # shows the cut-out on black.
                        o = np.array(out).astype(np.float64)
                        o[:, :, :3] *= (o[:, :, -1:] / 255)
                        out = Image.fromarray(o[:, :, :3].astype(np.uint8))

                    preview.src_base64 = to_base64(img.resize((480, 300)).convert('RGB'))
                    preview_out.src_base64 = to_base64(out.resize((480, 300)).convert('RGB'))
                    preview.update()
                    preview_out.update()

            if options is not None and options['abort']:
                break
    finally:
        # Finalise a video that was interrupted by abort or by an exception.
        if writer is not None:
            writer.release()

    # Webcam mode without `dest` has no output directory to report.
    if save_dir is not None:
        print("\nDone. Results are saved in {}".format(os.path.abspath(save_dir)))
| |
|
def console():
    """Command-line entry: parse the arguments and hand them to entry_point."""
    cli = parse_args()
    entry_point(
        out_type=cli.type,
        mode=cli.mode,
        device=cli.device,
        ckpt=cli.ckpt,
        source=cli.source,
        dest=cli.dest,
        jit=cli.jit,
        threshold=cli.threshold,
        resize=cli.resize,
        save_format=cli.format,
        reverse=cli.reverse,
    )
| |
|