# orchestrator.py from __future__ import annotations from dataclasses import dataclass from datetime import date, datetime from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Union import argparse import json import os import sys from dotenv import load_dotenv, find_dotenv from fastapi import FastAPI, HTTPException, Body from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, ConfigDict, Field, field_validator # Load .env so GROQ_API_KEY and other vars are available load_dotenv(find_dotenv(), override=False) # If these files live in the same folder as this file, keep imports as below. # If they live under an app/ package, change to: # from app.calculator.rules_engine import RuleCatalog, TaxEngine # from app.rag.rag_pipeline import RAGPipeline, DocumentStore from rules_engine import RuleCatalog, TaxEngine from rag_pipeline import RAGPipeline, DocumentStore from transaction_classifier import TransactionClassifier from transaction_aggregator import TransactionAggregator from tax_strategy_extractor import TaxStrategyExtractor from tax_optimizer import TaxOptimizer # -------------------- Config -------------------- RULES_PATH = "rules/rules_all.yaml" # adjust if yours is different PDF_SOURCE = "data" # folder or a single PDF EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2" GROQ_MODEL = "llama-3.1-8b-instant" # Use /tmp for vector store in Hugging Face Spaces (writable directory) VECTOR_STORE_DIR = os.getenv('VECTOR_STORE_DIR', '/tmp/vector_store') # Allow disabling RAG entirely for resource-constrained environments DISABLE_RAG = os.getenv('DISABLE_RAG', 'false').lower() in ('true', '1', 'yes') import os os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache" os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface_cache" os.environ["HF_HOME"] = "/tmp/huggingface_cache" os.makedirs("/tmp/huggingface_cache", exist_ok=True) # Pre-download embedding model to cache def 
_ensure_embedding_model_cached(model_name: str) -> bool: """Pre-download embedding model to avoid runtime errors""" try: from sentence_transformers import SentenceTransformer print(f"[INFO] Pre-downloading embedding model: {model_name}", file=sys.stderr) model = SentenceTransformer(model_name) print(f"[INFO] Embedding model cached successfully", file=sys.stderr) return True except Exception as e: print(f"[WARN] Failed to cache embedding model: {e}", file=sys.stderr) print(f"[INFO] This is common in Hugging Face Spaces with limited disk space", file=sys.stderr) print(f"[INFO] Set DISABLE_RAG=true to skip RAG initialization", file=sys.stderr) return False CALC_KEYWORDS = { "compute", "calculate", "calc", "how much tax", "tax due", "paye", "cit", "vat to pay", "what will i pay", "liability", "estimate", "breakdown", "net pay", "withholding" } INFO_KEYWORDS = { "what is", "explain", "definition", "section", "rate", "band", "threshold", "who is exempt", "am i exempt", "citation", "law", "clause", "which section" } # -------------------- Pydantic models -------------------- class HandleRequest(BaseModel): """Payload for the orchestrator endpoint.""" question: str = Field(..., min_length=1, description="User question or instruction.") as_of: Optional[date] = Field( default=None, description="Date context for tax rules. Defaults to today when omitted." ) tax_type: str = Field( default="PIT", description="Tax product to evaluate when calculations are requested (PIT, CIT, VAT)." ) jurisdiction: Optional[str] = Field( default="state", description="Jurisdiction key used to filter the rules catalog." ) inputs: Optional[Dict[str, float]] = Field( default=None, description="Numeric inputs required by the calculator, for example {'gross_income': 500000}." ) with_rag_quotes_on_calc: bool = Field( default=True, description="When true and RAG is available, attaches short supporting quotes to calculator lines." 
) rule_ids_whitelist: Optional[List[str]] = Field( default=None, description="Optional list of rule IDs to evaluate. When set, other rules are ignored." ) model_config = ConfigDict(extra="forbid") @field_validator("tax_type") @classmethod def _normalize_tax_type(cls, v: str) -> str: allowed = {"PIT", "CIT", "VAT"} value = (v or "").upper() if value not in allowed: raise ValueError(f"tax_type must be one of {sorted(allowed)}") return value @field_validator("inputs") @classmethod def _ensure_numeric_inputs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, float]]: if v is None: return None coerced: Dict[str, float] = {} for key, raw in v.items(): if raw is None: raise ValueError(f"Input '{key}' cannot be null.") try: coerced[key] = float(raw) except (TypeError, ValueError) as exc: raise ValueError(f"Input '{key}' must be numeric.") from exc return coerced class CalculationLine(BaseModel): rule_id: str title: str amount: float output: Optional[str] = None details: Dict[str, Any] = {} authority: List[Dict[str, Any]] = [] quote: Optional[str] = Field( default=None, description="Optional supporting quote from the RAG pipeline." 
) model_config = ConfigDict(extra="allow") class RagOnlyResponse(BaseModel): mode: Literal["rag_only"] as_of: str answer: str class CalculationResponse(BaseModel): mode: Literal["calculate"] as_of: str tax_type: str summary: Dict[str, float] lines: List[CalculationLine] model_config = ConfigDict(extra="allow") HandleResponse = Union[RagOnlyResponse, CalculationResponse] # -------------------- Optimization Models -------------------- class MonoTransaction(BaseModel): """Transaction from Mono API or manual entry""" id: Optional[str] = Field(default=None, alias="_id") type: str = Field(..., description="debit or credit (or income/expense)") amount: Optional[float] = None # Amount in Naira amount_kobo: Optional[int] = None # Amount in kobo (from backend) narration: Optional[str] = None # Mono API format description: Optional[str] = None # Backend format date: Optional[str] = None # ISO format date string timestamp: Optional[str] = None # Backend format balance: Optional[float] = None category: Optional[str] = None metadata: Optional[Dict[str, Any]] = None model_config = ConfigDict(extra="allow", populate_by_name=True) @field_validator('type') @classmethod def normalize_type(cls, v: str) -> str: """Normalize transaction type to credit/debit format""" v_lower = v.lower() if v_lower in ['income', 'credit']: return 'credit' elif v_lower in ['expense', 'debit']: return 'debit' return v def model_post_init(self, __context) -> None: """Normalize fields after initialization""" # Normalize amount: prefer amount, fallback to amount_kobo if self.amount is None and self.amount_kobo is not None: self.amount = self.amount_kobo / 100.0 elif self.amount is None: self.amount = 0.0 # Normalize narration: prefer narration, fallback to description if self.narration is None and self.description is not None: self.narration = self.description elif self.narration is None: self.narration = "Unknown transaction" # Normalize date: prefer date, fallback to timestamp if self.date is None and 
self.timestamp is not None: self.date = self.timestamp elif self.date is None: self.date = datetime.now().isoformat() class TaxpayerProfile(BaseModel): """Optional taxpayer profile information""" taxpayer_type: str = Field(default="individual", description="individual or company") employment_status: Optional[str] = Field(default=None, description="employed, self_employed, business_owner, mixed") annual_income: Optional[float] = None annual_turnover: Optional[float] = None has_rental_income: Optional[bool] = False location: Optional[str] = None model_config = ConfigDict(extra="allow") class OptimizationRequest(BaseModel): """Request payload for tax optimization endpoint""" user_id: str = Field(..., description="Unique user identifier") transactions: List[MonoTransaction] = Field(..., description="List of transactions from Mono API and manual entry") taxpayer_profile: Optional[TaxpayerProfile] = Field(default=None, description="Optional taxpayer profile (auto-inferred if omitted)") tax_year: int = Field(default=2025, description="Tax year to optimize for") tax_type: str = Field(default="PIT", description="PIT, CIT, or VAT") jurisdiction: str = Field(default="state", description="federal or state") model_config = ConfigDict(extra="forbid") class OptimizationResponse(BaseModel): """Response from tax optimization endpoint""" user_id: str tax_year: int tax_type: str analysis_date: str baseline_tax_liability: float optimized_tax_liability: float total_potential_savings: float savings_percentage: float total_annual_income: float current_deductions: Dict[str, float] recommendations: List[Dict[str, Any]] recommendation_count: int transaction_summary: Dict[str, Any] income_breakdown: Dict[str, Any] deduction_breakdown: Dict[str, Any] taxpayer_profile: Dict[str, Any] baseline_calculation: Dict[str, Any] model_config = ConfigDict(extra="allow") # -------------------- Helpers -------------------- def classify_intent(user_text: str) -> str: q = (user_text or "").lower().strip() 
if any(k in q for k in CALC_KEYWORDS): return "calculate" if any(k in q for k in INFO_KEYWORDS): return "explain" if any(tok in q for tok in ["₦", "ngn", "naira"]) or any(ch.isdigit() for ch in q): if "how much" in q or "pay" in q or "tax" in q: return "calculate" return "explain" # -------------------- Orchestrator core -------------------- @dataclass class Orchestrator: catalog: RuleCatalog engine: TaxEngine rag: Optional[RAGPipeline] = None # RAG optional if PDFs or GROQ are missing optimizer: Optional[TaxOptimizer] = None # Tax optimizer @classmethod def bootstrap(cls) -> "Orchestrator": # calculator if not os.path.exists(RULES_PATH): print(f"ERROR: Rules file not found at {RULES_PATH}", file=sys.stderr) sys.exit(1) catalog = RuleCatalog.from_yaml_files([RULES_PATH]) engine = TaxEngine(catalog, rounding_mode="half_up") # RAG rag = None if DISABLE_RAG: print(f"[INFO] RAG disabled via DISABLE_RAG environment variable", file=sys.stderr) else: try: # Pre-download embedding model if not _ensure_embedding_model_cached(EMBED_MODEL): print(f"[WARN] Embedding model not available. RAG disabled.", file=sys.stderr) raise RuntimeError("Embedding model unavailable") src = Path(PDF_SOURCE) # Use writable directory for Hugging Face Spaces vector_store_path = Path(VECTOR_STORE_DIR) # Create directory if it doesn't exist and is writable try: vector_store_path.mkdir(parents=True, exist_ok=True) except (PermissionError, OSError) as mkdir_err: print(f"[WARN] Cannot create vector_store directory: {mkdir_err}", file=sys.stderr) print(f"[INFO] RAG will be disabled. Tax calculations will still work.", file=sys.stderr) raise ds = DocumentStore(persist_dir=vector_store_path, embedding_model=EMBED_MODEL) pdfs = ds.discover_pdfs(src) if not pdfs: print(f"[WARN] No PDFs found under {src}. 
RAG disabled.", file=sys.stderr) raise FileNotFoundError(f"No PDFs found under {src}") ds.build_vector_store(pdfs, force_rebuild=False) # RAGPipeline reads GROQ_API_KEY from env via langchain_groq; ensure .env loaded rag = RAGPipeline(doc_store=ds, model=GROQ_MODEL, temperature=0.1) print("[INFO] RAG pipeline initialized successfully", file=sys.stderr) except Exception as e: print(f"[WARN] RAG not initialized: {e}", file=sys.stderr) print(f"[INFO] Service will continue without RAG. Tax calculations available.", file=sys.stderr) # Tax Optimizer optimizer = None if rag: # Optimizer requires RAG for strategy extraction try: classifier = TransactionClassifier(rag_pipeline=rag) aggregator = TransactionAggregator() strategy_extractor = TaxStrategyExtractor(rag_pipeline=rag) optimizer = TaxOptimizer( classifier=classifier, aggregator=aggregator, strategy_extractor=strategy_extractor, tax_engine=engine ) print("[INFO] Tax Optimizer initialized successfully", file=sys.stderr) except Exception as e: print(f"[WARN] Tax Optimizer not initialized: {e}", file=sys.stderr) else: print("[INFO] Tax Optimizer disabled (requires RAG)", file=sys.stderr) return cls(catalog=catalog, engine=engine, rag=rag, optimizer=optimizer) def handle( self, *, user_text: str, as_of: date, tax_type: str = "PIT", jurisdiction: Optional[str] = "state", inputs: Optional[Dict[str, float]] = None, with_rag_quotes_on_calc: bool = True, rule_ids_whitelist: Optional[List[str]] = None ) -> Dict[str, Any]: intent = classify_intent(user_text) use_calc = intent == "calculate" and inputs is not None # RAG-only if not use_calc: if not self.rag: return { "mode": "rag_only", "as_of": as_of.isoformat(), "answer": "RAG unavailable. Add PDFs under 'data' and set GROQ_API_KEY." 
} answer = self.rag.query(user_text, verbose=False) return {"mode": "rag_only", "as_of": as_of.isoformat(), "answer": str(answer)} # Calculate ctx = self.engine.run( tax_type=tax_type, as_of=as_of, jurisdiction=jurisdiction, inputs=inputs, rule_ids_whitelist=rule_ids_whitelist ) lines: List[Dict[str, Any]] = ctx.lines # Optional: enrich with short quotes if with_rag_quotes_on_calc and self.rag: enriched = [] for ln in lines: auth = ln.get("authority", []) hint = "" if auth: a0 = auth[0] doc = a0.get("doc") or "" sec = a0.get("section") or "" hint = f" from {doc} {sec}".strip() q = f"Quote the operative text{hint}. Keep under 120 words with section and page if visible." try: quote = self.rag.query(q, verbose=False) except Exception: quote = None enriched.append({**ln, "quote": quote}) lines = enriched return { "mode": "calculate", "as_of": as_of.isoformat(), "tax_type": tax_type, "summary": {"tax_due": float(ctx.values.get("tax_due", ctx.values.get("computed_tax", 0.0)))}, "lines": lines } # -------------------- FastAPI app -------------------- app = FastAPI( title="Kaanta Tax Assistant API", version="0.1.0", description="Routes informational Nigeria tax queries to the RAG pipeline and calculations to the deterministic engine.", contact={"name": "Kaanta AI", "url": "https://huggingface.co/spaces"} ) # CORS: open by default. Lock down in production. 
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
def _startup_event() -> None:
    """Build the orchestrator once at startup and store it on ``app.state``."""
    app.state.orchestrator = Orchestrator.bootstrap()


def _get_orchestrator() -> Orchestrator:
    """Return the running orchestrator, or raise 503 while startup has not finished."""
    orch = getattr(app.state, "orchestrator", None)
    if orch is None:
        raise HTTPException(status_code=503, detail="Service is still warming up.")
    return orch


@app.get("/", tags=["Meta"])
def read_root() -> Dict[str, Any]:
    """Report service metadata and which components are ready."""
    orch = getattr(app.state, "orchestrator", None)
    return {
        "service": "Kaanta Tax Assistant",
        # Read from the app so this always matches the OpenAPI-advertised version.
        "version": app.version,
        "rag_ready": bool(orch and orch.rag),
        "calculator_ready": bool(orch),
        "optimizer_ready": bool(orch and orch.optimizer),
        "docs_url": "/docs",
    }


@app.get("/health", tags=["Meta"])
def health_check() -> Dict[str, Any]:
    """Liveness probe: "ok" once the orchestrator exists, else "initializing"."""
    orch = getattr(app.state, "orchestrator", None)
    status = "ok" if orch else "initializing"
    return {"status": status, "rag_ready": bool(orch and orch.rag)}


@app.post("/v1/query", tags=["Assistant"], response_model=HandleResponse)
def orchestrate_query(payload: HandleRequest = Body(...)) -> HandleResponse:
    """Route a question to RAG or the calculator; ``as_of`` defaults to today."""
    orch = _get_orchestrator()
    effective_date = payload.as_of or date.today()
    result = orch.handle(
        user_text=payload.question,
        as_of=effective_date,
        tax_type=payload.tax_type,
        jurisdiction=payload.jurisdiction,
        inputs=payload.inputs,
        with_rag_quotes_on_calc=payload.with_rag_quotes_on_calc,
        rule_ids_whitelist=payload.rule_ids_whitelist,
    )
    return result  # FastAPI will validate against HandleResponse


@app.post("/v1/optimize", tags=["Optimization"], response_model=OptimizationResponse)
def optimize_tax(payload: OptimizationRequest = Body(...)) -> OptimizationResponse:
    """
    Analyze user transactions and generate tax optimization recommendations

    This endpoint:
    1. Classifies transactions from Mono API and manual entry
    2. Aggregates them into tax calculation inputs
    3. Calculates baseline tax liability
    4. Extracts relevant optimization strategies from tax acts
    5. Simulates optimization scenarios
    6. Returns ranked recommendations with estimated savings

    Example request:
    ```json
    {
        "user_id": "user123",
        "transactions": [
            {
                "type": "credit",
                "amount": 500000,
                "narration": "SALARY PAYMENT FROM ABC LTD",
                "date": "2025-01-31",
                "balance": 750000
            },
            {
                "type": "debit",
                "amount": 40000,
                "narration": "PENSION CONTRIBUTION TO XYZ PFA",
                "date": "2025-01-31",
                "balance": 710000
            }
        ],
        "tax_year": 2025
    }
    ```
    """
    orch = _get_orchestrator()

    # The optimizer only exists when RAG initialized; explain why it is missing.
    if not orch.optimizer:
        error_detail = {
            "error": "Tax optimizer not available",
            "reason": "RAG pipeline failed to initialize",
            "possible_causes": [
                "GROQ_API_KEY not set in environment",
                "PDF files not found in 'data' directory",
                "Vector store directory not writable (check /tmp permissions)",
                "Missing dependencies",
            ],
            "status": {
                "calculator_ready": bool(orch.engine),
                "rag_ready": bool(orch.rag),
                "optimizer_ready": False,
            },
        }
        raise HTTPException(status_code=503, detail=error_detail)

    # Convert Pydantic models to dicts for processing (by_alias keeps "_id").
    transactions = [tx.model_dump(by_alias=True) for tx in payload.transactions]
    taxpayer_profile = payload.taxpayer_profile.model_dump() if payload.taxpayer_profile else None

    try:
        result = orch.optimizer.optimize(
            user_id=payload.user_id,
            transactions=transactions,
            taxpayer_profile=taxpayer_profile,
            tax_year=payload.tax_year,
            tax_type=payload.tax_type,
            jurisdiction=payload.jurisdiction,
        )
        return OptimizationResponse(**result)
    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except Exception as e:
        print(f"[ERROR] Optimization failed for user {payload.user_id}: {e!r}", file=sys.stderr)
        raise HTTPException(status_code=500, detail=f"Optimization failed: {str(e)}") from e


# -------------------- CLI entrypoint --------------------
def _parse_args() -> argparse.Namespace:
    """Parse command-line options for a single orchestrator run."""
    p = argparse.ArgumentParser(description="Kaanta Tax Orchestrator (RAG + Calculator router)")
    p.add_argument("--question", required=True, help="User question or instruction")
    p.add_argument("--as-of", default=None, help="YYYY-MM-DD. Defaults to today.")
    p.add_argument("--tax-type", default="PIT", choices=["PIT", "CIT", "VAT"])
    p.add_argument("--jurisdiction", default="state")
    p.add_argument("--inputs-json", default=None, help="Path to JSON file with calculator inputs")
    p.add_argument("--no-rag-quotes", action="store_true", help="Skip RAG quotes after calculation")
    p.add_argument(
        "--rule-ids",
        nargs="+",
        default=None,
        help="Optional rule IDs to evaluate; other rules are ignored",
    )
    return p.parse_args()


def _load_cli_inputs(path: str) -> Dict[str, float]:
    """Read calculator inputs from a JSON object file, coercing every value to float.

    Mirrors the API's numeric validation so CLI and HTTP runs feed the engine
    the same types. Exits with a message on malformed input.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    if not isinstance(raw, dict):
        sys.exit(f"--inputs-json must contain a JSON object, got {type(raw).__name__}")
    inputs: Dict[str, float] = {}
    for key, value in raw.items():
        if value is None:
            sys.exit(f"Input '{key}' cannot be null.")
        try:
            inputs[key] = float(value)
        except (TypeError, ValueError):
            sys.exit(f"Input '{key}' must be numeric.")
    return inputs


def main() -> None:
    """Run one question through the orchestrator and print the result as JSON."""
    args = _parse_args()
    try:
        as_of = datetime.strptime(args.as_of, "%Y-%m-%d").date() if args.as_of else date.today()
    except ValueError:
        sys.exit(f"--as-of must be YYYY-MM-DD, got {args.as_of!r}")

    inputs = _load_cli_inputs(args.inputs_json) if args.inputs_json else None

    if not os.getenv("GROQ_API_KEY"):
        print("Note: GROQ_API_KEY not set. RAG queries will fail if executed.", file=sys.stderr)
    orch = Orchestrator.bootstrap()

    result = orch.handle(
        user_text=args.question,
        as_of=as_of,
        tax_type=args.tax_type,
        jurisdiction=args.jurisdiction,
        inputs=inputs,
        with_rag_quotes_on_calc=not args.no_rag_quotes,
        rule_ids_whitelist=args.rule_ids,
    )
    print(json.dumps(result, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()