from __future__ import annotations import argparse import os import sys import warnings import pickle from pathlib import Path from typing import List, Dict, Any, Tuple, Optional import hashlib import re from dataclasses import dataclass os.environ.setdefault("TRANSFORMERS_NO_TF", "1") os.environ.setdefault("USE_TF", "0") os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0") # Silence warnings warnings.filterwarnings("ignore") try: from langchain_core._api import LangChainDeprecationWarning warnings.filterwarnings("ignore", category=LangChainDeprecationWarning) except Exception: pass from dotenv import load_dotenv from langchain_core.prompts import ChatPromptTemplate, PromptTemplate from langchain_core.documents import Document from langchain_core.output_parsers import StrOutputParser from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_community.document_loaders import PyPDFLoader from langchain_community.vectorstores import FAISS from langchain_huggingface import HuggingFaceEmbeddings from langchain_groq import ChatGroq # Optional hybrid and rerankers from langchain_community.retrievers import BM25Retriever from langchain.retrievers import EnsembleRetriever # Cross encoder is optional try: from sentence_transformers import CrossEncoder _HAS_CE = True except Exception: _HAS_CE = False load_dotenv() @dataclass class RetrievalConfig: use_hybrid: bool = True use_mmr: bool = True use_reranker: bool = True mmr_fetch_k: int = 50 mmr_lambda: float = 0.5 top_k: int = 8 neighbor_window: int = 1 # include adjacent pages for continuity class DocumentStore: """Manages document loading, chunking, and vector storage.""" def __init__( self, persist_dir: Path, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2", chunk_size: int = 800, chunk_overlap: int = 200, ): self.persist_dir = persist_dir self.persist_dir.mkdir(parents=True, exist_ok=True) self.embedding_model_name = embedding_model self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap self.vector_store_path = self.persist_dir / "faiss_index" self.metadata_path = self.persist_dir / "metadata.pkl" self.chunks_path = self.persist_dir / "chunks.pkl" print(f"Initializing embedding model: {embedding_model}") self.embeddings = HuggingFaceEmbeddings( model_name=embedding_model, model_kwargs={"device": "cpu"}, encode_kwargs={ "normalize_embeddings": True, "batch_size": 8, # Reduced from 32 to prevent hanging }, ) print("Embedding model loaded") self.vector_store: Optional[FAISS] = None self.metadata: Dict[str, Any] = {} self.chunks: List[Document] = [] self.page_counts: Dict[str, int] = {} def _fast_file_hash(self, path: Path, sample_bytes: int = 1_000_000) -> bytes: h = hashlib.sha256() try: with open(path, "rb") as f: h.update(f.read(sample_bytes)) except Exception: h.update(b"") return h.digest() def _compute_source_hash(self, pdf_paths: List[Path]) -> str: """Compute hash of PDF files to detect changes. Uses path, mtime, and a sample of content.""" hasher = hashlib.sha256() for pdf_path in sorted(pdf_paths): hasher.update(str(pdf_path).encode()) if pdf_path.exists(): hasher.update(str(pdf_path.stat().st_mtime).encode()) hasher.update(self._fast_file_hash(pdf_path)) return hasher.hexdigest() def discover_pdfs(self, source: Path) -> List[Path]: """Find all PDF files in source path.""" print(f"\nSearching for PDFs in: {source.absolute()}") if source.is_file() and source.suffix.lower() == ".pdf": print(f"Found single PDF: {source.name}") return [source] if source.is_dir(): pdfs = sorted(path for path in source.glob("*.pdf") if path.is_file()) if not pdfs: pdfs = sorted(path for path in source.glob("**/*.pdf") if path.is_file()) if pdfs: print(f"Found {len(pdfs)} PDF(s):") for pdf in pdfs: size_mb = pdf.stat().st_size / (1024 * 1024) print(f" - {pdf.name} ({size_mb:.2f} MB)") return pdfs else: raise FileNotFoundError(f"No PDF files found in {source}") raise FileNotFoundError(f"Path does not exist: {source}") def _load_pages(self, pdf_path: Path) -> List[Document]: loader = PyPDFLoader(str(pdf_path)) docs = loader.load() for doc in docs: doc.metadata["source"] = pdf_path.name doc.metadata["source_path"] = str(pdf_path) return docs def load_and_split_documents(self, pdf_paths: List[Path]) -> List[Document]: """Load PDFs and split into chunks.""" print(f"\nLoading and processing documents...") all_page_docs: List[Document] = [] total_pages = 0 self.page_counts = {} for pdf_path in pdf_paths: try: print(f" Loading: {pdf_path.name}...", end=" ", flush=True) page_docs = self._load_pages(pdf_path) all_page_docs.extend(page_docs) total_pages += len(page_docs) self.page_counts[pdf_path.name] = len(page_docs) print(f"{len(page_docs)} pages") except Exception as e: print(f"Error: {e}") continue if not all_page_docs: raise ValueError("Failed to load any documents") print(f"Loaded {total_pages} pages from {len(pdf_paths)} document(s)") # Split into chunks print(f"\nSplitting into chunks (size={self.chunk_size}, overlap={self.chunk_overlap})...") text_splitter = RecursiveCharacterTextSplitter( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, separators=["\n\n", "\n", ". ", "? ", "! ", "; ", ", ", " ", ""], length_function=len, ) chunks = text_splitter.split_documents(all_page_docs) print(f"Created {len(chunks)} chunks") # Show sample if chunks: sample = chunks[0] preview = sample.page_content[:200].replace("\n", " ") print(f"\nSample chunk:") print(f" Source: {sample.metadata.get('source', 'unknown')}") print(f" Page: {sample.metadata.get('page', 'unknown')}") print(f" Preview: {preview}...") return chunks def build_vector_store(self, pdf_paths: List[Path], force_rebuild: bool = False): """Build or load vector store and persist chunks for hybrid retrieval.""" source_hash = self._compute_source_hash(pdf_paths) if ( not force_rebuild and self.vector_store_path.exists() and self.metadata_path.exists() and self.chunks_path.exists() ): try: with open(self.metadata_path, "rb") as f: saved_metadata = pickle.load(f) if saved_metadata.get("source_hash") == source_hash: print("\nLoading existing vector store...") self.vector_store = FAISS.load_local( str(self.vector_store_path), self.embeddings, allow_dangerous_deserialization=True, ) with open(self.chunks_path, "rb") as f: self.chunks = pickle.load(f) self.metadata = saved_metadata self.page_counts = saved_metadata.get("page_counts", {}) print(f"Loaded vector store with {saved_metadata.get('chunk_count', 0)} chunks") return else: print("\nSource files changed, rebuilding vector store...") except Exception as e: print(f"\nCould not load existing store: {e}") print("Building new vector store...") print("\nBuilding new vector store...") chunks = self.load_and_split_documents(pdf_paths) if not chunks: raise ValueError("No chunks created from documents") print(f"Creating embeddings for {len(chunks)} chunks...") self.vector_store = FAISS.from_documents(chunks, self.embeddings) print("Saving vector store to disk...") self.vector_store.save_local(str(self.vector_store_path)) with open(self.chunks_path, "wb") as f: pickle.dump(chunks, f) self.chunks = chunks self.metadata = { "source_hash": source_hash, "chunk_count": len(chunks), "pdf_files": [str(p) for p in pdf_paths], "embedding_model": self.embedding_model_name, "page_counts": self.page_counts, } with open(self.metadata_path, "wb") as f: pickle.dump(self.metadata, f) print(f"Vector store built and saved with {len(chunks)} chunks") def _build_bm25(self) -> BM25Retriever: if not self.chunks: if self.chunks_path.exists(): with open(self.chunks_path, "rb") as f: self.chunks = pickle.load(f) else: raise ValueError("Chunks not available to build BM25") bm25 = BM25Retriever.from_documents(self.chunks) bm25.k = 20 return bm25 def get_retriever(self, cfg: RetrievalConfig): """Get a retriever. Hybrid BM25 plus FAISS with MMR if requested.""" if self.vector_store is None: raise ValueError("Vector store not initialized. Call build_vector_store first.") if cfg.use_mmr: faiss_ret = self.vector_store.as_retriever( search_type="mmr", search_kwargs={"k": max(cfg.top_k, 20), "fetch_k": cfg.mmr_fetch_k, "lambda_mult": cfg.mmr_lambda}, ) else: faiss_ret = self.vector_store.as_retriever( search_type="similarity", search_kwargs={"k": max(cfg.top_k, 20)}, ) if cfg.use_hybrid: bm25 = self._build_bm25() hybrid = EnsembleRetriever(retrievers=[bm25, faiss_ret], weights=[0.55, 0.45]) return hybrid return faiss_ret def get_page_count(self, source_name: str) -> Optional[int]: return self.page_counts.get(source_name) class RAGPipeline: """RAG pipeline with hybrid retrieval, multi-query, reranking, neighbor expansion, and task routing.""" def __init__( self, doc_store: DocumentStore, model: str = "llama-3.1-8b-instant", temperature: float = 0.1, max_tokens: int = 4096, top_k: int = 8, use_hybrid: bool = True, use_mmr: bool = True, use_reranker: bool = True, neighbor_window: int = 1, ): self.doc_store = doc_store self.model = model self.temperature = temperature self.max_tokens = max_tokens self.cfg = RetrievalConfig( use_hybrid=use_hybrid, use_mmr=use_mmr, use_reranker=use_reranker and _HAS_CE, top_k=top_k, neighbor_window=neighbor_window, ) print(f"\nInitializing RAG pipeline") print(f" Model: {model}") print(f" Temperature: {temperature}") print(f" Retrieval Top-K: {top_k}") print(f" Hybrid: {self.cfg.use_hybrid} MMR: {self.cfg.use_mmr} Rerank: {self.cfg.use_reranker}") self.retriever = doc_store.get_retriever(self.cfg) self.llm = ChatGroq(model=model, temperature=temperature, max_tokens=max_tokens) self.reranker = None if self.cfg.use_reranker: try: self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu") print("Cross-encoder reranker loaded") except Exception as e: print(f"Could not load cross-encoder reranker: {e}") self.reranker = None self.chain = self._build_chain() print("RAG pipeline ready") # -------- Retrieval helpers -------- def _multi_query_variants(self, question: str, n: int = 3) -> List[str]: prompt = PromptTemplate.from_template( "Produce {n} different short search queries that target the same information need.\n" "Input: {q}\n" "Output one per line, no numbering." ) text = (prompt | self.llm | StrOutputParser()).invoke({"q": question, "n": n}) variants = [ln.strip("- ").strip() for ln in text.splitlines() if ln.strip()] # Always include the original question first uniq = [] for s in [question] + variants: if s not in uniq: uniq.append(s) return uniq @staticmethod def _dedupe_by_source_page(docs: List[Document]) -> List[Document]: seen = set() out = [] for d in docs: key = (d.metadata.get("source"), d.metadata.get("page")) if key not in seen: seen.add(key) out.append(d) return out def _neighbor_expand(self, docs: List[Document], window: int) -> List[Document]: if window <= 0: return docs # Build a lookup of page docs by source and page from the persisted chunks if not self.doc_store.chunks: return docs page_map: Dict[Tuple[str, int], List[Document]] = {} for ch in self.doc_store.chunks: src = ch.metadata.get("source") page = ch.metadata.get("page") if isinstance(src, str) and isinstance(page, int): page_map.setdefault((src, page), []).append(ch) expanded = list(docs) for d in docs: src = d.metadata.get("source") page = d.metadata.get("page") if not isinstance(src, str) or not isinstance(page, int): continue for p in range(page - window, page + window + 1): if (src, p) in page_map: expanded.extend(page_map[(src, p)]) return self._dedupe_by_source_page(expanded) def _rerank(self, question: str, docs: List[Document], top_n: int) -> List[Document]: if not self.reranker or not docs: return docs[:top_n] pairs = [[question, d.page_content] for d in docs] scores = self.reranker.predict(pairs) ranked = [d for _, d in sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)] return ranked[:top_n] def _retrieve(self, question: str) -> List[Document]: variants = self._multi_query_variants(question, n=3) candidates: List[Document] = [] for q in variants: # retriever is Runnable, so use invoke try: res = self.retriever.invoke(q) except AttributeError: # fallback if retriever does not implement invoke res = self.retriever.get_relevant_documents(q) candidates.extend(res) docs = self._dedupe_by_source_page(candidates) docs = self._neighbor_expand(docs, self.cfg.neighbor_window) docs = self._rerank(question, docs, self.cfg.top_k) return docs # -------- Chains -------- def _format_docs(self, docs: List[Document]) -> str: if not docs: return "No relevant information found in the provided documents." parts = [] for i, doc in enumerate(docs, 1): source = doc.metadata.get("source", "Unknown") page = doc.metadata.get("page", "Unknown") content = doc.page_content.strip() parts.append( f"[Excerpt {i}]\n" f"Source: {source}, Page: {page}\n" f"Content: {content}" ) return "\n\n" + ("\n" + ("=" * 80) + "\n\n").join(parts) def _build_chain(self): """Build a strict-citation QA chain.""" prompt = ChatPromptTemplate.from_messages([ ("system", "You are a precise assistant that answers using only the given context.\n" "Rules:\n" "1) Use only the context to answer.\n" "2) Cite sources as: (Document Name, page X).\n" "3) If information is missing, reply exactly: \"This information is not available in the provided documents\".\n" "4) No external knowledge. No assumptions.\n" "5) Prefer concise bullets.\n" "6) End with Key Takeaways - 2 to 3 bullets.\n\n" "Context:\n{context}"), ("human", "Question: {question}\n\nAnswer using only the context above.") ]) def retrieve_and_pack(question: str) -> Dict[str, Any]: docs = self._retrieve(question) return {"context": self._format_docs(docs), "question": question} chain = retrieve_and_pack | prompt | self.llm | StrOutputParser() return chain # -------- Chapter summarization -------- def _find_chapter_span( self, question: str, pdf_paths: List[str] ) -> Optional[Tuple[str, int, int, List[str]]]: """ Find chapter span by scanning page texts for a heading like ^CHAPTER EIGHT or ^CHAPTER 8. Returns tuple: (pdf_name, start_page, end_page, page_texts[start:end+1]) Pages are 1-based for readability, but we keep 0-based indexing for internal operations. """ # Extract chapter token from question if possible # Accept words or numbers after 'chapter' m = re.search(r"chapter\s+([ivxlcdm]+|\d+)", question, re.IGNORECASE) chapter_token = m.group(1) if m else None start_pat = None if chapter_token: # Build a tolerant regex like ^CHAPTER\s+(EIGHT|8) roman = chapter_token.upper() num = chapter_token try: # If user gave digits, keep digits. If romans, keep romans too. start_pat = re.compile(rf"^CHAPTER\s+{re.escape(chapter_token)}\b", re.IGNORECASE | re.MULTILINE) except Exception: start_pat = re.compile(r"^CHAPTER\s+\w+", re.IGNORECASE | re.MULTILINE) else: start_pat = re.compile(r"^CHAPTER\s+\w+", re.IGNORECASE | re.MULTILINE) next_pat = re.compile(r"^CHAPTER\s+\w+", re.IGNORECASE | re.MULTILINE) # Try each PDF until we find a matching chapter start for pdf in pdf_paths: pages = self._load_entire_pdf_text_by_page(pdf) if not pages: continue start_idx = None for i, text in enumerate(pages): if start_pat.search(text): start_idx = i break if start_idx is None: continue # find end at the next chapter heading end_idx = len(pages) - 1 for j in range(start_idx + 1, len(pages)): if next_pat.search(pages[j]): end_idx = j - 1 break # Return texts and 1-based page numbers return (Path(pdf).name, start_idx + 1, end_idx + 1, pages[start_idx:end_idx + 1]) return None def _load_entire_pdf_text_by_page(self, pdf_path_str: str) -> List[str]: pdf_path = Path(pdf_path_str) try: page_docs = self.doc_store._load_pages(pdf_path) return [d.page_content or "" for d in page_docs] except Exception: return [] def _summarize_chapter(self, question: str) -> str: # Collect candidate PDFs from metadata pdfs = self.doc_store.metadata.get("pdf_files", []) span = self._find_chapter_span(question, pdfs) if not span: # Fall back to regular QA chain return self.chain.invoke(question) pdf_name, start_page, end_page, page_texts = span chapter_text = "\n\n".join(page_texts) # Map-reduce summarization # Map: summarize per slice map_prompt = ChatPromptTemplate.from_template( "You are summarizing a legal chapter from a statute. Summarize the following text into 6-10 bullet points. " "Keep every bullet tied to specific page numbers shown inline as (p. X). " "Do not use external knowledge.\n\n" "{text}" ) # Chunk chapter_text into moderately large pieces by naive split # Keep boundaries aligned with pages for reliable citations pieces = [] piece_buf = [] char_budget = 3500 # target per LLM call - adjust if needed running = 0 for idx, page in enumerate(page_texts): if running + len(page) > char_budget and piece_buf: pieces.append("\n\n".join(piece_buf)) piece_buf = [] running = 0 # Prepend page tag to help the model cite correctly page_num = start_page + idx piece_buf.append(f"[Page {page_num}]\n{page}") running += len(page) if piece_buf: pieces.append("\n\n".join(piece_buf)) map_summaries = [] for pc in pieces: ms = (map_prompt | self.llm | StrOutputParser()).invoke({"text": pc}) map_summaries.append(ms) reduce_prompt = ChatPromptTemplate.from_template( "Combine the partial summaries into a cohesive chapter summary with the following sections:\n" "1) Executive summary - 8 to 12 bullets with page citations.\n" "2) Section map - list section numbers and titles with page ranges.\n" "3) Detailed summary by section - concise rules, conditions, and any calculations with page citations.\n" "4) Table-friendly lines - incentives or exemptions with eligibility, conditions, limits, compliance steps, page.\n" "5) Open issues - ambiguities or cross-references.\n\n" "Document: {pdf_name}, Pages: {start_page}-{end_page}\n\n" "Partials:\n{partials}\n\n" "All claims must include page citations like (p. X). No external knowledge." ) final = (reduce_prompt | self.llm | StrOutputParser()).invoke({ "pdf_name": pdf_name, "start_page": start_page, "end_page": end_page, "partials": "\n\n---\n\n".join(map_summaries) }) return final # -------- Task routing -------- @staticmethod def _route(question: str) -> str: q = question.lower() if re.search(r"\bchapter\b|\bsection\b|\bpart\s+[ivxlcdm]+\b|^summari[sz]e\b", q): return "summarize" if re.search(r"\bextract\b|\blist\b|\btable\b|\brate\b|\bband\b|\bthreshold\b|\ballowance\b|\brelief\b", q): return "extract" return "qa" # Stub for a future extractor chain - currently route extractor requests to QA chain with strict rules def _extract_structured(self, question: str) -> str: return self.chain.invoke(question) def query(self, question: str, verbose: bool = False) -> str: """Route and answer the question.""" if verbose: print(f"\nRetrieving relevant documents...") docs = self._retrieve(question) print(f"Found {len(docs)} relevant chunks:") for i, doc in enumerate(docs[:20], 1): source = doc.metadata.get("source", "Unknown") page = doc.metadata.get("page", "Unknown") preview = doc.page_content[:150].replace("\n", " ") print(f" [{i}] {source} (page {page}): {preview}...") print() task = self._route(question) if task == "summarize": return self._summarize_chapter(question) elif task == "extract": return self._extract_structured(question) else: return self.chain.invoke(question) def main(): parser = argparse.ArgumentParser( description="Enhanced RAG pipeline with hybrid retrieval, reranking, and chapter summarization", formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument( "--source", type=Path, default=Path("."), help="Path to a PDF file or directory" ) parser.add_argument( "--persist-dir", type=Path, default=Path("vector_store"), help="Directory for vector store and caches" ) parser.add_argument( "--rebuild", action="store_true", help="Force rebuild of vector store" ) parser.add_argument( "--model", type=str, default="llama-3.1-8b-instant", help="Groq model name" ) parser.add_argument( "--embedding-model", type=str, default="sentence-transformers/all-mpnet-base-v2", help="HuggingFace embedding model" ) parser.add_argument( "--temperature", type=float, default=0.1, help="LLM temperature" ) parser.add_argument( "--top-k", type=int, default=8, help="Number of chunks to return after rerank" ) parser.add_argument( "--max-tokens", type=int, default=4096, help="Max tokens for response" ) parser.add_argument( "--question", type=str, help="Single question for non-interactive mode" ) parser.add_argument( "--no-hybrid", action="store_true", help="Disable BM25 plus FAISS hybrid retrieval" ) parser.add_argument( "--no-mmr", action="store_true", help="Disable MMR search on FAISS retriever" ) parser.add_argument( "--no-rerank", action="store_true", help="Disable cross-encoder reranking" ) parser.add_argument( "--neighbor-window", type=int, default=1, help="Include N neighbor pages around hits" ) parser.add_argument( "--verbose", action="store_true", help="Verbose retrieval logging" ) args = parser.parse_args() print("=" * 80) print("Kaanta AI - Nigeria Tax Acts RAG") print("=" * 80) if not os.getenv("GROQ_API_KEY"): print("\nERROR: GROQ_API_KEY not set") print("Set it with: export GROQ_API_KEY='your-key'") sys.exit(1) try: # Initialize document store doc_store = DocumentStore( persist_dir=args.persist_dir, embedding_model=args.embedding_model, ) # Discover PDFs pdf_paths = doc_store.discover_pdfs(args.source) # Build or load vector store doc_store.build_vector_store(pdf_paths, force_rebuild=args.rebuild) # Initialize pipeline rag = RAGPipeline( doc_store=doc_store, model=args.model, temperature=args.temperature, max_tokens=args.max_tokens, top_k=args.top_k, use_hybrid=not args.no_hybrid, use_mmr=not args.no_mmr, use_reranker=not args.no_rerank, neighbor_window=args.neighbor_window, ) print("\n" + "=" * 80) # Single question mode if args.question: print(f"\nQuestion: {args.question}\n") print("Kaanta AI is thinking...\n") answer = rag.query(args.question, verbose=args.verbose) print("Answer:") print("-" * 80) print(answer) print("-" * 80) return # Interactive mode print("\nReady. Ask questions about the Nigeria Tax Acts.") print("Type 'exit' or 'quit' to stop\n") print("=" * 80) while True: try: question = input("\nYour question: ").strip() except (EOFError, KeyboardInterrupt): print("\n\nGoodbye") break if not question: continue if question.lower() in ["exit", "quit", "q"]: print("\nGoodbye") break try: print("\nThinking...\n") answer = rag.query(question, verbose=args.verbose) print("Answer:") print("-" * 80) print(answer) print("-" * 80) except Exception as e: print(f"\nError: {e}") except Exception as e: print(f"\nFatal error: {e}") import traceback traceback.print_exc() sys.exit(1) if __name__ == "__main__": main()