# Kaanta / tax_optimizer.py
# Oluwaferanmi
# This is the latest changes
# 66d6b11
# tax_optimizer.py
"""
Main Tax Optimization Engine
Integrates classifier, aggregator, strategy extractor, and tax engine
"""
from __future__ import annotations
from typing import Dict, List, Any, Optional
from datetime import date
from dataclasses import dataclass, asdict
from transaction_classifier import TransactionClassifier
from transaction_aggregator import TransactionAggregator
from tax_strategy_extractor import TaxStrategyExtractor, TaxStrategy
from rules_engine import TaxEngine, CalculationResult
@dataclass
class OptimizationScenario:
    """A "what-if" variant of the baseline tax inputs.

    Scenarios are simulated through the tax engine and their liability is
    compared against the baseline to measure potential savings.
    """

    # Stable identifier for the scenario (e.g. "maximize_pension").
    scenario_id: str
    # Human-readable title of the scenario.
    name: str
    # Explanation of what the scenario changes.
    description: str
    # Complete set of tax-engine inputs with this scenario's changes applied.
    modified_inputs: Dict[str, float]
    # Record of what changed, keyed by change name.
    changes_made: Dict[str, Any]
    # Identifiers of the strategies this scenario implements.
    strategy_ids: List[str]
@dataclass
class OptimizationRecommendation:
    """One ranked tax-optimization recommendation.

    Produced by TaxOptimizer from a simulated scenario that lowers the tax
    due compared with the baseline.
    """

    # 1-based position in the ranking (highest savings first).
    rank: int
    # Display name of the underlying scenario.
    strategy_name: str
    # Identifier of the underlying scenario.
    strategy_id: str
    # Narrative explanation shown to the user.
    description: str
    # Baseline tax minus the scenario's tax.
    annual_tax_savings: float
    # Tax due if the scenario is applied.
    optimized_tax: float
    # Tax due with no changes.
    baseline_tax: float
    # Steps gathered from the scenario's strategies.
    implementation_steps: List[str]
    # Legal references gathered from the scenario's strategies.
    legal_citations: List[str]
    # Overall risk ("low", "medium" or "high" as set by TaxOptimizer).
    risk_level: str
    # Effort estimate ("easy", "medium" or "complex" as set by TaxOptimizer).
    complexity: str
    # Heuristic confidence that TaxOptimizer derives from risk_level.
    confidence_score: float
    # The scenario's change records (what to adjust, from/to amounts).
    changes_required: Dict[str, Any]
class TaxOptimizer:
    """
    Main tax optimization engine.

    Analyzes transactions and generates optimization recommendations:
    transactions are classified, aggregated into tax-engine inputs, a baseline
    liability is computed, "what-if" scenarios are built from the extracted
    strategies and simulated, and the scenarios that lower the tax are ranked.
    """

    def __init__(
        self,
        classifier: TransactionClassifier,
        aggregator: TransactionAggregator,
        strategy_extractor: TaxStrategyExtractor,
        tax_engine: TaxEngine
    ):
        """
        Initialize optimizer with required components.

        Args:
            classifier: TransactionClassifier instance
            aggregator: TransactionAggregator instance
            strategy_extractor: TaxStrategyExtractor instance
            tax_engine: TaxEngine instance (stored as ``self.engine``)
        """
        self.classifier = classifier
        self.aggregator = aggregator
        self.strategy_extractor = strategy_extractor
        self.engine = tax_engine

    def optimize(
        self,
        user_id: str,
        transactions: List[Dict[str, Any]],
        taxpayer_profile: Optional[Dict[str, Any]] = None,
        tax_year: int = 2025,
        tax_type: str = "PIT",
        jurisdiction: str = "state"
    ) -> Dict[str, Any]:
        """
        Main optimization workflow.

        Args:
            user_id: Unique user identifier (echoed back in the report).
            transactions: List of transactions from Mono API + manual entry.
            taxpayer_profile: Optional profile info. It is inferred from the
                transactions when missing or empty. A supplied dict is copied,
                never mutated.
            tax_year: Tax year to optimize for.
            tax_type: PIT, CIT, or VAT.
            jurisdiction: federal or state.

        Returns:
            Comprehensive optimization report as a plain dict. The headline
            ``total_potential_savings`` and ``optimized_tax_liability`` describe
            the single best recommendation. Each recommendation is an
            alternative scenario simulated against the same baseline, so their
            savings cannot be added together.
        """
        # Step 1: Classify transactions
        print(f"[Optimizer] Classifying {len(transactions)} transactions...")
        classified_txs = self.classifier.classify_batch(transactions)

        # Step 2: Aggregate into tax inputs
        print(f"[Optimizer] Aggregating transactions for tax year {tax_year}...")
        tax_inputs = self.aggregator.aggregate_for_tax_year(classified_txs, tax_year)

        # Step 3: Infer taxpayer profile if not provided. A caller-supplied
        # profile is copied so the annual_income update below does not leak
        # back into the caller's dict.
        if taxpayer_profile:
            taxpayer_profile = dict(taxpayer_profile)
        else:
            taxpayer_profile = self._infer_profile(tax_inputs, classified_txs)
        taxpayer_profile["annual_income"] = tax_inputs.get("gross_income", 0)

        # Step 4: Calculate baseline tax
        print("[Optimizer] Calculating baseline tax liability...")
        baseline_result = self._calculate_tax(
            tax_inputs=tax_inputs,
            tax_type=tax_type,
            tax_year=tax_year,
            jurisdiction=jurisdiction
        )
        baseline_tax = baseline_result.values.get("tax_due", 0)

        # Step 5: Extract applicable strategies
        print("[Optimizer] Extracting optimization strategies...")
        strategies = self.strategy_extractor.extract_strategies_for_profile(
            taxpayer_profile=taxpayer_profile,
            tax_year=tax_year
        )

        # Step 6: Identify opportunities from transaction analysis
        print("[Optimizer] Identifying optimization opportunities...")
        opportunities = self.aggregator.identify_optimization_opportunities(
            aggregated=tax_inputs,
            tax_year=tax_year
        )

        # Step 7: Generate optimization scenarios
        print("[Optimizer] Generating optimization scenarios...")
        scenarios = self._generate_scenarios(
            baseline_inputs=tax_inputs,
            strategies=strategies,
            opportunities=opportunities
        )

        # Step 8: Simulate each scenario against the same baseline
        print(f"[Optimizer] Simulating {len(scenarios)} scenarios...")
        scenario_results = []
        for scenario in scenarios:
            result = self._calculate_tax(
                tax_inputs=scenario.modified_inputs,
                tax_type=tax_type,
                tax_year=tax_year,
                jurisdiction=jurisdiction
            )
            scenario_tax = result.values.get("tax_due", 0)
            scenario_results.append({
                "scenario": scenario,
                "tax": scenario_tax,
                "savings": baseline_tax - scenario_tax,
                "result": result
            })

        # Step 9: Rank and create recommendations
        print("[Optimizer] Ranking recommendations...")
        recommendations = self._create_recommendations(
            scenario_results=scenario_results,
            baseline_tax=baseline_tax,
            strategies=strategies
        )

        # Step 10: Generate comprehensive report
        classification_summary = self.classifier.get_classification_summary(classified_txs)
        income_breakdown = self.aggregator.get_income_breakdown(classified_txs, tax_year)
        deduction_breakdown = self.aggregator.get_deduction_breakdown(classified_txs, tax_year)

        # Recommendations overlap (the combined scenario already includes the
        # individual ones), so report the best achievable outcome, not a sum.
        total_potential_savings, optimized_tax = self._best_outcome(recommendations, baseline_tax)

        return {
            "user_id": user_id,
            "tax_year": tax_year,
            "tax_type": tax_type,
            "analysis_date": date.today().isoformat(),
            # Tax summary
            "baseline_tax_liability": baseline_tax,
            "optimized_tax_liability": optimized_tax,
            "total_potential_savings": total_potential_savings,
            "savings_percentage": (total_potential_savings / baseline_tax * 100) if baseline_tax > 0 else 0,
            # Income & deductions
            "total_annual_income": tax_inputs.get("gross_income", 0),
            "current_deductions": {
                "pension": tax_inputs.get("employee_pension_contribution", 0),
                "nhf": tax_inputs.get("nhf", 0),
                "life_insurance": tax_inputs.get("life_insurance", 0),
                "union_dues": tax_inputs.get("union_dues", 0),
                "total": sum([
                    tax_inputs.get("employee_pension_contribution", 0),
                    tax_inputs.get("nhf", 0),
                    tax_inputs.get("life_insurance", 0),
                    tax_inputs.get("union_dues", 0)
                ])
            },
            # Recommendations
            "recommendations": [asdict(r) for r in recommendations],
            "recommendation_count": len(recommendations),
            # Transaction analysis
            "transaction_summary": classification_summary,
            "income_breakdown": income_breakdown,
            "deduction_breakdown": deduction_breakdown,
            # Taxpayer profile
            "taxpayer_profile": taxpayer_profile,
            # Baseline calculation details
            "baseline_calculation": {
                "tax_due": baseline_tax,
                "taxable_income": baseline_result.values.get("taxable_income", 0),
                "gross_income": baseline_result.values.get("gross_income", 0),
                # NOTE(review): excludes union_dues, unlike current_deductions["total"]
                # above; confirm whether the rules engine deducts union dues.
                "total_deductions": baseline_result.values.get("cra_amount", 0) +
                                    tax_inputs.get("employee_pension_contribution", 0) +
                                    tax_inputs.get("nhf", 0) +
                                    tax_inputs.get("life_insurance", 0)
            }
        }

    @staticmethod
    def _best_outcome(
        recommendations: List[OptimizationRecommendation],
        baseline_tax: float
    ) -> tuple:
        """
        Return ``(savings, optimized_tax)`` for the best single recommendation.

        Every recommendation is simulated independently against the same
        baseline, so summing their savings would double-count. With no
        recommendations the result is ``(0, baseline_tax)``.
        """
        if not recommendations:
            return 0, baseline_tax
        best = max(recommendations, key=lambda r: r.annual_tax_savings)
        return best.annual_tax_savings, best.optimized_tax

    def _calculate_tax(
        self,
        tax_inputs: Dict[str, float],
        tax_type: str,
        tax_year: int,
        jurisdiction: str
    ) -> CalculationResult:
        """Run the rules engine for the given inputs, as of Dec 31 of ``tax_year``."""
        return self.engine.run(
            tax_type=tax_type,
            as_of=date(tax_year, 12, 31),
            jurisdiction=jurisdiction,
            inputs=tax_inputs
        )

    def _infer_profile(
        self,
        tax_inputs: Dict[str, float],
        classified_txs: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Infer a taxpayer profile from aggregated inputs and transaction categories.

        Any positive ``turnover_annual`` marks the taxpayer as a "company". The
        employment status comes from whether "employment_income" and/or
        "business_income" transactions are present.
        """
        gross_income = tax_inputs.get("gross_income", 0)
        turnover = tax_inputs.get("turnover_annual", 0)

        taxpayer_type = "company" if turnover > 0 else "individual"

        categories = [tx.get("tax_category") for tx in classified_txs]
        has_employment = "employment_income" in categories
        has_business = "business_income" in categories

        if has_employment and not has_business:
            employment_status = "employed"
        elif has_business and not has_employment:
            employment_status = "self_employed"
        elif has_employment and has_business:
            employment_status = "mixed"
        else:
            employment_status = "unknown"

        return {
            "taxpayer_type": taxpayer_type,
            "employment_status": employment_status,
            "annual_income": gross_income,
            "annual_turnover": turnover,
            "has_rental_income": "rental_income" in categories,
            "inferred": True
        }

    @staticmethod
    def _strategy_metadata(strategy: Any) -> Dict[str, Any]:
        """Return the strategy's ``metadata`` mapping, or ``{}`` when absent or empty."""
        return getattr(strategy, "metadata", None) or {}

    @staticmethod
    def _merge_scenarios(
        baseline_inputs: Dict[str, float],
        scenarios: List[OptimizationScenario]
    ) -> tuple:
        """
        Stack several scenarios into one set of inputs.

        Returns:
            ``(combined_inputs, combined_changes, combined_strategy_ids)``.
            ``combined_inputs`` is a copy of ``baseline_inputs`` with every
            input that any scenario changed overwritten. ``combined_changes``
            merges each scenario's ``changes_made`` records. The strategy ids
            are kept in scenario order.
        """
        combined_inputs = baseline_inputs.copy()
        combined_changes: Dict[str, Any] = {}
        combined_strategy_ids: List[str] = []
        for scenario in scenarios:
            for key, value in scenario.modified_inputs.items():
                if value != baseline_inputs.get(key, 0):
                    combined_inputs[key] = value
            # changes_made is keyed by change name (e.g. "pension_contribution"),
            # which need not equal the input key ("employee_pension_contribution").
            # Merge the records directly instead of looking them up by input key,
            # which silently dropped them as {}.
            combined_changes.update(scenario.changes_made)
            combined_strategy_ids.extend(scenario.strategy_ids)
        return combined_inputs, combined_changes, combined_strategy_ids

    def _generate_scenarios(
        self,
        baseline_inputs: Dict[str, float],
        strategies: List[TaxStrategy],
        opportunities: List[Dict[str, Any]]
    ) -> List[OptimizationScenario]:
        """
        Generate optimization scenarios from RAG-extracted strategies.

        Only strategies present in ``strategies`` produce scenarios. Their
        names, descriptions and (when available) metadata parameters come
        from the extracted strategy. If two or more scenarios are produced, a
        "combined_optimization" scenario applying all of them is appended.

        ``opportunities`` is accepted for interface stability but is not
        currently used to build scenarios.
        """
        scenarios = []
        gross_income = baseline_inputs.get("gross_income", 0)
        strategy_map = {s.strategy_id: s for s in strategies}

        # Pension optimization (if the strategy was extracted)
        pension_strategy = strategy_map.get("pit_pension_maximization")
        if pension_strategy and gross_income > 0:
            current_pension = baseline_inputs.get("employee_pension_contribution", 0)
            # Maximum percentage comes from strategy metadata; 20% is the fallback.
            max_pct = self._strategy_metadata(pension_strategy).get("max_percentage", 0.20)
            max_pension = gross_income * max_pct
            if max_pension > current_pension:
                max_pension_inputs = baseline_inputs.copy()
                max_pension_inputs["employee_pension_contribution"] = max_pension
                scenarios.append(OptimizationScenario(
                    scenario_id="maximize_pension",
                    name=pension_strategy.name,
                    description=pension_strategy.description,
                    modified_inputs=max_pension_inputs,
                    changes_made={
                        "pension_contribution": {
                            "from": current_pension,
                            "to": max_pension,
                            "increase": max_pension - current_pension
                        }
                    },
                    strategy_ids=[pension_strategy.strategy_id]
                ))

        # Life insurance (if the strategy was extracted)
        insurance_strategy = strategy_map.get("pit_life_insurance")
        if insurance_strategy:
            current_insurance = baseline_inputs.get("life_insurance", 0)
            # Suggested premium comes from strategy metadata; 1% of gross income is the fallback.
            suggested_premium = self._strategy_metadata(insurance_strategy).get(
                "suggested_premium", gross_income * 0.01
            )
            if suggested_premium > current_insurance:
                insurance_inputs = baseline_inputs.copy()
                insurance_inputs["life_insurance"] = suggested_premium
                scenarios.append(OptimizationScenario(
                    scenario_id="add_life_insurance",
                    name=insurance_strategy.name,
                    description=insurance_strategy.description,
                    modified_inputs=insurance_inputs,
                    changes_made={
                        "life_insurance": {
                            "from": current_insurance,
                            "to": suggested_premium,
                            "increase": suggested_premium - current_insurance
                        }
                    },
                    strategy_ids=[insurance_strategy.strategy_id]
                ))

        # Combined optimization: stack every individual scenario.
        if len(scenarios) > 1:
            combined_inputs, combined_changes, combined_strategy_ids = self._merge_scenarios(
                baseline_inputs, scenarios
            )
            scenarios.append(OptimizationScenario(
                scenario_id="combined_optimization",
                name="Combined Strategy",
                description="Apply all recommended optimizations together",
                modified_inputs=combined_inputs,
                changes_made=combined_changes,
                strategy_ids=combined_strategy_ids
            ))

        return scenarios

    def _create_recommendations(
        self,
        scenario_results: List[Dict[str, Any]],
        baseline_tax: float,
        strategies: List[TaxStrategy]
    ) -> List[OptimizationRecommendation]:
        """
        Turn simulated scenarios into ranked recommendations.

        Only scenarios with positive savings are kept. They are sorted by
        savings in descending order, and at most 10 are returned.
        """
        recommendations = []
        strategy_map = {s.strategy_id: s for s in strategies}

        viable_scenarios = [sr for sr in scenario_results if sr["savings"] > 0]
        viable_scenarios.sort(key=lambda x: x["savings"], reverse=True)

        for rank, sr in enumerate(viable_scenarios, 1):
            scenario = sr["scenario"]

            # Gather implementation steps, citations and risk from strategies
            implementation_steps = []
            legal_citations = []
            risk_levels = []
            for strategy_id in scenario.strategy_ids:
                strategy = strategy_map.get(strategy_id)
                if strategy:
                    implementation_steps.extend(strategy.implementation_steps)
                    legal_citations.extend(strategy.legal_citations)
                    risk_levels.append(strategy.risk_level)

            # Overall risk is the highest risk among the strategies
            if "high" in risk_levels:
                overall_risk = "high"
            elif "medium" in risk_levels:
                overall_risk = "medium"
            else:
                overall_risk = "low"

            # Complexity grows with the number of changes required
            num_changes = len(scenario.changes_made)
            if num_changes == 1:
                complexity = "easy"
            elif num_changes == 2:
                complexity = "medium"
            else:
                complexity = "complex"

            confidence = 0.95 if overall_risk == "low" else (0.80 if overall_risk == "medium" else 0.65)

            narrative_description = self._generate_narrative_description(
                scenario=scenario,
                savings=sr["savings"],
                baseline_tax=baseline_tax,
                optimized_tax=sr["tax"],
                strategies=strategies
            )

            recommendations.append(OptimizationRecommendation(
                rank=rank,
                strategy_name=scenario.name,
                strategy_id=scenario.scenario_id,
                description=narrative_description,
                annual_tax_savings=sr["savings"],
                optimized_tax=sr["tax"],
                baseline_tax=baseline_tax,
                implementation_steps=implementation_steps[:5],  # Top 5 steps
                # dict.fromkeys removes duplicates while keeping a stable order
                # (set() order varies between runs under hash randomization).
                legal_citations=list(dict.fromkeys(legal_citations)),
                risk_level=overall_risk,
                complexity=complexity,
                confidence_score=confidence,
                changes_required=scenario.changes_made
            ))

        return recommendations[:10]  # Return top 10 recommendations

    def _generate_narrative_description(
        self,
        scenario: OptimizationScenario,
        savings: float,
        baseline_tax: float,
        optimized_tax: float,
        strategies: List[TaxStrategy]
    ) -> str:
        """
        Generate a prose description of a scenario from its extracted strategies.

        Each change in ``scenario.changes_made`` is described exactly once,
        using one of three rules:

        - If there is one strategy id per change, each change is paired with
          its own strategy's description. ``_generate_scenarios`` records
          changes and strategy ids in the same order, so this is the normal case.
        - If there is a single relevant strategy, its description is used for
          every change.
        - Otherwise, each strategy description is listed once, before the
          change details.

        Returns:
            Narrative string with amounts formatted in Naira.
        """
        changes = scenario.changes_made
        strategy_map = {s.strategy_id: s for s in strategies}

        relevant_strategies = [
            strategy_map[sid] for sid in scenario.strategy_ids if sid in strategy_map
        ]

        if not relevant_strategies:
            # Fallback if no strategy found
            return (
                f"Based on our analysis of your financial profile and Nigerian tax legislation, "
                f"implementing this strategy will reduce your tax liability from ₦{baseline_tax:,.0f} "
                f"to ₦{optimized_tax:,.0f}, resulting in annual savings of ₦{savings:,.0f}."
            )

        narrative_parts = []

        # Introduction
        if len(changes) > 1:
            narrative_parts.append(
                f"After a comprehensive analysis of your income and current deductions against "
                f"Nigerian tax legislation, we've identified {len(changes)} optimization opportunities. "
            )
        else:
            narrative_parts.append(
                f"After analyzing your financial profile against Nigerian tax legislation, "
                f"we've identified a key optimization opportunity. "
            )

        # Pair each change record with the strategy that produced it.
        change_records = [
            (key, data) for key, data in changes.items() if isinstance(data, dict)
        ]
        if len(change_records) == len(scenario.strategy_ids):
            owners = [strategy_map.get(sid) for sid in scenario.strategy_ids]
        elif len(relevant_strategies) == 1:
            owners = relevant_strategies * len(change_records)
        else:
            owners = [None] * len(change_records)
            narrative_parts.extend(s.description for s in relevant_strategies)

        for (change_key, change_data), owner in zip(change_records, owners):
            desc_part = f"{owner.description} " if owner is not None else ""
            current = change_data.get("from", 0)
            optimal = change_data.get("to", 0)
            increase = change_data.get("increase", 0)
            if increase > 0:
                narrative_parts.append(
                    f"Your current {change_key.replace('_', ' ')} is ₦{current:,.0f}. "
                    f"{desc_part}"
                    f"This means increasing to ₦{optimal:,.0f} (an additional ₦{increase:,.0f})."
                )
            elif optimal > current:
                narrative_parts.append(
                    f"{desc_part}"
                    f"We recommend adjusting from ₦{current:,.0f} to ₦{optimal:,.0f}."
                )

        # Savings impact
        narrative_parts.append(
            f"Implementing {'these strategies' if len(changes) > 1 else 'this strategy'} "
            f"will reduce your annual tax liability from ₦{baseline_tax:,.0f} to ₦{optimized_tax:,.0f}, "
            f"saving you ₦{savings:,.0f} per year."
        )

        # Legal backing. Citations are de-duplicated in first-seen order so the
        # three citations shown are the same on every run.
        all_citations = []
        for strategy in relevant_strategies:
            all_citations.extend(strategy.legal_citations)
        if all_citations:
            unique_citations = list(dict.fromkeys(all_citations))
            narrative_parts.append(
                f"This recommendation is backed by {', '.join(unique_citations[:3])}."
            )

        return " ".join(narrative_parts)