|
|
|
|
|
from __future__ import annotations
|
|
|
from dataclasses import dataclass, field
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Set
|
|
|
from datetime import date, datetime
|
|
|
import math
|
|
|
import yaml
|
|
|
import ast
|
|
|
|
|
|
|
|
|
class SafeEvalError(Exception):
    """Error type for expressions that cannot be parsed or evaluated."""
|
|
|
|
|
|
class SafeExpr:
    """
    Very small expression evaluator over a dict of variables.

    Supports + - * / // % **, unary + and -, parentheses, constants, names,
    comparisons (< > <= >= == !=, including chains), ``and`` / ``or``,
    conditional expressions (``a if c else b``), list / tuple / dict
    literals, subscripting with a single index, and simple calls to
    min, max, abs, round with at most 2 positional args.

    Anything else (attribute access, lambdas, comprehensions, keyword
    arguments, slices, ...) is rejected before evaluation starts.

    NOTE(review): operand sizes are not bounded, so an expression such as
    ``9 ** 9 ** 9`` or ``"a" * 10 ** 9`` can still exhaust CPU or memory.
    Confirm that expressions only come from trusted rule files.
    """

    ALLOWED_FUNCS = {"min": min, "max": max, "abs": abs, "round": round}

    # Legacy node classes that only older interpreters still produce.
    # They are looked up in the module dict instead of as attributes so a
    # Python version that has dropped (or deprecation-wraps) them neither
    # fails at import time nor emits a warning.
    _LEGACY_NUM = vars(ast).get("Num")
    _LEGACY_INDEX = vars(ast).get("Index")

    ALLOWED_NODES = (
        ast.Expression, ast.BinOp, ast.UnaryOp, ast.Name,
        ast.Load, ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow,
        ast.USub, ast.UAdd, ast.Call, ast.Tuple, ast.Constant, ast.Compare,
        ast.Lt, ast.Gt, ast.LtE, ast.GtE, ast.Eq, ast.NotEq, ast.BoolOp, ast.And, ast.Or,
        ast.IfExp, ast.Subscript, ast.Dict, ast.List,
    ) + tuple(n for n in (_LEGACY_NUM, _LEGACY_INDEX) if n is not None)

    # Operator node type -> implementation.
    _BIN_OPS = {
        ast.Add: lambda a, b: a + b,
        ast.Sub: lambda a, b: a - b,
        ast.Mult: lambda a, b: a * b,
        ast.Div: lambda a, b: a / b,
        ast.FloorDiv: lambda a, b: a // b,
        ast.Mod: lambda a, b: a % b,
        ast.Pow: lambda a, b: a ** b,
    }
    _CMP_OPS = {
        ast.Lt: lambda a, b: a < b,
        ast.Gt: lambda a, b: a > b,
        ast.LtE: lambda a, b: a <= b,
        ast.GtE: lambda a, b: a >= b,
        ast.Eq: lambda a, b: a == b,
        ast.NotEq: lambda a, b: a != b,
    }

    @classmethod
    def eval(cls, expr: str, variables: Dict[str, Any]) -> Any:
        """Parse ``expr`` and evaluate it against ``variables``.

        Args:
            expr: Source text of a single expression.
            variables: Mapping used to resolve names in the expression.

        Returns:
            The value of the expression.

        Raises:
            SafeEvalError: If the text cannot be parsed, contains a syntax
                node outside ``ALLOWED_NODES``, references an unknown
                variable, or calls a function that is not allowed.
                Errors raised by the operations themselves (for example
                ``ZeroDivisionError``) propagate unchanged.
        """
        try:
            tree = ast.parse(expr, mode="eval")
        except Exception as e:
            raise SafeEvalError(f"Parse error: {e}") from e
        # Validate the whole tree up front so nothing is evaluated when
        # any part of the expression is disallowed.
        for n in ast.walk(tree):
            if not isinstance(n, cls.ALLOWED_NODES):
                raise SafeEvalError("Disallowed syntax in expression")
        return cls._eval_node(tree.body, variables)

    @classmethod
    def _eval_node(cls, node, vars):
        """Recursively evaluate one AST node against the ``vars`` mapping."""
        if isinstance(node, ast.Constant):
            return node.value
        if cls._LEGACY_NUM is not None and isinstance(node, cls._LEGACY_NUM):
            return node.n

        if isinstance(node, ast.Name):
            try:
                return vars[node.id]
            except KeyError:
                # The KeyError adds nothing beyond the message below.
                raise SafeEvalError(f"Unknown variable '{node.id}'") from None

        if isinstance(node, ast.UnaryOp):
            val = cls._eval_node(node.operand, vars)
            if isinstance(node.op, ast.UAdd):
                return +val
            if isinstance(node.op, ast.USub):
                return -val
            raise SafeEvalError("Unsupported unary op")

        if isinstance(node, ast.BinOp):
            left = cls._eval_node(node.left, vars)
            right = cls._eval_node(node.right, vars)
            apply_op = cls._BIN_OPS.get(type(node.op))
            if apply_op is None:
                raise SafeEvalError("Unsupported binary op")
            return apply_op(left, right)

        if isinstance(node, ast.Compare):
            # Chained comparison (a < b < c): like Python, stop at the
            # first link that fails so later operands are not evaluated.
            cur = cls._eval_node(node.left, vars)
            result = True
            for op, comparator in zip(node.ops, node.comparators):
                right = cls._eval_node(comparator, vars)
                compare = cls._CMP_OPS.get(type(op))
                if compare is None:
                    raise SafeEvalError("Unsupported comparator")
                result = compare(cur, right)
                if not result:
                    return result
                cur = right
            return result

        if isinstance(node, ast.BoolOp):
            # Operands are evaluated lazily (all/any short-circuit), so
            # "x != 0 and y / x > 1" does not divide by zero. Unlike
            # Python, the result is always a plain bool.
            if isinstance(node.op, ast.And):
                return all(bool(cls._eval_node(v, vars)) for v in node.values)
            if isinstance(node.op, ast.Or):
                return any(bool(cls._eval_node(v, vars)) for v in node.values)
            raise SafeEvalError("Unsupported bool op")

        if isinstance(node, ast.IfExp):
            cond = cls._eval_node(node.test, vars)
            return cls._eval_node(node.body if cond else node.orelse, vars)

        if isinstance(node, ast.Call):
            if not isinstance(node.func, ast.Name):
                raise SafeEvalError("Only simple function calls allowed")
            fname = node.func.id
            if fname not in cls.ALLOWED_FUNCS:
                raise SafeEvalError(f"Function '{fname}' not allowed")
            args = [cls._eval_node(a, vars) for a in node.args]
            if len(args) > 2:
                raise SafeEvalError("Too many args")
            return cls.ALLOWED_FUNCS[fname](*args)

        if isinstance(node, (ast.List, ast.Tuple)):
            # Tuples are returned as lists as well.
            return [cls._eval_node(e, vars) for e in node.elts]

        if isinstance(node, ast.Dict):
            return {
                cls._eval_node(k, vars): cls._eval_node(v, vars)
                for k, v in zip(node.keys, node.values)
            }

        if isinstance(node, ast.Subscript):
            container = cls._eval_node(node.value, vars)
            index_node = node.slice
            # Only the legacy Index wrapper is unwrapped. Testing for a
            # generic ".value" attribute would also match Constant and
            # Subscript nodes and evaluate the wrong thing.
            if cls._LEGACY_INDEX is not None and isinstance(index_node, cls._LEGACY_INDEX):
                index_node = index_node.value
            return container[cls._eval_node(index_node, vars)]

        raise SafeEvalError(f"Unsupported node: {type(node).__name__}")
|
|
|
|
|
|
|
|
|
@dataclass
class AuthorityRef:
    """Pointer to a location inside a source document.

    Only ``doc`` is required; every locator field defaults to ``None``.
    NOTE(review): the exact meaning and format of each field is not
    visible here -- presumably a citation for a rule; confirm against
    the rule files.
    """

    doc: str
    section: Optional[str] = None
    subsection: Optional[str] = None
    page: Optional[str] = None
    url_anchor: Optional[str] = None
|
|
|
|
|
|
@dataclass
class RuleAtom:
    """One rule definition: identity, formula type, parameters and validity window.

    ``effective_from`` / ``effective_to`` are annotated as strings in
    ``YYYY-MM-DD`` form, but ``is_active_on`` also accepts ``date`` and
    ``datetime`` objects in those fields.
    """

    id: str
    title: str
    description: str
    tax_type: str
    jurisdiction_level: str
    formula_type: str
    inputs: List[str]
    output: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    ordering_constraints: Dict[str, List[str]] = field(default_factory=dict)
    effective_from: str = "1900-01-01"
    effective_to: Optional[str] = None  # None means the rule never expires
    authority: List["AuthorityRef"] = field(default_factory=list)
    notes: Optional[str] = None
    status: str = "approved"

    @staticmethod
    def _as_date(value: Any) -> date:
        """Normalise a ``YYYY-MM-DD`` string, ``datetime`` or ``date`` to a ``date``."""
        if isinstance(value, str):
            return datetime.strptime(value, "%Y-%m-%d").date()
        # datetime is a subclass of date but cannot be ordered against a
        # plain date, so strip the time part before comparing.
        if isinstance(value, datetime):
            return value.date()
        return value

    def is_active_on(self, on_date: date) -> bool:
        """Return True when ``on_date`` lies within the rule's validity window.

        Both ends of the window are inclusive. A missing ``effective_to``
        leaves the window open-ended.

        Raises:
            ValueError: If a string date is not in ``YYYY-MM-DD`` form.
        """
        start = self._as_date(self.effective_from)
        if self.effective_to is None:
            end = date.max
        else:
            end = self._as_date(self.effective_to)
        return start <= self._as_date(on_date) <= end
|
|
|
|
|
|
|
|
|
class RuleCatalog:
    """In-memory collection of rule atoms with a simple selection query."""

    def __init__(self, atoms: List["RuleAtom"]):
        """Store ``atoms`` and index them by ``id``.

        When two atoms share an id, the later one wins in the index;
        both remain in ``self.atoms``.
        """
        self.atoms = atoms
        self._by_id = {a.id: a for a in atoms}

    @classmethod
    def from_yaml_files(cls, paths: List[str]) -> "RuleCatalog":
        """Build a catalog from YAML files.

        Each file may hold a single mapping (one rule) or a sequence of
        mappings. Each mapping is passed to ``RuleAtom`` as keyword
        arguments, with its ``authority`` entries converted to
        ``AuthorityRef`` objects. Empty files contribute no rules.
        """
        atoms: List["RuleAtom"] = []
        for p in paths:
            with open(p, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
            # safe_load yields None for an empty document.
            if data is None:
                continue
            if isinstance(data, dict):
                data = [data]
            for item in data:
                # "authority: null" in the file yields None, not a list.
                refs = item.get("authority") or []
                auth = [AuthorityRef(**r) for r in refs]
                atoms.append(RuleAtom(**{**item, "authority": auth}))
        return cls(atoms)

    def select(self, *, tax_type: str, on_date: date, jurisdiction: Optional[str] = None) -> List["RuleAtom"]:
        """Return the atoms applicable to ``tax_type`` on ``on_date``.

        An atom is kept when its ``tax_type`` matches, it is active on
        ``on_date``, its status is not ``"deprecated"`` and, if a truthy
        ``jurisdiction`` is given, its ``jurisdiction_level`` equals it.
        Catalog order is preserved.
        """
        selected = []
        for atom in self.atoms:
            if atom.tax_type != tax_type:
                continue
            if jurisdiction and atom.jurisdiction_level != jurisdiction:
                continue
            if not atom.is_active_on(on_date):
                continue
            if atom.status == "deprecated":
                continue
            selected.append(atom)
        return selected
|
|
|
|
|
|
class CalculationResult:
    """Holds named float values plus a list of per-rule line records."""

    def __init__(self):
        # Named values, always stored as floats.
        self.values: Dict[str, float] = {}
        # Free-form line records appended by whoever fills this result.
        self.lines: List[Dict[str, Any]] = []

    def set_value(self, key: str, val: float):
        """Store ``val`` under ``key`` after converting it to ``float``."""
        self.values[key] = float(val)

    def get(self, key: str, default: float = 0.0) -> float:
        """Return the value for ``key`` as a float, or ``default`` when absent."""
        return float(self.values.get(key, default))
|
|
|
|
|
|
class TaxEngine:
    """Selects the active rules from a catalog and evaluates them in dependency order."""

    def __init__(self, catalog: "RuleCatalog", rounding_mode: str = "half_up"):
        """
        Args:
            catalog: Source of rules; ``run`` calls its ``select`` method.
            rounding_mode: ``"half_up"`` rounds halves away from zero;
                any other value falls back to the built-in ``round``.
        """
        self.catalog = catalog
        self.rounding_mode = rounding_mode

    def _toposort(self, rules: List["RuleAtom"]) -> List["RuleAtom"]:
        """Order ``rules`` so each comes after the rules in its ``applies_after``.

        Ids listed in ``applies_after`` that are not part of ``rules`` are
        ignored. Rules with no mutual constraint keep their input order.

        Raises:
            ValueError: On duplicate rule ids or a dependency cycle.
        """
        id_map = {r.id: r for r in rules}
        if len(id_map) != len(rules):
            raise ValueError("Duplicate rule ids in rule set")

        indeg: Dict[str, int] = {}
        dependents: Dict[str, List[str]] = {rid: [] for rid in id_map}
        for r in rules:
            deps = {d for d in r.ordering_constraints.get("applies_after", []) if d in id_map}
            indeg[r.id] = len(deps)
            for d in deps:
                dependents[d].append(r.id)

        # Kahn's algorithm. The queue is a list read through a moving
        # index so taking the next item never shifts the remainder.
        queue = [rid for rid, deg in indeg.items() if deg == 0]
        ordered: List["RuleAtom"] = []
        head = 0
        while head < len(queue):
            rid = queue[head]
            head += 1
            ordered.append(id_map[rid])
            for nid in dependents[rid]:
                indeg[nid] -= 1
                if indeg[nid] == 0:
                    queue.append(nid)

        if len(ordered) != len(rules):
            raise ValueError("Dependency cycle or missing rule id in applies_after")
        return ordered

    def _round(self, x: float) -> float:
        """Round ``x`` to a whole number according to ``self.rounding_mode``."""
        if self.rounding_mode == "half_up":
            # Compare the fractional part with 0.5 rather than adding
            # 0.5 and truncating: the addition itself can round up in
            # floating point (0.49999999999999994 + 0.5 == 1.0).
            magnitude = abs(x)
            whole = math.floor(magnitude)
            if magnitude - whole >= 0.5:
                whole += 1
            return float(whole) if x >= 0 else -float(whole)
        return float(round(x))

    def _evaluate_rule(self, r: "RuleAtom", ctx: "CalculationResult") -> Tuple[str, float, Dict[str, Any]]:
        """Compute the amount for one rule from the values already in ``ctx``.

        The formula is chosen by ``r.formula_type``; expressions in
        ``r.parameters`` are evaluated with ``SafeExpr`` against
        ``ctx.values``. The amount is rounded only when the rule's
        ``round`` parameter is truthy.

        Returns:
            ``(output key, amount, details)`` where ``details`` records
            the intermediate figures of the formula.

        Raises:
            ValueError: If ``r.formula_type`` is not recognised.
        """
        v = ctx.values
        params = r.parameters

        def ex(expr: str) -> float:
            return float(SafeExpr.eval(expr, v))

        details: Dict[str, Any] = {}

        if r.formula_type == "fixed_amount":
            amt = ex(params.get("amount_expr", "0"))

        elif r.formula_type == "rate_on_base":
            base = ex(params.get("base_expr", "0"))
            rate = float(params.get("rate", 0))
            amt = base * rate
            details.update({"base": base, "rate": rate})

        elif r.formula_type == "capped_percentage":
            base = ex(params.get("base_expr", "0"))
            cap_rate = float(params.get("cap_rate", 0))
            amt = min(base, base * cap_rate)
            details.update({"base": base, "cap_rate": cap_rate})

        elif r.formula_type == "max_of_plus":
            base_opts = [ex(opt.get("expr", "0")) for opt in params.get("base_options", [])]
            plus_expr = params.get("plus_expr", "0")
            plus = ex(plus_expr) if plus_expr else 0.0
            amt = max(base_opts) + plus if base_opts else plus
            details.update({"base_options": base_opts, "plus": plus})

        elif r.formula_type == "piecewise_bands":
            taxable = ex(params.get("base_expr", "0"))
            bands = params.get("bands", [])
            remaining = taxable
            tax = 0.0
            calc_steps = []
            prev_upper = 0.0
            # Bands are consumed in the order given; each "up_to" is read
            # as the cumulative upper bound of its band (None = no limit).
            # NOTE(review): this assumes bands are listed in ascending
            # order of "up_to" -- confirm against the rule files.
            for b in bands:
                upper = float("inf") if b.get("up_to") is None else float(b["up_to"])
                rate = float(b["rate"])
                chunk = max(0.0, min(remaining, upper - prev_upper))
                if chunk > 0:
                    part = chunk * rate
                    tax += part
                    calc_steps.append({"range": [prev_upper, upper], "chunk": chunk, "rate": rate, "tax": part})
                remaining -= chunk
                prev_upper = upper
                if remaining <= 0:
                    break
            amt = tax
            details.update({"base": taxable, "bands_applied": calc_steps})

        elif r.formula_type == "conditional_min":
            computed = ex(params.get("computed_expr", "computed_tax"))
            min_amount = ex(params.get("min_amount_expr", "0"))
            amt = max(computed, min_amount)
            details.update({"computed": computed, "minimum": min_amount})

        else:
            raise ValueError(f"Unknown formula_type: {r.formula_type}")

        if params.get("round", False):
            amt = self._round(amt)
        return r.output, amt, details

    def run(
        self,
        *,
        tax_type: str,
        as_of: date,
        jurisdiction: Optional[str],
        inputs: Dict[str, float],
        rule_ids_whitelist: Optional[List[str]] = None
    ) -> "CalculationResult":
        """Evaluate every applicable rule and return the accumulated result.

        Rules are selected from the catalog for ``tax_type``, ``as_of``
        and ``jurisdiction``, optionally narrowed to
        ``rule_ids_whitelist``, then evaluated in dependency order. Each
        rule's amount is stored under its output key, so later rules can
        reference it, and a line record is appended to the result.

        A rule whose ``applicability_expr`` parameter evaluates falsy is
        skipped. An empty or ``None`` whitelist applies no filtering.

        Raises:
            SafeEvalError: If a rule's applicability guard fails to evaluate.
            ValueError: On a dependency cycle or unknown formula type.
        """
        active = self.catalog.select(tax_type=tax_type, on_date=as_of, jurisdiction=jurisdiction)
        if rule_ids_whitelist:
            idset = set(rule_ids_whitelist)
            active = [r for r in active if r.id in idset]

        ordered = self._toposort(active)
        ctx = CalculationResult()

        for k, v in inputs.items():
            ctx.set_value(k, float(v))

        for r in ordered:
            guard = r.parameters.get("applicability_expr")
            if guard:
                try:
                    applies = bool(SafeExpr.eval(guard, ctx.values))
                except Exception as e:
                    # Chain the cause so the underlying failure stays
                    # visible in the traceback.
                    raise SafeEvalError(f"Guard error in {r.id}: {e}") from e
                if not applies:
                    continue

            out_key, amount, details = self._evaluate_rule(r, ctx)
            ctx.set_value(out_key, amount)
            ctx.lines.append({
                "rule_id": r.id,
                "title": r.title,
                "amount": amount,
                "output": out_key,
                "details": details,
                # Copy each dict so editing a result line cannot mutate
                # the catalog's AuthorityRef objects.
                "authority": [dict(a.__dict__) for a in r.authority],
            })
        return ctx
|
|
|
|