| """Model manager for keypoint–argument matching model""" | |
| import os | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| class KpaModelManager: | |
| """Manages loading and inference for keypoint matching model""" | |
| def __init__(self): | |
| self.model = None | |
| self.tokenizer = None | |
| self.device = None | |
| self.model_loaded = False | |
| self.max_length = 256 | |
| self.model_id = None | |
    def load_model(self, model_id: str, api_key: str = None):
        """Load complete model and tokenizer directly from Hugging Face"""
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return

        try:
            # Debug: check the input parameters
            logger.info("=== DEBUG KPA MODEL LOADING ===")
            logger.info(f"model_id received: {model_id}")
            logger.info(f"model_id type: {type(model_id)}")
            logger.info(f"api_key present: {api_key is not None}")

            if model_id is None:
                raise ValueError("model_id cannot be None - check your .env file")

            logger.info(f"Loading KPA model from Hugging Face: {model_id}")

            # Determine device
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")

            # Store model ID
            self.model_id = model_id

            # Prepare token for authentication if an API key is provided
            token = api_key if api_key else None

            # Load tokenizer and model directly from Hugging Face
            logger.info("Step 1: Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            logger.info("✓ Tokenizer loaded successfully")

            logger.info("Step 2: Loading model...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            logger.info("✓ Model architecture loaded")

            self.model.to(self.device)
            self.model.eval()
            logger.info("✓ Model moved to device and set to eval mode")

            self.model_loaded = True
            logger.info("✓ KPA model loaded successfully from Hugging Face!")
            logger.info("=== KPA MODEL LOADING COMPLETE ===")
        except Exception as e:
            logger.error(f"❌ Error loading KPA model: {str(e)}")
            logger.error(f"❌ Model ID was: {model_id}")
            logger.error(f"❌ API key present: {api_key is not None}")
            raise RuntimeError(f"Failed to load KPA model: {str(e)}") from e
    def predict(self, argument: str, key_point: str) -> dict:
        """Run a prediction for (argument, key_point)"""
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")

        try:
            # Tokenize the argument and key point as a sentence pair
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)

            # Forward pass
            with torch.no_grad():
                outputs = self.model(**encoding)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=-1)
                predicted_class = torch.argmax(probabilities, dim=-1).item()
                confidence = probabilities[0][predicted_class].item()

            # Class 1 = "apparie" (matched), class 0 = "non_apparie" (not matched)
            return {
                "prediction": predicted_class,
                "confidence": confidence,
                "label": "apparie" if predicted_class == 1 else "non_apparie",
                "probabilities": {
                    "non_apparie": probabilities[0][0].item(),
                    "apparie": probabilities[0][1].item(),
                },
            }
        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}")
            raise RuntimeError(f"KPA prediction failed: {str(e)}") from e
    def get_model_info(self):
        """Get model information"""
        if not self.model_loaded:
            return {"loaded": False}

        return {
            "model_name": self.model_id,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": 2,
            "loaded": self.model_loaded,
        }


# Initialize singleton instance
kpa_model_manager = KpaModelManager()
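
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how the singleton above might be used, assuming this
# file is importable as `kpa_model_manager` and that the hypothetical
# environment variables KPA_MODEL_ID and HF_TOKEN point at a valid
# sequence-classification checkpoint on the Hugging Face Hub:
#
#     import os
#     from kpa_model_manager import kpa_model_manager
#
#     kpa_model_manager.load_model(
#         model_id=os.getenv("KPA_MODEL_ID"),
#         api_key=os.getenv("HF_TOKEN"),
#     )
#     result = kpa_model_manager.predict(
#         argument="School uniforms reduce morning decision fatigue.",
#         key_point="Uniforms simplify students' daily routines.",
#     )
#     print(result["label"], result["confidence"])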