update
app.py CHANGED

@@ -15,6 +15,8 @@ from tqdm import tqdm, trange
 import skimage.io as io
 import PIL.Image
 import gradio as gr
+
+
 N = type(None)
 V = np.array
 ARRAY = np.ndarray

@@ -228,47 +230,47 @@ clip_model, preprocess = clip.load("ViT-B/16", device=device, jit=False)
 from transformers import AutoTokenizer
 tokenizer = AutoTokenizer.from_pretrained("imthanhlv/gpt2news")

-def inference(img, text, is_translate):
-    prefix_length = 10
-    model = ClipCaptionModel(prefix_length)
-    model_path = 'sat_019.pt'
-    model.load_state_dict(torch.load(model_path, map_location=CPU))
-    model = model.eval()
-    device = CUDA(0) if is_gpu else "cpu"
-    model = model.to(device)
-    use_beam_search = True
-    if is_translate:
-        # encode text
-        if text is None:
-            return "No text provided"
-        text = clip.tokenize([text]).to(device)
-        with torch.no_grad():
-            prefix = clip_model.encode_text(text).to(device, dtype=torch.float32)
-            prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
-        generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
-    [further removed lines (old 250-262) were not captured in the page extraction]
+def inference(img, text, is_translation):
+    prefix_length = 10
+    model = ClipCaptionModel(prefix_length)
+    model_path = 'sat_019.pt'
+    model.load_state_dict(torch.load(model_path, map_location=CPU))
+    model = model.eval()
+    device = CUDA(0) if is_gpu else "cpu"
+    model = model.to(device)
+    if is_translation:
+        # encode text
+        if text is None:
+            return "No text provided"
+        text = clip.tokenize([text]).to(device)
+        with torch.no_grad():
+            prefix = clip_model.encode_text(text).to(device, dtype=torch.float32)
+            prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+        generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+    else:
+        if img is None:
+            return "No image"
+        image = io.imread(img.name)
+        pil_image = PIL.Image.fromarray(image)
+        image = preprocess(pil_image).unsqueeze(0).to(device)
+
+        with torch.no_grad():
+            prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+            prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
+        generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed, prompt="Một bức ảnh về")[0]

+    return generated_text_prefix

 title = "CLIP Dual encoder"
-description = "You can translate English
+description = "You can translate English to Vietnamese or generate Vietnamese caption from image"
 examples=[["drug.jpg","", False], ["", "What is your name?", True]]

 inputs = [
-    [removed lines (old 269-271) were not captured in the page extraction]
+    gr.inputs.Image(type="file", label="Image to generate Vietnamese caption", optional=True),
+    gr.inputs.Textbox(lines=2, placeholder="English sentence for translation"),
+    gr.inputs.Checkbox()
 ]

 gr.Interface(
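
The hunk ends at the opening of the gr.Interface( call, so the rest of the wiring is not shown in this commit view. Below is a minimal sketch of how the pieces defined above would typically be connected; the outputs component and the launch() call are assumptions, not part of the shown change.

# Sketch only: continues app.py after the diff above. The outputs component and
# launch() call are assumed, since the commit view cuts off at "gr.Interface(".
gr.Interface(
    fn=inference,            # function defined in the hunk above
    inputs=inputs,           # optional image file, English textbox, is_translation checkbox
    outputs=gr.outputs.Textbox(label="Vietnamese output"),  # assumed output component
    title=title,
    description=description,
    examples=examples,       # [["drug.jpg", "", False], ["", "What is your name?", True]]
).launch()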