nikitaaaswsdwdw committed on
Commit
5bb063f
·
verified ·
1 Parent(s): ba23b99
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
# 1. Model setup
# The checkpoint below was chosen by the author as capable yet small enough
# to run on the free CPU tier.
model_name = 'google/flan-t5-large'

print("Loading model... this may take a minute.")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prefer CUDA when PyTorch reports it as available; otherwise stay on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the seq2seq weights and place them on the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
16
+
17
+ # 2. Define the generation function
18
def generate_text(task_prefix, input_text):
    """Generate model output for ``input_text`` under the selected task.

    The prompt sent to the model is ``"<task_prefix>: <input_text>"``.

    Args:
        task_prefix: Instruction placed in front of the text (for example
            ``"summarize"``). When empty or ``None`` the input text alone
            is used as the prompt instead of producing ``"None: ..."``.
        input_text: The text supplied by the user.

    Returns:
        The decoded first generated sequence with special tokens removed,
        or an empty string when ``input_text`` is empty or whitespace-only.
    """
    # Nothing to work on: skip the (slow) beam search entirely rather than
    # generating from a bare task prefix.
    if not input_text or not input_text.strip():
        return ""

    if task_prefix:
        prompt = f"{task_prefix}: {input_text}"
    else:
        prompt = input_text

    # truncation=True caps the encoded prompt at the tokenizer's configured
    # maximum length, so a very long paste cannot blow up memory/latency.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)

    # Beam search with a short output budget; no_repeat_ngram_size=2 blocks
    # repeated bigrams in the generated text.
    output_ids = model.generate(
        **inputs,
        max_length=64,
        num_beams=4,
        early_stopping=True,
        no_repeat_ngram_size=2,
    )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
34
+
35
# 3. Web interface
# The selected dropdown entry is handed to generate_text as-is and becomes
# the task prefix of the prompt.
task_choices = [
    "summarize",
    "translate French to English",
    "paraphrase",
    "generate question",
    "sst2 sentence",
]

# Build the individual components first so the Interface call stays compact.
_task_selector = gr.Dropdown(
    choices=task_choices,
    label="Select Task",
    value="summarize",
)
_text_input = gr.Textbox(
    label="Input Text",
    placeholder="Enter your text here...",
)
_result_box = gr.Textbox(label="AI Output")

demo = gr.Interface(
    fn=generate_text,
    inputs=[_task_selector, _text_input],
    outputs=_result_box,
    title="Multi-Task AI Generator",
    description="Select a task and enter text. Powered by Google Flan-T5.",
)

# 4. Start the app
demo.launch()