jk12p committed
Commit 1ac7b06 · verified · 1 Parent(s): 5dec005

Update app.py

Files changed (1):
  1. app.py +13 -12
app.py CHANGED
@@ -1,20 +1,24 @@
  import streamlit as st
  import torch
  import fitz  # PyMuPDF
+ import os
+ from dotenv import load_dotenv
  from sentence_transformers import SentenceTransformer
  import faiss
  import numpy as np
  from transformers import AutoTokenizer, AutoModelForCausalLM

  # --- CONFIG ---
- HF_TOKEN = "your_huggingface_token_here"  # Add your Hugging Face token
+ load_dotenv()
+ HF_TOKEN = os.getenv("HF_TOKEN")

  # Load tokenizer and model with optimizations
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=HF_TOKEN)
  model = AutoModelForCausalLM.from_pretrained(
      "google/gemma-2b-it",
-     torch_dtype=torch.float16,  # Use half-precision for less memory
-     device_map="auto"  # This will place the model on the best device (CPU/GPU)
+     torch_dtype=torch.float16,
+     device_map="auto",
+     token=HF_TOKEN
  )

  # Load sentence transformer model for embedding generation
@@ -25,7 +29,6 @@ st.title("🔍 RAG App using 🤖 Gemma 2B")

  uploaded_file = st.file_uploader("📄 Upload a PDF or TXT file", type=["pdf", "txt"])

- # Extract text from file (PDF/TXT)
  def extract_text(file):
      text = ""
      if file.type == "application/pdf":
@@ -36,11 +39,9 @@ def extract_text(file):
      text = file.read().decode("utf-8")
      return text

- # Split text into chunks for indexing
  def split_into_chunks(text, chunk_size=500):
      return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

- # Create FAISS index for fast retrieval
  def create_faiss_index(chunks):
      embeddings = embedder.encode(chunks)
      dim = embeddings.shape[1]
@@ -48,13 +49,11 @@ def create_faiss_index(chunks):
      index.add(np.array(embeddings))
      return index, embeddings

- # Retrieve top-k relevant chunks for the query
  def retrieve_chunks(query, chunks, index, embeddings, k=3):
      query_embedding = embedder.encode([query])
      D, I = index.search(np.array(query_embedding), k)
      return [chunks[i] for i in I[0]]

- # --- MAIN LOGIC ---
  if uploaded_file:
      st.success("✅ File uploaded successfully!")
      raw_text = extract_text(uploaded_file)
@@ -70,10 +69,12 @@ if uploaded_file:
      with st.spinner("Thinking..."):
          context = "\n".join(retrieve_chunks(user_question, chunks, index, embeddings))

-         # Generate response from Gemma 2B
-         input_ids = tokenizer.encode(f"Answer the question based on the context below:\n\nContext:\n{context}\n\nQuestion: {user_question}\nAnswer:", return_tensors="pt").to(model.device)
+         input_ids = tokenizer.encode(
+             f"Answer the question based on the context below:\n\nContext:\n{context}\n\nQuestion: {user_question}\nAnswer:",
+             return_tensors="pt"
+         ).to(model.device)

-         with torch.no_grad():  # Disable gradient computation for inference
+         with torch.no_grad():
              outputs = model.generate(input_ids, max_length=512, num_return_sequences=1, temperature=0.7)

          generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
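
After this change the token no longer lives in the source: app.py expects HF_TOKEN in the environment, either exported in the shell, defined in a .env file read by python-dotenv's load_dotenv, or, if this runs as a Hugging Face Space, configured as a repository secret. The sketch below is not part of the commit; it is a minimal, hypothetical check that mirrors the new CONFIG block and fails fast when the token is missing, with a placeholder token value in the comment.

    # check_token.py (hypothetical helper, not in the repository)
    # Assumes a .env file beside app.py containing, for example:
    #   HF_TOKEN=hf_your_token_here   (placeholder, use a real token)
    import os

    from dotenv import load_dotenv  # from the python-dotenv package

    load_dotenv()  # loads key=value pairs from .env into the process environment
    token = os.getenv("HF_TOKEN")

    if not token:
        raise SystemExit("HF_TOKEN is not set; export it or add it to .env")
    print("HF_TOKEN found; gated checkpoints such as google/gemma-2b-it can be downloaded with token=HF_TOKEN.")

When running locally the app still starts with streamlit run app.py; the only new dependency is python-dotenv, which needs to be listed in requirements.txt if it is not there already.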