import numpy as np
import torch
import streamlit as st
import joblib
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model

# ----------- Space config -----------
SPACE_REPO = "melvinalves/protein_function_prediction"  # <- your Space
MODELS_DIR = "models"
DATA_DIR = "data"
TOP_N = 10
CHUNK_PB = 512
CHUNK_ESM = 1024

# ----------- Helpers -----------
@st.cache_resource
def hf_cached(path_inside_repo: str):
    """Download once (cached) and return the local path."""
    return hf_hub_download(
        repo_id=SPACE_REPO,
        repo_type="space",
        filename=path_inside_repo,
    )

@st.cache_resource
def load_hf_model(model_name):
    tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    mdl = AutoModel.from_pretrained(model_name)
    mdl.eval()
    return tok, mdl

@st.cache_resource
def load_local_model(file_name):
    local_path = hf_cached(f"{MODELS_DIR}/{file_name}")
    return load_model(local_path, compile=False)

# ----------- Load models (.keras) -----------
mlp_pb = load_local_model("mlp_protbert.keras")
mlp_bfd = load_local_model("mlp_protbertbfd.keras")
mlp_esm = load_local_model("mlp_esm2.keras")
stacking = load_local_model("ensemble_stacking.keras")

# ----------- MultiLabelBinarizer -----------
mlb_path = hf_cached(f"{DATA_DIR}/mlb_597.pkl")
mlb = joblib.load(mlb_path)
go_terms = mlb.classes_

# ----------- Chunked embedding -----------
def embed_sequence(model_name: str, seq: str, chunk: int) -> np.ndarray:
    tok, mdl = load_hf_model(model_name)
    # Space-separated residues, as the ProtBERT tokenizers expect.
    # (Note: Rostlab examples also map rare residues U/Z/O/B to X before
    # tokenizing; that step is not applied here.)
    fmt = lambda s: " ".join(list(s))
    parts = [seq[i:i + chunk] for i in range(0, len(seq), chunk)]
    vecs = []
    for p in parts:
        with torch.no_grad():
            out = mdl(**tok(fmt(p), return_tensors="pt", truncation=True))
        # Use the first token ([CLS]/BOS) embedding as the chunk representation
        vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
    # Average the chunk embeddings -> shape (1, hidden_dim)
    return np.mean(vecs, axis=0, keepdims=True)

# ----------- UI -----------
st.title("Protein Function Prediction 🔬")

fa_input = st.text_area("Paste the FASTA sequence:", height=200)

if fa_input and st.button("Predict GO terms"):
    # Clean the FASTA input: drop header lines, spaces, and newlines
    seq = "\n".join(l for l in fa_input.splitlines() if not l.startswith(">"))
    seq = seq.replace(" ", "").replace("\n", "").upper()
    if not seq:
        st.warning("Empty sequence.")
        st.stop()

    st.write("🔄 Generating embeddings…")
    emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
    emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
    emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)

    st.write("🧠 Running predictions…")
    y_pb = mlp_pb.predict(emb_pb)
    y_bfd = mlp_bfd.predict(emb_bfd)
    y_esm = mlp_esm.predict(emb_esm)[:, :597]  # ensure 597 columns

    X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
    y_pred = stacking.predict(X_stack)

    # ----------- Output -----------
    st.subheader("GO terms with probability ≥ 0.5")
    hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
    st.code("\n".join(hits) or "— none —")

    st.subheader(f"Top {TOP_N} most likely GO terms")
    for idx in np.argsort(-y_pred[0])[:TOP_N]:
        st.write(f"{go_terms[idx]} : {y_pred[0][idx]:.4f}")
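
# ----------- Running locally (sketch) -----------
# A minimal way to try the app outside the Space, assuming this file is saved
# as app.py (the file name and the dependency list below are assumptions, not
# part of the Space config). scikit-learn is needed to unpickle the
# MultiLabelBinarizer that joblib loads above:
#
#   pip install streamlit torch transformers huggingface_hub tensorflow joblib scikit-learn numpy
#   streamlit run app.py
#
# Example input to paste into the text area (a toy sequence, illustrative
# only; any FASTA record works, with or without the ">" header line):
#
#   >toy_protein
#   MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ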