melvinalves's picture
Update app.py
0e2cc06 verified
# app.py – Streamlit app para predição de GO:MF
# ProtBERT / ProtBERT-BFD fine-tuned (melvinalves/FineTune)
# ESM-2 base (facebook/esm2_t33_650M_UR50D)
import os, re, numpy as np, torch, joblib, streamlit as st
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModel
from keras.models import load_model
from goatools.obo_parser import GODag
# Authentication #
# Log in to the Hugging Face Hub with the token stored in the HF_TOKEN
# environment variable; os.environ[...] raises KeyError if it is unset.
login(token=os.environ["HF_TOKEN"])
# Configuration #
SPACE_ID: str = "melvinalves/protein_function_prediction"  # Space repo holding models/ and data/
TOP_N: int = 20  # number of highest-scoring GO terms listed per sequence
THRESH: float = 0.37  # score cut-off for the "hits" column
CHUNK_PB: int = 512  # max residues per ProtBERT / ProtBERT-BFD forward pass
CHUNK_ESM: int = 1024  # max residues per ESM-2 forward pass

# Hugging Face repositories
# Fine-tuned ProtBERT variants are addressed as (repo_id, subfolder) pairs.
FINETUNED_PB: tuple = ("melvinalves/FineTune", "fineTunedProtbert")
FINETUNED_BFD: tuple = ("melvinalves/FineTune", "fineTunedProtbertbfd")
# Base (not fine-tuned) ESM-2 model, addressed by a plain repo id.
BASE_ESM: str = "facebook/esm2_t33_650M_UR50D"
# Helpers #
@st.cache_resource
def download_file(path):
    """Fetch *path* from this app's Hugging Face Space and return the local file path.

    Intended for the small (<= 1 GB) files kept inside the Space repo.
    The result is cached by Streamlit, so each path is downloaded once.
    """
    from huggingface_hub import hf_hub_download

    return hf_hub_download(filename=path, repo_id=SPACE_ID, repo_type="space")
@st.cache_resource
def load_keras(name):
    """Load the Keras model stored at ``models/<name>`` in the Space (MLP heads and stacking model).

    The model is loaded with ``compile=False`` since it is only used for inference here.
    """
    local_path = download_file(f"models/{name}")
    return load_model(local_path, compile=False)
@st.cache_resource
def load_hf_encoder(repo_id, subfolder=None, base_tok=None):
    """Load a (tokenizer, encoder) pair from the Hugging Face Hub.

    Args:
        repo_id: repository holding the encoder weights.
        subfolder: optional subfolder inside *repo_id* with the weights.
        base_tok: repository to load the tokenizer from; defaults to *repo_id*.

    The encoder is always loaded with ``from_tf=True``, i.e. TensorFlow weights
    converted to PyTorch on load, and is switched to eval mode before returning.
    """
    tokenizer_source = repo_id if base_tok is None else base_tok
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, do_lower_case=False)

    load_kwargs = {"from_tf": True}
    if subfolder:
        load_kwargs["subfolder"] = subfolder
    encoder = AutoModel.from_pretrained(repo_id, **load_kwargs)
    encoder.eval()
    return tokenizer, encoder
def embed_seq(model_ref, seq, chunk):
    """Return the [CLS] embedding of *seq*, averaged over fixed-size chunks.

    Long sequences are cut into consecutive pieces of at most *chunk*
    residues; each piece is encoded on its own and the per-piece [CLS]
    vectors are averaged with equal weight.

    Args:
        model_ref: either a ``(repo_id, subfolder)`` tuple for a fine-tuned
            ProtBERT checkpoint (tokenizer taken from ``Rostlab/prot_bert``),
            or a repo id string for a base model such as ESM-2.
        seq: amino-acid sequence, one character per residue.
        chunk: maximum number of residues per forward pass.

    Returns:
        np.ndarray of shape ``(1, hidden_size)`` (one string is tokenized per
        pass, so each [CLS] slice has a batch dimension of 1).

    Raises:
        ValueError: if *seq* is empty or *chunk* < 1; without this guard an
            empty chunk list would make ``np.mean`` return NaN.
    """
    if not seq:
        raise ValueError("embed_seq: empty sequence")
    if chunk < 1:
        raise ValueError(f"embed_seq: chunk must be >= 1, got {chunk}")

    if isinstance(model_ref, tuple):  # fine-tuned ProtBERT
        repo_id, subf = model_ref
        tok, mdl = load_hf_encoder(repo_id, subfolder=subf, base_tok="Rostlab/prot_bert")
    else:  # base ESM-2 model
        tok, mdl = load_hf_encoder(model_ref)

    vecs = []
    with torch.no_grad():
        for start in range(0, len(seq), chunk):
            piece = seq[start:start + chunk]
            # Residues are space-separated so each one becomes its own token.
            toks = tok(" ".join(piece), return_tensors="pt", truncation=False)
            out = mdl(**{k: v.to(mdl.device) for k, v in toks.items()})
            # Position 0 of the last hidden state is the [CLS] token.
            vecs.append(out.last_hidden_state[:, 0, :].cpu().numpy())
    return np.mean(vecs, axis=0)
@st.cache_resource
def load_go_info():
    """Parse data/go.obo and return ``{go_id: (name, raw_definition)}``.

    The definition is the ``defn`` value stored by goatools, left uncleaned
    (display code strips it down, e.g. via clean_definition). Terms that carry
    no ``defn`` attribute map to "" instead of crashing the app at startup.
    """
    dag = GODag(download_file("data/go.obo"), optional_attrs=["defn"])
    # goatools may leave `defn` unset on a term whose stanza has no `def:`
    # line, so read it defensively rather than via attribute access.
    return {tid: (term.name, getattr(term, "defn", "")) for tid, term in dag.items()}


# Loaded once at import time; used by the results view and the GO-term search list.
GO_INFO = load_go_info()
# Models #
# Three per-embedding MLP heads plus the stacking model that combines them,
# loaded in this order.
mlp_pb, mlp_bfd, mlp_esm, stacking = (
    load_keras(fname)
    for fname in (
        "mlp_protbert.h5",
        "mlp_protbertbfd.h5",
        "mlp_esm2.h5",
        "ensemble_stack.h5",
    )
)
# Fitted label binarizer (presumably sklearn MultiLabelBinarizer — it is used
# via inverse_transform/classes_). NOTE(review): joblib.load unpickles, which is
# only safe because the file comes from this app's own Space.
mlb = joblib.load(download_file("data/mlb_597.pkl"))
# GO ids in the column order of the model outputs.
GO = mlb.classes_
# UI #
# Page-wide styling: white background, tighter top padding, smaller textarea
# font and a divider to the right of the first column.
_PAGE_CSS = """
<style>
body, .stApp { background:#FFFFFF !important; }
.block-container { padding-top:1.5rem; }
textarea { font-size:0.9rem !important; }
div[data-testid="column"]:first-child {
border-right:1px solid #E0E0E0; padding-right:1rem !important;
}
</style>
"""
_LOGO_PATH = "logo.png"

st.set_page_config(
    page_title="Predição de Funções Moleculares de Proteínas",
    page_icon="🧬",
    layout="centered",
)
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# The logo is optional: only shown when the file is present next to the app.
if os.path.exists(_LOGO_PATH):
    st.image(_LOGO_PATH, width=180)

st.title("Predição de Funções Moleculares de Proteínas (GO:MF)")
fasta_input = st.text_area("Insere uma ou mais sequências FASTA:", height=300)
predict_clicked = st.button("Prever GO terms")
# Utilities #
def parse_fasta_multiple(text):
    """Parse FASTA text into a list of ``(header, sequence)`` tuples.

    Records are split on ``>``. A leading block with no ``>`` header is
    accepted as a bare sequence and labelled ``Seq_1``. All whitespace
    (spaces, tabs, line breaks, other Unicode whitespace) is removed from the
    sequence, which is upper-cased. Records with an empty sequence are
    skipped. Empty or ``None`` input yields ``[]``.
    """
    records = []
    if not text:
        return records
    for i, blk in enumerate(text.strip().split(">")):
        if not blk.strip():
            continue
        lines = blk.strip().splitlines()
        if i:
            header, body = lines[0].strip(), lines[1:]
        else:
            # Text before the first '>' has no header line: treat it all as sequence.
            header, body = "Seq_1", lines
        # str.split() with no argument drops every kind of whitespace, not just spaces.
        seq = "".join("".join(body).split()).upper()
        if seq:
            records.append((header, seq))
    return records
def clean_definition(defin: str) -> str:
    """Return the human-readable part of a raw OBO ``def`` value.

    If the value contains a double-quoted string, return its contents.
    Escaped quotes (``\\"``) are handled and unescaped, and an empty quoted
    string yields "". Otherwise return the text before the first ``[``
    (the reference list). Falsy input returns "".
    """
    if not defin:
        return ""
    # Quoted string that may contain backslash escapes such as \" inside it.
    m = re.search(r'"((?:[^"\\]|\\.)*)"', defin)
    if m:
        # Undo only escaped quotes/backslashes; leave other escapes untouched.
        return re.sub(r'\\(["\\])', r"\1", m.group(1)).strip()
    return defin.split("[", 1)[0].strip()
def go_link(go_id, name=""):
    """Return a Markdown link to the QuickGO page of *go_id*.

    The link text is ``"<go_id> - <name>"`` when *name* is truthy, else just the id.
    """
    target = f"https://www.ebi.ac.uk/QuickGO/term/{go_id}"
    label = f"{go_id} - {name}" if name else f"{go_id}"
    return f"[{label}]({target})"
# Result rendering #
def _uniprot_accession(header):
    """Return the identifier used for the UniProt link from a FASTA header.

    Takes the first whitespace-separated word of *header*. For UniProt-style
    ids such as ``sp|P69905|HBA_HUMAN`` or ``tr|...``, returns the accession
    field (``P69905``). Returns "" for a blank header.
    """
    words = header.split()
    if not words:
        return ""
    pid = words[0]
    fields = pid.split("|")
    if len(fields) >= 3 and fields[0] in ("sp", "tr") and fields[1]:
        return fields[1]
    return pid


def mostrar(header, y_pred):
    """Render the prediction for one sequence inside an expander.

    Args:
        header: FASTA header; used as the expander label and to build the
            UniProt link.
        y_pred: 2-D score array (passed to mlb.inverse_transform); only its
            first row is displayed.

    Shows a UniProt button, the GO terms whose score is >= THRESH (with their
    cleaned definitions), and the TOP_N highest-scoring GO terms with scores.
    """
    from urllib.parse import quote

    # The header is user-typed text: percent-encode the accession so it cannot
    # break out of the href attribute in the raw-HTML block below.
    uniprot = f"https://www.uniprot.org/uniprotkb/{quote(_uniprot_accession(header), safe='')}"
    with st.expander(header, expanded=True):
        st.markdown(
            f"""
<div style="text-align:right;margin-bottom:0.5rem">
<a href="{uniprot}" target="_blank">
<button style="background:#2b8cbe;border:none;border-radius:4px;
padding:0.35rem 0.8rem;color:#fff;font-size:0.9rem;
cursor:pointer">Visitar UniProt</button>
</a>
</div>
""",
            unsafe_allow_html=True,
        )
        col1, col2 = st.columns(2)

        # Column 1: GO terms scoring >= THRESH, in label order.
        with col1:
            st.markdown(f"**GO terms com prob ≥ {THRESH}**")
            hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
            if hits:
                for go_id in hits:
                    name, defin_raw = GO_INFO.get(go_id, ("- sem nome -", ""))
                    defin = clean_definition(defin_raw)
                    st.markdown(f"- {go_link(go_id, name)}")
                    if defin:
                        st.caption(defin)
            else:
                st.code("- nenhum -")

        # Column 2: the TOP_N highest-scoring GO terms, best first.
        with col2:
            st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
            for rank, idx in enumerate(np.argsort(-y_pred[0])[:TOP_N], 1):
                go_id = GO[idx]
                name, _ = GO_INFO.get(go_id, ("", ""))
                st.markdown(f"{rank}. {go_link(go_id, name)} : {y_pred[0][idx]:.4f}")
# Inference #
if predict_clicked:
    sequences = parse_fasta_multiple(fasta_input)
    if not sequences:
        # Previously an empty/invalid input produced no feedback at all.
        st.warning("Nenhuma sequência válida encontrada. Insere pelo menos uma sequência FASTA.")
    # Width of the label space; keeps the ESM slice in sync with the binarizer.
    n_classes = len(GO)
    for header, seq in sequences:
        with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
            emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB)
            emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB)
            emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM)
            # verbose=0 suppresses Keras' per-call progress bar in the server log.
            y_pb = mlp_pb.predict(emb_pb, verbose=0)
            y_bfd = mlp_bfd.predict(emb_bfd, verbose=0)
            # The ESM-2 head's output is cut to the label-space width;
            # presumably that head has extra output columns — TODO confirm.
            y_esm = mlp_esm.predict(emb_esm, verbose=0)[:, :n_classes]
            y_ens = stacking.predict(np.concatenate([y_pb, y_bfd, y_esm], axis=1), verbose=0)
        mostrar(header, y_ens)
# Full list of predictable GO terms, with a search box #
with st.expander(f"Mostrar lista completa dos {len(GO)} GO terms possíveis", expanded=False):
    search_term = st.text_input("Filtra GO term ou nome:")
    # Case-insensitive substring match on id or name; computed once, not per term.
    needle = search_term.strip().lower()
    filtered_go_terms = []
    for go_id in GO:
        name = GO_INFO.get(go_id, ("", ""))[0] or ""
        if needle in go_id.lower() or needle in name.lower():
            filtered_go_terms.append((go_id, name))

    # Lay the matches out round-robin over three columns.
    if filtered_go_terms:
        cols = st.columns(3)
        for i, (go_id, name) in enumerate(filtered_go_terms):
            cols[i % 3].markdown(f"- {go_link(go_id, name)}")
    else:
        st.info("Nenhum GO term corresponde ao filtro inserido.")