|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os, re, numpy as np, torch, joblib, streamlit as st |
|
|
from huggingface_hub import login |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
from keras.models import load_model |
|
|
from goatools.obo_parser import GODag |
|
|
|
|
|
|
|
|
# Authenticate against the Hugging Face Hub with the token from the environment.
# A missing HF_TOKEN variable raises KeyError here, at import time.
login(token=os.environ["HF_TOKEN"])

# --- Application settings ----------------------------------------------------
SPACE_ID = "melvinalves/protein_function_prediction"  # HF Space holding models/ and data/
TOP_N = 20        # how many highest-probability GO terms to list per sequence
THRESH = 0.37     # probability cut-off for the "predicted GO terms" list
CHUNK_PB = 512    # residues per chunk fed to the ProtBert / ProtBert-BFD encoders
CHUNK_ESM = 1024  # residues per chunk fed to the ESM-2 encoder

# --- Encoders ----------------------------------------------------------------
# (repo_id, subfolder) tuples are fine-tuned ProtBert variants stored in a
# subfolder of one repo; a plain string is a Hub model id used as-is.
FINETUNED_PB = ("melvinalves/FineTune", "fineTunedProtbert")
FINETUNED_BFD = ("melvinalves/FineTune", "fineTunedProtbertbfd")
BASE_ESM = "facebook/esm2_t33_650M_UR50D"
|
|
|
|
|
|
|
|
@st.cache_resource
def download_file(path):
    """Fetch *path* from this app's Hugging Face Space and return the local file path.

    Per the original note, this is meant for the small (<= 1 GB) files kept in
    the Space itself. The result is cached per *path* by st.cache_resource.
    """
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_id=SPACE_ID,
        filename=path,
        repo_type="space",
    )
    return local_path
|
|
|
|
|
@st.cache_resource
def load_keras(name):
    """Load the Keras model stored at ``models/<name>`` in the Space (MLP heads / stacking).

    The model is loaded with ``compile=False`` because it is only used for
    inference (``predict``).
    """
    model_path = download_file(f"models/{name}")
    model = load_model(model_path, compile=False)
    return model
|
|
|
|
|
@st.cache_resource
def load_hf_encoder(repo_id, subfolder=None, base_tok=None):
    """Load a tokenizer and a transformer encoder from the Hugging Face Hub.

    Args:
        repo_id: Hub repo holding the encoder weights.
        subfolder: Optional subfolder of *repo_id* containing the weights.
        base_tok: Repo to load the tokenizer from; defaults to *repo_id*.

    Returns:
        ``(tokenizer, model)``. The model is in eval mode. Its weights are
        loaded with ``from_tf=True``, i.e. TF checkpoints are converted to
        PyTorch on load.
    """
    tokenizer_repo = repo_id if base_tok is None else base_tok
    tok = AutoTokenizer.from_pretrained(tokenizer_repo, do_lower_case=False)

    model_kwargs = {"from_tf": True}
    if subfolder:
        model_kwargs["subfolder"] = subfolder
    mdl = AutoModel.from_pretrained(repo_id, **model_kwargs)
    mdl.eval()
    return tok, mdl
|
|
|
|
|
def embed_seq(model_ref, seq, chunk):
    """Embed a protein sequence as the mean first-token (CLS) vector over chunks.

    Args:
        model_ref: Either a ``(repo_id, subfolder)`` tuple for a fine-tuned
            ProtBert variant (tokenized with the "Rostlab/prot_bert"
            tokenizer), or a plain Hub model id.
        seq: Amino-acid string. It must be non-empty; with no chunks,
            ``np.mean`` returns NaN.
        chunk: Maximum number of residues fed to the encoder per forward pass.

    Returns:
        Numpy array of shape ``(1, hidden_size)``. Each chunk is tokenized as
        one single-sequence batch, and the CLS vectors of all chunks are
        averaged with equal weight.
    """
    if isinstance(model_ref, tuple):
        repo_id, subfolder = model_ref
        tok, mdl = load_hf_encoder(repo_id, subfolder=subfolder,
                                   base_tok="Rostlab/prot_bert")
    else:
        tok, mdl = load_hf_encoder(model_ref)

    cls_vectors = []
    for start in range(0, len(seq), chunk):
        piece = seq[start:start + chunk]
        # The tokenizers expect residues separated by spaces.
        encoded = tok(" ".join(piece), return_tensors="pt", truncation=False)
        inputs = {name: tensor.to(mdl.device) for name, tensor in encoded.items()}
        with torch.no_grad():
            hidden = mdl(**inputs).last_hidden_state
        cls_vectors.append(hidden[:, 0, :].cpu().numpy())
    return np.mean(cls_vectors, axis=0)
|
|
|
|
|
@st.cache_resource
def load_go_info():
    """Parse the GO ontology shipped in the Space (``data/go.obo``).

    Returns:
        Dict mapping each id known to the GODag to ``(name, raw_definition)``.
        The raw definition is the unprocessed ``def:`` value (quoted text plus
        references). It is ``""`` for terms without a ``def:`` line:
        goatools only sets the ``defn`` attribute when the line is present.
    """
    dag = GODag(download_file("data/go.obo"), optional_attrs=["defn"])
    return {
        go_id: (term.name, getattr(term, "defn", ""))
        for go_id, term in dag.items()
    }
|
|
|
|
|
# --- Resources (each loader is cached with st.cache_resource) ----------------
GO_INFO = load_go_info()  # {GO id: (name, raw definition)}

# One MLP head per encoder, plus the stacking model that combines their outputs.
mlp_pb, mlp_bfd, mlp_esm, stacking = (
    load_keras(model_file)
    for model_file in (
        "mlp_protbert.h5",
        "mlp_protbertbfd.h5",
        "mlp_esm2.h5",
        "ensemble_stack.h5",
    )
)

# Fitted label binarizer (presumably sklearn's MultiLabelBinarizer). Its
# classes_ fix the GO-term column order of every prediction array.
# NOTE(security): joblib.load unpickles the file. That is acceptable only
# because it comes from this app's own Space; never point it at untrusted data.
mlb = joblib.load(download_file("data/mlb_597.pkl"))
GO = mlb.classes_
|
|
|
|
|
|
|
|
# --- Page setup and input widgets --------------------------------------------
# NOTE(review): set_page_config runs after the cached loaders above
# (load_go_info / load_keras), which may render spinners. Older Streamlit
# releases require set_page_config to be the first Streamlit command of the
# script. Confirm against the installed version.
st.set_page_config(page_title="Predição de Funções Moleculares de Proteínas",
                   page_icon="🧬", layout="centered")

# Global CSS overrides:
# - white background;
# - smaller top padding;
# - smaller textarea font;
# - a right-hand divider on the first column of st.columns layouts.
st.markdown(
    """
<style>
body, .stApp { background:#FFFFFF !important; }
.block-container { padding-top:1.5rem; }
textarea { font-size:0.9rem !important; }
div[data-testid="column"]:first-child {
border-right:1px solid #E0E0E0; padding-right:1rem !important;
}
</style>
""",
    unsafe_allow_html=True
)

# Optional logo, shown only when logo.png exists in the current working directory.
if os.path.exists("logo.png"):
    st.image("logo.png", width=180)

st.title("Predição de Funções Moleculares de Proteínas (GO:MF)")

# Raw FASTA text (one or more records); it is parsed once the button is pressed.
fasta_input = st.text_area("Insere uma ou mais sequências FASTA:", height=300)
predict_clicked = st.button("Prever GO terms")
|
|
|
|
|
|
|
|
def parse_fasta_multiple(text):
    """Parse FASTA text into a list of ``(header, sequence)`` tuples.

    Records start at lines beginning with ``>``. The header is the rest of
    that line, stripped.

    Sequence handling:
    - All whitespace (spaces, tabs, line breaks) is removed.
    - The sequence is upper-cased.
    - Records whose sequence ends up empty are skipped.

    Sequence lines before the first ``>`` form a headerless record. A headerless
    record, and any record whose header is blank, gets the fallback header
    ``Seq_<n>``, where n is its 1-based position in the returned list.
    """
    # Each entry is [header or None, list of raw sequence lines].
    records = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if line.startswith(">"):
            # Only a leading '>' starts a record, so '>' inside a header is kept.
            records.append([line[1:].strip(), []])
        else:
            if not records:
                records.append([None, []])  # sequence text before any header
            records[-1][1].append(line)

    out = []
    for header, seq_lines in records:
        seq = "".join("".join(seq_lines).split()).upper()
        if seq:
            out.append((header or f"Seq_{len(out) + 1}", seq))
    return out
|
|
|
|
|
# First double-quoted OBO string. The body is any run of characters other than
# a quote or backslash, or a backslash-escaped character, so an escaped quote
# does not end the match early.
_QUOTED_DEF_RE = re.compile(r'"((?:[^"\\]|\\.)*)"')
# An OBO backslash escape: a backslash followed by the escaped character.
_OBO_ESCAPE_RE = re.compile(r"\\(.)")


def clean_definition(defin: str) -> str:
    """Extract the human-readable text from a raw GO ``def`` value.

    Returns the content of the first double-quoted string, stripped. Every
    backslash escape inside it is replaced by the escaped character (for
    example, an escaped quote becomes a plain quote). If there is no quoted
    string, returns the stripped text before the first ``[``. Empty or
    ``None`` input yields ``""``.
    """
    if not defin:
        return ""
    match = _QUOTED_DEF_RE.search(defin)
    if match:
        return _OBO_ESCAPE_RE.sub(r"\1", match.group(1)).strip()
    return defin.split("[", 1)[0].strip()
|
|
|
|
|
# Characters that must be backslash-escaped inside Markdown link text.
_MD_LINK_TEXT_SPECIALS = re.compile(r"([\\\[\]])")


def go_link(go_id, name=""):
    """Return a Markdown link to the QuickGO page for *go_id*.

    The link text is ``"<go_id> - <name>"`` when *name* is non-empty, otherwise
    just *go_id*. Backslashes and square brackets in *name* are escaped.
    Several GO names contain brackets (e.g. "[acyl-carrier-protein] ..."), and
    an unbalanced bracket would otherwise break the link syntax.
    """
    url = f"https://www.ebi.ac.uk/QuickGO/term/{go_id}"
    if not name:
        return f"[{go_id}]({url})"
    safe_name = _MD_LINK_TEXT_SPECIALS.sub(r"\\\1", name)
    return f"[{go_id} - {safe_name}]({url})"
|
|
|
|
|
|
|
|
def _uniprot_accession(header):
    """Return the identifier used for the UniProt link of a FASTA *header*.

    Takes the first whitespace-separated token. For UniProt-style
    ``db|ACCESSION|ENTRY_NAME`` tokens (db ``sp`` or ``tr``), only the
    accession is kept; the full token is not a UniProtKB entry id. Returns
    ``""`` for a blank header.
    """
    tokens = header.split()
    if not tokens:
        return ""
    pid = tokens[0]
    fields = pid.split("|")
    if len(fields) >= 3 and fields[0] in ("sp", "tr"):
        return fields[1]
    return pid


def mostrar(header, y_pred):
    """Render the prediction panel for one FASTA record.

    Shows an expander titled *header* containing:
    - a "Visitar UniProt" button;
    - the GO terms whose probability is >= THRESH, with their cleaned definitions;
    - the TOP_N highest-scoring GO terms with their probabilities.

    Args:
        header: FASTA header of the record (without the leading '>').
        y_pred: Stacking-model output for this record. Row 0 holds one
            probability per label in ``mlb.classes_`` order. The shape is
            presumably (1, len(GO)); TODO confirm.
    """
    from urllib.parse import quote

    pid = _uniprot_accession(header)

    with st.expander(header, expanded=True):
        if pid:
            # pid comes from user input and is interpolated into raw HTML, so
            # percent-encode it. This also neutralises quotes and angle brackets.
            uniprot = f"https://www.uniprot.org/uniprotkb/{quote(pid, safe='')}"
            st.markdown(
                f"""
<div style="text-align:right;margin-bottom:0.5rem">
<a href="{uniprot}" target="_blank">
<button style="background:#2b8cbe;border:none;border-radius:4px;
padding:0.35rem 0.8rem;color:#fff;font-size:0.9rem;
cursor:pointer">Visitar UniProt</button>
</a>
</div>
""",
                unsafe_allow_html=True
            )

        col1, col2 = st.columns(2)

        with col1:
            st.markdown(f"**GO terms com prob ≥ {THRESH}**")
            # Binarise at THRESH and map the positive columns back to GO ids.
            hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
            if hits:
                for go_id in hits:
                    name, defin_raw = GO_INFO.get(go_id, ("- sem nome -", ""))
                    defin = clean_definition(defin_raw)
                    st.markdown(f"- {go_link(go_id, name)}")
                    if defin:
                        st.caption(defin)
            else:
                st.code("- nenhum -")

        with col2:
            st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
            probs = y_pred[0]
            for rank, idx in enumerate(np.argsort(-probs)[:TOP_N], 1):
                go_id = GO[idx]
                name, _ = GO_INFO.get(go_id, ("", ""))
                st.markdown(f"{rank}. {go_link(go_id, name)} : {probs[idx]:.4f}")
|
|
|
|
|
|
|
|
# --- Prediction: runs only on the rerun triggered by the button --------------
if predict_clicked:
    records = parse_fasta_multiple(fasta_input)
    if not records:
        st.warning("Nenhuma sequência válida encontrada. Insere pelo menos uma sequência FASTA.")

    # Number of GO labels known to the binarizer; every model output is aligned to it.
    n_terms = len(GO)

    for header, seq in records:
        with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
            # One embedding per encoder (each of shape (1, hidden_size)).
            emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB)
            emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB)
            emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM)

            y_pb = mlp_pb.predict(emb_pb)
            y_bfd = mlp_bfd.predict(emb_bfd)
            # NOTE(review): the ESM head presumably emits more columns than the
            # label set. Only the first len(GO) are kept so all three blocks
            # align with mlb.classes_. TODO confirm.
            y_esm = mlp_esm.predict(emb_esm)[:, :n_terms]

            # The stacking model consumes the three heads' outputs side by side.
            y_ens = stacking.predict(np.concatenate([y_pb, y_bfd, y_esm], axis=1))

        mostrar(header, y_ens)
|
|
|
|
|
|
|
|
# --- Browser over every GO term the model can predict ------------------------
with st.expander(f"Mostrar lista completa dos {len(GO)} GO terms possíveis", expanded=False):
    search_term = st.text_input("Filtra GO term ou nome:")
    # Normalise the query once rather than twice per GO term.
    query = search_term.strip().lower()

    # Case-insensitive substring match on the GO id or the term name
    # (an empty query keeps every term).
    filtered_go_terms = []
    for go_id in GO:
        name, _ = GO_INFO.get(go_id, ("", ""))
        if query in go_id.lower() or query in name.lower():
            filtered_go_terms.append((go_id, name))

    if filtered_go_terms:
        # Spread the links round-robin over three columns.
        cols = st.columns(3)
        for i, (go_id, name) in enumerate(filtered_go_terms):
            cols[i % 3].markdown(f"- {go_link(go_id, name)}")
    else:
        st.info("Nenhum GO term corresponde ao filtro inserido.")
|
|
|
|
|
|
|
|
|