| | """Keypoint–Argument Matching Endpoints""" |
| |
|
| | from fastapi import APIRouter, HTTPException |
| | from datetime import datetime |
| | import logging |
| |
|
| | from models import ( |
| | PredictionRequest, |
| | PredictionResponse, |
| | BatchPredictionRequest, |
| | BatchPredictionResponse |
| | ) |
| |
|
| | from services import kpa_model_manager |
| |
|
# Module-scoped logger, named after this module so its output can be filtered per file.
logger = logging.getLogger(__name__)
# Router that collects the KPA endpoints below; presumably included into the app elsewhere — not visible here.
router = APIRouter()
| |
|
| |
|
@router.get("/model-info", tags=["KPA"])
async def get_model_info():
    """
    Return information about the loaded KPA model.

    Fields: model_name, device, max_length, num_labels, loaded, and the
    server-side timestamp at which the information was read.
    """
    # This docstring is published as the OpenAPI description, so
    # maintainer notes are kept in comments instead.
    try:
        model_info = kpa_model_manager.get_model_info()

        # Any field the manager does not report falls back to a fixed default.
        # NOTE(review): the defaults ("cpu", 256, 2) are hard-coded here and
        # may drift from the manager's real configuration — confirm.
        return {
            "model_name": model_info.get("model_name", "unknown"),
            "device": model_info.get("device", "cpu"),
            "max_length": model_info.get("max_length", 256),
            "num_labels": model_info.get("num_labels", 2),
            "loaded": model_info.get("loaded", False),
            # astimezone() attaches the local UTC offset, so clients get an
            # unambiguous ISO-8601 timestamp instead of a naive one.
            "timestamp": datetime.now().astimezone().isoformat(),
        }

    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Model info error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Failed to get model info: {e}"
        ) from e
| |
|
| |
|
@router.post("/predict", response_model=PredictionResponse, tags=["KPA"])
async def predict_kpa(request: PredictionRequest):
    """
    Predict keypoint-argument matching for a single pair.

    - **argument**: The argument text
    - **key_point**: The key point to evaluate

    Returns the predicted class (apparie / non_apparie) with probabilities.
    """
    # NOTE(review): kpa_model_manager.predict is called synchronously inside
    # an async endpoint, so inference blocks the event loop while it runs.
    # Offloading it (e.g. fastapi.concurrency.run_in_threadpool) would avoid
    # that, but only if the manager is safe to call from worker threads —
    # confirm before changing.
    try:
        result = kpa_model_manager.predict(
            argument=request.argument,
            key_point=request.key_point,
        )

        # A missing key in `result` raises KeyError and becomes a 500 below.
        response = PredictionResponse(
            prediction=result["prediction"],
            confidence=result["confidence"],
            label=result["label"],
            probabilities=result["probabilities"],
        )

        logger.info(
            "KPA Prediction: %s (conf=%.4f)", response.label, response.confidence
        )
        return response

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("KPA prediction error: %s", e)
        # NOTE(review): the raw exception text is sent to the client; use a
        # generic message if internal details must not leak.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e
| |
|
| |
|
def _summarize_batch(results: list) -> dict:
    """Aggregate per-pair batch results into the summary dict sent to clients.

    Items with ``prediction == -1`` are the error placeholders produced by
    ``batch_predict_kpa`` and are counted as failures. ``average_confidence``
    is taken over successful items only and rounded to 4 decimals; it is
    ``0.0`` when nothing succeeded.
    """
    successful = [r for r in results if r.prediction != -1]

    if not successful:
        return {
            "total_apparie": 0,
            "total_non_apparie": 0,
            "average_confidence": 0.0,
            "successful_predictions": 0,
            "failed_predictions": len(results),
        }

    average_confidence = sum(r.confidence for r in successful) / len(successful)
    return {
        "total_apparie": sum(1 for r in successful if r.prediction == 1),
        "total_non_apparie": sum(1 for r in successful if r.prediction == 0),
        "average_confidence": round(average_confidence, 4),
        "successful_predictions": len(successful),
        "failed_predictions": len(results) - len(successful),
    }


@router.post("/batch-predict", response_model=BatchPredictionResponse, tags=["KPA"])
async def batch_predict_kpa(request: BatchPredictionRequest):
    """
    Predict keypoint-argument matching for multiple argument/keypoint pairs.

    - **pairs**: List of items to classify

    Returns predictions for all pairs. A pair that fails is reported with
    prediction -1 and label "error", and is counted in the summary's
    failed_predictions instead of failing the whole batch.
    """
    # NOTE(review): each pair runs kpa_model_manager.predict synchronously
    # inside this async endpoint, so a large batch blocks the event loop for
    # the whole run. Offloading requires the manager to be thread-safe —
    # confirm first.
    try:
        results = []

        for index, item in enumerate(request.pairs):
            try:
                result = kpa_model_manager.predict(
                    argument=item.argument,
                    key_point=item.key_point,
                )
                results.append(
                    PredictionResponse(
                        prediction=result["prediction"],
                        confidence=result["confidence"],
                        label=result["label"],
                        probabilities=result["probabilities"],
                    )
                )
            except Exception:
                # Best effort: one bad pair must not fail the batch. The error
                # is now logged instead of being silently discarded.
                logger.exception("Batch KPA prediction failed for pair %d", index)
                results.append(
                    PredictionResponse(
                        prediction=-1,
                        confidence=0.0,
                        label="error",
                        probabilities={"error": 1.0},
                    )
                )

        summary = _summarize_batch(results)

        logger.info("Batch KPA prediction completed — %d items processed", len(results))

        return BatchPredictionResponse(
            predictions=results,
            total_processed=len(results),
            summary=summary,
        )

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Batch KPA prediction error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Batch prediction failed: {e}"
        ) from e