Spaces:
Running
Running
| import torch | |
| import os | |
| from PIL import Image | |
| from transformers import AutoModelForImageClassification, SiglipImageProcessor | |
| import gradio as gr | |
# Alternative OCR using transformers
def setup_alternative_ocr():
    """Load the TrOCR printed-text checkpoint used for optional text extraction.

    Returns:
        ``(processor, model, True)`` when the ``transformers`` import and both
        ``from_pretrained`` calls succeed, otherwise ``(None, None, False)``.
        Any exception is printed rather than raised, so the app can still start
        (without OCR) when the model cannot be loaded.
    """
    checkpoint = "microsoft/trocr-base-printed"
    try:
        # Imported lazily so a missing/broken install only disables OCR.
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        print("Setting up TrOCR for text extraction...")
        trocr_processor = TrOCRProcessor.from_pretrained(checkpoint)
        trocr_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
        print("β TrOCR model loaded successfully!")
        return trocr_processor, trocr_model, True
    except Exception as exc:  # best-effort: any failure just turns OCR off
        print(f"β οΈ Could not load TrOCR: {exc}")
    return None, None, False
# Try to set up OCR once at import time; OCR_AVAILABLE gates text extraction.
OCR_PROCESSOR, OCR_MODEL, OCR_AVAILABLE = setup_alternative_ocr()

# Directory containing the fine-tuned classifier checkpoint.
MODEL_PATH = "./model"


def _describe_model_dir(path):
    """Return ``os.listdir(path)`` rendered as a string, for diagnostics only.

    Never raises: if *path* cannot be listed (e.g. it does not exist), a short
    explanation is returned instead, so a caller's original error is not masked.
    """
    try:
        return str(os.listdir(path))
    except OSError as err:
        return f"<cannot list {path!r}: {err}>"


try:
    print(f"=== Loading model from: {MODEL_PATH} ===")
    print(f"Available files: {os.listdir(MODEL_PATH)}")

    # Load the classification model strictly from the local directory.
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("β Model loaded successfully!")

    # Prefer a preprocessor config saved next to the model; if that fails, fall
    # back to the preprocessing of the google/siglip-base-patch16-224 checkpoint.
    print("Loading image processor...")
    try:
        processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
        print("β Image processor loaded from local files!")
    except Exception as e:
        print(f"β οΈ Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
        processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        print("β Image processor loaded from base model!")

    # Class-index -> label-name mapping; fall back to generic "class_<i>" names
    # when the config carries no id2label entries.
    if hasattr(model.config, 'id2label') and model.config.id2label:
        labels = model.config.id2label
        print(f"β Found {len(labels)} labels in model config")
    else:
        num_labels = model.config.num_labels if hasattr(model.config, 'num_labels') else 2
        labels = {i: f"class_{i}" for i in range(num_labels)}
        print(f"β Created {len(labels)} generic labels")
    print("π Model setup complete!")
except Exception as e:
    print(f"β Error loading model: {e}")
    # List defensively: if MODEL_PATH is missing, a bare os.listdir would raise
    # again here and replace the original error before `raise` is reached.
    print(f"Files in model directory: {_describe_model_dir(MODEL_PATH)}")
    raise
def extract_text_alternative(image):
    """Run the module-level TrOCR model on *image* and return the decoded text.

    Returns the literal ``"OCR not available"`` when OCR failed to load, and
    ``"OCR error: <message>"`` if conversion, preprocessing, generation or
    decoding raises; errors are reported in the return value, never raised.
    """
    if not OCR_AVAILABLE:
        return "OCR not available"
    try:
        # TrOCR preprocessing is fed 3-channel RGB input.
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        encoded = OCR_PROCESSOR(rgb_image, return_tensors="pt")
        token_ids = OCR_MODEL.generate(encoded.pixel_values)
        # Single image in, so the first decoded sequence is the answer.
        decoded = OCR_PROCESSOR.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as exc:
        return f"OCR error: {exc}"
def classify_meme(image: "Image.Image"):
    """Classify a meme image and extract any text printed on it.

    Args:
        image: The uploaded PIL image, or ``None`` when the user submits
            without uploading anything (Gradio passes ``None`` in that case).

    Returns:
        ``(predictions, text)``: ``predictions`` maps each class label to its
        softmax probability, ordered from most to least likely; ``text`` is the
        stripped OCR output (or a status message when OCR is unavailable).
        On any failure, returns ``({"Error": 1.0}, message)`` instead of raising,
        so both Gradio outputs always receive a displayable value.
    """
    if image is None:
        return {"Error": 1.0}, "Error processing image: no image provided"
    try:
        # Normalize to RGB once, as extract_text_alternative already does for OCR,
        # so the classifier and the OCR model see the same 3-channel input.
        if image.mode != "RGB":
            image = image.convert("RGB")

        extracted_text = (
            extract_text_alternative(image)
            if OCR_AVAILABLE
            else "OCR not available in this environment"
        ).strip()

        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # Probabilities for the single image in the batch, as Python floats.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()

        # Iterate over the outputs the model actually produced, so a label map that
        # is longer or shorter than the output layer cannot crash or drop classes.
        predictions = {labels.get(i, f"class_{i}"): prob for i, prob in enumerate(probs)}
        sorted_predictions = dict(
            sorted(predictions.items(), key=lambda item: item[1], reverse=True)
        )

        # Debug prints
        print("=== Classification Results ===")
        print(f"Extracted Text: '{extracted_text}'")
        print("Top 3 Predictions:")
        for rank, (label, prob) in enumerate(list(sorted_predictions.items())[:3], start=1):
            print(f" {rank}. {label}: {prob:.4f}")
        return sorted_predictions, extracted_text
    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        print(f"β {error_msg}")
        return {"Error": 1.0}, error_msg
# Create Gradio interface: one PIL image in; a label distribution plus the
# extracted text out. The OCR status line is chosen once at import time.
_ocr_note = (
    'β **Text extraction** available via TrOCR (Microsoft Transformer OCR)'
    if OCR_AVAILABLE
    else 'β οΈ **Text extraction** not available'
)

demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil", label="Upload Meme Image"),
    outputs=[
        gr.Label(num_top_classes=5, label="Meme Classification"),
        gr.Textbox(label="Extracted Text", lines=3),
    ],
    title=f"π Meme Classifier{' with TrOCR' if OCR_AVAILABLE else ''}",
    description=f"""
Upload a meme image to **classify** its content using your trained SigLIP2_77 model.
{_ocr_note}
Your model will predict the category/sentiment of the uploaded meme.
""",
    examples=None,
    allow_flagging="never",
)
if __name__ == "__main__":
    # Serve on every network interface at port 7860, without a public share link.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    print("π Starting Gradio interface...")
    demo.launch(**launch_options)