# tax_optimizer.py
"""
Main Tax Optimization Engine
Integrates classifier, aggregator, strategy extractor, and tax engine
"""

from __future__ import annotations

from typing import Dict, List, Any, Optional
from datetime import date
from dataclasses import dataclass, asdict

from transaction_classifier import TransactionClassifier
from transaction_aggregator import TransactionAggregator
from tax_strategy_extractor import TaxStrategyExtractor, TaxStrategy
from rules_engine import TaxEngine, CalculationResult


@dataclass
class OptimizationScenario:
    """A candidate set of modified tax inputs to run through the tax engine."""

    scenario_id: str
    name: str
    description: str
    # Full copy of the baseline inputs with this scenario's changes applied.
    modified_inputs: Dict[str, float]
    # Human-readable change records keyed by change name; each record built in
    # this module is a dict with "from", "to" and "increase".
    changes_made: Dict[str, Any]
    # IDs of the strategies this scenario is derived from.
    strategy_ids: List[str]


@dataclass
class OptimizationRecommendation:
    """A single ranked tax optimization recommendation."""

    rank: int
    strategy_name: str
    # Holds the *scenario* id (e.g. "maximize_pension"), not a TaxStrategy id.
    strategy_id: str
    description: str
    annual_tax_savings: float
    optimized_tax: float
    baseline_tax: float
    implementation_steps: List[str]
    legal_citations: List[str]
    risk_level: str
    complexity: str
    confidence_score: float
    changes_required: Dict[str, Any]


class TaxOptimizer:
    """
    Main tax optimization engine

    Analyzes transactions and generates optimization recommendations
    """

    def __init__(
        self,
        classifier: TransactionClassifier,
        aggregator: TransactionAggregator,
        strategy_extractor: TaxStrategyExtractor,
        tax_engine: TaxEngine
    ):
        """
        Initialize optimizer with required components

        Args:
            classifier: TransactionClassifier instance
            aggregator: TransactionAggregator instance
            strategy_extractor: TaxStrategyExtractor instance
            tax_engine: TaxEngine instance
        """
        self.classifier = classifier
        self.aggregator = aggregator
        self.strategy_extractor = strategy_extractor
        self.engine = tax_engine

    def optimize(
        self,
        user_id: str,
        transactions: List[Dict[str, Any]],
        taxpayer_profile: Optional[Dict[str, Any]] = None,
        tax_year: int = 2025,
        tax_type: str = "PIT",
        jurisdiction: str = "state"
    ) -> Dict[str, Any]:
        """
        Main optimization workflow

        Classifies and aggregates the transactions, computes the baseline
        tax, builds and simulates optimization scenarios, and assembles a
        report dictionary.

        Args:
            user_id: Unique user identifier
            transactions: List of transactions from Mono API + manual entry
            taxpayer_profile: Optional profile info (auto-inferred if not
                provided). The caller's dict is not modified.
            tax_year: Tax year to optimize for
            tax_type: PIT, CIT, or VAT
            jurisdiction: federal or state

        Returns:
            Comprehensive optimization report. The headline savings figures
            describe the single best scenario: scenarios are alternatives
            (the combined scenario already contains the individual ones), so
            their savings must not be added together.
        """
        # Step 1: Classify transactions
        print(f"[Optimizer] Classifying {len(transactions)} transactions...")
        classified_txs = self.classifier.classify_batch(transactions)

        # Step 2: Aggregate into tax inputs
        print(f"[Optimizer] Aggregating transactions for tax year {tax_year}...")
        tax_inputs = self.aggregator.aggregate_for_tax_year(classified_txs, tax_year)

        # Step 3: Infer taxpayer profile if not provided.
        # A supplied profile is copied so the caller's dict is never mutated.
        if taxpayer_profile:
            taxpayer_profile = dict(taxpayer_profile)
        else:
            taxpayer_profile = self._infer_profile(tax_inputs, classified_txs)

        # Add annual income to profile
        taxpayer_profile["annual_income"] = tax_inputs.get("gross_income", 0)

        # Step 4: Calculate baseline tax
        print("[Optimizer] Calculating baseline tax liability...")
        baseline_result = self._calculate_tax(
            tax_inputs=tax_inputs,
            tax_type=tax_type,
            tax_year=tax_year,
            jurisdiction=jurisdiction
        )
        baseline_tax = baseline_result.values.get("tax_due", 0)

        # Step 5: Extract applicable strategies
        print("[Optimizer] Extracting optimization strategies...")
        strategies = self.strategy_extractor.extract_strategies_for_profile(
            taxpayer_profile=taxpayer_profile,
            tax_year=tax_year
        )

        # Step 6: Identify opportunities from transaction analysis
        print("[Optimizer] Identifying optimization opportunities...")
        opportunities = self.aggregator.identify_optimization_opportunities(
            aggregated=tax_inputs,
            tax_year=tax_year
        )

        # Step 7: Generate optimization scenarios
        print("[Optimizer] Generating optimization scenarios...")
        scenarios = self._generate_scenarios(
            baseline_inputs=tax_inputs,
            strategies=strategies,
            opportunities=opportunities
        )

        # Step 8: Simulate each scenario
        print(f"[Optimizer] Simulating {len(scenarios)} scenarios...")
        scenario_results = []
        for scenario in scenarios:
            result = self._calculate_tax(
                tax_inputs=scenario.modified_inputs,
                tax_type=tax_type,
                tax_year=tax_year,
                jurisdiction=jurisdiction
            )
            scenario_tax = result.values.get("tax_due", 0)
            scenario_results.append({
                "scenario": scenario,
                "tax": scenario_tax,
                "savings": baseline_tax - scenario_tax,
                "result": result
            })

        # Step 9: Rank and create recommendations
        print("[Optimizer] Ranking recommendations...")
        recommendations = self._create_recommendations(
            scenario_results=scenario_results,
            baseline_tax=baseline_tax,
            strategies=strategies
        )

        # Step 10: Generate comprehensive report
        classification_summary = self.classifier.get_classification_summary(classified_txs)
        income_breakdown = self.aggregator.get_income_breakdown(classified_txs, tax_year)
        deduction_breakdown = self.aggregator.get_deduction_breakdown(classified_txs, tax_year)

        # Potential savings come from the best single scenario. Summing all
        # recommendations would double count, because the combined scenario
        # overlaps with the individual ones.
        best = max(
            recommendations,
            key=lambda rec: rec.annual_tax_savings,
            default=None
        )
        if best is not None:
            total_potential_savings = best.annual_tax_savings
            optimized_tax = best.optimized_tax
        else:
            total_potential_savings = 0
            optimized_tax = baseline_tax

        pension = tax_inputs.get("employee_pension_contribution", 0)
        nhf = tax_inputs.get("nhf", 0)
        life_insurance = tax_inputs.get("life_insurance", 0)
        union_dues = tax_inputs.get("union_dues", 0)

        return {
            "user_id": user_id,
            "tax_year": tax_year,
            "tax_type": tax_type,
            "analysis_date": date.today().isoformat(),

            # Tax summary
            "baseline_tax_liability": baseline_tax,
            "optimized_tax_liability": optimized_tax,
            "total_potential_savings": total_potential_savings,
            "savings_percentage": (
                (total_potential_savings / baseline_tax * 100)
                if baseline_tax > 0 else 0
            ),

            # Income & deductions
            "total_annual_income": tax_inputs.get("gross_income", 0),
            "current_deductions": {
                "pension": pension,
                "nhf": nhf,
                "life_insurance": life_insurance,
                "union_dues": union_dues,
                "total": sum([pension, nhf, life_insurance, union_dues])
            },

            # Recommendations
            "recommendations": [asdict(r) for r in recommendations],
            "recommendation_count": len(recommendations),

            # Transaction analysis
            "transaction_summary": classification_summary,
            "income_breakdown": income_breakdown,
            "deduction_breakdown": deduction_breakdown,

            # Taxpayer profile
            "taxpayer_profile": taxpayer_profile,

            # Baseline calculation details
            "baseline_calculation": {
                "tax_due": baseline_tax,
                "taxable_income": baseline_result.values.get("taxable_income", 0),
                "gross_income": baseline_result.values.get("gross_income", 0),
                # NOTE(review): union_dues is part of current_deductions.total
                # above but is not added here - confirm which is intended.
                "total_deductions": (
                    baseline_result.values.get("cra_amount", 0)
                    + pension
                    + nhf
                    + life_insurance
                )
            }
        }

    def _calculate_tax(
        self,
        tax_inputs: Dict[str, float],
        tax_type: str,
        tax_year: int,
        jurisdiction: str
    ) -> CalculationResult:
        """Calculate tax using the rules engine, as of 31 December of tax_year."""
        return self.engine.run(
            tax_type=tax_type,
            as_of=date(tax_year, 12, 31),
            jurisdiction=jurisdiction,
            inputs=tax_inputs
        )

    def _infer_profile(
        self,
        tax_inputs: Dict[str, float],
        classified_txs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Infer taxpayer profile from transaction patterns

        Args:
            tax_inputs: Aggregated inputs; "gross_income" and
                "turnover_annual" are read (missing keys count as 0).
            classified_txs: Classified transactions; only each item's
                "tax_category" value is inspected.

        Returns:
            Profile dict with taxpayer_type, employment_status,
            annual_income, annual_turnover, has_rental_income and
            inferred=True.
        """
        gross_income = tax_inputs.get("gross_income", 0)
        turnover = tax_inputs.get("turnover_annual", 0)

        # Any positive turnover is treated as a company.
        taxpayer_type = "company" if turnover > 0 else "individual"

        categories = {tx.get("tax_category") for tx in classified_txs}
        has_employment = "employment_income" in categories
        has_business = "business_income" in categories

        if has_employment and has_business:
            employment_status = "mixed"
        elif has_employment:
            employment_status = "employed"
        elif has_business:
            employment_status = "self_employed"
        else:
            employment_status = "unknown"

        return {
            "taxpayer_type": taxpayer_type,
            "employment_status": employment_status,
            "annual_income": gross_income,
            "annual_turnover": turnover,
            "has_rental_income": "rental_income" in categories,
            "inferred": True
        }

    def _generate_scenarios(
        self,
        baseline_inputs: Dict[str, float],
        strategies: List[TaxStrategy],
        opportunities: List[Dict[str, Any]]
    ) -> List[OptimizationScenario]:
        """
        Generate optimization scenarios from the extracted strategies

        A scenario is produced only when the matching strategy id is present
        in `strategies`; names and descriptions come from the strategy
        objects, and limits come from strategy metadata when available.

        Args:
            baseline_inputs: Current tax inputs; never mutated (each scenario
                works on a copy).
            strategies: Extracted strategies, looked up by strategy_id.
            opportunities: Accepted for interface compatibility; not used by
                the current implementation.

        Returns:
            Individual scenarios followed by a "combined_optimization"
            scenario when more than one individual scenario exists.
        """
        scenarios: List[OptimizationScenario] = []
        gross_income = baseline_inputs.get("gross_income", 0)
        strategy_map = {s.strategy_id: s for s in strategies}

        # Pension optimization (only if the strategy was extracted)
        pension_strategy = strategy_map.get("pit_pension_maximization")
        if pension_strategy and gross_income > 0:
            current_pension = baseline_inputs.get("employee_pension_contribution", 0)

            # Maximum percentage comes from strategy metadata; 0.20 is the
            # fallback when metadata is missing or empty.
            pension_meta = getattr(pension_strategy, "metadata", None) or {}
            max_pct = pension_meta.get("max_percentage", 0.20)
            max_pension = gross_income * max_pct

            if max_pension > current_pension:
                max_pension_inputs = baseline_inputs.copy()
                max_pension_inputs["employee_pension_contribution"] = max_pension
                scenarios.append(OptimizationScenario(
                    scenario_id="maximize_pension",
                    name=pension_strategy.name,
                    description=pension_strategy.description,
                    modified_inputs=max_pension_inputs,
                    changes_made={
                        "pension_contribution": {
                            "from": current_pension,
                            "to": max_pension,
                            "increase": max_pension - current_pension
                        }
                    },
                    strategy_ids=[pension_strategy.strategy_id]
                ))

        # Life insurance (only if the strategy was extracted)
        insurance_strategy = strategy_map.get("pit_life_insurance")
        if insurance_strategy:
            current_insurance = baseline_inputs.get("life_insurance", 0)

            # Suggested premium comes from strategy metadata; 1% of gross
            # income is the fallback when metadata is missing or empty.
            insurance_meta = getattr(insurance_strategy, "metadata", None) or {}
            suggested_premium = insurance_meta.get(
                "suggested_premium", gross_income * 0.01
            )

            if suggested_premium > current_insurance:
                insurance_inputs = baseline_inputs.copy()
                insurance_inputs["life_insurance"] = suggested_premium
                scenarios.append(OptimizationScenario(
                    scenario_id="add_life_insurance",
                    name=insurance_strategy.name,
                    description=insurance_strategy.description,
                    modified_inputs=insurance_inputs,
                    changes_made={
                        "life_insurance": {
                            "from": current_insurance,
                            "to": suggested_premium,
                            "increase": suggested_premium - current_insurance
                        }
                    },
                    strategy_ids=[insurance_strategy.strategy_id]
                ))

        # Combined optimization: apply every individual scenario together
        if len(scenarios) > 1:
            combined_inputs = baseline_inputs.copy()
            combined_changes: Dict[str, Any] = {}
            combined_strategy_ids: List[str] = []

            for scenario in scenarios:
                for key, value in scenario.modified_inputs.items():
                    if value != baseline_inputs.get(key, 0):
                        combined_inputs[key] = value
                # Merge the change records as-is. Their keys (e.g.
                # "pension_contribution") differ from the input keys (e.g.
                # "employee_pension_contribution"), so looking them up by
                # input key would lose the record.
                combined_changes.update(scenario.changes_made)
                combined_strategy_ids.extend(scenario.strategy_ids)

            scenarios.append(OptimizationScenario(
                scenario_id="combined_optimization",
                name="Combined Strategy",
                description="Apply all recommended optimizations together",
                modified_inputs=combined_inputs,
                changes_made=combined_changes,
                strategy_ids=combined_strategy_ids
            ))

        return scenarios

    def _create_recommendations(
        self,
        scenario_results: List[Dict[str, Any]],
        baseline_tax: float,
        strategies: List[TaxStrategy]
    ) -> List[OptimizationRecommendation]:
        """
        Create ranked recommendations from scenario results

        Args:
            scenario_results: Dicts with "scenario", "tax" and "savings"
                keys, as built by optimize().
            baseline_tax: Tax due before any optimization.
            strategies: Extracted strategies, used for steps, citations and
                risk levels.

        Returns:
            At most 10 recommendations, limited to scenarios with positive
            savings and ordered by savings, largest first (rank starts at 1).
        """
        strategy_map = {s.strategy_id: s for s in strategies}

        # Keep only scenarios that actually save tax, best first, top 10.
        viable_scenarios = sorted(
            (sr for sr in scenario_results if sr["savings"] > 0),
            key=lambda sr: sr["savings"],
            reverse=True
        )[:10]

        recommendations = []
        for rank, sr in enumerate(viable_scenarios, 1):
            scenario = sr["scenario"]

            # Collect steps, citations and risk levels from the strategies
            # behind this scenario; unknown strategy ids are skipped.
            implementation_steps: List[str] = []
            legal_citations: List[str] = []
            risk_levels: List[str] = []
            for strategy_id in scenario.strategy_ids:
                strategy = strategy_map.get(strategy_id)
                if strategy:
                    implementation_steps.extend(strategy.implementation_steps)
                    legal_citations.extend(strategy.legal_citations)
                    risk_levels.append(strategy.risk_level)

            # Overall risk is the highest risk among the strategies.
            if "high" in risk_levels:
                overall_risk = "high"
            elif "medium" in risk_levels:
                overall_risk = "medium"
            else:
                overall_risk = "low"

            # Complexity is driven by the number of changes required.
            num_changes = len(scenario.changes_made)
            if num_changes == 1:
                complexity = "easy"
            elif num_changes == 2:
                complexity = "medium"
            else:
                complexity = "complex"

            # Confidence score is derived from the overall risk level.
            if overall_risk == "low":
                confidence = 0.95
            elif overall_risk == "medium":
                confidence = 0.80
            else:
                confidence = 0.65

            narrative_description = self._generate_narrative_description(
                scenario=scenario,
                savings=sr["savings"],
                baseline_tax=baseline_tax,
                optimized_tax=sr["tax"],
                strategies=strategies
            )

            recommendations.append(OptimizationRecommendation(
                rank=rank,
                strategy_name=scenario.name,
                strategy_id=scenario.scenario_id,
                description=narrative_description,
                annual_tax_savings=sr["savings"],
                optimized_tax=sr["tax"],
                baseline_tax=baseline_tax,
                implementation_steps=implementation_steps[:5],  # Top 5 steps
                # De-duplicate while keeping first-seen order so the output
                # is the same on every run (set() ordering is not).
                legal_citations=list(dict.fromkeys(legal_citations)),
                risk_level=overall_risk,
                complexity=complexity,
                confidence_score=confidence,
                changes_required=scenario.changes_made
            ))

        return recommendations

    @staticmethod
    def _describe_change(
        change_key: str,
        change_data: Any,
        strategy_desc: str
    ) -> Optional[str]:
        """Return the sentence describing one change, or None if there is nothing to say."""
        if not isinstance(change_data, dict):
            return None

        current = change_data.get("from", 0)
        optimal = change_data.get("to", 0)
        increase = change_data.get("increase", 0)

        if increase > 0:
            return (
                f"Your current {change_key.replace('_', ' ')} is ₦{current:,.0f}. "
                f"{strategy_desc} "
                f"This means increasing to ₦{optimal:,.0f} (an additional ₦{increase:,.0f})."
            )
        if optimal > current:
            return (
                f"{strategy_desc} "
                f"We recommend adjusting from ₦{current:,.0f} to ₦{optimal:,.0f}."
            )
        return None

    def _generate_narrative_description(
        self,
        scenario: OptimizationScenario,
        savings: float,
        baseline_tax: float,
        optimized_tax: float,
        strategies: List[TaxStrategy]
    ) -> str:
        """
        Generate a narrative/prose description for a scenario

        The text is assembled from the descriptions and legal citations of
        the strategies referenced by the scenario.

        Args:
            scenario: Scenario being described.
            savings: Annual tax saved by the scenario.
            baseline_tax: Tax due before optimization.
            optimized_tax: Tax due under the scenario.
            strategies: Extracted strategies, looked up by strategy_id.

        Returns:
            Sentences joined by single spaces. A generic summary is returned
            when none of the scenario's strategy ids is known.
        """
        changes = scenario.changes_made
        strategy_map = {s.strategy_id: s for s in strategies}

        # Get the relevant strategies for this scenario
        relevant_strategies = [
            strategy_map[sid]
            for sid in scenario.strategy_ids
            if sid in strategy_map
        ]

        if not relevant_strategies:
            # Fallback if no strategy found
            return (
                f"Based on our analysis of your financial profile and Nigerian tax legislation, "
                f"implementing this strategy will reduce your tax liability from ₦{baseline_tax:,.0f} "
                f"to ₦{optimized_tax:,.0f}, resulting in annual savings of ₦{savings:,.0f}."
            )

        narrative_parts = []

        # Introduction
        if len(changes) > 1:
            narrative_parts.append(
                f"After a comprehensive analysis of your income and current deductions against "
                f"Nigerian tax legislation, we've identified {len(changes)} optimization opportunities. "
            )
        else:
            narrative_parts.append(
                "After analyzing your financial profile against Nigerian tax legislation, "
                "we've identified a key optimization opportunity. "
            )

        # Pair each strategy with the change it produced. Scenarios built by
        # _generate_scenarios add one change per strategy in the same order,
        # so equal counts are paired positionally. Otherwise every strategy
        # is described against every change.
        change_items = list(changes.items())
        if len(relevant_strategies) == len(change_items):
            pairings = [
                (strategy, [item])
                for strategy, item in zip(relevant_strategies, change_items)
            ]
        else:
            pairings = [(strategy, change_items) for strategy in relevant_strategies]

        for strategy, items in pairings:
            for change_key, change_data in items:
                sentence = self._describe_change(
                    change_key, change_data, strategy.description
                )
                if sentence:
                    narrative_parts.append(sentence)

        # Add savings impact
        narrative_parts.append(
            f"Implementing {'these strategies' if len(changes) > 1 else 'this strategy'} "
            f"will reduce your annual tax liability from ₦{baseline_tax:,.0f} to ₦{optimized_tax:,.0f}, "
            f"saving you ₦{savings:,.0f} per year."
        )

        # Add legal backing: first three distinct citations, in first-seen
        # order so the narrative is reproducible between runs.
        all_citations = []
        for strategy in relevant_strategies:
            all_citations.extend(strategy.legal_citations)

        if all_citations:
            unique_citations = list(dict.fromkeys(all_citations))
            narrative_parts.append(
                f"This recommendation is backed by {', '.join(unique_citations[:3])}."
            )

        return " ".join(narrative_parts)