| """ | |
| Inference Module | |
| Handles question answering predictions | |
| """ | |
| import torch | |
| from typing import Tuple | |


class QAInference:
    """Handles question answering inference."""

    def __init__(self, model, tokenizer, device):
        """
        Initialize QA inference.

        Args:
            model: Loaded model
            tokenizer: Loaded tokenizer
            device: Torch device
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def answer_question(
        self,
        question: str,
        context: str,
        language: str = "English",
        max_length: int = 64,
    ) -> Tuple[str, str]:
        """
        Generate an answer for the given question and context.

        Args:
            question: Question text
            context: Context/passage text
            language: "English" or "German"
            max_length: Maximum answer length in tokens

        Returns:
            Tuple of (answer, response_info)
        """
        if not question.strip() or not context.strip():
            return "⚠️ Please provide both a question and context!", ""

        try:
            # Configure language
            if language == "English":
                self.tokenizer.src_lang = "en_XX"
                self.tokenizer.tgt_lang = "en_XX"
                lang_code = self.tokenizer.lang_code_to_id["en_XX"]
            else:
                self.tokenizer.src_lang = "de_DE"
                self.tokenizer.tgt_lang = "de_DE"
                lang_code = self.tokenizer.lang_code_to_id["de_DE"]
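            # mBART-50 is multilingual: forcing the language code as the first
            # generated token tells the decoder which language to produce;
            # without it, generation can drift into the wrong language.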
            self.model.config.forced_bos_token_id = lang_code

            # Prepare input
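            # NOTE: this "question: ... context: ..." template is assumed to
            # match the prompt format used during fine-tuning.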
            input_text = f"question: {question} context: {context}"
            inputs = self.tokenizer(
                input_text,
                max_length=256,
                truncation=True,
                return_tensors="pt",
            ).to(self.device)

            # Generate answer
            self.model.eval()
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    num_beams=4,
                    early_stopping=True,
                    forced_bos_token_id=lang_code,
                )
            answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Calculate confidence
            confidence = self._calculate_confidence(answer, context)

            # Format response info
            response_info = f"""
### 📊 Response Details
- **Language**: {language}
- **Answer Length**: {len(answer.split())} words
- **Confidence**: {confidence}
- **Model**: mBART-large-50 + LoRA
"""
            return answer, response_info
        except Exception as e:
            return f"❌ Error: {str(e)}", ""

    def _calculate_confidence(self, answer: str, context: str) -> str:
        """
        Estimate answer confidence with a simple heuristic.

        Args:
            answer: Generated answer
            context: Input context

        Returns:
            Confidence level string: "Low", "Medium", or "High"
        """
        if len(answer.split()) < 2:
            return "Low"
        elif answer.lower() in context.lower():
            return "High"
        else:
            return "Medium"