|
|
|
|
|
from __future__ import annotations |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Any, Dict, List, Optional, Tuple, Set |
|
|
from datetime import date, datetime |
|
|
import math |
|
|
import yaml |
|
|
import ast |
|
|
|
|
|
|
|
|
class SafeEvalError(Exception):
    """Signals that an expression could not be parsed or evaluated safely."""
|
|
|
|
|
class SafeExpr:
    """
    Small expression evaluator over a dict of variables.

    The expression is parsed with ``ast`` and every node is checked against a
    whitelist before anything is evaluated. Supported syntax:

    * arithmetic ``+ - * / // % **`` and unary ``+``/``-``
    * comparisons ``< > <= >= == !=`` (chained comparisons short-circuit)
    * ``and`` / ``or`` (short-circuiting, always yielding a ``bool``)
    * conditional expressions ``a if cond else b``
    * literals, variable names, list / tuple / dict displays, subscripts
    * calls to ``min``, ``max``, ``abs``, ``round`` with at most 2
      positional arguments

    NOTE(review): ``**`` is not bounded, so an expression such as
    ``9 ** 9 ** 9`` can consume unbounded CPU/memory. Add a limit if
    expressions can come from untrusted sources.
    """

    ALLOWED_FUNCS = {"min": min, "max": max, "abs": abs, "round": round}

    # ``ast.Num`` and ``ast.Index`` are deliberately not referenced here: they
    # are deprecated aliases that no longer exist on recent Python versions,
    # and touching them would break (or warn) at class-creation time.
    ALLOWED_NODES = (
        ast.Expression, ast.BinOp, ast.UnaryOp, ast.Name,
        ast.Load, ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow,
        ast.USub, ast.UAdd, ast.Call, ast.Tuple, ast.Constant, ast.Compare,
        ast.Lt, ast.Gt, ast.LtE, ast.GtE, ast.Eq, ast.NotEq, ast.BoolOp, ast.And, ast.Or,
        ast.IfExp, ast.Subscript, ast.Dict, ast.List,
    )

    # Node classes only produced by old parsers (numeric literals before 3.8,
    # subscript wrappers before 3.9). They are matched by class name so that
    # the deprecated attributes on ``ast`` are never accessed.
    _LEGACY_NODE_NAMES = frozenset({"Num", "Index"})

    _BIN_OPS = {
        ast.Add: lambda a, b: a + b,
        ast.Sub: lambda a, b: a - b,
        ast.Mult: lambda a, b: a * b,
        ast.Div: lambda a, b: a / b,
        ast.FloorDiv: lambda a, b: a // b,
        ast.Mod: lambda a, b: a % b,
        ast.Pow: lambda a, b: a ** b,
    }

    _CMP_OPS = {
        ast.Lt: lambda a, b: a < b,
        ast.Gt: lambda a, b: a > b,
        ast.LtE: lambda a, b: a <= b,
        ast.GtE: lambda a, b: a >= b,
        ast.Eq: lambda a, b: a == b,
        ast.NotEq: lambda a, b: a != b,
    }

    @classmethod
    def eval(cls, expr: str, variables: Dict[str, Any]) -> Any:
        """Parse ``expr`` and evaluate it against ``variables``.

        Raises:
            SafeEvalError: if the expression does not parse, uses syntax
                outside the whitelist, names an unknown variable or calls a
                function that is not allowed.

        Errors raised by the operations themselves (for example
        ``ZeroDivisionError`` or ``TypeError``) propagate unchanged.
        """
        try:
            tree = ast.parse(expr, mode="eval")
        except Exception as e:
            raise SafeEvalError(f"Parse error: {e}") from e
        for n in ast.walk(tree):
            if isinstance(n, cls.ALLOWED_NODES):
                continue
            if type(n).__name__ in cls._LEGACY_NODE_NAMES:
                continue
            raise SafeEvalError("Disallowed syntax in expression")
        return cls._eval_node(tree.body, variables)

    @classmethod
    def _eval_node(cls, node, vars):
        """Recursively evaluate one whitelisted AST node."""
        if isinstance(node, ast.Constant):
            return node.value
        if type(node).__name__ == "Num":
            # Numeric literal as produced by parsers older than Python 3.8.
            return node.n
        if isinstance(node, ast.Name):
            try:
                return vars[node.id]
            except KeyError:
                # The KeyError adds nothing to the message; hide it.
                raise SafeEvalError(f"Unknown variable '{node.id}'") from None
        if isinstance(node, ast.UnaryOp):
            val = cls._eval_node(node.operand, vars)
            if isinstance(node.op, ast.UAdd):
                return +val
            if isinstance(node.op, ast.USub):
                return -val
            raise SafeEvalError("Unsupported unary op")
        if isinstance(node, ast.BinOp):
            lhs = cls._eval_node(node.left, vars)
            rhs = cls._eval_node(node.right, vars)
            bin_fn = cls._BIN_OPS.get(type(node.op))
            if bin_fn is None:
                raise SafeEvalError("Unsupported binary op")
            return bin_fn(lhs, rhs)
        if isinstance(node, ast.Compare):
            # Like Python, stop at the first failing link of a chain so that
            # later comparators are not evaluated (and cannot raise).
            cur = cls._eval_node(node.left, vars)
            ok = True
            for op, comparator in zip(node.ops, node.comparators):
                cmp_fn = cls._CMP_OPS.get(type(op))
                if cmp_fn is None:
                    raise SafeEvalError("Unsupported comparator")
                right = cls._eval_node(comparator, vars)
                ok = cmp_fn(cur, right)
                if not ok:
                    return ok
                cur = right
            return ok
        if isinstance(node, ast.BoolOp):
            # Operands are evaluated lazily so guards such as
            # ``x != 0 and y / x > 1`` behave as they do in Python. The result
            # is always a bool, not the deciding operand.
            if isinstance(node.op, ast.And):
                for operand in node.values:
                    if not bool(cls._eval_node(operand, vars)):
                        return False
                return True
            if isinstance(node.op, ast.Or):
                for operand in node.values:
                    if bool(cls._eval_node(operand, vars)):
                        return True
                return False
            raise SafeEvalError("Unsupported bool op")
        if isinstance(node, ast.IfExp):
            cond = cls._eval_node(node.test, vars)
            return cls._eval_node(node.body if cond else node.orelse, vars)
        if isinstance(node, ast.Call):
            if not isinstance(node.func, ast.Name):
                raise SafeEvalError("Only simple function calls allowed")
            fname = node.func.id
            if fname not in cls.ALLOWED_FUNCS:
                raise SafeEvalError(f"Function '{fname}' not allowed")
            if len(node.args) > 2:
                raise SafeEvalError("Too many args")
            args = [cls._eval_node(a, vars) for a in node.args]
            return cls.ALLOWED_FUNCS[fname](*args)
        if isinstance(node, ast.List):
            return [cls._eval_node(e, vars) for e in node.elts]
        if isinstance(node, ast.Tuple):
            # A real tuple, so it is hashable and compares equal to tuples.
            return tuple(cls._eval_node(e, vars) for e in node.elts)
        if isinstance(node, ast.Dict):
            return {
                cls._eval_node(k, vars): cls._eval_node(v, vars)
                for k, v in zip(node.keys, node.values)
            }
        if isinstance(node, ast.Subscript):
            container = cls._eval_node(node.value, vars)
            index_node = node.slice
            # Old parsers wrap the index expression in an ``Index`` node;
            # newer ones store the expression directly. Only unwrap the
            # former: many expression nodes (constants, nested subscripts)
            # also have a ``value`` attribute that must not be mistaken for it.
            if type(index_node).__name__ == "Index":
                index_node = index_node.value
            return container[cls._eval_node(index_node, vars)]
        raise SafeEvalError(f"Unsupported node: {type(node).__name__}")
|
|
|
|
|
|
|
|
@dataclass
class AuthorityRef:
    """Reference to a location inside a document.

    Only ``doc`` is mandatory; each finer-grained locator is optional and
    defaults to ``None``.
    """

    # Identifier of the referenced document.
    doc: str
    # Optional locators within the document.
    section: Optional[str] = None
    subsection: Optional[str] = None
    page: Optional[str] = None
    url_anchor: Optional[str] = None
|
|
|
|
|
@dataclass
class RuleAtom:
    """A single rule: identification, formula description and validity window."""

    id: str
    title: str
    description: str
    # Matched against the ``tax_type`` requested from the catalog.
    tax_type: str
    # Matched against the ``jurisdiction`` requested from the catalog.
    jurisdiction_level: str
    # Selects the formula the engine applies (e.g. "rate_on_base").
    formula_type: str
    inputs: List[str]
    # Key under which the engine stores this rule's result.
    output: str
    # Formula-specific settings (expressions, rates, bands, ...).
    parameters: Dict[str, Any] = field(default_factory=dict)
    # The engine reads the "applies_after" list of rule ids from here.
    ordering_constraints: Dict[str, List[str]] = field(default_factory=dict)
    # Validity window, inclusive on both ends. Strings must be "YYYY-MM-DD";
    # date/datetime objects are accepted as well (a loader may already have
    # converted the values).
    effective_from: str = "1900-01-01"
    effective_to: Optional[str] = None
    authority: List[AuthorityRef] = field(default_factory=list)
    notes: Optional[str] = None
    # Rules whose status is "deprecated" are skipped by the catalog.
    status: str = "approved"

    @staticmethod
    def _as_date(value: Any) -> Any:
        """Normalize a "YYYY-MM-DD" string or a datetime to a plain ``date``.

        ``date`` objects (and anything else) are returned unchanged.
        """
        if isinstance(value, str):
            return datetime.strptime(value, "%Y-%m-%d").date()
        # datetime is a subclass of date but the two cannot be ordered against
        # each other, so reduce it to its calendar date first.
        if isinstance(value, datetime):
            return value.date()
        return value

    def is_active_on(self, on_date: date) -> bool:
        """Return True when ``on_date`` lies within the validity window.

        Both bounds are inclusive. A missing ``effective_to`` means the rule
        never expires. ``on_date`` may be a ``date`` or a ``datetime`` (only
        its calendar date is considered).

        Raises:
            ValueError: if a string bound is not in "YYYY-MM-DD" format.
        """
        start = self._as_date(self.effective_from)
        end = date.max if self.effective_to is None else self._as_date(self.effective_to)
        return start <= self._as_date(on_date) <= end
|
|
|
|
|
|
|
|
class RuleCatalog:
    """In-memory collection of rules with YAML loading and filtering."""

    def __init__(self, atoms: List[RuleAtom]):
        """Store ``atoms`` and index them by id.

        If several atoms share an id, the index keeps the last one while
        ``self.atoms`` still contains all of them.
        """
        self.atoms = atoms
        self._by_id = {a.id: a for a in atoms}

    @classmethod
    def from_yaml_files(cls, paths: List[str]) -> "RuleCatalog":
        """Build a catalog from YAML files.

        Each file may hold a single rule mapping or a list of rule mappings.
        Empty files are skipped. Each entry of a rule's ``authority`` list is
        turned into an ``AuthorityRef``; a missing or null ``authority`` is
        treated as an empty list.

        Raises:
            OSError: if a file cannot be opened.
            TypeError: if a mapping has keys the rule classes do not accept,
                or lacks required ones.
        """
        atoms: List[RuleAtom] = []
        for p in paths:
            # Only the parse needs the file handle; release it before
            # building the rule objects.
            with open(p, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
            if data is None:
                # An empty document loads as None; there is nothing to add.
                continue
            if isinstance(data, dict):
                data = [data]
            for item in data:
                # ``authority: null`` in the file yields None, not a list.
                refs = [AuthorityRef(**r) for r in (item.get("authority") or [])]
                atoms.append(RuleAtom(**{**item, "authority": refs}))
        return cls(atoms)

    def select(self, *, tax_type: str, on_date: date, jurisdiction: Optional[str] = None) -> List[RuleAtom]:
        """Return the rules applicable to a calculation, in catalog order.

        A rule is kept when its ``tax_type`` matches, it is active on
        ``on_date``, its status is not "deprecated" and, if ``jurisdiction``
        is given (non-empty), its ``jurisdiction_level`` equals it.
        """
        out = []
        for a in self.atoms:
            if a.tax_type != tax_type:
                continue
            if jurisdiction and a.jurisdiction_level != jurisdiction:
                continue
            if not a.is_active_on(on_date):
                continue
            if a.status == "deprecated":
                continue
            out.append(a)
        return out
|
|
|
|
|
class CalculationResult:
    """Holds the named numeric values and the per-rule lines of a calculation."""

    def __init__(self):
        # Named values, always stored as floats.
        self.values: Dict[str, float] = dict()
        # One dict per recorded line, in insertion order.
        self.lines: List[Dict[str, Any]] = list()

    def set_value(self, key: str, val: float):
        """Store ``val`` under ``key`` after converting it to ``float``."""
        as_float = float(val)
        self.values[key] = as_float

    def get(self, key: str, default: float = 0.0) -> float:
        """Return the value stored under ``key``, or ``default``, as a float."""
        if key in self.values:
            return float(self.values[key])
        return float(default)
|
|
|
|
|
class TaxEngine:
    """Runs the applicable rules of a catalog, in dependency order, over a set of inputs."""

    def __init__(self, catalog: RuleCatalog, rounding_mode: str = "half_up"):
        """
        Args:
            catalog: source of the rules to evaluate.
            rounding_mode: "half_up" rounds halves away from zero; any other
                value falls back to the built-in ``round`` (halves to even).
        """
        self.catalog = catalog
        self.rounding_mode = rounding_mode

    def _toposort(self, rules: List[RuleAtom]) -> List[RuleAtom]:
        """Order ``rules`` so each one comes after its "applies_after" rules.

        Ids listed in "applies_after" that are not among ``rules`` are
        ignored. Rules without a mutual constraint keep their relative input
        order.

        Raises:
            ValueError: if the constraints contain a cycle, or if two rules
                share an id (not every rule can then be placed).
        """
        id_map = {r.id: r for r in rules}
        indeg: Dict[str, int] = {}
        # dependents[x] lists the rules that must wait for x.
        dependents: Dict[str, List[str]] = {rid: [] for rid in id_map}
        for rid, rule in id_map.items():
            deps = {d for d in rule.ordering_constraints.get("applies_after", []) if d in id_map}
            indeg[rid] = len(deps)
            for d in deps:
                dependents[d].append(rid)

        # Kahn's algorithm. ``ready`` is used as a FIFO queue through a
        # cursor, which avoids the O(n) cost of popping from the list front.
        ready = [rid for rid, deg in indeg.items() if deg == 0]
        ordered: List[RuleAtom] = []
        cursor = 0
        while cursor < len(ready):
            rid = ready[cursor]
            cursor += 1
            ordered.append(id_map[rid])
            for nid in dependents[rid]:
                indeg[nid] -= 1
                if indeg[nid] == 0:
                    ready.append(nid)

        if len(ordered) != len(rules):
            raise ValueError("Dependency cycle or missing rule id in applies_after")
        return ordered

    def _round(self, x: float) -> float:
        """Round ``x`` to a whole number according to ``self.rounding_mode``."""
        if self.rounding_mode == "half_up":
            # Compare the fractional part with 0.5 instead of adding 0.5 and
            # truncating: the addition itself is rounded, which turns values
            # just below one half (0.49999999999999994) into 1.0. The
            # subtraction below is exact for floats.
            magnitude = abs(x)
            whole = math.floor(magnitude)
            if magnitude - whole >= 0.5:
                whole += 1
            return math.copysign(float(whole), x)
        return float(round(x))

    @staticmethod
    def _apply_bands(taxable: float, bands: List[Dict[str, Any]]) -> Tuple[float, List[Dict[str, Any]]]:
        """Tax ``taxable`` slice by slice over ``bands``; return (tax, steps).

        Each band has a ``rate`` and an ``up_to`` upper bound (``None`` or
        absent means unbounded); the lower bound of a band is the upper bound
        of the previous one, starting at 0. Any amount above the last band's
        upper bound is not taxed.
        """
        remaining = taxable
        tax = 0.0
        steps: List[Dict[str, Any]] = []
        prev_upper = 0.0
        for b in bands:
            upper = float("inf") if b.get("up_to") is None else float(b["up_to"])
            rate = float(b["rate"])
            chunk = max(0.0, min(remaining, upper - prev_upper))
            if chunk > 0:
                part = chunk * rate
                tax += part
                steps.append({"range": [prev_upper, upper], "chunk": chunk, "rate": rate, "tax": part})
            remaining -= chunk
            prev_upper = upper
            if remaining <= 0:
                break
        return tax, steps

    def _evaluate_rule(self, r: RuleAtom, ctx: CalculationResult) -> Tuple[str, float, Dict[str, Any]]:
        """Compute one rule against the values accumulated in ``ctx``.

        Returns:
            ``(output key, amount, details)`` where ``details`` holds the
            intermediate figures of the formula. The amount is rounded with
            ``_round`` when the rule's ``round`` parameter is truthy.

        Raises:
            ValueError: if ``r.formula_type`` is not a known formula.
            SafeEvalError: if one of the rule's expressions cannot be evaluated.
        """
        params = r.parameters
        scope = ctx.values

        def ex(expr: str) -> float:
            return float(SafeExpr.eval(expr, scope))

        details: Dict[str, Any] = {}
        kind = r.formula_type

        if kind == "fixed_amount":
            amt = ex(params.get("amount_expr", "0"))
        elif kind == "rate_on_base":
            base = ex(params.get("base_expr", "0"))
            rate = float(params.get("rate", 0))
            amt = base * rate
            details.update({"base": base, "rate": rate})
        elif kind == "capped_percentage":
            base = ex(params.get("base_expr", "0"))
            cap_rate = float(params.get("cap_rate", 0))
            # NOTE(review): for base >= 0 and 0 <= cap_rate <= 1 this is always
            # base * cap_rate, so the min() never caps anything - confirm the
            # intended formula.
            amt = min(base, base * cap_rate)
            details.update({"base": base, "cap_rate": cap_rate})
        elif kind == "max_of_plus":
            base_opts = [ex(opt.get("expr", "0")) for opt in params.get("base_options", [])]
            plus_expr = params.get("plus_expr", "0")
            plus = ex(plus_expr) if plus_expr else 0.0
            amt = max(base_opts) + plus if base_opts else plus
            details.update({"base_options": base_opts, "plus": plus})
        elif kind == "piecewise_bands":
            taxable = ex(params.get("base_expr", "0"))
            amt, calc_steps = self._apply_bands(taxable, params.get("bands", []))
            details.update({"base": taxable, "bands_applied": calc_steps})
        elif kind == "conditional_min":
            computed = ex(params.get("computed_expr", "computed_tax"))
            min_amount = ex(params.get("min_amount_expr", "0"))
            amt = max(computed, min_amount)
            details.update({"computed": computed, "minimum": min_amount})
        else:
            raise ValueError(f"Unknown formula_type: {r.formula_type}")

        if params.get("round", False):
            amt = self._round(amt)
        return r.output, amt, details

    def run(
        self,
        *,
        tax_type: str,
        as_of: date,
        jurisdiction: Optional[str],
        inputs: Dict[str, float],
        rule_ids_whitelist: Optional[List[str]] = None
    ) -> CalculationResult:
        """Evaluate every applicable rule and return the accumulated result.

        Rules are selected from the catalog by ``tax_type``, ``as_of`` and
        ``jurisdiction``, optionally restricted to ``rule_ids_whitelist``
        (``None`` or an empty list means no restriction), and evaluated in
        dependency order. ``inputs`` seed the value table; each rule's amount
        is stored under its output key, so later rules can refer to it and a
        later rule with the same output key overwrites the earlier value.
        A rule whose ``applicability_expr`` parameter evaluates falsy is
        skipped.

        Raises:
            SafeEvalError: if a guard or rule expression cannot be evaluated.
            ValueError: on a dependency cycle or an unknown formula type.
        """
        active = self.catalog.select(tax_type=tax_type, on_date=as_of, jurisdiction=jurisdiction)
        if rule_ids_whitelist:
            idset = set(rule_ids_whitelist)
            active = [r for r in active if r.id in idset]

        ordered = self._toposort(active)
        ctx = CalculationResult()
        for k, v in inputs.items():
            ctx.set_value(k, v)

        for r in ordered:
            guard = r.parameters.get("applicability_expr")
            if guard:
                try:
                    applies = bool(SafeExpr.eval(guard, ctx.values))
                except Exception as e:
                    # Keep the original failure attached as the cause.
                    raise SafeEvalError(f"Guard error in {r.id}: {e}") from e
                if not applies:
                    continue

            out_key, amount, details = self._evaluate_rule(r, ctx)
            ctx.set_value(out_key, amount)
            ctx.lines.append({
                "rule_id": r.id,
                "title": r.title,
                "amount": amount,
                "output": out_key,
                "details": details,
                # Copy each reference so that editing a result line cannot
                # modify the catalog's own objects.
                "authority": [dict(a.__dict__) for a in r.authority],
            })
        return ctx
|
|
|