import os
import sys
import tqdm
import wget
import gdown
import torch
import shutil
import base64
import hashlib
import warnings
import importlib
import importlib.util
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
import albumentations as A
import albumentations.pytorch as AP

from PIL import Image
from io import BytesIO
from packaging import version

filepath = os.path.abspath(__file__)
repopath = os.path.split(filepath)[0]
sys.path.append(repopath)

from transparent_background.InSPyReNet import InSPyReNet_SwinB
# NOTE(review): names such as cv2, load_config, static_resize, dynamic_resize,
# dynamic_resize_a, tonumpy, normalize, totensor, get_format, parse_args,
# IMG_EXTS and the *Loader classes are not imported explicitly in this file;
# presumably they all arrive through this star import -- confirm in utils.
from transparent_background.utils import *


def _file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at *path*.

    The file is read in chunks of ``chunk_size`` bytes so that a large
    checkpoint is never held in memory as a whole, and the handle is always
    closed.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


class Remover:
    def __init__(self, mode="base", jit=False, device=None, ckpt=None, resize='static'):
        """Load the configuration, fetch the checkpoint if needed and build the model.

        Args:
            mode (str): Key of the entry to read from ``config.yaml``.
                        Choose among below options
                   base -> slow & large gpu memory required, high quality results
                   fast -> resize input into small size for fast computation
                   base-nightly -> nightly release for base mode
            jit (bool): use TorchScript for fast computation. A traced model is
                        loaded from disk if present, otherwise traced and saved.
            device (str, optional): specifying device for computation.
                        If not specified, "cuda:0" is used when CUDA is available,
                        then "mps:0" (torch >= 1.13 with MPS available), else "cpu".
            ckpt (str, optional): specifying model checkpoint. If not specified,
                        the checkpoint named in the config is looked up in the
                        ``.transparent-background`` directory and downloaded when
                        it is missing or its md5 does not match the config.
            resize (str): 'static' resizes inputs to the configured base size,
                        'dynamic' uses the dynamic resize transforms. TorchScript
                        mode always falls back to 'static'.

        Raises:
            NotImplementedError: a download is needed but the configured URL is
                neither a Google Drive nor a GitHub URL.
            AttributeError: ``resize`` is neither 'static' nor 'dynamic'.
        """
        cfg_path = os.environ.get('TRANSPARENT_BACKGROUND_FILE_PATH', os.path.abspath(os.path.expanduser('~')))
        home_dir = os.path.join(cfg_path, ".transparent-background")
        os.makedirs(home_dir, exist_ok=True)

        # Seed the user's config from the copy shipped next to this file.
        if not os.path.isfile(os.path.join(home_dir, "config.yaml")):
            shutil.copy(os.path.join(repopath, "config.yaml"), os.path.join(home_dir, "config.yaml"))
        self.meta = load_config(os.path.join(home_dir, "config.yaml"))[mode]

        if device is not None:
            self.device = device
        else:
            self.device = "cpu"
            if torch.cuda.is_available():
                self.device = "cuda:0"
            elif (
                version.parse(torch.__version__) >= version.parse("1.13")
                and torch.backends.mps.is_available()
            ):
                self.device = "mps:0"

        download = False
        if ckpt is None:
            ckpt_dir = home_dir
            ckpt_name = self.meta.ckpt_name
            ckpt_path = os.path.join(ckpt_dir, ckpt_name)

            if not os.path.isfile(ckpt_path):
                download = True
            elif self.meta.md5 is not None and self.meta.md5 != _file_md5(ckpt_path):
                # A configured md5 of None disables verification, so the file
                # is only hashed when there is something to compare against.
                download = True

            if download:
                if 'drive.google.com' in self.meta.url:
                    gdown.download(self.meta.url, ckpt_path, fuzzy=True, proxy=self.meta.http_proxy)
                elif 'github.com' in self.meta.url:
                    wget.download(self.meta.url, ckpt_path)
                else:
                    raise NotImplementedError('Please use valid URL')
        else:
            ckpt_dir, ckpt_name = os.path.split(os.path.abspath(ckpt))

        self.model = InSPyReNet_SwinB(depth=64, pretrained=False, threshold=None, **self.meta)
        self.model.eval()
        self.model.load_state_dict(
            torch.load(os.path.join(ckpt_dir, ckpt_name), map_location="cpu", weights_only=True),
            strict=True,
        )
        self.model = self.model.to(self.device)

        if jit:
            # The traced model is cached next to the checkpoint, one file per device.
            ckpt_name = self.meta.ckpt_name.replace(
                ".pth", "_{}.pt".format(self.device)
            )
            try:
                traced_model = torch.jit.load(
                    os.path.join(ckpt_dir, ckpt_name), map_location=self.device
                )
                del self.model
                self.model = traced_model
            except Exception:
                # No usable cached trace: trace the eager model and cache it.
                # (Exception rather than a bare except, so Ctrl-C still works.)
                traced_model = torch.jit.trace(
                    self.model,
                    torch.rand(1, 3, *self.meta.base_size).to(self.device),
                    strict=True,
                )
                del self.model
                self.model = traced_model
                torch.jit.save(self.model, os.path.join(ckpt_dir, ckpt_name))
            if resize != 'static':
                warnings.warn('Resizing method for TorchScript mode only supports static resize. Fallback to static.')
                resize = 'static'

        resize_tf = None
        resize_fn = None
        if resize == 'static':
            resize_tf = static_resize(self.meta.base_size)
            resize_fn = A.Resize(*self.meta.base_size)
        elif resize == 'dynamic':
            if 'base' not in mode:
                warnings.warn('Dynamic resizing only supports base and base-nightly mode. It will cause severe performance degradation with fast mode.')
            resize_tf = dynamic_resize(L=1280)
            resize_fn = dynamic_resize_a(L=1280)
        else:
            raise AttributeError(f'Unsupported resizing method {resize}')

        # Preprocessing for PIL.Image inputs.
        self.transform = transforms.Compose(
            [
                resize_tf,
                tonumpy(),
                normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                totensor(),
            ]
        )

        # Preprocessing for np.ndarray inputs.
        self.cv2_transform = A.Compose(
            [
                resize_fn,
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                AP.ToTensorV2(),
            ]
        )

        # Cache of the last background image used by process(), keyed by path and size.
        self.background = {'img': None, 'name': None, 'shape': None}
        desc = "Mode={}, Device={}, Torchscript={}".format(
            mode, self.device, "enabled" if jit else "disabled"
        )
        print("Settings -> {}".format(desc))

    def process(self, img, type="rgba", threshold=None, reverse=False):
        """Run the model on one image and compose the requested output.

        Args:
            img (PIL.Image or np.ndarray): input image as PIL.Image or np.ndarray type
            type (str, list or tuple): output type option as below.
                        'rgba' will generate RGBA output regarding saliency score as an alpha map.
                        'map' will output the saliency map itself as a 3-channel image.
                        'green' will change the background with green screen.
                        'white' will change the background with white color.
                        '[255, 0, 0]' will change the background with color code [255, 0, 0].
                              A list or tuple of three values is accepted as well.
                        'blur' will blur the background.
                        'overlay' will cover the salient object with translucent green color, and highlight the edges.
                        Another image file (e.g., 'samples/backgroud.png') will be used as a background, and the object will be overlapped on it.
                        Any other string leaves the image unchanged.
            threshold (float or str, optional): produce hard prediction w.r.t specified threshold value (0.0 ~ 1.0)
            reverse (bool): invert the prediction (``1 - pred``) before composing.

        Returns:
            np.ndarray (uint8) when ``img`` was an np.ndarray, otherwise PIL.Image.

        Raises:
            ValueError: a color was given that does not have exactly 3 components.
        """
        if isinstance(img, np.ndarray):
            is_numpy = True
            shape = img.shape[:2]
            x = self.cv2_transform(image=img)["image"]
        else:
            is_numpy = False
            shape = img.size[::-1]  # PIL size is (width, height)
            x = self.transform(img)

        x = x.unsqueeze(0)
        x = x.to(self.device)

        with torch.no_grad():
            pred = self.model(x)

        # Bring the prediction back to the input resolution.
        pred = F.interpolate(pred, shape, mode="bilinear", align_corners=True)
        pred = pred.data.cpu()
        pred = pred.numpy().squeeze()

        if threshold is not None:
            pred = (pred > float(threshold)).astype(np.float64)
        if reverse:
            pred = 1 - pred

        img = np.array(img)

        # A color may arrive as the string "[r, g, b]" or already as a sequence.
        if isinstance(type, str) and type.startswith("["):
            type = [int(i) for i in type[1:-1].split(",")]
        is_color = isinstance(type, (list, tuple))
        if is_color and len(type) != 3:
            raise ValueError("Color code must have exactly 3 components, got {}".format(type))

        alpha = pred[..., np.newaxis]

        if type == "map":
            img = (np.stack([pred] * 3, axis=-1) * 255).astype(np.uint8)

        elif type == "rgba":
            if threshold is None:
                # pymatting is imported here to avoid the overhead in other cases.
                try:
                    from pymatting.foreground.estimate_foreground_ml_cupy import estimate_foreground_ml_cupy as estimate_foreground_ml
                except ImportError:
                    try:
                        from pymatting.foreground.estimate_foreground_ml_pyopencl import estimate_foreground_ml_pyopencl as estimate_foreground_ml
                    except ImportError:
                        from pymatting import estimate_foreground_ml
                img = estimate_foreground_ml(img / 255.0, pred)
                img = 255 * np.clip(img, 0., 1.) + 0.5  # +0.5 rounds on the uint8 cast
                img = img.astype(np.uint8)

            r, g, b = cv2.split(img)
            pred = (pred * 255).astype(np.uint8)
            img = cv2.merge([r, g, b, pred])

        elif type == "green":
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155]
            img = img * alpha + bg * (1 - alpha)

        elif type == "white":
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * [255, 255, 255]
            img = img * alpha + bg * (1 - alpha)

        elif is_color:
            bg = np.stack([np.ones_like(pred)] * 3, axis=-1) * type
            img = img * alpha + bg * (1 - alpha)

        elif type == "blur":
            img = img * alpha + cv2.GaussianBlur(img, (0, 0), 15) * (1 - alpha)

        elif type == "overlay":
            bg = (
                np.stack([np.ones_like(pred)] * 3, axis=-1) * [120, 255, 155] + img
            ) // 2
            img = bg * alpha + img * (1 - alpha)
            border = cv2.Canny(((pred > 0.5) * 255).astype(np.uint8), 50, 100)
            img[border != 0] = [120, 255, 155]

        elif type.lower().endswith(IMG_EXTS):
            target_size = img.shape[:2][::-1]  # (width, height) as cv2.resize expects
            if self.background['name'] != type:
                # New background file: load it and cache it at the current size.
                background_img = cv2.cvtColor(cv2.imread(type), cv2.COLOR_BGR2RGB)
                background_img = cv2.resize(background_img, target_size)

                self.background['img'] = background_img
                self.background['shape'] = target_size
                self.background['name'] = type

            elif self.background['shape'] != target_size:
                # Same file, different frame size: resize the cached copy.
                self.background['img'] = cv2.resize(self.background['img'], target_size)
                self.background['shape'] = target_size

            img = img * alpha + self.background['img'] * (1 - alpha)

        if is_numpy:
            return img.astype(np.uint8)
        else:
            return Image.fromarray(img.astype(np.uint8))


def to_base64(image):
    """Encode ``image`` as JPEG and return it as a base64 string.

    Args:
        image: object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        str: base64 text (UTF-8 decoded) of the JPEG bytes.
    """
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    base64_img = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return base64_img


def entry_point(out_type, mode, device, ckpt, source, dest, jit, threshold, resize, save_format=None, reverse=False, flet_progress=None, flet_page=None, preview=None, preview_out=None, options=None):
    """Process a webcam stream, a directory or a single file with a Remover.

    Args:
        out_type (str): output type forwarded to ``Remover.process`` as ``type``.
        mode, device, ckpt, jit, resize: forwarded to ``Remover``.
        source (str): a numeric string selects webcam input; otherwise a
            directory or a file path.
        dest (str or None): output directory. Defaults to a directory named
            after ``source`` in the current directory (directory input) or to
            the current directory (file input).
        threshold, reverse: forwarded to ``Remover.process``.
        save_format (str, optional): output file extension; defaults to the
            extension of each input name.
        flet_progress, flet_page, preview, preview_out (optional): UI objects.
            When ``flet_progress`` is given its ``value`` is advanced per frame
            and ``preview`` / ``preview_out`` receive base64 thumbnails.
        options (dict, optional): processing stops when ``options['abort']``
            is truthy.

    Raises:
        ImportError: webcam input requested but pyvirtualcam is not installed.
        FileNotFoundError: ``source`` is neither numeric, a directory nor a file.
        AttributeError: 'rgba' output requested for video input.
    """
    warnings.filterwarnings("ignore")

    remover = Remover(mode=mode, jit=jit, device=device, ckpt=ckpt, resize=resize)

    if source.isnumeric() is True:
        save_dir = None
        _format = "Webcam"
        if importlib.util.find_spec('pyvirtualcam') is not None:
            try:
                import pyvirtualcam
                vcam = pyvirtualcam.Camera(width=640, height=480, fps=30)
            except Exception:
                # Best effort: without a virtual camera, frames go to a window.
                vcam = None
        else:
            raise ImportError("pyvirtualcam not found. Install with \"pip install transparent-background[webcam]\"")

    elif os.path.isdir(source):
        save_dir = os.path.join(os.getcwd(), source.split(os.sep)[-1])
        _format = get_format(os.listdir(source))

    elif os.path.isfile(source):
        save_dir = os.getcwd()
        _format = get_format([source])

    else:
        raise FileNotFoundError("File or directory {} is invalid.".format(source))

    if out_type == "rgba" and _format == "Video":
        raise AttributeError("type 'rgba' cannot be applied to video input.")

    if dest is not None:
        save_dir = dest

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    # `_format` is either the literal "Webcam" or whatever get_format()
    # returned; eval resolves the matching "<format>Loader" name.
    # NOTE(review): consider an explicit name -> class mapping instead of eval.
    loader = eval(_format + "Loader")(source)
    is_multi_video = _format == "Video" and len(loader) > 1
    frame_progress = tqdm.tqdm(
        total=len(loader),
        position=1 if is_multi_video else 0,
        leave=False,
        bar_format="{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}",
    )
    sample_progress = (
        tqdm.tqdm(
            total=len(loader),
            desc="Total:",
            position=0,
            bar_format="{desc:<15}{percentage:3.0f}%|{bar:50}{r_bar}",
        )
        if is_multi_video
        else None
    )

    if flet_progress is not None:
        assert flet_page is not None
        flet_progress.value = 0
        flet_step = 1 / frame_progress.total

    writer = None

    for img, name in loader:
        # For video input an `img` of None marks the end of the current video.
        # Handle it before anything touches `img`.
        if _format == "Video" and img is None:
            if writer is not None:
                writer.release()
            writer = None
            continue

        filename, ext = os.path.splitext(name)
        ext = ext[1:]
        ext = save_format if save_format is not None else ext
        frame_progress.set_description("{}".format(name))
        if out_type.lower().endswith(IMG_EXTS):
            outname = "{}_{}".format(
                filename,
                os.path.splitext(os.path.split(out_type)[-1])[0],
            )
        else:
            outname = "{}_{}".format(filename, out_type)

        if reverse:
            outname += '_reverse'

        if _format == "Video" and writer is None:
            # First frame of a video: open its writer and restart the
            # per-frame progress bar with this video's frame count.
            writer = cv2.VideoWriter(
                os.path.join(save_dir, f"{outname}.{ext}"),
                cv2.VideoWriter_fourcc(*"mp4v"),
                loader.fps,
                img.size,
            )
            writer.set(cv2.VIDEOWRITER_PROP_QUALITY, 100)
            frame_progress.refresh()
            frame_progress.reset()
            frame_progress.total = int(loader.cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if sample_progress is not None:
                sample_progress.update()

            if flet_progress is not None:
                assert flet_page is not None
                flet_progress.value = 0
                flet_step = 1 / frame_progress.total
                flet_progress.update()

        out = remover.process(img, type=out_type, threshold=threshold, reverse=reverse)

        if _format == "Image":
            if out_type == "rgba" and ext.lower() != 'png':
                warnings.warn('Output format for rgba mode only supports png format. Fallback to png output.')
                ext = 'png'
            out.save(os.path.join(save_dir, f"{outname}.{ext}"))
        elif _format == "Video" and writer is not None:
            # Swaps the first and last channel for OpenCV.
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_BGR2RGB))
        elif _format == "Webcam":
            if vcam is not None:
                vcam.send(np.array(out))
                vcam.sleep_until_next_frame()
            else:
                cv2.imshow(
                    "transparent-background", cv2.cvtColor(np.array(out), cv2.COLOR_BGR2RGB)
                )
        frame_progress.update()
        if flet_progress is not None:
            flet_progress.value += flet_step
            flet_progress.update()

            if out_type == 'rgba':
                # The preview is sent as JPEG (no alpha), so the alpha channel
                # is multiplied into the color channels first.
                o = np.array(out).astype(np.float64)
                o[:, :, :3] *= (o[:, :, -1:] / 255)
                out = Image.fromarray(o[:, :, :3].astype(np.uint8))

            preview.src_base64 = to_base64(img.resize((480, 300)).convert('RGB'))
            preview_out.src_base64 = to_base64(out.resize((480, 300)).convert('RGB'))
            preview.update()
            preview_out.update()

        if options is not None and options['abort']:
            break

    # The loop can end (abort, or no end-of-video marker) with a writer open.
    if writer is not None:
        writer.release()

    if save_dir is not None:
        print("\nDone. Results are saved in {}".format(os.path.abspath(save_dir)))
    else:
        # Webcam input without `dest` has no output directory to report.
        print("\nDone.")


def console():
    """Command-line entry: parse arguments and forward them to entry_point."""
    args = parse_args()
    entry_point(args.type, args.mode, args.device, args.ckpt, args.source, args.dest, args.jit, args.threshold, args.resize, args.format, args.reverse)