prithivMLmods committed on
Commit
8316ee2
·
verified ·
1 Parent(s): c3fa496

update [kernels:flash-attn3] (cleaned) ✅

Browse files
Files changed (1) hide show
  1. app.py +203 -42
app.py CHANGED
@@ -29,9 +29,6 @@ from transformers.image_utils import load_image
29
  from gradio.themes import Soft
30
  from gradio.themes.utils import colors, fonts, sizes
31
 
32
- # --- Theme and CSS Definition ---
33
-
34
- # Define the new OrangeRed color palette
35
  colors.orange_red = colors.Color(
36
  name="orange_red",
37
  c50="#FFF0E5",
@@ -39,7 +36,7 @@ colors.orange_red = colors.Color(
39
  c200="#FFC299",
40
  c300="#FFA366",
41
  c400="#FF8533",
42
- c500="#FF4500", # OrangeRed base color
43
  c600="#E63E00",
44
  c700="#CC3700",
45
  c800="#B33000",
@@ -52,7 +49,7 @@ class OrangeRedTheme(Soft):
52
  self,
53
  *,
54
  primary_hue: colors.Color | str = colors.gray,
55
- secondary_hue: colors.Color | str = colors.orange_red, # Use the new color
56
  neutral_hue: colors.Color | str = colors.slate,
57
  text_size: sizes.Size | str = sizes.text_lg,
58
  font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -98,7 +95,6 @@ class OrangeRedTheme(Soft):
98
  block_label_background_fill="*primary_200",
99
  )
100
 
101
- # Instantiate the new theme
102
  orange_red_theme = OrangeRedTheme()
103
 
104
  css = """
@@ -106,7 +102,41 @@ css = """
106
  font-size: 2.3em !important;
107
  }
108
  #output-title h2 {
109
- font-size: 2.1em !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  }
111
  """
112
 
@@ -116,61 +146,126 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
 
117
  print("Using device:", device)
118
 
119
- # --- Model Loading ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- # Load Qwen3-VL-4B-Instruct
122
  MODEL_ID_Q4B = "Qwen/Qwen3-VL-4B-Instruct"
123
  processor_q4b = AutoProcessor.from_pretrained(MODEL_ID_Q4B, trust_remote_code=True)
124
  model_q4b = Qwen3VLForConditionalGeneration.from_pretrained(
125
  MODEL_ID_Q4B,
126
- attn_implementation="flash_attention_2",
127
  trust_remote_code=True,
128
  torch_dtype=torch.bfloat16
129
  ).to(device).eval()
130
 
131
- # Load Qwen3-VL-8B-Instruct
132
  MODEL_ID_Q8B = "Qwen/Qwen3-VL-8B-Instruct"
133
  processor_q8b = AutoProcessor.from_pretrained(MODEL_ID_Q8B, trust_remote_code=True)
134
  model_q8b = Qwen3VLForConditionalGeneration.from_pretrained(
135
  MODEL_ID_Q8B,
136
- attn_implementation="flash_attention_2",
137
  trust_remote_code=True,
138
  torch_dtype=torch.bfloat16
139
  ).to(device).eval()
140
 
141
- # Load Qwen3-VL-2B-Instruct
142
  MODEL_ID_Q2B = "Qwen/Qwen3-VL-2B-Instruct"
143
  processor_q2b = AutoProcessor.from_pretrained(MODEL_ID_Q2B, trust_remote_code=True)
144
  model_q2b = Qwen3VLForConditionalGeneration.from_pretrained(
145
  MODEL_ID_Q2B,
146
- #attn_implementation="flash_attention_2",
147
  trust_remote_code=True,
148
  torch_dtype=torch.bfloat16
149
  ).to(device).eval()
150
 
151
- # Load Qwen2.5-VL-7B-Instruct
152
  MODEL_ID_M7B = "Qwen/Qwen2.5-VL-7B-Instruct"
153
  processor_m7b = AutoProcessor.from_pretrained(MODEL_ID_M7B, trust_remote_code=True)
154
  model_m7b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
155
  MODEL_ID_M7B,
156
- attn_implementation="flash_attention_2",
157
  trust_remote_code=True,
158
  torch_dtype=torch.float16
159
  ).to(device).eval()
160
 
161
- # Load Qwen2.5-VL-3B-Instruct
162
  MODEL_ID_X3B = "Qwen/Qwen2.5-VL-3B-Instruct"
163
  processor_x3b = AutoProcessor.from_pretrained(MODEL_ID_X3B, trust_remote_code=True)
164
  model_x3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
165
  MODEL_ID_X3B,
166
- #attn_implementation="flash_attention_2",
167
  trust_remote_code=True,
168
  torch_dtype=torch.float16
169
  ).to(device).eval()
170
 
171
 
172
- # --- Helper Functions ---
173
-
174
  def select_model(model_name: str):
175
  if model_name == "Qwen3-VL-4B-Instruct":
176
  return processor_q4b, model_q4b
@@ -261,10 +356,51 @@ def navigate_pdf_page(direction: str, state: Dict[str, Any]):
261
  page_info_html = f'<div style="text-align:center;">Page {new_index + 1} / {total_pages}</div>'
262
  return image_preview, state, page_info_html
263
 
264
- # --- Generation Functions ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
- @spaces.GPU
267
- def generate_image(model_name: str, text: str, image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
 
 
 
 
 
 
 
 
 
 
 
268
  if image is None:
269
  yield "Please upload an image.", "Please upload an image."
270
  return
@@ -287,8 +423,11 @@ def generate_image(model_name: str, text: str, image: Image.Image, max_new_token
287
  time.sleep(0.01)
288
  yield buffer, buffer
289
 
290
- @spaces.GPU
291
- def generate_video(model_name: str, text: str, video_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
 
 
 
292
  if video_path is None:
293
  yield "Please upload a video.", "Please upload a video."
294
  return
@@ -315,12 +454,14 @@ def generate_video(model_name: str, text: str, video_path: str, max_new_tokens:
315
  buffer = ""
316
  for new_text in streamer:
317
  buffer += new_text
318
- # buffer = buffer.replace("<|im_end|>", "")
319
  time.sleep(0.01)
320
  yield buffer, buffer
321
 
322
- @spaces.GPU
323
- def generate_pdf(model_name: str, text: str, state: Dict[str, Any], max_new_tokens: int = 2048, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
 
 
 
324
  if not state or not state["pages"]:
325
  yield "Please upload a PDF file first.", "Please upload a PDF file first."
326
  return
@@ -349,8 +490,11 @@ def generate_pdf(model_name: str, text: str, state: Dict[str, Any], max_new_toke
349
  time.sleep(0.01)
350
  full_response += page_header + page_buffer + "\n\n"
351
 
352
- @spaces.GPU
353
- def generate_caption(model_name: str, image: Image.Image, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
 
 
 
354
  if image is None:
355
  yield "Please upload an image to caption.", "Please upload an image to caption."
356
  return
@@ -377,8 +521,11 @@ def generate_caption(model_name: str, image: Image.Image, max_new_tokens: int =
377
  time.sleep(0.01)
378
  yield buffer, buffer
379
 
380
- @spaces.GPU
381
- def generate_gif(model_name: str, text: str, gif_path: str, max_new_tokens: int = 1024, temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
 
 
 
382
  if gif_path is None:
383
  yield "Please upload a GIF.", "Please upload a GIF."
384
  return
@@ -404,12 +551,9 @@ def generate_gif(model_name: str, text: str, gif_path: str, max_new_tokens: int
404
  buffer = ""
405
  for new_text in streamer:
406
  buffer += new_text
407
- # buffer = buffer.replace("<|im_end|>", "")
408
  time.sleep(0.01)
409
  yield buffer, buffer
410
 
411
- # --- Examples and Gradio UI ---
412
-
413
  image_examples = [["Perform OCR on the image...", "examples/images/1.jpg"],
414
  ["Caption the image. Describe the safety measures shown in the image. Conclude whether the situation is (safe or unsafe)...", "examples/images/2.jpg"],
415
  ["Solve the problem...", "examples/images/3.png"]]
@@ -489,27 +633,44 @@ with gr.Blocks() as demo:
489
  label="Select Model",
490
  value="Qwen3-VL-4B-Instruct"
491
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492
 
493
- # --- Event Handlers ---
494
-
495
  image_submit.click(fn=generate_image,
496
- inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
497
  outputs=[output, markdown_output])
498
 
499
  video_submit.click(fn=generate_video,
500
- inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
501
  outputs=[output, markdown_output])
502
 
503
  pdf_submit.click(fn=generate_pdf,
504
- inputs=[model_choice, pdf_query, pdf_state, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
505
  outputs=[output, markdown_output])
506
 
507
  gif_submit.click(fn=generate_gif,
508
- inputs=[model_choice, gif_query, gif_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
509
  outputs=[output, markdown_output])
510
 
511
  caption_submit.click(fn=generate_caption,
512
- inputs=[model_choice, caption_image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
513
  outputs=[output, markdown_output])
514
 
515
  pdf_upload.change(fn=load_and_preview_pdf, inputs=[pdf_upload], outputs=[pdf_preview_img, pdf_state, page_info])
 
29
  from gradio.themes import Soft
30
  from gradio.themes.utils import colors, fonts, sizes
31
 
 
 
 
32
  colors.orange_red = colors.Color(
33
  name="orange_red",
34
  c50="#FFF0E5",
 
36
  c200="#FFC299",
37
  c300="#FFA366",
38
  c400="#FF8533",
39
+ c500="#FF4500",
40
  c600="#E63E00",
41
  c700="#CC3700",
42
  c800="#B33000",
 
49
  self,
50
  *,
51
  primary_hue: colors.Color | str = colors.gray,
52
+ secondary_hue: colors.Color | str = colors.orange_red,
53
  neutral_hue: colors.Color | str = colors.slate,
54
  text_size: sizes.Size | str = sizes.text_lg,
55
  font: fonts.Font | str | Iterable[fonts.Font | str] = (
 
95
  block_label_background_fill="*primary_200",
96
  )
97
 
 
98
  orange_red_theme = OrangeRedTheme()
99
 
100
  css = """
 
102
  font-size: 2.3em !important;
103
  }
104
  #output-title h2 {
105
+ font-size: 2.2em !important;
106
+ }
107
+
108
+ /* RadioAnimated Styles */
109
+ .ra-wrap{ width: fit-content; }
110
+ .ra-inner{
111
+ position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
112
+ background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
113
+ }
114
+ .ra-input{ display: none; }
115
+ .ra-label{
116
+ position: relative; z-index: 2; padding: 8px 16px;
117
+ font-family: inherit; font-size: 14px; font-weight: 600;
118
+ color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
119
+ }
120
+ .ra-highlight{
121
+ position: absolute; z-index: 1; top: 6px; left: 6px;
122
+ height: calc(100% - 12px); border-radius: 9999px;
123
+ background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
124
+ transition: transform 0.2s, width 0.2s;
125
+ }
126
+ .ra-input:checked + .ra-label{ color: black; }
127
+
128
+ /* Dark mode adjustments for Radio */
129
+ .dark .ra-inner { background: var(--neutral-800); }
130
+ .dark .ra-label { color: var(--neutral-400); }
131
+ .dark .ra-highlight { background: var(--neutral-600); }
132
+ .dark .ra-input:checked + .ra-label { color: white; }
133
+
134
+ #gpu-duration-container {
135
+ padding: 10px;
136
+ border-radius: 8px;
137
+ background: var(--background-fill-secondary);
138
+ border: 1px solid var(--border-color-primary);
139
+ margin-top: 10px;
140
  }
141
  """
142
 
 
146
 
147
  print("Using device:", device)
148
 
149
+ class RadioAnimated(gr.HTML):
150
+ def __init__(self, choices, value=None, **kwargs):
151
+ if not choices or len(choices) < 2:
152
+ raise ValueError("RadioAnimated requires at least 2 choices.")
153
+ if value is None:
154
+ value = choices[0]
155
+
156
+ uid = uuid.uuid4().hex[:8]
157
+ group_name = f"ra-{uid}"
158
+
159
+ inputs_html = "\n".join(
160
+ f"""
161
+ <input class="ra-input" type="radio" name="{group_name}" id="{group_name}-{i}" value="{c}">
162
+ <label class="ra-label" for="{group_name}-{i}">{c}</label>
163
+ """
164
+ for i, c in enumerate(choices)
165
+ )
166
+
167
+ html_template = f"""
168
+ <div class="ra-wrap" data-ra="{uid}">
169
+ <div class="ra-inner">
170
+ <div class="ra-highlight"></div>
171
+ {inputs_html}
172
+ </div>
173
+ </div>
174
+ """
175
+
176
+ js_on_load = r"""
177
+ (() => {
178
+ const wrap = element.querySelector('.ra-wrap');
179
+ const inner = element.querySelector('.ra-inner');
180
+ const highlight = element.querySelector('.ra-highlight');
181
+ const inputs = Array.from(element.querySelectorAll('.ra-input'));
182
+
183
+ if (!inputs.length) return;
184
+
185
+ const choices = inputs.map(i => i.value);
186
+
187
+ function setHighlightByIndex(idx) {
188
+ const n = choices.length;
189
+ const pct = 100 / n;
190
+ highlight.style.width = `calc(${pct}% - 6px)`;
191
+ highlight.style.transform = `translateX(${idx * 100}%)`;
192
+ }
193
+
194
+ function setCheckedByValue(val, shouldTrigger=false) {
195
+ const idx = Math.max(0, choices.indexOf(val));
196
+ inputs.forEach((inp, i) => { inp.checked = (i === idx); });
197
+ setHighlightByIndex(idx);
198
+
199
+ props.value = choices[idx];
200
+ if (shouldTrigger) trigger('change', props.value);
201
+ }
202
+
203
+ setCheckedByValue(props.value ?? choices[0], false);
204
+
205
+ inputs.forEach((inp) => {
206
+ inp.addEventListener('change', () => {
207
+ setCheckedByValue(inp.value, true);
208
+ });
209
+ });
210
+ })();
211
+ """
212
+
213
+ super().__init__(
214
+ value=value,
215
+ html_template=html_template,
216
+ js_on_load=js_on_load,
217
+ **kwargs
218
+ )
219
+
220
+ def apply_gpu_duration(val: str):
221
+ return int(val)
222
 
 
223
  MODEL_ID_Q4B = "Qwen/Qwen3-VL-4B-Instruct"
224
  processor_q4b = AutoProcessor.from_pretrained(MODEL_ID_Q4B, trust_remote_code=True)
225
  model_q4b = Qwen3VLForConditionalGeneration.from_pretrained(
226
  MODEL_ID_Q4B,
227
+ attn_implementation="kernels-community/flash-attn3",
228
  trust_remote_code=True,
229
  torch_dtype=torch.bfloat16
230
  ).to(device).eval()
231
 
 
232
  MODEL_ID_Q8B = "Qwen/Qwen3-VL-8B-Instruct"
233
  processor_q8b = AutoProcessor.from_pretrained(MODEL_ID_Q8B, trust_remote_code=True)
234
  model_q8b = Qwen3VLForConditionalGeneration.from_pretrained(
235
  MODEL_ID_Q8B,
236
+ attn_implementation="kernels-community/flash-attn3",
237
  trust_remote_code=True,
238
  torch_dtype=torch.bfloat16
239
  ).to(device).eval()
240
 
 
241
  MODEL_ID_Q2B = "Qwen/Qwen3-VL-2B-Instruct"
242
  processor_q2b = AutoProcessor.from_pretrained(MODEL_ID_Q2B, trust_remote_code=True)
243
  model_q2b = Qwen3VLForConditionalGeneration.from_pretrained(
244
  MODEL_ID_Q2B,
245
+ attn_implementation="kernels-community/flash-attn3",
246
  trust_remote_code=True,
247
  torch_dtype=torch.bfloat16
248
  ).to(device).eval()
249
 
 
250
  MODEL_ID_M7B = "Qwen/Qwen2.5-VL-7B-Instruct"
251
  processor_m7b = AutoProcessor.from_pretrained(MODEL_ID_M7B, trust_remote_code=True)
252
  model_m7b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
253
  MODEL_ID_M7B,
254
+ attn_implementation="kernels-community/flash-attn3",
255
  trust_remote_code=True,
256
  torch_dtype=torch.float16
257
  ).to(device).eval()
258
 
 
259
  MODEL_ID_X3B = "Qwen/Qwen2.5-VL-3B-Instruct"
260
  processor_x3b = AutoProcessor.from_pretrained(MODEL_ID_X3B, trust_remote_code=True)
261
  model_x3b = Qwen2_5_VLForConditionalGeneration.from_pretrained(
262
  MODEL_ID_X3B,
263
+ attn_implementation="kernels-community/flash-attn3",
264
  trust_remote_code=True,
265
  torch_dtype=torch.float16
266
  ).to(device).eval()
267
 
268
 
 
 
269
  def select_model(model_name: str):
270
  if model_name == "Qwen3-VL-4B-Instruct":
271
  return processor_q4b, model_q4b
 
356
  page_info_html = f'<div style="text-align:center;">Page {new_index + 1} / {total_pages}</div>'
357
  return image_preview, state, page_info_html
358
 
359
+ def calc_timeout_image(model_name: str, text: str, image: Image.Image,
360
+ max_new_tokens: int, temperature: float, top_p: float,
361
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
362
+ try:
363
+ return int(gpu_timeout)
364
+ except:
365
+ return 60
366
+
367
+ def calc_timeout_video(model_name: str, text: str, video_path: str,
368
+ max_new_tokens: int, temperature: float, top_p: float,
369
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
370
+ try:
371
+ return int(gpu_timeout)
372
+ except:
373
+ return 60
374
+
375
+ def calc_timeout_pdf(model_name: str, text: str, state: Dict[str, Any],
376
+ max_new_tokens: int, temperature: float, top_p: float,
377
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
378
+ try:
379
+ return int(gpu_timeout)
380
+ except:
381
+ return 60
382
+
383
+ def calc_timeout_caption(model_name: str, image: Image.Image,
384
+ max_new_tokens: int, temperature: float, top_p: float,
385
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
386
+ try:
387
+ return int(gpu_timeout)
388
+ except:
389
+ return 60
390
 
391
+ def calc_timeout_gif(model_name: str, text: str, gif_path: str,
392
+ max_new_tokens: int, temperature: float, top_p: float,
393
+ top_k: int, repetition_penalty: float, gpu_timeout: int):
394
+ try:
395
+ return int(gpu_timeout)
396
+ except:
397
+ return 60
398
+
399
+ @spaces.GPU(duration=calc_timeout_image)
400
+ def generate_image(model_name: str, text: str, image: Image.Image,
401
+ max_new_tokens: int = 1024, temperature: float = 0.6,
402
+ top_p: float = 0.9, top_k: int = 50,
403
+ repetition_penalty: float = 1.2, gpu_timeout: int = 60):
404
  if image is None:
405
  yield "Please upload an image.", "Please upload an image."
406
  return
 
423
  time.sleep(0.01)
424
  yield buffer, buffer
425
 
426
+ @spaces.GPU(duration=calc_timeout_video)
427
+ def generate_video(model_name: str, text: str, video_path: str,
428
+ max_new_tokens: int = 1024, temperature: float = 0.6,
429
+ top_p: float = 0.9, top_k: int = 50,
430
+ repetition_penalty: float = 1.2, gpu_timeout: int = 90):
431
  if video_path is None:
432
  yield "Please upload a video.", "Please upload a video."
433
  return
 
454
  buffer = ""
455
  for new_text in streamer:
456
  buffer += new_text
 
457
  time.sleep(0.01)
458
  yield buffer, buffer
459
 
460
+ @spaces.GPU(duration=calc_timeout_pdf)
461
+ def generate_pdf(model_name: str, text: str, state: Dict[str, Any],
462
+ max_new_tokens: int = 2048, temperature: float = 0.6,
463
+ top_p: float = 0.9, top_k: int = 50,
464
+ repetition_penalty: float = 1.2, gpu_timeout: int = 120):
465
  if not state or not state["pages"]:
466
  yield "Please upload a PDF file first.", "Please upload a PDF file first."
467
  return
 
490
  time.sleep(0.01)
491
  full_response += page_header + page_buffer + "\n\n"
492
 
493
+ @spaces.GPU(duration=calc_timeout_caption)
494
+ def generate_caption(model_name: str, image: Image.Image,
495
+ max_new_tokens: int = 1024, temperature: float = 0.6,
496
+ top_p: float = 0.9, top_k: int = 50,
497
+ repetition_penalty: float = 1.2, gpu_timeout: int = 60):
498
  if image is None:
499
  yield "Please upload an image to caption.", "Please upload an image to caption."
500
  return
 
521
  time.sleep(0.01)
522
  yield buffer, buffer
523
 
524
+ @spaces.GPU(duration=calc_timeout_gif)
525
+ def generate_gif(model_name: str, text: str, gif_path: str,
526
+ max_new_tokens: int = 1024, temperature: float = 0.6,
527
+ top_p: float = 0.9, top_k: int = 50,
528
+ repetition_penalty: float = 1.2, gpu_timeout: int = 90):
529
  if gif_path is None:
530
  yield "Please upload a GIF.", "Please upload a GIF."
531
  return
 
551
  buffer = ""
552
  for new_text in streamer:
553
  buffer += new_text
 
554
  time.sleep(0.01)
555
  yield buffer, buffer
556
 
 
 
557
  image_examples = [["Perform OCR on the image...", "examples/images/1.jpg"],
558
  ["Caption the image. Describe the safety measures shown in the image. Conclude whether the situation is (safe or unsafe)...", "examples/images/2.jpg"],
559
  ["Solve the problem...", "examples/images/3.png"]]
 
633
  label="Select Model",
634
  value="Qwen3-VL-4B-Instruct"
635
  )
636
+
637
+ with gr.Row(elem_id="gpu-duration-container"):
638
+ with gr.Column():
639
+ gr.Markdown("**GPU Duration (seconds)**")
640
+ radioanimated_gpu_duration = RadioAnimated(
641
+ choices=["60", "90", "120", "180", "240", "300"],
642
+ value="60",
643
+ elem_id="radioanimated_gpu_duration"
644
+ )
645
+ gpu_duration_state = gr.Number(value=60, visible=False)
646
+
647
+ gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
648
+
649
+ radioanimated_gpu_duration.change(
650
+ fn=apply_gpu_duration,
651
+ inputs=radioanimated_gpu_duration,
652
+ outputs=[gpu_duration_state],
653
+ api_visibility="private"
654
+ )
655
 
 
 
656
  image_submit.click(fn=generate_image,
657
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
658
  outputs=[output, markdown_output])
659
 
660
  video_submit.click(fn=generate_video,
661
+ inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
662
  outputs=[output, markdown_output])
663
 
664
  pdf_submit.click(fn=generate_pdf,
665
+ inputs=[model_choice, pdf_query, pdf_state, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
666
  outputs=[output, markdown_output])
667
 
668
  gif_submit.click(fn=generate_gif,
669
+ inputs=[model_choice, gif_query, gif_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
670
  outputs=[output, markdown_output])
671
 
672
  caption_submit.click(fn=generate_caption,
673
+ inputs=[model_choice, caption_image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_duration_state],
674
  outputs=[output, markdown_output])
675
 
676
  pdf_upload.change(fn=load_and_preview_pdf, inputs=[pdf_upload], outputs=[pdf_preview_img, pdf_state, page_info])