snikhilesh committed (verified)
Commit 15ee17a · Parent(s): 19d55ff

Deploy specialized_model_router.py to backend/ directory

Files changed (1)
  1. backend/specialized_model_router.py +811 -0
backend/specialized_model_router.py ADDED
@@ -0,0 +1,811 @@
"""
Specialized Medical AI Model Router - Phase 3
Routes structured medical data to appropriate specialized AI models.

This module integrates with the preprocessing pipeline to provide model-specific
preprocessing, inference, and confidence scoring for medical AI analysis.

Author: MiniMax Agent
Date: 2025-10-29
Version: 1.0.0
"""

import os
import logging
import asyncio
import time
from typing import Dict, List, Optional, Any, Tuple, Union
from dataclasses import dataclass
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Import existing model infrastructure
from model_loader import MedicalModelLoader

# Import new preprocessing components
from preprocessing_pipeline import ProcessingPipelineResult
from medical_schemas import (
    ValidationResult, ConfidenceScore, ECGAnalysis, RadiologyAnalysis,
    LaboratoryResults, ClinicalNotesAnalysis
)

logger = logging.getLogger(__name__)


@dataclass
class ModelInferenceResult:
    """Result of specialized model inference"""
    model_name: str
    input_data: Dict[str, Any]
    output_data: Dict[str, Any]
    confidence_score: float
    processing_time: float
    model_metadata: Dict[str, Any]
    warnings: List[str]
    errors: List[str]


@dataclass
class SpecializedModelConfig:
    """Configuration for specialized medical models"""
    model_name: str
    model_type: str  # "classification", "segmentation", "generation", "extraction"
    input_format: str  # "ecg_signal", "dicom_image", "clinical_text", "lab_values"
    output_schema: str  # Schema name for output validation
    preprocessing_required: bool
    gpu_memory_mb: Optional[int]
    timeout_seconds: int
    fallback_models: List[str]

class SpecializedModelRouter:
    """Routes structured medical data to specialized AI models"""

    def __init__(self, model_loader: Optional[MedicalModelLoader] = None):
        self.model_loader = model_loader or MedicalModelLoader()
        self.model_configs = self._initialize_model_configs()
        self.model_cache = {}
        self.inference_stats = {
            "total_inferences": 0,
            "successful_inferences": 0,
            "average_processing_time": 0.0,
            "model_usage_counts": {},
            "error_counts": {}
        }

        logger.info("Specialized Model Router initialized")

    def _initialize_model_configs(self) -> Dict[str, SpecializedModelConfig]:
        """Initialize configuration for specialized medical models"""
        return {
            # ECG Models
            "hubert_ecg": SpecializedModelConfig(
                model_name="HuBERT-ECG",  # Placeholder for the actual HuBERT-ECG checkpoint
                model_type="classification",
                input_format="ecg_signal",
                output_schema="ECGAnalysis",
                preprocessing_required=True,
                gpu_memory_mb=4096,
                timeout_seconds=30,
                fallback_models=["bio_clinicalbert"]
            ),

            # Radiology Models
            "monai_unetr": SpecializedModelConfig(
                model_name="monai/UNet",  # Will be loaded from local or remote
                model_type="segmentation",
                input_format="dicom_image",
                output_schema="RadiologyAnalysis",
                preprocessing_required=True,
                gpu_memory_mb=8192,
                timeout_seconds=60,
                fallback_models=["generic_segmentation"]
            ),

            # Clinical Text Models
            "medgemma": SpecializedModelConfig(
                model_name="google/medgemma-4b",  # Placeholder for actual MedGemma model
                model_type="generation",
                input_format="clinical_text",
                output_schema="ClinicalNotesAnalysis",
                preprocessing_required=True,
                gpu_memory_mb=16384,
                timeout_seconds=45,
                fallback_models=["bio_clinicalbert", "pubmedbert"]
            ),

            # Laboratory Models
            "biomedical_ner": SpecializedModelConfig(
                model_name="Clinical-AI-Apollo/BiomedNLP-PubMedBERT-base-uncased-abstract",
                model_type="extraction",
                input_format="lab_text",
                output_schema="LaboratoryResults",
                preprocessing_required=False,
                gpu_memory_mb=2048,
                timeout_seconds=20,
                fallback_models=["scibert"]
            ),

            # Generic fallback models
            "bio_clinicalbert": SpecializedModelConfig(
                model_name="emilyalsentzer/Bio_ClinicalBERT",
                model_type="classification",
                input_format="clinical_text",
                output_schema="ClinicalNotesAnalysis",
                preprocessing_required=False,
                gpu_memory_mb=1024,
                timeout_seconds=15,
                fallback_models=[]
            ),

            "pubmedbert": SpecializedModelConfig(
                model_name="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                model_type="classification",
                input_format="clinical_text",
                output_schema="ClinicalNotesAnalysis",
                preprocessing_required=False,
                gpu_memory_mb=1024,
                timeout_seconds=15,
                fallback_models=[]
            )
        }

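    # Note (illustrative): the entries above are plain dataclass instances held in a
    # dict, so a deployment can tune one at runtime without code changes, e.g.
    #     router = SpecializedModelRouter()
    #     router.model_configs["medgemma"].timeout_seconds = 90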
    async def route_and_infer(self, pipeline_result: ProcessingPipelineResult) -> ModelInferenceResult:
        """
        Route structured data to appropriate specialized model and perform inference

        Args:
            pipeline_result: Result from preprocessing pipeline

        Returns:
            ModelInferenceResult with model output and confidence
        """
        start_time = time.time()

        try:
            # Step 1: Determine optimal model routing
            model_config = self._select_optimal_model(pipeline_result)

            # Step 2: Validate input data format
            input_validation = self._validate_input_format(pipeline_result, model_config)
            if not input_validation["is_valid"]:
                logger.warning(f"Input validation failed: {input_validation['errors']}")
                return self._create_error_result(model_config.model_name, input_validation["errors"])

            # Step 3: Preprocess input data for model
            preprocessed_input = await self._preprocess_for_model(pipeline_result, model_config)

            # Step 4: Perform model inference
            inference_result = await self._perform_model_inference(preprocessed_input, model_config)

            # Step 5: Post-process and validate output
            final_output = self._postprocess_model_output(inference_result, model_config)

            # Step 6: Calculate confidence score
            confidence_score = self._calculate_model_confidence(
                pipeline_result, model_config, final_output
            )

            processing_time = time.time() - start_time

            # Update statistics
            self._update_inference_stats(model_config.model_name, True, processing_time)

            return ModelInferenceResult(
                model_name=model_config.model_name,
                input_data=preprocessed_input,
                output_data=final_output,
                confidence_score=confidence_score,
                processing_time=processing_time,
                model_metadata={
                    "model_config": model_config.__dict__,
                    "input_validation": input_validation,
                    "pipeline_confidence": pipeline_result.validation_result.compliance_score
                },
                warnings=[],
                errors=[]
            )

        except Exception as e:
            logger.error(f"Model routing/inference error: {str(e)}")

            # Try fallback model
            fallback_result = await self._try_fallback_model(pipeline_result)
            if fallback_result:
                return fallback_result

            # Return error result
            error_result = ModelInferenceResult(
                model_name="error",
                input_data={},
                output_data={"error": str(e)},
                confidence_score=0.0,
                processing_time=time.time() - start_time,
                model_metadata={"error": str(e)},
                warnings=[],
                errors=[str(e)]
            )

            self._update_inference_stats("error", False, time.time() - start_time)
            return error_result

    def _select_optimal_model(self, pipeline_result: ProcessingPipelineResult) -> SpecializedModelConfig:
        """Select optimal model based on data type and quality"""
        # Extract document type from pipeline result
        doc_type = "unknown"
        confidence = pipeline_result.validation_result.compliance_score

        if "ECG" in pipeline_result.file_detection.file_type.value:
            doc_type = "ecg"
        elif "radiology" in pipeline_result.file_detection.file_type.value:
            doc_type = "radiology"
        elif "laboratory" in pipeline_result.file_detection.file_type.value:
            doc_type = "laboratory"
        elif "clinical" in pipeline_result.file_detection.file_type.value:
            doc_type = "clinical"

        # Model selection logic
        if doc_type == "ecg" and confidence > 0.8:
            return self.model_configs["hubert_ecg"]
        elif doc_type == "radiology" and confidence > 0.7:
            return self.model_configs["monai_unetr"]
        elif doc_type == "clinical" and confidence > 0.6:
            return self.model_configs["medgemma"]
        elif doc_type == "laboratory":
            return self.model_configs["biomedical_ner"]
        else:
            # Use general biomedical model for low confidence or unknown types
            return self.model_configs["bio_clinicalbert"]

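    # Illustrative routing outcomes under the thresholds above (assuming
    # compliance_score is the quality signal): an ECG file scoring 0.92 routes to
    # "hubert_ecg"; a radiology file scoring 0.65 misses its 0.7 threshold and falls
    # through to "bio_clinicalbert"; laboratory files route to "biomedical_ner"
    # regardless of score.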
    def _validate_input_format(self, pipeline_result: ProcessingPipelineResult,
                               model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Validate input data format for the selected model"""
        validation_result = {
            "is_valid": True,
            "errors": [],
            "warnings": [],
            "input_checks": {}
        }

        try:
            # Check required fields based on input format
            if model_config.input_format == "ecg_signal":
                validation_result["input_checks"] = self._validate_ecg_input(pipeline_result)
            elif model_config.input_format == "dicom_image":
                validation_result["input_checks"] = self._validate_dicom_input(pipeline_result)
            elif model_config.input_format in ["clinical_text", "lab_text"]:
                validation_result["input_checks"] = self._validate_text_input(pipeline_result)

            # Apply validation rules
            for check_name, check_result in validation_result["input_checks"].items():
                if not check_result["passed"]:
                    validation_result["is_valid"] = False
                    validation_result["errors"].append(f"{check_name}: {check_result['error']}")

        except Exception as e:
            validation_result["is_valid"] = False
            validation_result["errors"].append(f"Validation error: {str(e)}")

        return validation_result

    def _validate_ecg_input(self, pipeline_result: ProcessingPipelineResult) -> Dict[str, Any]:
        """Validate ECG signal input format"""
        checks = {}

        # Check if we have signal data
        if hasattr(pipeline_result.extraction_result, 'signal_data'):
            signal_data = pipeline_result.extraction_result.signal_data
            checks["has_signal_data"] = {
                "passed": bool(signal_data),
                "error": "No ECG signal data found" if not signal_data else None
            }

            # Check sampling rate
            if hasattr(pipeline_result.extraction_result, 'sampling_rate'):
                sampling_rate = pipeline_result.extraction_result.sampling_rate
                checks["adequate_sampling_rate"] = {
                    "passed": sampling_rate >= 250,  # Minimum 250 Hz for ECG
                    "error": f"Sampling rate {sampling_rate} Hz too low for ECG analysis" if sampling_rate < 250 else None
                }

            # Check signal duration
            if hasattr(pipeline_result.extraction_result, 'duration'):
                duration = pipeline_result.extraction_result.duration
                checks["adequate_duration"] = {
                    "passed": duration >= 5.0,  # Minimum 5 seconds
                    "error": f"Signal duration {duration:.1f}s too short for analysis" if duration < 5.0 else None
                }
        else:
            checks["has_signal_data"] = {
                "passed": False,
                "error": "Extraction result does not contain ECG signal data"
            }

        return checks

    def _validate_dicom_input(self, pipeline_result: ProcessingPipelineResult) -> Dict[str, Any]:
        """Validate DICOM image input format"""
        checks = {}

        if hasattr(pipeline_result.extraction_result, 'image_data'):
            image_data = pipeline_result.extraction_result.image_data
            checks["has_image_data"] = {
                "passed": bool(image_data.size > 0),
                "error": "No image data found" if image_data.size == 0 else None
            }

            # Check image dimensions
            if image_data.size > 0:
                checks["adequate_resolution"] = {
                    "passed": min(image_data.shape) >= 64,
                    "error": f"Image resolution too low: {image_data.shape}" if min(image_data.shape) < 64 else None
                }
        else:
            checks["has_image_data"] = {
                "passed": False,
                "error": "Extraction result does not contain DICOM image data"
            }

        return checks

    def _validate_text_input(self, pipeline_result: ProcessingPipelineResult) -> Dict[str, Any]:
        """Validate text input format"""
        checks = {}

        # Check for text content
        if hasattr(pipeline_result.extraction_result, 'raw_text'):
            text = pipeline_result.extraction_result.raw_text
            checks["has_text_content"] = {
                "passed": bool(text and len(text.strip()) > 50),
                "error": "Insufficient text content for analysis" if not text or len(text.strip()) <= 50 else None
            }
        else:
            checks["has_text_content"] = {
                "passed": False,
                "error": "No text content found in extraction result"
            }

        return checks

    async def _preprocess_for_model(self, pipeline_result: ProcessingPipelineResult,
                                    model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Preprocess input data for model-specific requirements"""
        if not model_config.preprocessing_required:
            # Return structured data as-is for models that don't need preprocessing
            return {
                "raw_data": pipeline_result.structured_data,
                "metadata": pipeline_result.pipeline_metadata,
                "validation_result": pipeline_result.validation_result
            }

        try:
            if model_config.input_format == "ecg_signal":
                return await self._preprocess_ecg_signal(pipeline_result, model_config)
            elif model_config.input_format == "dicom_image":
                return await self._preprocess_dicom_image(pipeline_result, model_config)
            elif model_config.input_format in ["clinical_text", "lab_text"]:
                return await self._preprocess_clinical_text(pipeline_result, model_config)
            else:
                return {"raw_data": pipeline_result.structured_data}

        except Exception as e:
            logger.error(f"Preprocessing error: {str(e)}")
            return {"raw_data": pipeline_result.structured_data, "preprocessing_error": str(e)}

    async def _preprocess_ecg_signal(self, pipeline_result: ProcessingPipelineResult,
                                     model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Preprocess ECG signal data for HuBERT-ECG model"""
        extraction_result = pipeline_result.extraction_result

        # Prepare ECG signal in format expected by HuBERT-ECG
        ecg_input = {
            "signals": extraction_result.signal_data,
            "sampling_rate": extraction_result.sampling_rate,
            "duration": extraction_result.duration,
            "leads": extraction_result.lead_names
        }

        # Add preprocessing metadata
        preprocessing_metadata = {
            "original_sampling_rate": extraction_result.sampling_rate,
            "resampled": False,  # Would implement resampling if needed
            "filtered": True,  # Assuming signal was already filtered
            "segment_length_seconds": min(10.0, extraction_result.duration)  # Use up to 10 seconds
        }

        return {
            "ecg_data": ecg_input,
            "preprocessing_metadata": preprocessing_metadata,
            "model_ready": True
        }

    async def _preprocess_dicom_image(self, pipeline_result: ProcessingPipelineResult,
                                      model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Preprocess DICOM image data for MONAI UNETR"""
        extraction_result = pipeline_result.extraction_result

        # Prepare image data for MONAI
        image_input = {
            "image_array": extraction_result.image_data,
            "spacing": extraction_result.pixel_spacing,
            "modality": extraction_result.modality,
            "body_part": extraction_result.body_part
        }

        # Add preprocessing metadata
        preprocessing_metadata = {
            "window_level": self._get_window_settings(extraction_result.modality),
            "normalized": True,
            "resized": False,  # Would implement resizing if needed
            "channels_added": True  # MONAI expects channel dimension
        }

        return {
            "dicom_data": image_input,
            "preprocessing_metadata": preprocessing_metadata,
            "model_ready": True
        }

    async def _preprocess_clinical_text(self, pipeline_result: ProcessingPipelineResult,
                                        model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Preprocess clinical text for MedGemma or biomedical models"""
        extraction_result = pipeline_result.extraction_result

        # Extract text content
        if hasattr(extraction_result, 'raw_text'):
            text_content = extraction_result.raw_text
        elif hasattr(extraction_result, 'structured_data'):
            text_content = str(extraction_result.structured_data)
        else:
            text_content = str(pipeline_result.structured_data)

        # Prepare text for model
        text_input = {
            "raw_text": text_content,
            "document_type": pipeline_result.file_detection.file_type.value,
            "deidentified": pipeline_result.deidentification_result is not None
        }

        # Add preprocessing metadata
        preprocessing_metadata = {
            "tokenized": False,  # Will be done by model
            "max_length": 512,  # Typical max sequence length
            "language": "en",
            "medical_domain": self._extract_medical_domain(pipeline_result)
        }

        return {
            "text_data": text_input,
            "preprocessing_metadata": preprocessing_metadata,
            "model_ready": True
        }

    def _get_window_settings(self, modality: str) -> Dict[str, float]:
        """Get appropriate window settings for medical imaging"""
        window_configs = {
            "CT": {"level": 40, "width": 400},  # Soft-tissue window
            "MRI": {"level": 0, "width": 500},  # Brain window
            "XRAY": {"level": 0, "width": 1000}  # General window
        }
        return window_configs.get(modality, {"level": 0, "width": 500})

    def _extract_medical_domain(self, pipeline_result: ProcessingPipelineResult) -> str:
        """Extract medical domain from pipeline result"""
        file_type = pipeline_result.file_detection.file_type.value

        if "ecg" in file_type or "ECG" in file_type:
            return "cardiology"
        elif "radiology" in file_type:
            return "radiology"
        elif "laboratory" in file_type:
            return "laboratory"
        elif "clinical" in file_type:
            return "clinical"
        else:
            return "general"

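    # Example lookups (illustrative): _get_window_settings("CT") returns
    # {"level": 40, "width": 400}; an unrecognized modality falls back to the
    # default {"level": 0, "width": 500}.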
    async def _perform_model_inference(self, preprocessed_input: Dict[str, Any],
                                       model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Perform inference using the specialized model"""
        try:
            if model_config.model_type == "classification":
                return await self._perform_classification_inference(preprocessed_input, model_config)
            elif model_config.model_type == "segmentation":
                return await self._perform_segmentation_inference(preprocessed_input, model_config)
            elif model_config.model_type == "generation":
                return await self._perform_generation_inference(preprocessed_input, model_config)
            elif model_config.model_type == "extraction":
                return await self._perform_extraction_inference(preprocessed_input, model_config)
            else:
                raise ValueError(f"Unsupported model type: {model_config.model_type}")

        except Exception as e:
            logger.error(f"Model inference error: {str(e)}")
            raise

    async def _perform_classification_inference(self, preprocessed_input: Dict[str, Any],
                                                model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Perform classification inference (e.g., ECG rhythm classification)"""
        # Use existing model loader for classification tasks
        model_key = "bio_clinicalbert"  # Use biomedical model for now

        try:
            # Prepare input for model
            if "ecg_data" in preprocessed_input:
                # ECG classification
                ecg_data = preprocessed_input["ecg_data"]
                text_input = f"ECG Analysis: {len(ecg_data['signals'])} leads, {ecg_data['duration']:.1f}s duration"
            else:
                text_input = preprocessed_input.get("text_data", {}).get("raw_text", "")

            # Perform inference using model loader
            result = await self.model_loader.run_inference(
                model_key,
                text_input,
                {"max_new_tokens": 200, "task": "classification"}
            )

            return {
                "model_output": result,
                "classification_type": "medical_document_classification",
                "confidence": 0.8  # Default confidence
            }

        except Exception as e:
            logger.error(f"Classification inference error: {str(e)}")
            raise

    async def _perform_segmentation_inference(self, preprocessed_input: Dict[str, Any],
                                              model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Perform segmentation inference (e.g., organ segmentation in medical images)"""
        try:
            dicom_data = preprocessed_input["dicom_data"]
            image_array = dicom_data["image_array"]
            modality = dicom_data["modality"]

            # Placeholder segmentation result
            # In real implementation, would use MONAI UNETR
            segmentation_result = {
                "segmentation_mask": np.random.rand(*image_array.shape) > 0.7,  # Placeholder
                "organ_detected": f"{modality.lower()}_tissue",
                "volume_estimate_ml": np.prod(image_array.shape) * 0.001,  # Placeholder
                "confidence": 0.75
            }

            return {
                "model_output": segmentation_result,
                "segmentation_type": f"{modality}_segmentation"
            }

        except Exception as e:
            logger.error(f"Segmentation inference error: {str(e)}")
            raise

    async def _perform_generation_inference(self, preprocessed_input: Dict[str, Any],
                                            model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Perform text generation inference (e.g., clinical summary generation)"""
        try:
            text_data = preprocessed_input["text_data"]
            raw_text = text_data["raw_text"]

            # Use biomedical model for text generation
            model_key = "bio_clinicalbert"

            # Prepare generation prompt
            prompt = f"Analyze the following medical text and provide a structured summary:\n\n{raw_text}"

            # Perform inference
            result = await self.model_loader.run_inference(
                model_key,
                prompt,
                {"max_new_tokens": 300, "task": "generation"}
            )

            return {
                "model_output": result,
                "generation_type": "clinical_summary",
                "original_length": len(raw_text),
                "generated_length": len(str(result))
            }

        except Exception as e:
            logger.error(f"Generation inference error: {str(e)}")
            raise

    async def _perform_extraction_inference(self, preprocessed_input: Dict[str, Any],
                                            model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Perform extraction inference (e.g., lab value extraction)"""
        try:
            text_data = preprocessed_input["text_data"]
            raw_text = text_data["raw_text"]

            # Use biomedical NER model for extraction
            model_key = "biomedical_ner_all"

            # Perform NER extraction
            result = await self.model_loader.run_inference(
                model_key,
                raw_text,
                {"task": "ner", "aggregation_strategy": "simple"}
            )

            return {
                "model_output": result,
                "extraction_type": "medical_entities",
                "entities_found": len(result) if isinstance(result, list) else 0
            }

        except Exception as e:
            logger.error(f"Extraction inference error: {str(e)}")
            raise

    def _postprocess_model_output(self, inference_result: Dict[str, Any],
                                  model_config: SpecializedModelConfig) -> Dict[str, Any]:
        """Post-process model output to match expected schema"""
        try:
            model_output = inference_result["model_output"]

            # Convert to appropriate schema format
            if model_config.output_schema == "ECGAnalysis":
                return self._convert_to_ecg_schema(model_output, inference_result)
            elif model_config.output_schema == "RadiologyAnalysis":
                return self._convert_to_radiology_schema(model_output, inference_result)
            elif model_config.output_schema == "LaboratoryResults":
                return self._convert_to_laboratory_schema(model_output, inference_result)
            elif model_config.output_schema == "ClinicalNotesAnalysis":
                return self._convert_to_clinical_notes_schema(model_output, inference_result)
            else:
                return {"model_output": model_output, "schema": "generic"}

        except Exception as e:
            logger.error(f"Post-processing error: {str(e)}")
            return {"model_output": inference_result.get("model_output", {}), "error": str(e)}

    def _convert_to_ecg_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert model output to ECG schema format"""
        # This would convert model-specific ECG output to the canonical ECGAnalysis schema
        return {
            "model_output": model_output,
            "schema": "ECGAnalysis",
            "postprocessed": True
        }

    def _convert_to_radiology_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert model output to radiology schema format"""
        return {
            "model_output": model_output,
            "schema": "RadiologyAnalysis",
            "postprocessed": True
        }

    def _convert_to_laboratory_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert model output to laboratory schema format"""
        return {
            "model_output": model_output,
            "schema": "LaboratoryResults",
            "postprocessed": True
        }

    def _convert_to_clinical_notes_schema(self, model_output: Any, inference_result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert model output to clinical notes schema format"""
        return {
            "model_output": model_output,
            "schema": "ClinicalNotesAnalysis",
            "postprocessed": True
        }

    def _calculate_model_confidence(self, pipeline_result: ProcessingPipelineResult,
                                    model_config: SpecializedModelConfig,
                                    model_output: Dict[str, Any]) -> float:
        """Calculate confidence score for model inference"""
        try:
            # Base confidence from pipeline
            pipeline_confidence = pipeline_result.validation_result.compliance_score

            # Model-specific confidence adjustments
            model_confidence = 0.8  # Default high confidence for specialized models

            # Adjust based on model type
            if model_config.model_type == "classification":
                model_confidence = 0.85
            elif model_config.model_type == "segmentation":
                model_confidence = 0.80
            elif model_config.model_type == "generation":
                model_confidence = 0.75
            elif model_config.model_type == "extraction":
                model_confidence = 0.90

            # Check for model output quality
            if "error" in model_output:
                model_confidence *= 0.3  # Reduce confidence for error outputs

            # Calculate weighted confidence
            overall_confidence = (0.4 * pipeline_confidence + 0.6 * model_confidence)

            return min(1.0, max(0.0, overall_confidence))

        except Exception as e:
            logger.error(f"Confidence calculation error: {str(e)}")
            return 0.5

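    # Worked example of the weighting above (illustrative): a pipeline compliance
    # score of 0.9 routed to a classification model (base 0.85) yields
    # 0.4 * 0.9 + 0.6 * 0.85 = 0.87, which is then clamped to [0.0, 1.0].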
    async def _try_fallback_model(self, pipeline_result: ProcessingPipelineResult) -> Optional[ModelInferenceResult]:
        """Try fallback model when primary model fails"""
        try:
            # Use generic biomedical model as fallback
            fallback_config = self.model_configs["bio_clinicalbert"]

            # Prepare generic text input
            text_input = str(pipeline_result.structured_data)

            # Perform inference with fallback
            result = await self.model_loader.run_inference(
                "bio_clinicalbert",
                text_input[:1000],  # Limit text length
                {"max_new_tokens": 150, "task": "general"}
            )

            return ModelInferenceResult(
                model_name="fallback_bio_clinicalbert",
                input_data={"fallback_text": text_input[:1000]},
                output_data={"model_output": result, "fallback_used": True},
                confidence_score=0.4,  # Lower confidence for fallback
                processing_time=0.0,
                model_metadata={"fallback_reason": "primary_model_failed"},
                warnings=["Used fallback model due to primary model failure"],
                errors=[]
            )

        except Exception as e:
            logger.error(f"Fallback model error: {str(e)}")
            return None

    def _create_error_result(self, model_name: str, errors: List[str]) -> ModelInferenceResult:
        """Create error result for failed inference"""
        return ModelInferenceResult(
            model_name=model_name,
            input_data={},
            output_data={"error": "Input validation failed"},
            confidence_score=0.0,
            processing_time=0.0,
            model_metadata={"validation_errors": errors},
            warnings=[],
            errors=errors
        )

    def _update_inference_stats(self, model_name: str, success: bool, processing_time: float):
        """Update inference statistics"""
        self.inference_stats["total_inferences"] += 1

        if success:
            self.inference_stats["successful_inferences"] += 1

        # Update running average of processing time over all inferences
        total_time = self.inference_stats["average_processing_time"] * (self.inference_stats["total_inferences"] - 1)
        self.inference_stats["average_processing_time"] = (total_time + processing_time) / self.inference_stats["total_inferences"]

        # Update usage counts
        self.inference_stats["model_usage_counts"][model_name] = self.inference_stats["model_usage_counts"].get(model_name, 0) + 1

        if not success:
            error_type = "inference_failure"
            self.inference_stats["error_counts"][error_type] = self.inference_stats["error_counts"].get(error_type, 0) + 1

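    # Running-average update used above (illustrative): with n total inferences,
    # new_avg = (old_avg * (n - 1) + processing_time) / n, so no per-call history
    # needs to be stored.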
    def get_inference_statistics(self) -> Dict[str, Any]:
        """Get comprehensive inference statistics"""
        return {
            "total_inferences": self.inference_stats["total_inferences"],
            "success_rate": self.inference_stats["successful_inferences"] / max(self.inference_stats["total_inferences"], 1),
            "average_processing_time": self.inference_stats["average_processing_time"],
            "model_usage_breakdown": self.inference_stats["model_usage_counts"],
            "error_breakdown": self.inference_stats["error_counts"],
            "router_health": "healthy" if self.inference_stats["successful_inferences"] > self.inference_stats["total_inferences"] * 0.8 else "degraded"
        }


# Export main classes
__all__ = [
    "SpecializedModelRouter",
    "ModelInferenceResult",
    "SpecializedModelConfig"
]
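
# Minimal usage sketch (illustrative; assumes a ProcessingPipelineResult named
# `pipeline_result` has already been produced by the preprocessing pipeline):
#
#     import asyncio
#
#     async def main():
#         router = SpecializedModelRouter()
#         result = await router.route_and_infer(pipeline_result)
#         print(result.model_name, result.confidence_score)
#         print(router.get_inference_statistics())
#
#     asyncio.run(main())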