#!/usr/bin/env python3
"""
Test script to validate FREE optimization improvements.

Measures before/after quality on sample tax queries.
"""

import sys
import json  # noqa: F401  (currently unused in this script)
import re
from pathlib import Path

from rag_pipeline import RAGPipeline, DocumentStore

# Test questions covering different tax scenarios
TEST_QUESTIONS = [
    {
        "question": "What are the personal income tax rates in Nigeria?",
        "expected_keywords": ["₦800,000", "15%", "18%", "21%", "23%", "25%"],
        "category": "rates",
    },
    {
        "question": "What is CRA and how is it calculated?",
        "expected_keywords": ["Consolidated Relief Allowance", "₦200,000", "20%", "1%"],
        "category": "relief",
    },
    {
        "question": "What are the company income tax rates?",
        "expected_keywords": ["30%", "20%", "CIT", "company"],
        "category": "corporate",
    },
    {
        "question": "Tell me about PAYE deductions",
        "expected_keywords": ["Pay As You Earn", "employer", "monthly", "withholding"],
        "category": "paye",
    },
    {
        "question": "What tax reliefs are available for individuals?",
        "expected_keywords": ["relief", "allowance", "deduction", "pension"],
        "category": "reliefs",
    },
]

# Markdown bold text that starts with a currency sign or a digit, e.g.
# "**₦800,000**" or "**15%**". A literal "**%" substring cannot detect bold
# percentages because the "%" follows the digits ("**15%**").
_BOLD_NUMBER_RE = re.compile(r"\*\*\s*[₦\d]")

# Any internal fact citation such as "[F1]", "[F7]" or "[F12]".
_FACT_ID_RE = re.compile(r"\[F\d+\]")


def _mark(ok: bool) -> str:
    """Return a check mark for a passing check and a cross for a failing one."""
    return "✓" if ok else "✗"


def test_retrieval_quality(rag: RAGPipeline, max_docs: int = 10) -> float:
    """Test if retrieval finds expected keywords.

    For each entry in TEST_QUESTIONS, documents are fetched with the
    pipeline's private ``_retrieve`` method. The first ``max_docs`` of them
    are concatenated and searched case-insensitively for the expected
    keywords.

    The per-question score is labelled "precision" in the output. It is
    really keyword coverage: the fraction of expected keywords found.

    Args:
        rag: Initialized pipeline. ``_retrieve(question)`` must return a
            sequence of objects exposing ``page_content``.
        max_docs: How many top retrieved documents to score (default 10).

    Returns:
        Mean per-question score in [0, 1], or 0.0 when there are no questions.
    """
    print("\n" + "=" * 80)
    print("RETRIEVAL QUALITY TEST")
    print("=" * 80)

    results = []

    for item in TEST_QUESTIONS:
        question = item["question"]
        expected = item["expected_keywords"]

        # Retrieve docs
        docs = rag._retrieve(question)
        retrieved_text = " ".join(d.page_content for d in docs[:max_docs]).lower()

        # Check if expected keywords found
        found = [kw for kw in expected if kw.lower() in retrieved_text]
        precision = len(found) / len(expected) if expected else 0

        results.append({
            "question": question,
            "precision": precision,
            "found": len(found),
            "total": len(expected),
            "found_keywords": found,
        })

        print(f"\n{item['category'].upper()}: {question}")
        print(f" Found: {len(found)}/{len(expected)} keywords ({precision*100:.0f}%)")

        # Report missing keywords in their declared order. A set difference
        # would print them in arbitrary order.
        missing = [kw for kw in expected if kw not in found]
        if missing:
            print(f" Missing: {', '.join(missing)}")

    avg_precision = (
        sum(r["precision"] for r in results) / len(results) if results else 0.0
    )
    print(f"\n{'='*80}")
    print(f"AVERAGE RETRIEVAL PRECISION: {avg_precision*100:.1f}%")
    print(f"{'='*80}\n")

    return avg_precision


def test_answer_quality(rag: RAGPipeline) -> None:
    """Test if answers have good formatting and content.

    Queries the first three TEST_QUESTIONS and prints each answer with a set
    of formatting checks. A per-question error is printed and the loop
    continues.
    """
    print("\n" + "=" * 80)
    print("ANSWER QUALITY TEST")
    print("=" * 80)

    for idx, item in enumerate(TEST_QUESTIONS[:3], 1):  # Test first 3 for speed
        question = item["question"]
        print(f"\n[{idx}] QUESTION: {question}")
        print("-" * 80)

        try:
            answer = rag.query(question, verbose=False)

            # Quality checks
            has_bottom_line = "**Bottom line**" in answer
            has_numbers = any(char.isdigit() for char in answer)
            has_bold_numbers = _BOLD_NUMBER_RE.search(answer) is not None
            no_fact_ids = _FACT_ID_RE.search(answer) is None
            has_structure = "**Here's what you need to know**" in answer

            print(f"ANSWER:\n{answer}\n")
            print("QUALITY CHECKS:")
            print(f" {_mark(has_bottom_line)} Has bottom line: {has_bottom_line}")
            print(f" {_mark(has_numbers)} Contains numbers: {has_numbers}")
            print(f" {_mark(has_bold_numbers)} Numbers emphasized (bold): {has_bold_numbers}")
            print(f" {_mark(no_fact_ids)} No fact IDs ([F1], etc.): {no_fact_ids}")
            print(f" {_mark(has_structure)} Structured format: {has_structure}")

            # Bold emphasis is reported but is not part of the pass/fail
            # warning (same set of checks as before).
            if not all([has_bottom_line, has_numbers, no_fact_ids, has_structure]):
                print(" ⚠️ WARNING: Some quality checks failed!")

        except Exception as e:
            print(f" ❌ ERROR: {e}")

    print(f"\n{'='*80}\n")


def test_hallucination_prevention(rag: RAGPipeline) -> None:
    """Test if system avoids hallucinating specific examples.

    Asks questions designed to tempt the model into inventing concrete
    figures. An answer counts as a hallucination if it contains at least one
    of its question's forbidden phrases (case-insensitive). The reported
    rate is hallucinating answers divided by answers checked. Questions
    whose query raises are excluded from both counts.
    """
    print("\n" + "=" * 80)
    print("HALLUCINATION PREVENTION TEST")
    print("=" * 80)

    # Questions designed to tempt hallucination.
    # NOTE(review): "should_calculate" is not read anywhere in this function.
    trick_questions = [
        {
            "question": "How much tax will I pay if I earn ₦500,000 per month?",
            "should_calculate": True,  # Should use tax calculator
            "forbidden_phrases": [],  # Calculator is allowed to show examples
        },
        {
            "question": "What happens if I don't pay my taxes?",
            "should_calculate": False,
            "forbidden_phrases": ["for example, you could be fined ₦", "typically around ₦"],
        },
    ]

    hallucinations = 0  # answers containing at least one forbidden phrase
    total = 0  # answers successfully generated and checked

    for item in trick_questions:
        question = item["question"]
        print(f"\nQUESTION: {question}")

        try:
            answer = rag.query(question, verbose=False)
            answer_lower = answer.lower()

            # Check for forbidden phrases
            found_forbidden = [
                phrase
                for phrase in item["forbidden_phrases"]
                if phrase.lower() in answer_lower
            ]

            if found_forbidden:
                # Count the answer once, no matter how many phrases matched.
                # This keeps the rate per-answer, so it cannot exceed 100%.
                hallucinations += 1
                print(f" ❌ HALLUCINATION DETECTED: {found_forbidden}")
                print(f" Answer excerpt: {answer[:200]}...")
            else:
                print(" ✓ No hallucinations detected")

            total += 1

        except Exception as e:
            print(f" ⚠️ ERROR: {e}")

    if total > 0:
        hallucination_rate = (hallucinations / total) * 100
        print(f"\n{'='*80}")
        print(f"HALLUCINATION RATE: {hallucination_rate:.1f}%")
        if hallucination_rate == 0:
            print("✓ EXCELLENT: No hallucinations detected!")
        elif hallucination_rate < 10:
            print("✓ GOOD: Low hallucination rate")
        else:
            print("⚠️ WARNING: High hallucination rate, review prompts")
        print(f"{'='*80}\n")


def main() -> None:
    """Build the RAG pipeline, run all validation tests and print a summary.

    If running the tests raises, prints the traceback and exits with
    status 1.
    """
    print("=" * 80)
    print("FREE OPTIMIZATION VALIDATION TEST")
    print("Testing: Improved embeddings, prompts, formatting, and retrieval")
    print("=" * 80)

    # Initialize RAG pipeline
    print("\nInitializing RAG pipeline...")
    vector_store_path = Path("vector_store")
    doc_store = DocumentStore(
        persist_dir=vector_store_path,
        embedding_model="BAAI/bge-large-en-v1.5",  # New embedding model
    )

    src = Path("data")
    pdfs = doc_store.discover_pdfs(src)
    doc_store.build_vector_store(pdfs, force_rebuild=False)

    rag = RAGPipeline(
        doc_store=doc_store,
        model="llama-3.3-70b-versatile",
        temperature=0.1,
        top_k=15,  # Increased from 8
        use_hybrid=True,
        use_mmr=True,
        use_reranker=True,
    )
    print("✓ RAG pipeline initialized\n")

    # Run tests
    try:
        retrieval_precision = test_retrieval_quality(rag)
        test_answer_quality(rag)
        test_hallucination_prevention(rag)

        # Summary
        print("\n" + "=" * 80)
        print("SUMMARY")
        print("=" * 80)
        print(f"Retrieval Precision: {retrieval_precision*100:.1f}%")
        print(" Target: >55% (baseline was ~42%)")
        if retrieval_precision > 0.55:
            print(" ✓ EXCELLENT: Retrieval improved!")
        elif retrieval_precision > 0.45:
            print(" ✓ GOOD: Retrieval improved")
        else:
            print(" ⚠️ Need improvement")

        print("\nOPTIMIZATIONS APPLIED:")
        print(" ✓ Upgraded embedding: all-MiniLM-L6-v2 → bge-large-en-v1.5")
        print(" ✓ Upgraded reranker: MiniLM-L-6 → MiniLM-L-12")
        print(" ✓ Anti-hallucination system prompts")
        print(" ✓ Enhanced fact schema with number extraction")
        print(" ✓ Removed fact IDs from output")
        print(" ✓ Bold emphasis on numbers and percentages")
        print(" ✓ Tax-aware query expansion")
        print(" ✓ Increased retrieval: 8 → 15 docs")
        print(" ✓ Context added to thresholds (₦800K → ₦800K (₦66,667/month))")

        print("\n" + "=" * 80)
        print("TEST COMPLETE")
        print("=" * 80)

    except Exception as e:
        print(f"\n❌ TEST FAILED: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()