import torch
import os
from PIL import Image
from transformers import AutoModelForImageClassification, SiglipImageProcessor
import gradio as gr
import pytesseract
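
# Note: pytesseract is only a Python wrapper; the Tesseract binary itself must be
# installed on the host (on Hugging Face Spaces, add `tesseract-ocr` to packages.txt).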

# Model path
MODEL_PATH = "./model"

try:
    print(f"=== Loading model from: {MODEL_PATH} ===")
    print(f"Available files: {os.listdir(MODEL_PATH)}")

    # Load the model (this should work with your files)
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("✅ Model loaded successfully!")
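    # from_pretrained already returns the model in eval mode, so no explicit model.eval() is needed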

    # Load just the image processor (not the full AutoProcessor)
    print("Loading image processor...")
    try:
        # Try to load the image processor from your local files
        processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
        print("✅ Image processor loaded from local files!")
    except Exception as e:
        print(f"⚠️ Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
        # Fallback: load processor from base model online
        processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        print("✅ Image processor loaded from base model!")
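        # Caveat: this fallback assumes the fine-tune kept the base checkpoint's
        # preprocessing (image size, normalization); if that was changed during
        # training, predictions may degrade.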

    # Get labels from your model config
    if hasattr(model.config, 'id2label') and model.config.id2label:
        labels = model.config.id2label
        print(f"✅ Found {len(labels)} labels in model config")
    else:
        # Create generic labels if none exist
        num_labels = model.config.num_labels if hasattr(model.config, 'num_labels') else 2
        labels = {i: f"class_{i}" for i in range(num_labels)}
        print(f"✅ Created {len(labels)} generic labels")

    print("🎉 Model setup complete!")

except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("\n=== Debug Information ===")
    print(f"Files in model directory: {os.listdir(MODEL_PATH)}")
    raise
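
# Optional (a sketch, not wired in here): run on GPU when one is available.
# This also requires moving `inputs` to the same device inside classify_meme:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model.to(device)
#   inputs = {k: v.to(device) for k, v in inputs.items()}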

def classify_meme(image: Image.Image):
    """
    Classify a meme and extract its text using OCR.
    """
    try:
        # OCR: extract text from the image
        extracted_text = pytesseract.image_to_string(image)
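        # pytesseract accepts PIL images directly and defaults to English ('eng');
        # other languages need lang=... plus the matching traineddata installed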

        # Process image for the model
        inputs = processor(images=image, return_tensors="pt")

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
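
        # `probs` has shape (1, num_labels): softmax over the last dim turns the
        # logits into probabilities that sum to 1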

        # Get predictions
        predictions = {}
        for i in range(len(labels)):
            label = labels.get(i, f"class_{i}")
            predictions[label] = float(probs[0][i])

        # Sort predictions by confidence
        sorted_predictions = dict(sorted(predictions.items(), key=lambda x: x[1], reverse=True))
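        # gr.Label consumes this label-to-confidence dict directly; num_top_classes
        # on the output component controls how many entries the UI shows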

        # Debug prints
        print("=== Classification Results ===")
        print(f"Extracted Text: '{extracted_text.strip()}'")
        print("Top 3 Predictions:")
        for i, (label, prob) in enumerate(list(sorted_predictions.items())[:3]):
            print(f"  {i+1}. {label}: {prob:.4f}")

        return sorted_predictions, extracted_text.strip()

    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        print(f"❌ {error_msg}")
        return {"Error": 1.0}, error_msg

# Create Gradio interface
demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil", label="Upload Meme Image"),
    outputs=[
        gr.Label(num_top_classes=5, label="Meme Classification"),
        gr.Textbox(label="Extracted Text", lines=3)
    ],
    title="🎭 Meme Classifier with OCR",
    description="""
    Upload a meme image to:
    1. **Classify** its content using your trained SigLIP2_77 model
    2. **Extract text** using OCR (Optical Character Recognition)

    Your model was trained on meme data and will predict the category/sentiment of the uploaded meme.
    """,
    examples=None,
    allow_flagging="never"
)
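
# Quick smoke test without the UI (assumes a local test image at
# ./examples/meme.png; adjust the path as needed):
#   preds, text = classify_meme(Image.open("examples/meme.png"))
#   print(preds, text)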

if __name__ == "__main__":
    print("🚀 Starting Gradio interface...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
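
# Note: binding to 0.0.0.0 on port 7860 matches what a Hugging Face Space
# expects from a Gradio app by default.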