# rules_engine.py
#
# Declarative tax rule engine.
#
#   * SafeExpr          - whitelisted arithmetic/boolean expression evaluator.
#   * RuleAtom          - one rule (formula type, parameters, validity window).
#   * RuleCatalog       - collection of rules, optionally loaded from YAML.
#   * CalculationResult - values and audit lines produced by a run.
#   * TaxEngine         - orders the applicable rules and evaluates them.

from __future__ import annotations

import ast
import math
import operator
import sys
from collections import deque
from dataclasses import dataclass, field
from datetime import date, datetime
from typing import Any, Dict, List, Optional, Set, Tuple

try:
    import yaml
except ImportError:  # PyYAML is only needed by RuleCatalog.from_yaml_files
    yaml = None


# ------------- Safe expression evaluator -------------

# ast.Num / ast.Index are only produced by the parsers of Python < 3.8 / < 3.9.
# On newer interpreters they are deprecated aliases, so they are referenced
# only where the parser can actually emit them.
_LEGACY_NUM = ast.Num if sys.version_info < (3, 8) else None
_LEGACY_INDEX = ast.Index if sys.version_info < (3, 9) else None
_LEGACY_NODES = tuple(n for n in (_LEGACY_NUM, _LEGACY_INDEX) if n is not None)


class SafeEvalError(Exception):
    """Raised when an expression cannot be parsed or is not allowed."""


class SafeExpr:
    """
    Very small expression evaluator over a dict of variables.

    Supports + - * / // % **, parentheses, constants, names, comparisons
    (including chained ones), and/or, conditional expressions, list/tuple/dict
    literals, subscripting, and simple calls to min, max, abs, round with at
    most 2 positional args.

    Expressions are parsed with ``ast`` and every node is checked against
    ``ALLOWED_NODES`` before anything is evaluated; ``eval``/``exec`` are
    never used.
    """

    ALLOWED_FUNCS = {"min": min, "max": max, "abs": abs, "round": round}
    ALLOWED_NODES = (
        ast.Expression, ast.BinOp, ast.UnaryOp, ast.Name, ast.Load,
        ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow,
        ast.USub, ast.UAdd, ast.Call, ast.Tuple, ast.Constant,
        ast.Compare, ast.Lt, ast.Gt, ast.LtE, ast.GtE, ast.Eq, ast.NotEq,
        ast.BoolOp, ast.And, ast.Or, ast.IfExp, ast.Subscript,
        ast.Dict, ast.List,
    ) + _LEGACY_NODES

    # NOTE(review): ``**`` is unbounded; a rule such as ``9 ** 9 ** 9`` can
    # consume a lot of CPU/memory. Acceptable only while rule files are trusted.
    _BINARY_OPS = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    _COMPARE_OPS = {
        ast.Lt: operator.lt,
        ast.Gt: operator.gt,
        ast.LtE: operator.le,
        ast.GtE: operator.ge,
        ast.Eq: operator.eq,
        ast.NotEq: operator.ne,
    }

    @classmethod
    def eval(cls, expr: str, variables: Dict[str, Any]) -> Any:
        """Evaluate *expr* using *variables* for name lookups.

        Raises:
            SafeEvalError: if *expr* cannot be parsed, uses syntax outside the
                whitelist, references an unknown variable, or calls a function
                that is not in ``ALLOWED_FUNCS``.

        Errors raised by the operations themselves (e.g. ZeroDivisionError,
        TypeError) propagate unchanged.
        """
        try:
            tree = ast.parse(expr, mode="eval")
        except Exception as e:
            raise SafeEvalError(f"Parse error: {e}") from e
        if not all(isinstance(n, cls.ALLOWED_NODES) for n in ast.walk(tree)):
            raise SafeEvalError("Disallowed syntax in expression")
        return cls._eval_node(tree.body, variables)

    @classmethod
    def _eval_node(cls, node, variables):
        """Recursively evaluate a single (already whitelisted) AST node."""
        if isinstance(node, ast.Constant):
            return node.value
        if _LEGACY_NUM is not None and isinstance(node, _LEGACY_NUM):  # py<3.8
            return node.n

        if isinstance(node, ast.Name):
            try:
                return variables[node.id]
            except KeyError:
                raise SafeEvalError(f"Unknown variable '{node.id}'") from None

        if isinstance(node, ast.UnaryOp):
            val = cls._eval_node(node.operand, variables)
            if isinstance(node.op, ast.UAdd):
                return +val
            if isinstance(node.op, ast.USub):
                return -val
            raise SafeEvalError("Unsupported unary op")

        if isinstance(node, ast.BinOp):
            left = cls._eval_node(node.left, variables)
            right = cls._eval_node(node.right, variables)
            func = cls._BINARY_OPS.get(type(node.op))
            if func is None:
                raise SafeEvalError("Unsupported binary op")
            return func(left, right)

        if isinstance(node, ast.Compare):
            # Chained comparison (a < b <= c): stop at the first failing link,
            # as Python itself does, so later operands are not evaluated.
            current = cls._eval_node(node.left, variables)
            for op, comparator in zip(node.ops, node.comparators):
                func = cls._COMPARE_OPS.get(type(op))
                if func is None:
                    raise SafeEvalError("Unsupported comparator")
                right = cls._eval_node(comparator, variables)
                if not func(current, right):
                    return False
                current = right
            return True

        if isinstance(node, ast.BoolOp):
            # all()/any() over a generator short-circuit, so a guard such as
            # ``x != 0 and y / x > 1`` never evaluates the division when x == 0.
            # The result is always a bool (not the operand value).
            operands = (bool(cls._eval_node(v, variables)) for v in node.values)
            if isinstance(node.op, ast.And):
                return all(operands)
            if isinstance(node.op, ast.Or):
                return any(operands)
            raise SafeEvalError("Unsupported bool op")

        if isinstance(node, ast.IfExp):
            cond = cls._eval_node(node.test, variables)
            return cls._eval_node(node.body if cond else node.orelse, variables)

        if isinstance(node, ast.Call):
            if not isinstance(node.func, ast.Name):
                raise SafeEvalError("Only simple function calls allowed")
            fname = node.func.id
            if fname not in cls.ALLOWED_FUNCS:
                raise SafeEvalError(f"Function '{fname}' not allowed")
            if len(node.args) > 2:
                raise SafeEvalError("Too many args")
            args = [cls._eval_node(a, variables) for a in node.args]
            return cls.ALLOWED_FUNCS[fname](*args)

        if isinstance(node, (ast.List, ast.Tuple)):
            # Tuples are deliberately returned as lists too.
            return [cls._eval_node(e, variables) for e in node.elts]

        if isinstance(node, ast.Dict):
            return {
                cls._eval_node(k, variables): cls._eval_node(v, variables)
                for k, v in zip(node.keys, node.values)
            }

        if isinstance(node, ast.Subscript):
            container = cls._eval_node(node.value, variables)
            index_node = node.slice
            # Only a genuine legacy ast.Index wrapper is unwrapped. Checking
            # for a ``.value`` attribute would also match Constant/Subscript
            # nodes and break ``x[0]`` and ``a[b[0]]`` on Python 3.9+.
            if _LEGACY_INDEX is not None and isinstance(index_node, _LEGACY_INDEX):
                index_node = index_node.value
            return container[cls._eval_node(index_node, variables)]

        raise SafeEvalError(f"Unsupported node: {type(node).__name__}")


# ------------- Rule atoms -------------

@dataclass
class AuthorityRef:
    """Pointer to the source document backing a rule."""

    doc: str
    section: Optional[str] = None
    subsection: Optional[str] = None
    page: Optional[str] = None
    url_anchor: Optional[str] = None


@dataclass
class RuleAtom:
    """A single rule: what it computes, from what, and when it is in force."""

    id: str
    title: str
    description: str
    tax_type: str  # eg "PIT", "CIT", "VAT"
    jurisdiction_level: str  # eg "federal", "state"
    formula_type: str  # piecewise_bands, capped_percentage, etc
    inputs: List[str]
    output: str
    parameters: Dict[str, Any] = field(default_factory=dict)
    ordering_constraints: Dict[str, List[str]] = field(default_factory=dict)
    effective_from: str = "1900-01-01"
    effective_to: Optional[str] = None
    authority: List[AuthorityRef] = field(default_factory=list)
    notes: Optional[str] = None
    status: str = "approved"  # draft, approved, deprecated

    @staticmethod
    def _as_date(value: Any) -> date:
        """Normalise a "YYYY-MM-DD" string, datetime or date to a date."""
        if isinstance(value, str):
            return datetime.strptime(value, "%Y-%m-%d").date()
        # datetime is a subclass of date but cannot be compared with one.
        if isinstance(value, datetime):
            return value.date()
        return value

    def is_active_on(self, on_date: date) -> bool:
        """Return True if *on_date* lies within the inclusive validity window.

        ``effective_from`` / ``effective_to`` may be "YYYY-MM-DD" strings or
        date/datetime objects; a missing ``effective_to`` means open-ended.
        """
        start = self._as_date(self.effective_from)
        end = date.max if self.effective_to is None else self._as_date(self.effective_to)
        return start <= self._as_date(on_date) <= end


# ------------- Engine core -------------

class RuleCatalog:
    """In-memory collection of rule atoms."""

    def __init__(self, atoms: List[RuleAtom]):
        self.atoms = atoms
        # If ids repeat, the last atom with a given id wins in this index.
        self._by_id = {a.id: a for a in atoms}

    @classmethod
    def from_yaml_files(cls, paths: List[str]) -> "RuleCatalog":
        """Build a catalog from YAML files.

        Each file holds either a single rule mapping or a list of them; empty
        files are skipped. Files are read with ``yaml.safe_load``.

        Raises:
            ImportError: if PyYAML is not installed.
        """
        if yaml is None:
            raise ImportError("PyYAML is required to load rules from YAML files")
        atoms: List[RuleAtom] = []
        for p in paths:
            with open(p, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
            if data is None:  # empty document
                continue
            if isinstance(data, dict):
                data = [data]
            for item in data:
                # ``authority:`` with no value loads as None, hence ``or []``.
                auth = [AuthorityRef(**r) for r in item.get("authority") or []]
                atoms.append(RuleAtom(**{**item, "authority": auth}))
        return cls(atoms)

    def select(self, *, tax_type: str, on_date: date,
               jurisdiction: Optional[str] = None) -> List[RuleAtom]:
        """Return the non-deprecated atoms of *tax_type* active on *on_date*.

        When *jurisdiction* is given (truthy), only atoms whose
        ``jurisdiction_level`` equals it are kept. Catalog order is preserved.
        """
        return [
            a for a in self.atoms
            if a.tax_type == tax_type
            and not (jurisdiction and a.jurisdiction_level != jurisdiction)
            and a.is_active_on(on_date)
            and a.status != "deprecated"
        ]


class CalculationResult:
    """Named values plus an audit trail of the rules that produced them."""

    def __init__(self):
        self.values: Dict[str, float] = {}
        # each line: rule_id, title, amount, output, details, authority
        self.lines: List[Dict[str, Any]] = []

    def set_value(self, key: str, val: float):
        """Store *val* under *key*, coerced to float."""
        self.values[key] = float(val)

    def get(self, key: str, default: float = 0.0) -> float:
        """Return the value stored under *key*, or *default* if absent."""
        return float(self.values.get(key, default))


class TaxEngine:
    """Evaluates the applicable rules of a catalog in dependency order."""

    def __init__(self, catalog: RuleCatalog, rounding_mode: str = "half_up"):
        self.catalog = catalog
        self.rounding_mode = rounding_mode

    # dependency ordering
    def _toposort(self, rules: List[RuleAtom]) -> List[RuleAtom]:
        """Order *rules* so each comes after its ``applies_after`` ids.

        Ids in ``applies_after`` that are not among *rules* are ignored.
        Rules that are ready at the same time keep their input order.

        Raises:
            ValueError: if the ordering cannot cover every rule (a dependency
                cycle, or duplicate rule ids).
        """
        id_map = {r.id: r for r in rules}
        after_map: Dict[str, Set[str]] = {}
        for r in rules:
            raw = r.ordering_constraints.get("applies_after") or []
            if isinstance(raw, str):  # a lone id, not a sequence of characters
                raw = [raw]
            after_map[r.id] = {d for d in raw if d in id_map}

        indeg: Dict[str, int] = {rid: len(deps) for rid, deps in after_map.items()}
        # Reverse adjacency: prerequisite id -> ids waiting on it.
        dependents: Dict[str, List[str]] = {rid: [] for rid in after_map}
        for rid, deps in after_map.items():
            for d in deps:
                dependents[d].append(rid)

        queue = deque(rid for rid, deg in indeg.items() if deg == 0)
        ordered: List[RuleAtom] = []
        while queue:
            rid = queue.popleft()
            ordered.append(id_map[rid])
            for nid in dependents[rid]:
                indeg[nid] -= 1
                if indeg[nid] == 0:
                    queue.append(nid)

        if len(ordered) != len(rules):
            # cycle detected or missing ids
            raise ValueError("Dependency cycle or missing rule id in applies_after")
        return ordered

    def _round(self, x: float) -> float:
        """Round *x* to a whole number according to ``rounding_mode``.

        "half_up" rounds halves away from zero; any other mode falls back to
        the built-in ``round`` (round-half-to-even).
        """
        if self.rounding_mode == "half_up":
            return float(int(x + 0.5)) if x >= 0 else -float(int(abs(x) + 0.5))
        return float(round(x))

    @staticmethod
    def _piecewise_tax(taxable: float, bands: List[Dict[str, Any]]
                       ) -> Tuple[float, List[Dict[str, Any]]]:
        """Apply marginal *bands* to *taxable*; return (tax, per-band steps).

        Each band has a ``rate`` and an ``up_to`` upper bound (None/missing
        means unbounded). Bands are consumed in the order given, starting
        from a lower bound of 0.
        """
        remaining = taxable
        tax = 0.0
        steps: List[Dict[str, Any]] = []
        prev_upper = 0.0
        for band in bands:
            upper = float("inf") if band.get("up_to") is None else float(band["up_to"])
            rate = float(band["rate"])
            chunk = max(0.0, min(remaining, upper - prev_upper))
            if chunk > 0:
                part = chunk * rate
                tax += part
                steps.append({"range": [prev_upper, upper], "chunk": chunk,
                              "rate": rate, "tax": part})
                remaining -= chunk
            prev_upper = upper
            if remaining <= 0:
                break
        return tax, steps

    def _evaluate_rule(self, r: RuleAtom, ctx: CalculationResult
                       ) -> Tuple[str, float, Dict[str, Any]]:
        """Compute one rule against the values accumulated in *ctx*.

        Returns:
            (output key, amount, details dict for the audit line).

        Raises:
            ValueError: for an unknown ``formula_type``.
            SafeEvalError: if a parameter expression is invalid.
        """
        values = ctx.values
        params = r.parameters

        def ex(expr: Any) -> float:
            # YAML loads an unquoted ``amount_expr: 5000`` as a number, not a
            # string; accept that directly instead of failing to parse it.
            if isinstance(expr, (int, float)) and not isinstance(expr, bool):
                return float(expr)
            return float(SafeExpr.eval(expr, values))

        details: Dict[str, Any] = {}
        kind = r.formula_type

        if kind == "fixed_amount":
            amt = ex(params.get("amount_expr", "0"))

        elif kind == "rate_on_base":
            base = ex(params.get("base_expr", "0"))
            rate = float(params.get("rate", 0))
            amt = base * rate
            details.update({"base": base, "rate": rate})

        elif kind == "capped_percentage":
            base = ex(params.get("base_expr", "0"))
            cap_rate = float(params.get("cap_rate", 0))
            amt = min(base, base * cap_rate)
            details.update({"base": base, "cap_rate": cap_rate})

        elif kind == "max_of_plus":
            base_opts = [ex(opt.get("expr", "0"))
                         for opt in params.get("base_options", [])]
            plus_expr = params.get("plus_expr", "0")
            plus = ex(plus_expr) if plus_expr else 0.0
            amt = max(base_opts) + plus if base_opts else plus
            details.update({"base_options": base_opts, "plus": plus})

        elif kind == "piecewise_bands":
            taxable = ex(params.get("base_expr", "0"))
            amt, calc_steps = self._piecewise_tax(taxable, params.get("bands") or [])
            details.update({"base": taxable, "bands_applied": calc_steps})

        elif kind == "conditional_min":
            computed = ex(params.get("computed_expr", "computed_tax"))
            min_amount = ex(params.get("min_amount_expr", "0"))
            amt = max(computed, min_amount)
            details.update({"computed": computed, "minimum": min_amount})

        else:
            raise ValueError(f"Unknown formula_type: {r.formula_type}")

        if params.get("round", False):
            amt = self._round(amt)
        return r.output, amt, details

    def run(
        self,
        *,
        tax_type: str,
        as_of: date,
        jurisdiction: Optional[str],
        inputs: Dict[str, float],
        rule_ids_whitelist: Optional[List[str]] = None
    ) -> CalculationResult:
        """Run every applicable rule and return the accumulated result.

        Rules are selected from the catalog by *tax_type*, *as_of* and
        *jurisdiction*, optionally restricted to *rule_ids_whitelist* (an
        empty or None whitelist means "no restriction"), then evaluated in
        dependency order. A rule whose ``applicability_expr`` parameter
        evaluates falsy is skipped.

        Raises:
            SafeEvalError: if a guard expression fails to evaluate.
            ValueError: on dependency cycles or unknown formula types.
        """
        active = self.catalog.select(tax_type=tax_type, on_date=as_of,
                                     jurisdiction=jurisdiction)
        if rule_ids_whitelist:
            idset = set(rule_ids_whitelist)
            active = [r for r in active if r.id in idset]
        ordered = self._toposort(active)

        ctx = CalculationResult()
        # seed inputs
        for k, v in inputs.items():
            ctx.set_value(k, float(v))

        for r in ordered:
            # allow guard expressions like "applicability_expr": "employment_income > 0"
            guard = r.parameters.get("applicability_expr")
            if guard:
                try:
                    applies = bool(SafeExpr.eval(guard, ctx.values))
                except Exception as e:
                    raise SafeEvalError(f"Guard error in {r.id}: {e}") from e
                if not applies:
                    continue

            out_key, amount, details = self._evaluate_rule(r, ctx)
            ctx.set_value(out_key, amount)
            ctx.lines.append({
                "rule_id": r.id,
                "title": r.title,
                "amount": amount,
                "output": out_key,
                "details": details,
                # Copies, so editing the result cannot alter the catalog.
                "authority": [dict(vars(a)) for a in r.authority],
            })
        return ctx