# model.py - Fixed for CodeT5+
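"""Singleton loader for the Salesforce/codet5p-220m checkpoint.

Provides a synchronous loader (load_model_sync / get_model), an async
wrapper that keeps the event loop free (load_model_async), and status
helpers (is_model_loaded / get_model_info).
"""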
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging

logger = logging.getLogger(__name__)
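# Module-level singletons so the checkpoint is loaded at most once per process.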
_tokenizer = None
_model = None
_model_loading = False
_model_loaded = False


def get_model_config():
    return {
        "model_id": "Salesforce/codet5p-220m",
        "trust_remote_code": True,
    }


def load_model_sync():
    """Load tokenizer and model, blocking the caller. Idempotent."""
    global _tokenizer, _model, _model_loaded
    if _model_loaded:
        return _tokenizer, _model
    config = get_model_config()
    model_id = config["model_id"]
    logger.info(f"Loading model {model_id}...")
    try:
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "./model_cache")
        os.makedirs(cache_dir, exist_ok=True)
        logger.info("Loading tokenizer...")
        _tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
            use_fast=True,
        )
        logger.info("Loading model...")
        _model = AutoModelForSeq2SeqLM.from_pretrained(
            model_id,
            trust_remote_code=config["trust_remote_code"],
            cache_dir=cache_dir,
        )
        _model.eval()
        _model_loaded = True
        logger.info("Model loaded successfully!")
        return _tokenizer, _model
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise


async def load_model_async():
    """Load the model in a worker thread so the event loop stays responsive.

    Concurrent callers while a load is in flight poll until it completes.
    """
    global _model_loading
    if _model_loaded:
        return _tokenizer, _model
    if _model_loading:
        # Another task is already loading; wait for it to finish.
        while _model_loading and not _model_loaded:
            await asyncio.sleep(0.1)
        if not _model_loaded:
            raise RuntimeError("Model loading failed in a concurrent task")
        return _tokenizer, _model
    _model_loading = True
    try:
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=1) as executor:
            tokenizer, model = await loop.run_in_executor(executor, load_model_sync)
        return tokenizer, model
    finally:
        _model_loading = False


def get_model():
    """Return the tokenizer/model pair, loading synchronously if needed."""
    if not _model_loaded:
        return load_model_sync()
    return _tokenizer, _model


def is_model_loaded():
    return _model_loaded


def get_model_info():
    config = get_model_config()
    return {
        "model_id": config["model_id"],
        "loaded": _model_loaded,
        "loading": _model_loading,
    }
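

# --- Usage sketch (illustrative, not part of the original Space) ---
# A minimal demo of the async loader followed by a single generation.
# The prompt and generation parameters below are hypothetical demo values;
# CodeT5+ is a seq2seq model, so we encode a prompt and decode the output.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    async def demo():
        tokenizer, model = await load_model_async()
        inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=64)
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

    asyncio.run(demo())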