| | import os |
| | import random |
| | import cv2 |
| | from scipy import ndimage |
| |
|
| | import gradio as gr |
| | import argparse |
| | import litellm |
| |
|
| | import numpy as np |
| | import torch |
| | import torchvision |
| | from PIL import Image, ImageDraw, ImageFont |
| |
|
| | |
| | import GroundingDINO.groundingdino.datasets.transforms as T |
| | from GroundingDINO.groundingdino.models import build_model |
| | from GroundingDINO.groundingdino.util.slconfig import SLConfig |
| | from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap |
| |
|
| | |
| | from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator |
| | import numpy as np |
| |
|
| | |
| | import torch |
| | from diffusers import StableDiffusionInpaintPipeline |
| |
|
| | |
| | from transformers import BlipProcessor, BlipForConditionalGeneration |
| |
|
| | import openai |
| |
|
def show_anns(anns):
    """Render SAM automatic-mask annotations as a colour overlay plus an id map.

    Args:
        anns: list of dicts as produced by SamAutomaticMaskGenerator, each with
            a boolean 'segmentation' array and a numeric 'area'.

    Returns:
        (full_img, res): a PIL RGB image with one random colour per mask, and a
        float32 HxWx3 array whose first two channels encode each pixel's
        1-based mask id as (id % 256, id // 256). Returns None on empty input.
    """
    if len(anns) == 0:
        return
    # Draw larger masks first so smaller masks stay visible on top.
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    full_img = None
    id_map = None  # renamed from `map` to avoid shadowing the builtin

    for i in range(len(sorted_anns)):
        ann = sorted_anns[i]  # fixed: iterate the sorted list, not the original
        m = ann['segmentation']
        if full_img is None:
            full_img = np.zeros((m.shape[0], m.shape[1], 3))
            id_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
        id_map[m != 0] = i + 1
        color_mask = np.random.random((1, 3)).tolist()[0]
        full_img[m != 0] = color_mask
    # Scale to [0, 255] once after all masks are drawn; scaling inside the
    # loop would repeatedly multiply already-coloured pixels by 255.
    full_img = full_img * 255

    res = np.zeros((id_map.shape[0], id_map.shape[1], 3))
    res[:, :, 0] = id_map % 256
    res[:, :, 1] = id_map // 256
    res = res.astype(np.float32)  # astype returns a new array; keep the result
    full_img = Image.fromarray(np.uint8(full_img))
    return full_img, res
| |
|
def generate_caption(processor, blip_model, raw_image):
    """Produce a text caption for *raw_image* using a BLIP captioning model.

    Args:
        processor: BLIP processor used for both encoding and decoding.
        blip_model: BLIP conditional-generation model (expected on CUDA, fp16).
        raw_image: input PIL image.

    Returns:
        The decoded caption string.
    """
    encoded = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
    generated_ids = blip_model.generate(**encoded)
    return processor.decode(generated_ids[0], skip_special_tokens=True)
| |
|
def generate_tags(caption, split=',', max_tokens=100, model="gpt-3.5-turbo", openai_api_key=''):
    """Condense a caption into a list of unique singular nouns via a chat model.

    Args:
        caption: the caption text to extract nouns from.
        split: separator the model is asked to use between nouns.
        max_tokens: completion token budget.
        model: chat model name passed to litellm.
        openai_api_key: API key set on the openai module before the call.

    Returns:
        The model's reply with any leading "label:" prefix stripped.
    """
    openai.api_key = openai_api_key
    openai.api_base = 'https://closeai.deno.dev/v1'
    system_content = (
        'Extract the unique nouns in the caption. Remove all the adjectives. '
        + f'List the nouns in singular form. Split them by "{split} ". '
        + f'Caption: {caption}.'
    )
    messages = [{'role': 'system', 'content': system_content}]
    response = litellm.completion(model=model, messages=messages, temperature=0.6, max_tokens=max_tokens)
    reply = response['choices'][0]['message']['content']

    # Keep only the text after the final ':' in case the model echoes a label.
    return reply.split(':')[-1].strip()
| |
|
def transform_image(image_pil):
    """Preprocess a PIL image for GroundingDINO.

    Resizes (shorter side 800, capped at 1333), converts to a tensor, and
    normalizes with ImageNet statistics.

    Args:
        image_pil: input PIL image.

    Returns:
        The preprocessed image tensor.
    """
    preprocess = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    tensor, _ = preprocess(image_pil, None)
    return tensor
| |
|
| |
|
def load_model(model_config_path, model_checkpoint_path, device):
    """Build a GroundingDINO model from a config file and load its checkpoint.

    Args:
        model_config_path: path to the GroundingDINO config file.
        model_checkpoint_path: path to the checkpoint to load.
        device: device string stored on the config before building.

    Returns:
        The model in eval mode with checkpoint weights loaded (non-strict).
    """
    cfg = SLConfig.fromfile(model_config_path)
    cfg.device = device
    model = build_model(cfg)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    # strict=False: tolerate key mismatches; print which keys were skipped.
    load_result = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_result)
    model.eval()
    return model
| |
|
| |
|
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
    """Run GroundingDINO on one image and return threshold-filtered detections.

    Args:
        model: GroundingDINO model (must expose .tokenizer).
        image: preprocessed image tensor (batched internally via image[None]).
        caption: text prompt; lower-cased and forced to end with a period.
        box_threshold: minimum best-token score for a query box to be kept.
        text_threshold: per-token score threshold for phrase extraction.
        with_logits: if True, append "(score)" to each predicted phrase.

    Returns:
        (boxes_filt, scores, pred_phrases): filtered boxes, their max scores
        as a tensor, and the matching phrase strings.
    """
    caption = caption.lower()
    caption = caption.strip()
    # The model expects the caption to end with a period.
    if not caption.endswith("."):
        caption = caption + "."

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]
    boxes = outputs["pred_boxes"].cpu()[0]

    # Keep queries whose best token score clears the box threshold.
    # (Boolean-mask indexing already returns a copy, so no clone is needed,
    # and the original code's bare `.shape[0]` expressions were dead no-ops.)
    filt_mask = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[filt_mask]
    boxes_filt = boxes[filt_mask]

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    pred_phrases = []
    scores = []
    for logit, box in zip(logits_filt, boxes_filt):
        # Map high-scoring token positions back to caption phrases.
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
        else:
            pred_phrases.append(pred_phrase)
        scores.append(logit.max().item())

    return boxes_filt, torch.Tensor(scores), pred_phrases
| |
|
def draw_mask(mask, draw, random_color=False):
    """Paint every nonzero pixel of *mask* onto *draw* as a translucent point.

    Args:
        mask: 2-D array; nonzero entries are painted.
        draw: PIL ImageDraw target (RGBA expected for the alpha to matter).
        random_color: pick a random RGB per call instead of the default blue.
    """
    if random_color:
        fill = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 153)
    else:
        fill = (30, 144, 255, 153)  # dodger blue, alpha 153

    # np.nonzero yields (rows, cols); PIL wants (x, y), so swap per pixel.
    for row, col in zip(*np.nonzero(mask)):
        draw.point((col, row), fill=fill)
| |
|
def draw_box(box, draw, label):
    """Draw one bounding box with an optional filled text label.

    Args:
        box: sequence of four coordinates (x0, y0, x1, y1).
        draw: PIL ImageDraw target.
        label: text drawn at the box's top-left corner; falsy skips the label.
    """
    # Random colour per box so adjacent detections stay distinguishable.
    color = tuple(np.random.randint(0, 255, size=3).tolist())

    draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=2)

    if label:
        font = ImageFont.load_default()
        if hasattr(font, "getbbox"):
            # Pillow >= 8: measure the label via textbbox.
            bbox = draw.textbbox((box[0], box[1]), str(label), font)
        else:
            # Older Pillow fallback: textsize.
            w, h = draw.textsize(str(label), font)
            bbox = (box[0], box[1], w + box[0], box[1] + h)
        draw.rectangle(bbox, fill=color)
        draw.text((box[0], box[1]), str(label), fill="white")
    # Fixed: removed the trailing unconditional draw.text(...) that redrew the
    # label in the default colour over the white text (and broke on falsy labels).
| |
|
| |
|
| |
|
# --- Model / checkpoint configuration --------------------------------------
# GroundingDINO config file shipped inside the repo.
config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
# Hugging Face repo id hosting the GroundingDINO checkpoint.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
# NOTE(review): "filenmae" is a typo, but the name is referenced below, so it
# is left unchanged here.
ckpt_filenmae = "groundingdino_swint_ogc.pth"
# Local path to the SAM ViT-H checkpoint.
sam_checkpoint='sam_vit_h_4b8939.pth'
output_dir="outputs"
device="cuda"

# Lazily-initialised model singletons: populated on first use inside
# run_grounded_sam so the app starts without loading every model.
blip_processor = None
blip_model = None
groundingdino_model = None
sam_predictor = None
sam_automask_generator = None
inpaint_pipeline = None
| |
|
def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, scribble_mode, openai_api_key):
    """Gradio callback: run the selected Grounded-SAM pipeline on one image.

    Args:
        input_image: Gradio sketch-tool dict with PIL images under "image"
            and "mask" (the user's scribble).
        text_prompt: detection phrase(s) for GroundingDINO; ignored for
            'scribble'/'automask', auto-generated for 'automatic'.
        task_type: one of 'scribble', 'automask', 'det', 'seg', 'inpainting',
            'automatic'.
        inpaint_prompt: Stable Diffusion prompt (only used for 'inpainting').
        box_threshold: GroundingDINO box score threshold.
        text_threshold: GroundingDINO per-token score threshold.
        iou_threshold: NMS IoU threshold (only used for 'automatic').
        inpaint_mode: 'merge' unions all masks into one inpainting region;
            otherwise only the first mask is used.
        scribble_mode: 'split' treats each scribble blob as its own prompt.
        openai_api_key: optional key enabling ChatGPT-based tag extraction.

    Returns:
        A list of PIL images for the Gradio gallery (contents depend on task).
    """
    # Models are cached in module-level globals so each loads only once.
    global blip_processor, blip_model, groundingdino_model, sam_predictor, sam_automask_generator, inpaint_pipeline

    os.makedirs(output_dir, exist_ok=True)

    image = input_image["image"]
    scribble = input_image["mask"]
    size = image.size  # PIL size is (W, H)

    # Lazy-load SAM; the predictor and the automatic generator share one backbone.
    if sam_predictor is None:
        assert sam_checkpoint, 'sam_checkpoint is not found!'
        sam = build_sam(checkpoint=sam_checkpoint)
        sam.to(device=device)
        sam_predictor = SamPredictor(sam)
        sam_automask_generator = SamAutomaticMaskGenerator(sam)

    # Lazy-load GroundingDINO.
    if groundingdino_model is None:
        groundingdino_model = load_model(config_file, ckpt_filenmae, device=device)

    image_pil = image.convert("RGB")
    image = np.array(image_pil)

    if task_type == 'scribble':
        # Point-prompted segmentation from the user's scribble.
        sam_predictor.set_image(image)
        scribble = scribble.convert("RGB")
        scribble = np.array(scribble)
        # Take channel 0 of the transposed array; note the axes end up (W, H).
        scribble = scribble.transpose(2, 1, 0)[0]

        # Label connected scribble blobs (fully saturated pixels only).
        labeled_array, num_features = ndimage.label(scribble >= 255)

        # One point prompt per blob: its centre of mass.
        centers = ndimage.center_of_mass(scribble, labeled_array, range(1, num_features+1))
        centers = np.array(centers)

        point_coords = torch.from_numpy(centers)
        point_coords = sam_predictor.transform.apply_coords_torch(point_coords, image.shape[:2])
        point_coords = point_coords.unsqueeze(0).to(device)
        point_labels = torch.from_numpy(np.array([1] * len(centers))).unsqueeze(0).to(device)
        if scribble_mode == 'split':
            # One prompt per point (separate masks) instead of one joint prompt.
            point_coords = point_coords.permute(1, 0, 2)
            point_labels = point_labels.permute(1, 0)
        masks, _, _ = sam_predictor.predict_torch(
            point_coords=point_coords if len(point_coords) > 0 else None,
            point_labels=point_labels if len(point_coords) > 0 else None,
            mask_input = None,
            boxes = None,
            multimask_output = False,
        )
    elif task_type == 'automask':
        # Fully automatic mask generation over the whole image.
        masks = sam_automask_generator.generate(image)
    else:
        # Text-driven tasks: det / seg / inpainting / automatic.
        transformed_image = transform_image(image_pil)

        if task_type == 'automatic':
            # Caption the image with BLIP, optionally condense the caption to
            # noun tags via ChatGPT, and use the result as the detection prompt.
            blip_processor = blip_processor or BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
            blip_model = blip_model or BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
            text_prompt = generate_caption(blip_processor, blip_model, image_pil)
            if len(openai_api_key) > 0:
                text_prompt = generate_tags(text_prompt, split=",", openai_api_key=openai_api_key)
            print(f"Caption: {text_prompt}")

        # Detect boxes matching the text prompt.
        boxes_filt, scores, pred_phrases = get_grounding_output(
            groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold
        )

        # Convert normalized (cx, cy, w, h) boxes to absolute (x0, y0, x1, y1).
        H, W = size[1], size[0]
        for i in range(boxes_filt.size(0)):
            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
            boxes_filt[i][2:] += boxes_filt[i][:2]

        boxes_filt = boxes_filt.cpu()

        if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
            sam_predictor.set_image(image)

            if task_type == 'automatic':
                # NMS prunes overlapping boxes from the auto-generated prompt.
                print(f"Before NMS: {boxes_filt.shape[0]} boxes")
                nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
                boxes_filt = boxes_filt[nms_idx]
                pred_phrases = [pred_phrases[idx] for idx in nms_idx]
                print(f"After NMS: {boxes_filt.shape[0]} boxes")
                print(f"Revise caption with number: {text_prompt}")

            transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

            # Box-prompted segmentation, one mask per detection.
            masks, _, _ = sam_predictor.predict_torch(
                point_coords = None,
                point_labels = None,
                boxes = transformed_boxes,
                multimask_output = False,
            )

    if task_type == 'det':
        # Detection only: draw labelled boxes on the image.
        image_draw = ImageDraw.Draw(image_pil)
        for box, label in zip(boxes_filt, pred_phrases):
            draw_box(box, image_draw, label)

        return [image_pil]
    elif task_type == 'automask':
        full_img, res = show_anns(masks)
        return [full_img]
    elif task_type == 'scribble':
        # Composite the predicted masks over the input as a translucent layer.
        mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))

        mask_draw = ImageDraw.Draw(mask_image)

        for mask in masks:
            draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)

        image_pil = image_pil.convert('RGBA')
        image_pil.alpha_composite(mask_image)
        return [image_pil, mask_image]
    elif task_type == 'seg' or task_type == 'automatic':
        # Overlay all masks plus labelled boxes on the input image.
        mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))

        mask_draw = ImageDraw.Draw(mask_image)
        for mask in masks:
            draw_mask(mask[0].cpu().numpy(), mask_draw, random_color=True)

        image_draw = ImageDraw.Draw(image_pil)

        for box, label in zip(boxes_filt, pred_phrases):
            draw_box(box, image_draw, label)

        if task_type == 'automatic':
            # Also print the (possibly auto-generated) caption on the image.
            image_draw.text((10, 10), text_prompt, fill='black')

        image_pil = image_pil.convert('RGBA')
        image_pil.alpha_composite(mask_image)
        return [image_pil, mask_image]
    elif task_type == 'inpainting':
        assert inpaint_prompt, 'inpaint_prompt is not found!'

        if inpaint_mode == 'merge':
            # Union all masks into a single inpainting region.
            masks = torch.sum(masks, dim=0).unsqueeze(0)
            masks = torch.where(masks > 0, True, False)
        mask = masks[0][0].cpu().numpy()  # non-merge mode: only the first mask
        mask_pil = Image.fromarray(mask)

        # Lazy-load the Stable Diffusion inpainting pipeline.
        if inpaint_pipeline is None:
            inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
            )
            inpaint_pipeline = inpaint_pipeline.to("cuda")

        # SD inpainting runs at 512x512; resize back to the original size after.
        image = inpaint_pipeline(prompt=inpaint_prompt, image=image_pil.resize((512, 512)), mask_image=mask_pil.resize((512, 512))).images[0]
        image = image.resize(size)

        return [image, mask_pil]
    else:
        print("task_type:{} error!".format(task_type))
| |
|
if __name__ == "__main__":
    # CLI flags for running the Gradio demo server.
    parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    parser.add_argument('--port', type=int, default=7589, help='port to run the server')
    # NOTE(review): the help text below looks copy-pasted from a checkpoint
    # option; this flag actually disables the Gradio request queue.
    parser.add_argument('--no-gradio-queue', action="store_true", help='path to the SAM checkpoint')
    args = parser.parse_args()

    print(args)

    block = gr.Blocks()
    if not args.no_gradio_queue:
        block = block.queue()

    with block:
        with gr.Row():
            with gr.Column():
                # Sketch tool lets the user scribble a mask over the image.
                input_image = gr.Image(source='upload', type="pil", value="assets/demo1.jpg", tool="sketch")
                task_type = gr.Dropdown(["scribble", "automask", "det", "seg", "inpainting", "automatic"], value="automatic", label="task_type")
                text_prompt = gr.Textbox(label="Text Prompt")
                inpaint_prompt = gr.Textbox(label="Inpaint Prompt")
                run_button = gr.Button(label="Run")
                with gr.Accordion("Advanced options", open=False):
                    box_threshold = gr.Slider(
                        label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.05
                    )
                    text_threshold = gr.Slider(
                        label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.05
                    )
                    iou_threshold = gr.Slider(
                        label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.05
                    )
                    inpaint_mode = gr.Dropdown(["merge", "first"], value="merge", label="inpaint_mode")
                    scribble_mode = gr.Dropdown(["merge", "split"], value="split", label="scribble_mode")
                    openai_api_key= gr.Textbox(label="(Optional)OpenAI key, enable chatgpt")

            with gr.Column():
                # Output gallery showing whatever images the task returns.
                gallery = gr.Gallery(
                    label="Generated images", show_label=False, elem_id="gallery"
                ).style(preview=True, grid=2, object_fit="scale-down")

        run_button.click(fn=run_grounded_sam, inputs=[
            input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, scribble_mode, openai_api_key], outputs=gallery)

    block.queue(concurrency_count=100)
    block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)