|
|
|
|
|
""" |
|
|
Test script to validate FREE optimization improvements. |
|
|
Measures before/after quality on sample tax queries. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import json |
|
|
from pathlib import Path |
|
|
from rag_pipeline import RAGPipeline, DocumentStore |
|
|
|
|
|
|
|
|
# Sample tax questions used by the checks in this script. Each entry holds the
# question text, the keywords looked for in the retrieved text, and a short
# category label that is upper-cased and printed next to the question.
TEST_QUESTIONS = [
    {
        "question": "What are the personal income tax rates in Nigeria?",
        "expected_keywords": ["₦800,000", "15%", "18%", "21%", "23%", "25%"],
        "category": "rates",
    },
    {
        "question": "What is CRA and how is it calculated?",
        "expected_keywords": [
            "Consolidated Relief Allowance",
            "₦200,000",
            "20%",
            "1%",
        ],
        "category": "relief",
    },
    {
        "question": "What are the company income tax rates?",
        "expected_keywords": ["30%", "20%", "CIT", "company"],
        "category": "corporate",
    },
    {
        "question": "Tell me about PAYE deductions",
        "expected_keywords": [
            "Pay As You Earn",
            "employer",
            "monthly",
            "withholding",
        ],
        "category": "paye",
    },
    {
        "question": "What tax reliefs are available for individuals?",
        "expected_keywords": ["relief", "allowance", "deduction", "pension"],
        "category": "reliefs",
    },
]
|
|
|
|
|
|
|
|
def test_retrieval_quality(rag: RAGPipeline, max_docs: int = 10):
    """Report how many expected keywords show up in the retrieved documents.

    For every entry in ``TEST_QUESTIONS`` the question is passed to
    ``rag._retrieve``. The ``page_content`` of the first ``max_docs`` results
    is joined and lower-cased, and each expected keyword is looked up in it as
    a case-insensitive substring.

    The per-question score is ``found keywords / expected keywords``. It is
    printed under the label "precision" for continuity with earlier runs, but
    it is a keyword hit rate rather than precision in the IR sense.

    Args:
        rag: Pipeline object; only its ``_retrieve(question)`` method is used.
            The result must be sliceable and yield objects with a
            ``page_content`` string attribute.
        max_docs: How many of the top retrieved documents are searched.
            Defaults to 10, the cutoff this script has always used.

    Returns:
        The mean per-question score as a float between 0 and 1, or ``0.0``
        when ``TEST_QUESTIONS`` is empty.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("RETRIEVAL QUALITY TEST")
    print(rule)

    results = []
    for item in TEST_QUESTIONS:
        question = item["question"]
        expected = item["expected_keywords"]

        docs = rag._retrieve(question)
        retrieved_text = " ".join(
            d.page_content for d in docs[:max_docs]
        ).lower()

        # Single pass that keeps the keywords in their declared order, so the
        # "Missing" line is printed deterministically.
        found = []
        missing = []
        for kw in expected:
            if kw.lower() in retrieved_text:
                found.append(kw)
            else:
                missing.append(kw)

        precision = len(found) / len(expected) if expected else 0.0

        results.append({
            "question": question,
            "precision": precision,
            "found": len(found),
            "total": len(expected),
            "found_keywords": found,
        })

        print(f"\n{item['category'].upper()}: {question}")
        print(f" Found: {len(found)}/{len(expected)} keywords ({precision*100:.0f}%)")
        if missing:
            print(f" Missing: {', '.join(missing)}")

    # Guard against an empty question list instead of dividing by zero.
    if results:
        avg_precision = sum(r["precision"] for r in results) / len(results)
    else:
        avg_precision = 0.0

    print(f"\n{rule}")
    print(f"AVERAGE RETRIEVAL PRECISION: {avg_precision*100:.1f}%")
    print(f"{rule}\n")

    return avg_precision
|
|
|
|
|
|
|
|
def test_answer_quality(rag: RAGPipeline):
    """Print formatting and content checks for generated answers.

    The first three entries of ``TEST_QUESTIONS`` are sent to
    ``rag.query(question, verbose=False)``. Each answer is printed together
    with five checks:

    * contains the ``**Bottom line**`` marker;
    * contains at least one digit;
    * contains at least one ``**...**`` bold span that includes a digit;
    * contains no fact identifiers of the form ``[F<number>]``;
    * contains the ``**Here's what you need to know**`` marker.

    A warning is printed when any check fails. An exception raised while
    querying or checking one question is reported and the loop continues with
    the next question.

    Args:
        rag: Pipeline object; only its ``query`` method is used, and its
            result is treated as a string.

    Returns:
        None. Results are printed only.
    """
    import re  # local import keeps this function self-contained

    # Paired bold spans. findall consumes the text left to right, so opening
    # and closing markers are matched up instead of spanning two bold runs.
    bold_span = re.compile(r"\*\*(.+?)\*\*")
    # Any fact identifier, not just [F1] and [F2].
    fact_id = re.compile(r"\[F\d+\]")

    rule = "=" * 80
    print("\n" + rule)
    print("ANSWER QUALITY TEST")
    print(rule)

    for idx, item in enumerate(TEST_QUESTIONS[:3], 1):
        question = item["question"]
        print(f"\n[{idx}] QUESTION: {question}")
        print("-" * 80)

        try:
            answer = rag.query(question, verbose=False)

            has_bold_numbers = any(
                any(ch.isdigit() for ch in span)
                for span in bold_span.findall(answer)
            )
            checks = [
                ("Has bottom line", "**Bottom line**" in answer),
                ("Contains numbers", any(ch.isdigit() for ch in answer)),
                ("Numbers emphasized (bold)", has_bold_numbers),
                ("No fact IDs ([F1], etc.)", fact_id.search(answer) is None),
                ("Structured format",
                 "**Here's what you need to know**" in answer),
            ]

            print(f"ANSWER:\n{answer}\n")
            print("QUALITY CHECKS:")
            for label, passed in checks:
                # The marker follows the result, so a failed check is not
                # shown with a check mark.
                marker = "✓" if passed else "✗"
                print(f" {marker} {label}: {passed}")

            # Every printed check counts towards the warning, including the
            # bold-number check.
            if not all(passed for _, passed in checks):
                print(" ⚠️ WARNING: Some quality checks failed!")

        except Exception as e:
            # Best effort: one failing question must not stop the others.
            print(f" ❌ ERROR: {e}")

    print(f"\n{rule}\n")
|
|
|
|
|
|
|
|
def test_hallucination_prevention(rag: RAGPipeline):
    """Check answers to trick questions for forbidden, made-up phrasing.

    Each trick question is sent to ``rag.query(question, verbose=False)`` and
    the answer is searched, case-insensitively, for the question's
    ``forbidden_phrases``. A question counts as one hallucination no matter
    how many of its phrases match, so the reported rate is the share of
    evaluated questions with at least one match and cannot exceed 100%.

    Questions whose query or check raises are reported and left out of the
    rate. When no question could be evaluated, that is stated instead of
    printing a rate.

    Args:
        rag: Pipeline object; only its ``query`` method is used, and its
            result is treated as a string.

    Returns:
        None. Results are printed only.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("HALLUCINATION PREVENTION TEST")
    print(rule)

    # NOTE: "should_calculate" is recorded for each question but is not
    # evaluated by this function. The first question has no forbidden phrases
    # and therefore always passes the phrase check.
    trick_questions = [
        {
            "question": "How much tax will I pay if I earn ₦500,000 per month?",
            "should_calculate": True,
            "forbidden_phrases": [],
        },
        {
            "question": "What happens if I don't pay my taxes?",
            "should_calculate": False,
            "forbidden_phrases": [
                "for example, you could be fined ₦",
                "typically around ₦",
            ],
        },
    ]

    hallucinations = 0  # questions with at least one forbidden phrase
    total = 0           # questions that were evaluated successfully

    for item in trick_questions:
        question = item["question"]
        print(f"\nQUESTION: {question}")

        try:
            answer = rag.query(question, verbose=False)

            # Lower-case once instead of once per phrase.
            answer_lower = answer.lower()
            found_forbidden = [
                phrase
                for phrase in item["forbidden_phrases"]
                if phrase.lower() in answer_lower
            ]

            if found_forbidden:
                # Count the question once so the rate stays within 0-100%.
                hallucinations += 1
                print(f" ❌ HALLUCINATION DETECTED: {found_forbidden}")
                print(f" Answer excerpt: {answer[:200]}...")
            else:
                print(" ✓ No hallucinations detected")

            total += 1

        except Exception as e:
            # Best effort: report and move on to the next question.
            print(f" ⚠️ ERROR: {e}")

    print(f"\n{rule}")
    if total > 0:
        hallucination_rate = (hallucinations / total) * 100
        print(f"HALLUCINATION RATE: {hallucination_rate:.1f}%")
        if hallucination_rate == 0:
            print("✓ EXCELLENT: No hallucinations detected!")
        elif hallucination_rate < 10:
            print("✓ GOOD: Low hallucination rate")
        else:
            print("⚠️ WARNING: High hallucination rate, review prompts")
    else:
        # Every query failed, so say so rather than ending silently.
        print("⚠️ WARNING: No questions could be evaluated")
    print(f"{rule}\n")
|
|
|
|
|
|
|
|
def main():
    """Build the RAG pipeline, run the three checks and print a summary.

    The document store is created under ``vector_store`` and built from the
    PDFs discovered under ``data`` (without forcing a rebuild). If any check
    or the summary raises, the error and its traceback are printed and the
    process exits with status 1.
    """
    rule = "=" * 80

    print(rule)
    print("FREE OPTIMIZATION VALIDATION TEST")
    print("Testing: Improved embeddings, prompts, formatting, and retrieval")
    print(rule)

    # --- pipeline setup ---------------------------------------------------
    print("\nInitializing RAG pipeline...")
    doc_store = DocumentStore(
        persist_dir=Path("vector_store"),
        embedding_model="BAAI/bge-large-en-v1.5",
    )
    pdfs = doc_store.discover_pdfs(Path("data"))
    doc_store.build_vector_store(pdfs, force_rebuild=False)

    pipeline_options = {
        "model": "llama-3.3-70b-versatile",
        "temperature": 0.1,
        "top_k": 15,
        "use_hybrid": True,
        "use_mmr": True,
        "use_reranker": True,
    }
    rag = RAGPipeline(doc_store=doc_store, **pipeline_options)
    print("✓ RAG pipeline initialized\n")

    # Lines printed verbatim under "OPTIMIZATIONS APPLIED:".
    applied = (
        " ✓ Upgraded embedding: all-MiniLM-L6-v2 → bge-large-en-v1.5",
        " ✓ Upgraded reranker: MiniLM-L-6 → MiniLM-L-12",
        " ✓ Anti-hallucination system prompts",
        " ✓ Enhanced fact schema with number extraction",
        " ✓ Removed fact IDs from output",
        " ✓ Bold emphasis on numbers and percentages",
        " ✓ Tax-aware query expansion",
        " ✓ Increased retrieval: 8 → 15 docs",
        " ✓ Context added to thresholds (₦800K → ₦800K (₦66,667/month))",
    )

    # --- checks and summary -----------------------------------------------
    try:
        retrieval_precision = test_retrieval_quality(rag)
        test_answer_quality(rag)
        test_hallucination_prevention(rag)

        print(f"\n{rule}")
        print("SUMMARY")
        print(rule)
        print(f"Retrieval Precision: {retrieval_precision*100:.1f}%")
        print(" Target: >55% (baseline was ~42%)")

        if retrieval_precision > 0.55:
            verdict = " ✓ EXCELLENT: Retrieval improved!"
        elif retrieval_precision > 0.45:
            verdict = " ✓ GOOD: Retrieval improved"
        else:
            verdict = " ⚠️ Need improvement"
        print(verdict)

        print("\nOPTIMIZATIONS APPLIED:")
        for line in applied:
            print(line)

        print(f"\n{rule}")
        print("TEST COMPLETE")
        print(rule)

    except Exception as e:
        print(f"\n❌ TEST FAILED: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
|
|
|
|
|
|
|
|
# Run the validation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|