# Kaanta / rules_engine.py
# Eniiyanu's picture
# Upload 22 files
# 2d58264 verified
# rules_engine.py
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Set
from datetime import date, datetime
import math
import yaml
import ast
# ------------- Safe expression evaluator -------------
class SafeEvalError(Exception):
    """Raised when an expression cannot be parsed or evaluated safely."""
class SafeExpr:
    """
    Very small arithmetic evaluator over a dict of variables.

    Supports + - * / // % **, unary + and -, parentheses, constants, names,
    comparisons (including chained ones), ``and`` / ``or``, conditional
    expressions, list / tuple / dict literals, subscripting, and simple
    calls to min, max, abs, round with at most 2 positional args.

    Expressions are parsed with ``ast`` and interpreted node by node; the
    built-in ``eval`` is never called. ``and`` / ``or`` and chained
    comparisons short-circuit like regular Python, so guards such as
    ``x != 0 and y / x > 1`` do not evaluate the right-hand side when the
    left-hand side already decides the result.

    NOTE(review): operand sizes are not bounded, so an expression such as
    ``9 ** 9 ** 9`` can still burn CPU/memory. Only feed this evaluator
    expressions from trusted rule files.
    """

    ALLOWED_FUNCS = {"min": min, "max": max, "abs": abs, "round": round}

    # Legacy node classes only produced by old parsers (ast.Num before 3.8,
    # ast.Index before 3.9). They are looked up through the module dict so
    # that interpreters which deprecate or drop them neither warn nor fail.
    _LEGACY_NUM = ast.__dict__.get("Num")
    _LEGACY_INDEX = ast.__dict__.get("Index")

    ALLOWED_NODES = (
        ast.Expression, ast.BinOp, ast.UnaryOp, ast.Name,
        ast.Load, ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow,
        ast.USub, ast.UAdd, ast.Call, ast.Tuple, ast.Constant, ast.Compare,
        ast.Lt, ast.Gt, ast.LtE, ast.GtE, ast.Eq, ast.NotEq, ast.BoolOp, ast.And, ast.Or,
        ast.IfExp, ast.Subscript, ast.Dict, ast.List,
    ) + tuple(t for t in (_LEGACY_NUM, _LEGACY_INDEX) if t is not None)

    # Operator node type -> implementation.
    _BIN_OPS = {
        ast.Add: lambda a, b: a + b,
        ast.Sub: lambda a, b: a - b,
        ast.Mult: lambda a, b: a * b,
        ast.Div: lambda a, b: a / b,
        ast.FloorDiv: lambda a, b: a // b,
        ast.Mod: lambda a, b: a % b,
        ast.Pow: lambda a, b: a ** b,
    }
    _CMP_OPS = {
        ast.Lt: lambda a, b: a < b,
        ast.Gt: lambda a, b: a > b,
        ast.LtE: lambda a, b: a <= b,
        ast.GtE: lambda a, b: a >= b,
        ast.Eq: lambda a, b: a == b,
        ast.NotEq: lambda a, b: a != b,
    }

    @classmethod
    def eval(cls, expr: str, variables: Dict[str, Any]) -> Any:
        """Parse ``expr`` and evaluate it against ``variables``.

        Args:
            expr: Expression source text.
            variables: Mapping from names usable in the expression to values.

        Returns:
            The value of the expression.

        Raises:
            SafeEvalError: If the text cannot be parsed, uses syntax outside
                ALLOWED_NODES, references an unknown variable, or calls a
                function that is not in ALLOWED_FUNCS.
        """
        try:
            tree = ast.parse(expr, mode="eval")
        except Exception as e:
            raise SafeEvalError(f"Parse error: {e}") from e
        # Reject the whole expression up front if any node is not whitelisted.
        if not all(isinstance(n, cls.ALLOWED_NODES) for n in ast.walk(tree)):
            raise SafeEvalError("Disallowed syntax in expression")
        return cls._eval_node(tree.body, variables)

    @classmethod
    def _eval_node(cls, node, vars):
        """Recursively evaluate one AST node against the ``vars`` mapping."""
        if isinstance(node, ast.Constant):
            return node.value
        if cls._LEGACY_NUM is not None and isinstance(node, cls._LEGACY_NUM):  # py<3.8
            return node.n
        if isinstance(node, ast.Name):
            try:
                return vars[node.id]
            except KeyError:
                raise SafeEvalError(f"Unknown variable '{node.id}'") from None
        if isinstance(node, ast.UnaryOp):
            val = cls._eval_node(node.operand, vars)
            if isinstance(node.op, ast.UAdd):
                return +val
            if isinstance(node.op, ast.USub):
                return -val
            raise SafeEvalError("Unsupported unary op")
        if isinstance(node, ast.BinOp):
            apply_op = cls._BIN_OPS.get(type(node.op))
            if apply_op is None:
                raise SafeEvalError("Unsupported binary op")
            left = cls._eval_node(node.left, vars)
            right = cls._eval_node(node.right, vars)
            return apply_op(left, right)
        if isinstance(node, ast.Compare):
            cur = cls._eval_node(node.left, vars)
            ok = True
            for op, comparator in zip(node.ops, node.comparators):
                compare = cls._CMP_OPS.get(type(op))
                if compare is None:
                    raise SafeEvalError("Unsupported comparator")
                right = cls._eval_node(comparator, vars)
                ok = compare(cur, right)
                if not ok:
                    # Chained comparisons stop at the first failing link.
                    return ok
                cur = right
            return ok
        if isinstance(node, ast.BoolOp):
            if isinstance(node.op, ast.And):
                # Short-circuit: later operands are not evaluated once one is falsy.
                for operand in node.values:
                    if not bool(cls._eval_node(operand, vars)):
                        return False
                return True
            if isinstance(node.op, ast.Or):
                # Short-circuit: later operands are not evaluated once one is truthy.
                for operand in node.values:
                    if bool(cls._eval_node(operand, vars)):
                        return True
                return False
            raise SafeEvalError("Unsupported bool op")
        if isinstance(node, ast.IfExp):
            cond = cls._eval_node(node.test, vars)
            return cls._eval_node(node.body if cond else node.orelse, vars)
        if isinstance(node, ast.Call):
            if not isinstance(node.func, ast.Name):
                raise SafeEvalError("Only simple function calls allowed")
            fname = node.func.id
            if fname not in cls.ALLOWED_FUNCS:
                raise SafeEvalError(f"Function '{fname}' not allowed")
            if len(node.args) > 2:
                raise SafeEvalError("Too many args")
            args = [cls._eval_node(a, vars) for a in node.args]
            return cls.ALLOWED_FUNCS[fname](*args)
        if isinstance(node, (ast.List, ast.Tuple)):
            # Tuples are deliberately materialized as lists as well.
            return [cls._eval_node(e, vars) for e in node.elts]
        if isinstance(node, ast.Dict):
            return {
                cls._eval_node(k, vars): cls._eval_node(v, vars)
                for k, v in zip(node.keys, node.values)
            }
        if isinstance(node, ast.Subscript):
            container = cls._eval_node(node.value, vars)
            index_node = node.slice
            # Only the legacy ast.Index wrapper (py<3.9) needs unwrapping. On
            # newer parsers the slice IS the index expression, and many node
            # types (Constant, Subscript, ...) carry an unrelated ``.value``.
            if cls._LEGACY_INDEX is not None and isinstance(index_node, cls._LEGACY_INDEX):
                index_node = index_node.value
            return container[cls._eval_node(index_node, vars)]
        raise SafeEvalError(f"Unsupported node: {type(node).__name__}")
# ------------- Rule atoms -------------
@dataclass
class AuthorityRef:
    """Pointer to the source document (and optional location in it) backing a rule."""

    # Only the document identifier is mandatory; every locator is optional.
    doc: str
    section: Optional[str] = None
    subsection: Optional[str] = None
    page: Optional[str] = None
    url_anchor: Optional[str] = None
@dataclass
class RuleAtom:
    """A single rule definition.

    Holds the rule's identity, classification, formula type and parameters,
    the window of dates in which it is effective, and references to the
    authority backing it.
    """

    id: str
    title: str
    description: str
    tax_type: str  # eg "PIT", "CIT", "VAT"
    jurisdiction_level: str  # eg "federal", "state"
    formula_type: str  # piecewise_bands, capped_percentage, etc
    inputs: List[str]
    output: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    ordering_constraints: Dict[str, List[str]] = field(default_factory=dict)
    # Effective bounds are inclusive. They are usually "YYYY-MM-DD" strings,
    # but date / datetime objects are accepted too (see is_active_on).
    effective_from: str = "1900-01-01"
    effective_to: Optional[str] = None  # None means "no end date"
    authority: List[AuthorityRef] = field(default_factory=list)
    notes: Optional[str] = None
    status: str = "approved"  # draft, approved, deprecated

    @staticmethod
    def _as_date(value) -> date:
        """Normalize a "YYYY-MM-DD" string, datetime, or date to a plain date."""
        if isinstance(value, str):
            return datetime.strptime(value, "%Y-%m-%d").date()
        # datetime is a subclass of date but cannot be ordered against a plain
        # date, so it is reduced to its calendar day before any comparison.
        if isinstance(value, datetime):
            return value.date()
        return value

    def is_active_on(self, on_date: date) -> bool:
        """Return True when ``on_date`` lies within the inclusive effective window.

        Args:
            on_date: Day to test. A datetime or "YYYY-MM-DD" string is reduced
                to its calendar day first.

        Raises:
            ValueError: If a string bound (or string ``on_date``) is not in
                "YYYY-MM-DD" format.
        """
        start = self._as_date(self.effective_from)
        if self.effective_to is None:
            end = date.max
        else:
            end = self._as_date(self.effective_to)
        return start <= self._as_date(on_date) <= end
# ------------- Engine core -------------
class RuleCatalog:
    """In-memory collection of rule atoms with simple filtering."""

    def __init__(self, atoms: List[RuleAtom]):
        """Store ``atoms`` and index them by id (a later duplicate id wins in the index)."""
        self.atoms = atoms
        self._by_id = {a.id: a for a in atoms}

    @classmethod
    def from_yaml_files(cls, paths: List[str]) -> "RuleCatalog":
        """Build a catalog from YAML files.

        Each file may contain a single mapping (one rule) or a list of
        mappings. An ``authority`` entry, when present, is a list of mappings
        that become AuthorityRef objects. Empty files and ``authority: null``
        are tolerated and contribute nothing.

        Args:
            paths: Files to read, in order; rule order follows file order.

        Raises:
            OSError: If a file cannot be opened.
            TypeError: If a mapping has keys RuleAtom / AuthorityRef do not accept.
        """
        atoms: List[RuleAtom] = []
        for p in paths:
            with open(p, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
            if data is None:
                # safe_load returns None for an empty document; nothing to add.
                continue
            if isinstance(data, dict):
                data = [data]
            for item in data:
                # "or []" also covers an explicit ``authority: null``.
                auth = [AuthorityRef(**r) for r in (item.get("authority") or [])]
                atoms.append(RuleAtom(**{**item, "authority": auth}))
        return cls(atoms)

    def select(self, *, tax_type: str, on_date: date, jurisdiction: Optional[str] = None) -> List[RuleAtom]:
        """Return the rules applicable to ``tax_type`` on ``on_date``, in catalog order.

        A rule is kept when its tax type matches, its jurisdiction level
        matches (only checked when ``jurisdiction`` is truthy), it is active
        on ``on_date``, and its status is not "deprecated".
        """
        return [
            a
            for a in self.atoms
            if a.tax_type == tax_type
            and (not jurisdiction or a.jurisdiction_level == jurisdiction)
            and a.is_active_on(on_date)
            and a.status != "deprecated"
        ]
class CalculationResult:
    """Accumulates named float values and the per-rule report lines of a run."""

    def __init__(self):
        # each line: rule_id, title, amount, details, authority
        self.lines: List[Dict[str, Any]] = []
        self.values: Dict[str, float] = {}

    def set_value(self, key: str, val: float):
        """Store ``val`` under ``key``, coerced to float."""
        as_float = float(val)
        self.values[key] = as_float

    def get(self, key: str, default: float = 0.0) -> float:
        """Return the value stored under ``key`` as a float, or ``default`` if absent."""
        if key in self.values:
            return float(self.values[key])
        return float(default)
class TaxEngine:
    """Evaluates the applicable rules of a catalog, in dependency order, over a set of inputs."""

    def __init__(self, catalog: RuleCatalog, rounding_mode: str = "half_up"):
        """
        Args:
            catalog: Source of rules; only its ``select`` method is used.
            rounding_mode: "half_up" rounds halves away from zero to a whole
                number; any other value falls back to the built-in ``round``.
        """
        self.catalog = catalog
        self.rounding_mode = rounding_mode

    # dependency ordering
    def _toposort(self, rules: List[RuleAtom]) -> List[RuleAtom]:
        """Order ``rules`` so each comes after the rules named in its "applies_after".

        Ids in "applies_after" that are not part of ``rules`` are ignored.
        Rules with no unmet dependency keep their relative input order.

        Raises:
            ValueError: On a dependency cycle, or when ``rules`` contains
                duplicate ids.
        """
        id_map = {r.id: r for r in rules}
        indegree: Dict[str, int] = {}
        # Reverse edges (dependency -> rules waiting on it) so releasing the
        # dependents of a finished rule does not rescan every rule.
        dependents: Dict[str, List[str]] = {rid: [] for rid in id_map}
        for rid, rule in id_map.items():
            deps = {d for d in rule.ordering_constraints.get("applies_after", []) if d in id_map}
            indegree[rid] = len(deps)
            for d in deps:
                dependents[d].append(rid)

        ready = [rid for rid, deg in indegree.items() if deg == 0]
        ordered: List[RuleAtom] = []
        cursor = 0  # index-based FIFO; avoids O(n) list.pop(0)
        while cursor < len(ready):
            rid = ready[cursor]
            cursor += 1
            ordered.append(id_map[rid])
            for nid in dependents[rid]:
                indegree[nid] -= 1
                if indegree[nid] == 0:
                    ready.append(nid)
        if len(ordered) != len(rules):
            # cycle detected or duplicate ids
            raise ValueError("Dependency cycle or missing rule id in applies_after")
        return ordered

    def _round(self, x: float) -> float:
        """Round ``x`` to a whole number according to ``rounding_mode``; always returns a float."""
        if self.rounding_mode == "half_up":
            magnitude = float(int(abs(x) + 0.5))
            return magnitude if x >= 0 else -magnitude
        return float(round(x))

    @staticmethod
    def _apply_bands(taxable: float, bands: List[Dict[str, Any]]) -> Tuple[float, List[Dict[str, Any]]]:
        """Apply progressive ``bands`` to ``taxable``; return (total tax, per-band steps).

        Each band has a "rate" and a cumulative upper bound "up_to"
        (missing / None means unbounded).
        """
        remaining = taxable
        tax = 0.0
        steps: List[Dict[str, Any]] = []
        prev_upper = 0.0
        for b in bands:
            upper = float("inf") if b.get("up_to") is None else float(b["up_to"])
            rate = float(b["rate"])
            chunk = max(0.0, min(remaining, upper - prev_upper))
            if chunk > 0:
                part = chunk * rate
                tax += part
                steps.append({"range": [prev_upper, upper], "chunk": chunk, "rate": rate, "tax": part})
                remaining -= chunk
            prev_upper = upper
            if remaining <= 0:
                break
        return tax, steps

    def _evaluate_rule(self, r: RuleAtom, ctx: CalculationResult) -> Tuple[str, float, Dict[str, Any]]:
        """Compute one rule against the values accumulated so far.

        Expression parameters are normally strings evaluated over
        ``ctx.values``; plain numbers (e.g. an unquoted YAML ``1000``) are
        accepted as literal values.

        Returns:
            (output key, amount, details), where ``details`` holds the
            intermediate figures for the formula type.

        Raises:
            ValueError: If ``r.formula_type`` is not recognized.
            SafeEvalError: If an expression parameter cannot be evaluated.
        """
        v = ctx.values  # shorthand
        params = r.parameters

        def ex(expr) -> float:
            # Numeric literals need no parsing; bool is excluded on purpose.
            if isinstance(expr, (int, float)) and not isinstance(expr, bool):
                return float(expr)
            return float(SafeExpr.eval(expr, v))

        details: Dict[str, Any] = {}
        kind = r.formula_type
        if kind == "fixed_amount":
            amt = ex(params.get("amount_expr", "0"))
        elif kind == "rate_on_base":
            base = ex(params.get("base_expr", "0"))
            rate = float(params.get("rate", 0))
            amt = base * rate
            details.update({"base": base, "rate": rate})
        elif kind == "capped_percentage":
            base = ex(params.get("base_expr", "0"))
            cap_rate = float(params.get("cap_rate", 0))
            amt = min(base, base * cap_rate)
            details.update({"base": base, "cap_rate": cap_rate})
        elif kind == "max_of_plus":
            base_opts = [ex(opt.get("expr", "0")) for opt in params.get("base_options", [])]
            plus_expr = params.get("plus_expr", "0")
            plus = ex(plus_expr) if plus_expr else 0.0
            amt = max(base_opts) + plus if base_opts else plus
            details.update({"base_options": base_opts, "plus": plus})
        elif kind == "piecewise_bands":
            taxable = ex(params.get("base_expr", "0"))
            amt, calc_steps = self._apply_bands(taxable, params.get("bands", []))
            details.update({"base": taxable, "bands_applied": calc_steps})
        elif kind == "conditional_min":
            computed = ex(params.get("computed_expr", "computed_tax"))
            min_amount = ex(params.get("min_amount_expr", "0"))
            amt = max(computed, min_amount)
            details.update({"computed": computed, "minimum": min_amount})
        else:
            raise ValueError(f"Unknown formula_type: {r.formula_type}")
        if params.get("round", False):
            amt = self._round(amt)
        return r.output, amt, details

    def run(
        self,
        *,
        tax_type: str,
        as_of: date,
        jurisdiction: Optional[str],
        inputs: Dict[str, float],
        rule_ids_whitelist: Optional[List[str]] = None
    ) -> CalculationResult:
        """Select, order, and evaluate the applicable rules.

        Args:
            tax_type: Tax type to select rules for.
            as_of: Date on which rules must be active.
            jurisdiction: Jurisdiction level filter, passed to the catalog.
            inputs: Seed values, coerced to float.
            rule_ids_whitelist: When non-empty, only these rule ids are
                evaluated. None or an empty list means no filtering.

        Returns:
            A CalculationResult whose ``values`` holds inputs plus each
            applied rule's output, and whose ``lines`` has one entry per
            applied rule.

        Raises:
            SafeEvalError: If a rule's "applicability_expr" guard or an
                expression parameter fails to evaluate.
            ValueError: On a dependency cycle or an unknown formula type.
        """
        active = self.catalog.select(tax_type=tax_type, on_date=as_of, jurisdiction=jurisdiction)
        if rule_ids_whitelist:
            idset = set(rule_ids_whitelist)
            active = [r for r in active if r.id in idset]
        ordered = self._toposort(active)
        ctx = CalculationResult()
        # seed inputs
        for k, v in inputs.items():
            ctx.set_value(k, float(v))
        for r in ordered:
            # allow guard expressions like "applicability_expr": "employment_income > 0"
            guard = r.parameters.get("applicability_expr")
            if guard:
                try:
                    applies = bool(SafeExpr.eval(guard, ctx.values))
                except Exception as e:
                    raise SafeEvalError(f"Guard error in {r.id}: {e}") from e
                if not applies:
                    continue
            out_key, amount, details = self._evaluate_rule(r, ctx)
            ctx.set_value(out_key, amount)
            ctx.lines.append({
                "rule_id": r.id,
                "title": r.title,
                "amount": amount,
                "output": out_key,
                "details": details,
                # Copy so edits to a result line cannot mutate the catalog's rules.
                "authority": [dict(vars(a)) for a in r.authority],
            })
        return ctx