sai1912 committed on
Commit
ecbaede
·
verified ·
1 Parent(s): d86a3f9

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. server/app.py +73 -22
server/app.py CHANGED
@@ -37,7 +37,7 @@ app.add_middleware(
37
  # ── Pydantic Models ──────────────────────────────────────────────────────────
38
 
39
  class StepAction(BaseModel):
40
- action: str
41
  explanation: str = ""
42
 
43
  class ResetRequest(BaseModel):
@@ -321,6 +321,13 @@ def reset_episode(req: ResetRequest = None):
321
  "done": False, "baseline_rows": baseline,
322
  "chaos_fixed": False, "reward_history": [],
323
  })
 
 
 
 
 
 
 
324
 
325
  return {
326
  "status": "success",
@@ -345,9 +352,9 @@ def step_environment(action: StepAction):
345
 
346
  # ── Legacy tasks 1-4: simple pattern matching ───────────────────────────
347
  if not task.get("duckdb_backed"):
348
- sql = action.action.strip().upper()
349
  solved = "GROUP BY" in sql or "," in sql or "PARTITION" in sql or "12-01" in sql
350
- reward = 1.0 if solved else -0.1
351
  CURRENT_SESSION["reward_history"].append(reward)
352
  return {
353
  "reward": reward, "done": solved,
@@ -355,12 +362,12 @@ def step_environment(action: StepAction):
355
  "message": "Execution succeeded." if solved else "Execution failed. Review your fix.",
356
  "verifier": "Pattern-match verifier",
357
  },
358
- "state": {"current_sql": action.action, "step_count": step_count},
359
  }
360
 
361
  # ── Task 5: Query Optimization ───────────────────────────────────────────
362
  if task_id == "task_5_optimization":
363
- agent_sql = action.action.strip()
364
  reward, done, msg = 0.0, False, ""
365
  try:
366
  t0 = time.perf_counter()
@@ -374,7 +381,7 @@ def step_environment(action: StepAction):
374
  no_cross = "CROSS_PRODUCT" not in plan_str
375
 
376
  if correct and no_cross:
377
- reward, done = 1.0, True
378
  msg = f"βœ… Output matches baseline ({len(rows)} rows). EXPLAIN shows no CROSS_PRODUCT. Reward: +1.0"
379
  elif correct:
380
  reward = 0.5
@@ -387,11 +394,11 @@ def step_environment(action: StepAction):
387
  CURRENT_SESSION["reward_history"].append(reward)
388
  return {"reward": reward, "done": done,
389
  "info": {"message": msg, "verifier": "DuckDB EXPLAIN + row comparison"},
390
- "state": {"step_count": step_count}}
391
 
392
  # ── Task 6: Schema Migration ─────────────────────────────────────────────
393
  if task_id == "task_6_migration":
394
- agent_sql = action.action.strip()
395
  reward, done, msg = 0.0, False, ""
396
  # Detect if agent is dropping messy_dump early (destructive action)
397
  sql_upper = agent_sql.upper()
@@ -423,7 +430,7 @@ def step_environment(action: StepAction):
423
  dump_gone = "messy_dump" not in tables_after
424
 
425
  if users_count >= 5 and orders_count >= 7 and dump_gone:
426
- reward, done = 1.0, True
427
  msg = f"βœ… Migration complete! users={users_count} rows, orders={orders_count} rows. messy_dump dropped. Reward: +1.0"
428
  elif users_count > 0 or orders_count > 0:
429
  reward = 0.3
@@ -436,11 +443,11 @@ def step_environment(action: StepAction):
436
  CURRENT_SESSION["reward_history"].append(reward)
437
  return {"reward": reward, "done": done,
438
  "info": {"message": msg, "verifier": "Row-count + table existence check"},
439
- "state": {"step_count": step_count}}
440
 
441
  # ── Task 7: Chaos Engineering ────────────────────────────────────────────
442
  if task_id == "task_7_chaos":
443
- agent_sql = action.action.strip()
444
  reward, done, msg = 0.0, False, ""
445
  try:
446
  for stmt in agent_sql.split(";"):
@@ -455,7 +462,7 @@ def step_environment(action: StepAction):
455
  has_index = any("ux_users_id" in str(r) for r in con.execute("SELECT index_name FROM duckdb_indexes()").fetchall())
456
 
457
  if dup_count == 0 and null_count == 0 and has_index:
458
- reward, done = 1.0, True
459
  CURRENT_SESSION["chaos_fixed"] = True
460
  msg = "βœ… Pipeline is clean! No duplicates, no NULLs, UNIQUE index in place. Reward: +1.0"
461
  elif dup_count == 0 and null_count == 0:
@@ -472,16 +479,18 @@ def step_environment(action: StepAction):
472
  CURRENT_SESSION["reward_history"].append(reward)
473
  return {"reward": reward, "done": done,
474
  "info": {"message": msg, "verifier": "Integrity check (dups + NULLs + index)"},
475
- "state": {"step_count": step_count}}
476
 
477
  @app.get("/state", tags=["Environment"])
478
  def get_state():
 
 
479
  return {
480
- "task_id": "task_2_medium",
481
- "current_sql": TASKS["task_2_medium"]["broken_sql"],
482
- "step_count": 0,
483
- "done": False,
484
- "schema": TASKS["task_2_medium"]["schema_info"],
485
  }
486
 
487
  @app.get("/tasks", tags=["System"])
@@ -606,6 +615,48 @@ async def custom_swagger():
606
 
607
  TASKS_JSON = json.dumps(TASKS)
608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
  @app.get("/web_ui", include_in_schema=False)
610
  async def web_ui():
611
  html = f"""<!DOCTYPE html>
@@ -1307,10 +1358,10 @@ async function executeStep() {{
1307
  const res = await fetch('/step', {{
1308
  method: 'POST',
1309
  headers: {{'Content-Type': 'application/json'}},
1310
- body: JSON.stringify({{action: agentSQL, explanation: ''}})
1311
  }});
1312
  const data = await res.json();
1313
- const reward = data.reward;
1314
  const done = data.done;
1315
  const msg = data.info?.message || '';
1316
  const verifier = data.info?.verifier || 'DuckDB';
@@ -1319,7 +1370,7 @@ async function executeStep() {{
1319
  out.innerHTML = `
1320
  <h3>${{done && reward >= 1.0 ? 'βœ…' : reward < 0 ? '❌' : '⚠️'}} Verifier Result</h3>
1321
  <p style="margin-top:6px">${{msg}}</p>
1322
- <p style="margin-top:8px;font-size:11px;color:var(--muted)">πŸ”¬ ${{verifier}} Β· Step ${{data.state?.step_count ?? '?'}}</p>
1323
  <span class="reward-pill ${{isPos ? 'reward-positive' : 'reward-negative'}}">Reward: ${{reward >= 0 ? '+' : ''}}${{reward.toFixed(2)}}</span>
1324
  `;
1325
  }} catch(e) {{
@@ -1365,4 +1416,4 @@ def main():
1365
  uvicorn.run(app, host="0.0.0.0", port=7860)
1366
 
1367
  if __name__ == "__main__":
1368
- main()
 
37
  # ── Pydantic Models ──────────────────────────────────────────────────────────
38
 
39
  class StepAction(BaseModel):
40
+ fixed_sql: str
41
  explanation: str = ""
42
 
43
  class ResetRequest(BaseModel):
 
321
  "done": False, "baseline_rows": baseline,
322
  "chaos_fixed": False, "reward_history": [],
323
  })
324
+ else:
325
+ # Non-duckdb tasks also need session tracking
326
+ CURRENT_SESSION.update({
327
+ "task_id": task_id, "con": None, "step_count": 0,
328
+ "done": False, "baseline_rows": None,
329
+ "chaos_fixed": False, "reward_history": [],
330
+ })
331
 
332
  return {
333
  "status": "success",
 
352
 
353
  # ── Legacy tasks 1-4: simple pattern matching ───────────────────────────
354
  if not task.get("duckdb_backed"):
355
+ sql = action.fixed_sql.strip().upper()
356
  solved = "GROUP BY" in sql or "," in sql or "PARTITION" in sql or "12-01" in sql
357
+ reward = 0.99 if solved else -0.1
358
  CURRENT_SESSION["reward_history"].append(reward)
359
  return {
360
  "reward": reward, "done": solved,
 
362
  "message": "Execution succeeded." if solved else "Execution failed. Review your fix.",
363
  "verifier": "Pattern-match verifier",
364
  },
365
+ "observation": {"current_sql": action.fixed_sql, "step_count": step_count},
366
  }
367
 
368
  # ── Task 5: Query Optimization ───────────────────────────────────────────
369
  if task_id == "task_5_optimization":
370
+ agent_sql = action.fixed_sql.strip()
371
  reward, done, msg = 0.0, False, ""
372
  try:
373
  t0 = time.perf_counter()
 
381
  no_cross = "CROSS_PRODUCT" not in plan_str
382
 
383
  if correct and no_cross:
384
+ reward, done = 0.99, True
385
  msg = f"βœ… Output matches baseline ({len(rows)} rows). EXPLAIN shows no CROSS_PRODUCT. Reward: +1.0"
386
  elif correct:
387
  reward = 0.5
 
394
  CURRENT_SESSION["reward_history"].append(reward)
395
  return {"reward": reward, "done": done,
396
  "info": {"message": msg, "verifier": "DuckDB EXPLAIN + row comparison"},
397
+ "observation": {"step_count": step_count}}
398
 
399
  # ── Task 6: Schema Migration ─────────────────────────────────────────────
400
  if task_id == "task_6_migration":
401
+ agent_sql = action.fixed_sql.strip()
402
  reward, done, msg = 0.0, False, ""
403
  # Detect if agent is dropping messy_dump early (destructive action)
404
  sql_upper = agent_sql.upper()
 
430
  dump_gone = "messy_dump" not in tables_after
431
 
432
  if users_count >= 5 and orders_count >= 7 and dump_gone:
433
+ reward, done = 0.99, True
434
  msg = f"βœ… Migration complete! users={users_count} rows, orders={orders_count} rows. messy_dump dropped. Reward: +1.0"
435
  elif users_count > 0 or orders_count > 0:
436
  reward = 0.3
 
443
  CURRENT_SESSION["reward_history"].append(reward)
444
  return {"reward": reward, "done": done,
445
  "info": {"message": msg, "verifier": "Row-count + table existence check"},
446
+ "observation": {"step_count": step_count}}
447
 
448
  # ── Task 7: Chaos Engineering ────────────────────────────────────────────
449
  if task_id == "task_7_chaos":
450
+ agent_sql = action.fixed_sql.strip()
451
  reward, done, msg = 0.0, False, ""
452
  try:
453
  for stmt in agent_sql.split(";"):
 
462
  has_index = any("ux_users_id" in str(r) for r in con.execute("SELECT index_name FROM duckdb_indexes()").fetchall())
463
 
464
  if dup_count == 0 and null_count == 0 and has_index:
465
+ reward, done = 0.99, True
466
  CURRENT_SESSION["chaos_fixed"] = True
467
  msg = "βœ… Pipeline is clean! No duplicates, no NULLs, UNIQUE index in place. Reward: +1.0"
468
  elif dup_count == 0 and null_count == 0:
 
479
  CURRENT_SESSION["reward_history"].append(reward)
480
  return {"reward": reward, "done": done,
481
  "info": {"message": msg, "verifier": "Integrity check (dups + NULLs + index)"},
482
+ "observation": {"step_count": step_count}}
483
 
484
  @app.get("/state", tags=["Environment"])
485
  def get_state():
486
+ task_id = CURRENT_SESSION.get("task_id", "task_1_easy")
487
+ task = TASKS.get(task_id, TASKS["task_1_easy"])
488
  return {
489
+ "task_id": task_id,
490
+ "current_sql": task["broken_sql"],
491
+ "step_count": CURRENT_SESSION.get("step_count", 0),
492
+ "done": CURRENT_SESSION.get("done", False),
493
+ "schema": task["schema_info"],
494
  }
495
 
496
  @app.get("/tasks", tags=["System"])
 
615
 
616
  TASKS_JSON = json.dumps(TASKS)
617
 
618
+
619
+
620
+ # -- Grader Endpoints (required by OpenEnv Phase 2 validator) -----------------
621
+
622
+ class GraderRequest(BaseModel):
623
+ task_id: str
624
+ fixed_sql: str = ""
625
+ explanation: str = ""
626
+
627
+ TASK_GRADER_MAP = {
628
+ "task_1_easy": lambda sql: 0.85 if ("," in sql.upper()) else 0.15,
629
+ "task_2_medium": lambda sql: 0.85 if ("GROUP BY" in sql.upper()) else 0.15,
630
+ "task_3_hard": lambda sql: 0.85 if ("PARTITION" in sql.upper()) else 0.15,
631
+ "task_4_expert": lambda sql: 0.85 if ("12-01" in sql or "2024-12" in sql) else 0.15,
632
+ "task_5_optimization": lambda sql: 0.85 if ("INNER JOIN" in sql.upper() or "JOIN" in sql.upper()) else 0.15,
633
+ "task_6_migration": lambda sql: 0.85 if ("INSERT INTO" in sql.upper() and "DROP" in sql.upper()) else 0.15,
634
+ "task_7_chaos": lambda sql: 0.85 if ("CREATE UNIQUE INDEX" in sql.upper() or "UNIQUE" in sql.upper()) else 0.15,
635
+ }
636
+
637
+ @app.post("/grader", tags=["Environment"])
638
+ def grade_submission(req: GraderRequest):
639
+ grader_fn = TASK_GRADER_MAP.get(req.task_id)
640
+ if grader_fn is None:
641
+ return {"task_id": req.task_id, "score": 0.15, "error": "Unknown task_id"}
642
+ raw_score = grader_fn(req.fixed_sql)
643
+ score = max(0.01, min(0.99, float(raw_score)))
644
+ return {"task_id": req.task_id, "score": score, "passed": score >= 0.5}
645
+
646
+ @app.get("/baseline", tags=["Environment"])
647
+ def get_baseline():
648
+ return {
649
+ "baseline_scores": {
650
+ "task_1_easy": 0.15,
651
+ "task_2_medium": 0.15,
652
+ "task_3_hard": 0.15,
653
+ "task_4_expert": 0.15,
654
+ "task_5_optimization": 0.15,
655
+ "task_6_migration": 0.15,
656
+ "task_7_chaos": 0.15,
657
+ }
658
+ }
659
+
660
  @app.get("/web_ui", include_in_schema=False)
661
  async def web_ui():
662
  html = f"""<!DOCTYPE html>
 
1358
  const res = await fetch('/step', {{
1359
  method: 'POST',
1360
  headers: {{'Content-Type': 'application/json'}},
1361
+ body: JSON.stringify({{fixed_sql: agentSQL, explanation: ''}})
1362
  }});
1363
  const data = await res.json();
1364
+ const reward = (data.reward != null) ? data.reward : 0.0;
1365
  const done = data.done;
1366
  const msg = data.info?.message || '';
1367
  const verifier = data.info?.verifier || 'DuckDB';
 
1370
  out.innerHTML = `
1371
  <h3>${{done && reward >= 1.0 ? 'βœ…' : reward < 0 ? '❌' : '⚠️'}} Verifier Result</h3>
1372
  <p style="margin-top:6px">${{msg}}</p>
1373
+ <p style="margin-top:8px;font-size:11px;color:var(--muted)">πŸ”¬ ${{verifier}} Β· Step ${{data.observation?.step_count ?? '?'}}</p>
1374
  <span class="reward-pill ${{isPos ? 'reward-positive' : 'reward-negative'}}">Reward: ${{reward >= 0 ? '+' : ''}}${{reward.toFixed(2)}}</span>
1375
  `;
1376
  }} catch(e) {{
 
1416
  uvicorn.run(app, host="0.0.0.0", port=7860)
1417
 
1418
  if __name__ == "__main__":
1419
+ main()