rxpbtn21 commited on
Commit
dcb3064
·
verified ·
1 Parent(s): 10ba581

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -4,20 +4,30 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  from peft import PeftModel
5
  import torch
6
 
 
 
 
7
  # Load the tokenizer
8
  tokenizer = AutoTokenizer.from_pretrained("rxpbtn21/t5-small-lora-summarizer")
9
 
10
  # Load the base model and then the LoRA adapter
11
- base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", device_map="auto")
 
12
  model = PeftModel.from_pretrained(base_model, "rxpbtn21/t5-small-lora-summarizer")
13
  model.eval()
14
 
15
  def summarize(text):
16
- inputs = tokenizer(text, max_length=512, truncation=True, return_tensors="pt")
17
- with torch.no_grad():
18
- outputs = model.generate(inputs["input_ids"].to(model.device), num_beams=4, max_new_tokens=128, early_stopping=True)
19
- summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
20
- return summary
 
 
 
 
 
 
21
 
22
  # Create Gradio interface
23
  iface = gr.Interface(fn=summarize, inputs="text", outputs="text", title="LoRA Fine-tuned T5-small Summarizer")
 
4
  from peft import PeftModel
5
  import torch
6
 
7
+ # Determine device
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
  # Load the tokenizer
11
  tokenizer = AutoTokenizer.from_pretrained("rxpbtn21/t5-small-lora-summarizer")
12
 
13
  # Load the base model and then the LoRA adapter
14
+ # Ensure the base model is also moved to the correct device
15
+ base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)
16
  model = PeftModel.from_pretrained(base_model, "rxpbtn21/t5-small-lora-summarizer")
17
  model.eval()
18
 
19
  def summarize(text):
20
+ try:
21
+ inputs = tokenizer(text, max_length=512, truncation=True, return_tensors="pt")
22
+ with torch.no_grad():
23
+ # Ensure inputs are on the same device as the model
24
+ outputs = model.generate(inputs["input_ids"].to(device), num_beams=4, max_new_tokens=128, early_stopping=True)
25
+ summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
26
+ return summary
27
+ except Exception as e:
28
+ # Log the error and return an informative message
29
+ print(f"Error during summarization: {e}")
30
+ return f"An error occurred during summarization. Please check the Space logs for details. Error: {e}"
31
 
32
  # Create Gradio interface
33
  iface = gr.Interface(fn=summarize, inputs="text", outputs="text", title="LoRA Fine-tuned T5-small Summarizer")