Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- server/app.py +73 -22
server/app.py
CHANGED
|
@@ -37,7 +37,7 @@ app.add_middleware(
|
|
| 37 |
# ββ Pydantic Models ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
|
| 39 |
class StepAction(BaseModel):
|
| 40 |
-
|
| 41 |
explanation: str = ""
|
| 42 |
|
| 43 |
class ResetRequest(BaseModel):
|
|
@@ -321,6 +321,13 @@ def reset_episode(req: ResetRequest = None):
|
|
| 321 |
"done": False, "baseline_rows": baseline,
|
| 322 |
"chaos_fixed": False, "reward_history": [],
|
| 323 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
|
| 325 |
return {
|
| 326 |
"status": "success",
|
|
@@ -345,9 +352,9 @@ def step_environment(action: StepAction):
|
|
| 345 |
|
| 346 |
# ββ Legacy tasks 1-4: simple pattern matching βββββββββββββββββββββββββββ
|
| 347 |
if not task.get("duckdb_backed"):
|
| 348 |
-
sql = action.
|
| 349 |
solved = "GROUP BY" in sql or "," in sql or "PARTITION" in sql or "12-01" in sql
|
| 350 |
-
reward =
|
| 351 |
CURRENT_SESSION["reward_history"].append(reward)
|
| 352 |
return {
|
| 353 |
"reward": reward, "done": solved,
|
|
@@ -355,12 +362,12 @@ def step_environment(action: StepAction):
|
|
| 355 |
"message": "Execution succeeded." if solved else "Execution failed. Review your fix.",
|
| 356 |
"verifier": "Pattern-match verifier",
|
| 357 |
},
|
| 358 |
-
"
|
| 359 |
}
|
| 360 |
|
| 361 |
# ββ Task 5: Query Optimization βββββββββββββββββββββββββββββββββββββββββββ
|
| 362 |
if task_id == "task_5_optimization":
|
| 363 |
-
agent_sql = action.
|
| 364 |
reward, done, msg = 0.0, False, ""
|
| 365 |
try:
|
| 366 |
t0 = time.perf_counter()
|
|
@@ -374,7 +381,7 @@ def step_environment(action: StepAction):
|
|
| 374 |
no_cross = "CROSS_PRODUCT" not in plan_str
|
| 375 |
|
| 376 |
if correct and no_cross:
|
| 377 |
-
reward, done =
|
| 378 |
msg = f"β
Output matches baseline ({len(rows)} rows). EXPLAIN shows no CROSS_PRODUCT. Reward: +1.0"
|
| 379 |
elif correct:
|
| 380 |
reward = 0.5
|
|
@@ -387,11 +394,11 @@ def step_environment(action: StepAction):
|
|
| 387 |
CURRENT_SESSION["reward_history"].append(reward)
|
| 388 |
return {"reward": reward, "done": done,
|
| 389 |
"info": {"message": msg, "verifier": "DuckDB EXPLAIN + row comparison"},
|
| 390 |
-
"
|
| 391 |
|
| 392 |
# ββ Task 6: Schema Migration βββββββββββββββββββββββββββββββββββββββββββββ
|
| 393 |
if task_id == "task_6_migration":
|
| 394 |
-
agent_sql = action.
|
| 395 |
reward, done, msg = 0.0, False, ""
|
| 396 |
# Detect if agent is dropping messy_dump early (destructive action)
|
| 397 |
sql_upper = agent_sql.upper()
|
|
@@ -423,7 +430,7 @@ def step_environment(action: StepAction):
|
|
| 423 |
dump_gone = "messy_dump" not in tables_after
|
| 424 |
|
| 425 |
if users_count >= 5 and orders_count >= 7 and dump_gone:
|
| 426 |
-
reward, done =
|
| 427 |
msg = f"β
Migration complete! users={users_count} rows, orders={orders_count} rows. messy_dump dropped. Reward: +1.0"
|
| 428 |
elif users_count > 0 or orders_count > 0:
|
| 429 |
reward = 0.3
|
|
@@ -436,11 +443,11 @@ def step_environment(action: StepAction):
|
|
| 436 |
CURRENT_SESSION["reward_history"].append(reward)
|
| 437 |
return {"reward": reward, "done": done,
|
| 438 |
"info": {"message": msg, "verifier": "Row-count + table existence check"},
|
| 439 |
-
"
|
| 440 |
|
| 441 |
# ββ Task 7: Chaos Engineering ββββββββββββββββββββββββββββββββββββββββββββ
|
| 442 |
if task_id == "task_7_chaos":
|
| 443 |
-
agent_sql = action.
|
| 444 |
reward, done, msg = 0.0, False, ""
|
| 445 |
try:
|
| 446 |
for stmt in agent_sql.split(";"):
|
|
@@ -455,7 +462,7 @@ def step_environment(action: StepAction):
|
|
| 455 |
has_index = any("ux_users_id" in str(r) for r in con.execute("SELECT index_name FROM duckdb_indexes()").fetchall())
|
| 456 |
|
| 457 |
if dup_count == 0 and null_count == 0 and has_index:
|
| 458 |
-
reward, done =
|
| 459 |
CURRENT_SESSION["chaos_fixed"] = True
|
| 460 |
msg = "β
Pipeline is clean! No duplicates, no NULLs, UNIQUE index in place. Reward: +1.0"
|
| 461 |
elif dup_count == 0 and null_count == 0:
|
|
@@ -472,16 +479,18 @@ def step_environment(action: StepAction):
|
|
| 472 |
CURRENT_SESSION["reward_history"].append(reward)
|
| 473 |
return {"reward": reward, "done": done,
|
| 474 |
"info": {"message": msg, "verifier": "Integrity check (dups + NULLs + index)"},
|
| 475 |
-
"
|
| 476 |
|
| 477 |
@app.get("/state", tags=["Environment"])
|
| 478 |
def get_state():
|
|
|
|
|
|
|
| 479 |
return {
|
| 480 |
-
"task_id":
|
| 481 |
-
"current_sql":
|
| 482 |
-
"step_count": 0,
|
| 483 |
-
"done": False,
|
| 484 |
-
"schema":
|
| 485 |
}
|
| 486 |
|
| 487 |
@app.get("/tasks", tags=["System"])
|
|
@@ -606,6 +615,48 @@ async def custom_swagger():
|
|
| 606 |
|
| 607 |
TASKS_JSON = json.dumps(TASKS)
|
| 608 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
@app.get("/web_ui", include_in_schema=False)
|
| 610 |
async def web_ui():
|
| 611 |
html = f"""<!DOCTYPE html>
|
|
@@ -1307,10 +1358,10 @@ async function executeStep() {{
|
|
| 1307 |
const res = await fetch('/step', {{
|
| 1308 |
method: 'POST',
|
| 1309 |
headers: {{'Content-Type': 'application/json'}},
|
| 1310 |
-
body: JSON.stringify({{
|
| 1311 |
}});
|
| 1312 |
const data = await res.json();
|
| 1313 |
-
const reward = data.reward;
|
| 1314 |
const done = data.done;
|
| 1315 |
const msg = data.info?.message || '';
|
| 1316 |
const verifier = data.info?.verifier || 'DuckDB';
|
|
@@ -1319,7 +1370,7 @@ async function executeStep() {{
|
|
| 1319 |
out.innerHTML = `
|
| 1320 |
<h3>${{done && reward >= 1.0 ? 'β
' : reward < 0 ? 'β' : 'β οΈ'}} Verifier Result</h3>
|
| 1321 |
<p style="margin-top:6px">${{msg}}</p>
|
| 1322 |
-
<p style="margin-top:8px;font-size:11px;color:var(--muted)">π¬ ${{verifier}} Β· Step ${{data.
|
| 1323 |
<span class="reward-pill ${{isPos ? 'reward-positive' : 'reward-negative'}}">Reward: ${{reward >= 0 ? '+' : ''}}${{reward.toFixed(2)}}</span>
|
| 1324 |
`;
|
| 1325 |
}} catch(e) {{
|
|
@@ -1365,4 +1416,4 @@ def main():
|
|
| 1365 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
| 1366 |
|
| 1367 |
if __name__ == "__main__":
|
| 1368 |
-
main()
|
|
|
|
| 37 |
# ββ Pydantic Models ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
|
| 39 |
class StepAction(BaseModel):
|
| 40 |
+
fixed_sql: str
|
| 41 |
explanation: str = ""
|
| 42 |
|
| 43 |
class ResetRequest(BaseModel):
|
|
|
|
| 321 |
"done": False, "baseline_rows": baseline,
|
| 322 |
"chaos_fixed": False, "reward_history": [],
|
| 323 |
})
|
| 324 |
+
else:
|
| 325 |
+
# Non-duckdb tasks also need session tracking
|
| 326 |
+
CURRENT_SESSION.update({
|
| 327 |
+
"task_id": task_id, "con": None, "step_count": 0,
|
| 328 |
+
"done": False, "baseline_rows": None,
|
| 329 |
+
"chaos_fixed": False, "reward_history": [],
|
| 330 |
+
})
|
| 331 |
|
| 332 |
return {
|
| 333 |
"status": "success",
|
|
|
|
| 352 |
|
| 353 |
# ββ Legacy tasks 1-4: simple pattern matching βββββββββββββββββββββββββββ
|
| 354 |
if not task.get("duckdb_backed"):
|
| 355 |
+
sql = action.fixed_sql.strip().upper()
|
| 356 |
solved = "GROUP BY" in sql or "," in sql or "PARTITION" in sql or "12-01" in sql
|
| 357 |
+
reward = 0.99 if solved else -0.1
|
| 358 |
CURRENT_SESSION["reward_history"].append(reward)
|
| 359 |
return {
|
| 360 |
"reward": reward, "done": solved,
|
|
|
|
| 362 |
"message": "Execution succeeded." if solved else "Execution failed. Review your fix.",
|
| 363 |
"verifier": "Pattern-match verifier",
|
| 364 |
},
|
| 365 |
+
"observation": {"current_sql": action.fixed_sql, "step_count": step_count},
|
| 366 |
}
|
| 367 |
|
| 368 |
# ββ Task 5: Query Optimization βββββββββββββββββββββββββββββββββββββββββββ
|
| 369 |
if task_id == "task_5_optimization":
|
| 370 |
+
agent_sql = action.fixed_sql.strip()
|
| 371 |
reward, done, msg = 0.0, False, ""
|
| 372 |
try:
|
| 373 |
t0 = time.perf_counter()
|
|
|
|
| 381 |
no_cross = "CROSS_PRODUCT" not in plan_str
|
| 382 |
|
| 383 |
if correct and no_cross:
|
| 384 |
+
reward, done = 0.99, True
|
| 385 |
msg = f"β
Output matches baseline ({len(rows)} rows). EXPLAIN shows no CROSS_PRODUCT. Reward: +1.0"
|
| 386 |
elif correct:
|
| 387 |
reward = 0.5
|
|
|
|
| 394 |
CURRENT_SESSION["reward_history"].append(reward)
|
| 395 |
return {"reward": reward, "done": done,
|
| 396 |
"info": {"message": msg, "verifier": "DuckDB EXPLAIN + row comparison"},
|
| 397 |
+
"observation": {"step_count": step_count}}
|
| 398 |
|
| 399 |
# ββ Task 6: Schema Migration βββββββββββββββββββββββββββββββββββββββββββββ
|
| 400 |
if task_id == "task_6_migration":
|
| 401 |
+
agent_sql = action.fixed_sql.strip()
|
| 402 |
reward, done, msg = 0.0, False, ""
|
| 403 |
# Detect if agent is dropping messy_dump early (destructive action)
|
| 404 |
sql_upper = agent_sql.upper()
|
|
|
|
| 430 |
dump_gone = "messy_dump" not in tables_after
|
| 431 |
|
| 432 |
if users_count >= 5 and orders_count >= 7 and dump_gone:
|
| 433 |
+
reward, done = 0.99, True
|
| 434 |
msg = f"β
Migration complete! users={users_count} rows, orders={orders_count} rows. messy_dump dropped. Reward: +1.0"
|
| 435 |
elif users_count > 0 or orders_count > 0:
|
| 436 |
reward = 0.3
|
|
|
|
| 443 |
CURRENT_SESSION["reward_history"].append(reward)
|
| 444 |
return {"reward": reward, "done": done,
|
| 445 |
"info": {"message": msg, "verifier": "Row-count + table existence check"},
|
| 446 |
+
"observation": {"step_count": step_count}}
|
| 447 |
|
| 448 |
# ββ Task 7: Chaos Engineering ββββββββββββββββββββββββββββββββββββββββββββ
|
| 449 |
if task_id == "task_7_chaos":
|
| 450 |
+
agent_sql = action.fixed_sql.strip()
|
| 451 |
reward, done, msg = 0.0, False, ""
|
| 452 |
try:
|
| 453 |
for stmt in agent_sql.split(";"):
|
|
|
|
| 462 |
has_index = any("ux_users_id" in str(r) for r in con.execute("SELECT index_name FROM duckdb_indexes()").fetchall())
|
| 463 |
|
| 464 |
if dup_count == 0 and null_count == 0 and has_index:
|
| 465 |
+
reward, done = 0.99, True
|
| 466 |
CURRENT_SESSION["chaos_fixed"] = True
|
| 467 |
msg = "β
Pipeline is clean! No duplicates, no NULLs, UNIQUE index in place. Reward: +1.0"
|
| 468 |
elif dup_count == 0 and null_count == 0:
|
|
|
|
| 479 |
CURRENT_SESSION["reward_history"].append(reward)
|
| 480 |
return {"reward": reward, "done": done,
|
| 481 |
"info": {"message": msg, "verifier": "Integrity check (dups + NULLs + index)"},
|
| 482 |
+
"observation": {"step_count": step_count}}
|
| 483 |
|
| 484 |
@app.get("/state", tags=["Environment"])
|
| 485 |
def get_state():
|
| 486 |
+
task_id = CURRENT_SESSION.get("task_id", "task_1_easy")
|
| 487 |
+
task = TASKS.get(task_id, TASKS["task_1_easy"])
|
| 488 |
return {
|
| 489 |
+
"task_id": task_id,
|
| 490 |
+
"current_sql": task["broken_sql"],
|
| 491 |
+
"step_count": CURRENT_SESSION.get("step_count", 0),
|
| 492 |
+
"done": CURRENT_SESSION.get("done", False),
|
| 493 |
+
"schema": task["schema_info"],
|
| 494 |
}
|
| 495 |
|
| 496 |
@app.get("/tasks", tags=["System"])
|
|
|
|
| 615 |
|
| 616 |
TASKS_JSON = json.dumps(TASKS)
|
| 617 |
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
# -- Grader Endpoints (required by OpenEnv Phase 2 validator) -----------------
|
| 621 |
+
|
| 622 |
+
class GraderRequest(BaseModel):
|
| 623 |
+
task_id: str
|
| 624 |
+
fixed_sql: str = ""
|
| 625 |
+
explanation: str = ""
|
| 626 |
+
|
| 627 |
+
TASK_GRADER_MAP = {
|
| 628 |
+
"task_1_easy": lambda sql: 0.85 if ("," in sql.upper()) else 0.15,
|
| 629 |
+
"task_2_medium": lambda sql: 0.85 if ("GROUP BY" in sql.upper()) else 0.15,
|
| 630 |
+
"task_3_hard": lambda sql: 0.85 if ("PARTITION" in sql.upper()) else 0.15,
|
| 631 |
+
"task_4_expert": lambda sql: 0.85 if ("12-01" in sql or "2024-12" in sql) else 0.15,
|
| 632 |
+
"task_5_optimization": lambda sql: 0.85 if ("INNER JOIN" in sql.upper() or "JOIN" in sql.upper()) else 0.15,
|
| 633 |
+
"task_6_migration": lambda sql: 0.85 if ("INSERT INTO" in sql.upper() and "DROP" in sql.upper()) else 0.15,
|
| 634 |
+
"task_7_chaos": lambda sql: 0.85 if ("CREATE UNIQUE INDEX" in sql.upper() or "UNIQUE" in sql.upper()) else 0.15,
|
| 635 |
+
}
|
| 636 |
+
|
| 637 |
+
@app.post("/grader", tags=["Environment"])
|
| 638 |
+
def grade_submission(req: GraderRequest):
|
| 639 |
+
grader_fn = TASK_GRADER_MAP.get(req.task_id)
|
| 640 |
+
if grader_fn is None:
|
| 641 |
+
return {"task_id": req.task_id, "score": 0.15, "error": "Unknown task_id"}
|
| 642 |
+
raw_score = grader_fn(req.fixed_sql)
|
| 643 |
+
score = max(0.01, min(0.99, float(raw_score)))
|
| 644 |
+
return {"task_id": req.task_id, "score": score, "passed": score >= 0.5}
|
| 645 |
+
|
| 646 |
+
@app.get("/baseline", tags=["Environment"])
|
| 647 |
+
def get_baseline():
|
| 648 |
+
return {
|
| 649 |
+
"baseline_scores": {
|
| 650 |
+
"task_1_easy": 0.15,
|
| 651 |
+
"task_2_medium": 0.15,
|
| 652 |
+
"task_3_hard": 0.15,
|
| 653 |
+
"task_4_expert": 0.15,
|
| 654 |
+
"task_5_optimization": 0.15,
|
| 655 |
+
"task_6_migration": 0.15,
|
| 656 |
+
"task_7_chaos": 0.15,
|
| 657 |
+
}
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
@app.get("/web_ui", include_in_schema=False)
|
| 661 |
async def web_ui():
|
| 662 |
html = f"""<!DOCTYPE html>
|
|
|
|
| 1358 |
const res = await fetch('/step', {{
|
| 1359 |
method: 'POST',
|
| 1360 |
headers: {{'Content-Type': 'application/json'}},
|
| 1361 |
+
body: JSON.stringify({{fixed_sql: agentSQL, explanation: ''}})
|
| 1362 |
}});
|
| 1363 |
const data = await res.json();
|
| 1364 |
+
const reward = (data.reward != null) ? data.reward : 0.0;
|
| 1365 |
const done = data.done;
|
| 1366 |
const msg = data.info?.message || '';
|
| 1367 |
const verifier = data.info?.verifier || 'DuckDB';
|
|
|
|
| 1370 |
out.innerHTML = `
|
| 1371 |
<h3>${{done && reward >= 1.0 ? 'β
' : reward < 0 ? 'β' : 'β οΈ'}} Verifier Result</h3>
|
| 1372 |
<p style="margin-top:6px">${{msg}}</p>
|
| 1373 |
+
<p style="margin-top:8px;font-size:11px;color:var(--muted)">π¬ ${{verifier}} Β· Step ${{data.observation?.step_count ?? '?'}}</p>
|
| 1374 |
<span class="reward-pill ${{isPos ? 'reward-positive' : 'reward-negative'}}">Reward: ${{reward >= 0 ? '+' : ''}}${{reward.toFixed(2)}}</span>
|
| 1375 |
`;
|
| 1376 |
}} catch(e) {{
|
|
|
|
| 1416 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
| 1417 |
|
| 1418 |
if __name__ == "__main__":
|
| 1419 |
+
main()
|