absiitr committed
Commit d2e1b02 · verified · 1 Parent(s): bf8b348

Update backend.py

Files changed (1): backend.py +116 -54
backend.py CHANGED
@@ -3,141 +3,203 @@ import tempfile
 import gc
 import logging
 from fastapi import FastAPI, UploadFile, File, HTTPException
-from fastapi.middleware.cors import CORSMiddleware  # <-- ADDED
 from pydantic import BaseModel
 import torch
-from dotenv import load_dotenv
+from dotenv import load_dotenv  # Used to load API key from .env file
+
+# ---------------- Groq API ----------------
 from groq import Groq, APIError
+
+# ---------------- LangChain ----------------
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import Chroma

-# ---------------- Setup ----------------
+# --- Configuration & Setup ---
 logging.basicConfig(level=logging.INFO)
+
+# 1. Load environment variables from .env file
 load_dotenv()
-GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+
+# 2. Load API Key from Environment Variable
+GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
 GROQ_MODEL = "llama-3.1-8b-instant"
+
+# 3. Initialize Groq Client
 client = None
-if GROQ_API_KEY:
+if not GROQ_API_KEY:
+    logging.error(
+        "❌ GROQ_API_KEY is not set in the environment or the .env file. The service will run but cannot answer questions.")
+else:
     try:
         client = Groq(api_key=GROQ_API_KEY)
-        logging.info("✅ Groq client initialized")
+        logging.info("✅ Groq client initialized successfully.")
     except Exception as e:
-        logging.error(f"Groq init failed: {e}")
+        logging.error(f"❌ Failed to initialize Groq client: {e}")
+        client = None

 app = FastAPI()

-# ==================================================#
-# CORS Middleware (NEW SECTION)
-# ==================================================
-origins = [
-    "*",  # Allow all origins for deployment on HF Spaces
-]
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-# ==================================================#
-
+# Global state for RAG components
 retriever = None
 vectorstore = None

+
+# ---------------- Input Schema ----------------
 class Query(BaseModel):
     question: str

-# ==================================================#
-# PDF Upload
+
 # ==================================================
-@app.post("/api/upload")
+# PDF Upload → Chunk → Embed → Vectorstore
+# ==================================================
+@app.post("/upload")
 async def upload_pdf(file: UploadFile = File(...)):
+    """Handles PDF upload, processing, chunking, embedding, and vectorstore creation."""
     global retriever, vectorstore
+
     if not file.filename.endswith(".pdf"):
         raise HTTPException(400, "Only PDF files allowed")
+
     if not client:
-        raise HTTPException(500, "Groq API key missing")
+        raise HTTPException(500, "Service not fully initialized. Groq API key is missing or invalid.")
+
     path = None
     try:
+        # 1. Save file temporarily
         with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
             tmp.write(await file.read())
             path = tmp.name
+
+        logging.info(f"Processing PDF: {path}")
+
+        # 2. Load
         loader = PyPDFLoader(path)
         docs = loader.load()
+
+        # 3. Split
         splitter = RecursiveCharacterTextSplitter(
             chunk_size=800,
             chunk_overlap=50
         )
         chunks = splitter.split_documents(docs)
+
+        # 4. Embeddings (Using CPU-friendly model)
         embeddings = HuggingFaceEmbeddings(
             model_name="sentence-transformers/all-MiniLM-L6-v2",
             model_kwargs={"device": "cpu"},
             encode_kwargs={"normalize_embeddings": True}
         )
+
+        # 5. Clear previous vectorstore to free memory
         if vectorstore:
             del vectorstore
             gc.collect()
+
+        # 6. Create Vectorstore and Retriever
         vectorstore = Chroma.from_documents(chunks, embeddings)
+        # Search for 3 most relevant chunks
        retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
+
+        logging.info(f"PDF processed. Chunks created: {len(chunks)}")
+
         return {"message": "PDF processed", "chunks": len(chunks)}
+
     except Exception as e:
-        raise HTTPException(500, str(e))
+        logging.error(f"Error during PDF processing: {e}")
+        raise HTTPException(500, f"Error: {str(e)}")
     finally:
+        # 7. Cleanup temp file and memory
         if path and os.path.exists(path):
             os.unlink(path)
         gc.collect()

-# ==================================================#
-# Ask Question
+
+# ==================================================
+# ASK → RETRIEVE → GROQ → ANSWER
 # ==================================================
-@app.post("/api/ask")
+@app.post("/ask")
 async def ask(req: Query):
-    if not retriever:
-        raise HTTPException(400, "Upload PDF first")
+    global retriever
+
+    if client is None:
+        raise HTTPException(500, "Groq client is not initialized. Check API key setup.")
+
+    if retriever is None:
+        raise HTTPException(400, "Upload PDF first to initialize the knowledge base.")
+
     try:
+        # 1. Retrieve relevant chunks (NEW LangChain API)
         docs = retriever.invoke(req.question)
+
         context = "\n\n".join(d.page_content for d in docs)
-        prompt = f"""Use ONLY the context below.If answer not found, say: "I cannot find this in the PDF."CONTEXT:{context}QUESTION: {req.question}ANSWER:"""
+
+        # 2. Build prompt
+        prompt = f"""
+You are a strict RAG Q&A assistant.
+Use ONLY the context provided. If the answer is not found, reply:
+"I cannot find this in the PDF."
+
+---------------- CONTEXT ----------------
+{context}
+-----------------------------------------
+
+QUESTION: {req.question}
+
+FINAL ANSWER:
+"""
+
+        # 3. Call Groq
         response = client.chat.completions.create(
             model=GROQ_MODEL,
             messages=[
-                {"role": "system", "content": "Answer strictly from PDF context"},
+                {"role": "system",
+                 "content": "Use only the PDF content. If answer not found, say: 'I cannot find this in the PDF.'"},
                 {"role": "user", "content": prompt}
             ],
             temperature=0.0
         )
-        return {
-            "answer": response.choices[0].message.content.strip(),
-            "sources": len(docs)
-        }
+
+        answer = response.choices[0].message.content.strip()
+        return {"answer": answer, "sources": len(docs)}
+
     except APIError as e:
-        raise HTTPException(500, str(e))
+        logging.error(f"Groq API Error: {e}")
+        raise HTTPException(500, f"Groq API Error: {str(e)}")
+
+    except Exception as e:
+        logging.error(f"General error in /ask: {e}")
+        raise HTTPException(500, f"General error: {str(e)}")
+

-# ==================================================#
-# Clear Memory
 # ==================================================
-@app.post("/api/clear")
+# HEALTH & CLEAR
+# ==================================================
+@app.get("/health")
+async def health():
+    """Endpoint for checking service status."""
+    return {
+        "status": "running",
+        "pdf_loaded": retriever is not None,
+        "groq_client_ok": client is not None
+    }
+
+
+@app.post("/clear")
 async def clear():
+    """Clears the current RAG components from memory."""
     global retriever, vectorstore
+
+    # Explicitly clear objects
     if vectorstore:
         del vectorstore
     retriever = None
     vectorstore = None
+
     gc.collect()
+    # Clear CUDA cache if running on a machine with a GPU (good practice)
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-    return {"message": "Memory cleared"}

-# ==================================================#
-# Health
-# ==================================================
-@app.get("/api/health")
-async def health():
-    return {
-        "status": "running",
-        "pdf_loaded": retriever is not None,
-        "groq_client_ok": client is not None
-    }
+    return {"message": "Memory cleared. Upload a new PDF."}