melvinalves committed on
Commit
c6dfc57
·
verified ·
1 Parent(s): a0a95a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -45
app.py CHANGED
@@ -1,18 +1,16 @@
1
- import os
2
- import re
3
- import numpy as np
4
- import torch
5
- import joblib
6
- import streamlit as st
7
  from transformers import AutoTokenizer, AutoModel
8
- from huggingface_hub import hf_hub_download
9
  from keras.models import load_model
10
  from goatools.obo_parser import GODag
11
 
12
-
13
- from huggingface_hub import login
14
-
15
  login(os.environ["HF_TOKEN"])
 
16
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” CONFIG β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
17
  SPACE_ID = "melvinalves/protein_function_prediction"
18
  TOP_N = 10
@@ -20,36 +18,63 @@ THRESH = 0.37
20
  CHUNK_PB = 512
21
  CHUNK_ESM = 1024
22
 
 
 
 
 
 
23
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” HELPERS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
24
  @st.cache_resource
25
  def download_file(path):
 
 
26
  return hf_hub_download(repo_id=SPACE_ID, repo_type="space", filename=path)
27
 
28
  @st.cache_resource
29
  def load_keras(name):
 
30
  return load_model(download_file(f"models/{name}"), compile=False)
31
 
32
  @st.cache_resource
33
- def load_hf_encoder(model):
34
- tok = AutoTokenizer.from_pretrained(model, do_lower_case=False)
35
- mdl = AutoModel.from_pretrained(model)
 
 
 
 
 
 
 
 
36
  mdl.eval()
37
  return tok, mdl
38
 
39
- def embed_seq(model, seq, chunk):
40
- tok, mdl = load_hf_encoder(model)
 
 
 
 
 
 
 
 
 
 
41
  parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
42
- vecs = []
43
  for p in parts:
 
44
  with torch.no_grad():
45
- out = mdl(**tok(" ".join(p), return_tensors="pt", truncation=False))
46
- vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
47
  return np.mean(vecs, axis=0, keepdims=True)
48
 
49
  @st.cache_resource
50
  def load_go_info():
51
  obo_path = download_file("data/go.obo")
52
- dag = GODag(obo_path, optional_attrs=['defn'])
53
  return {tid: (term.name, term.defn) for tid, term in dag.items()}
54
 
55
  GO_INFO = load_go_info()
@@ -67,9 +92,7 @@ GO = mlb.classes_
67
  st.title("PrediΓ§Γ£o de FunΓ§Γ΅es Moleculares de ProteΓ­nas")
68
 
69
  st.markdown(
70
- """
71
- <style> textarea { font-size: 0.9rem !important; } </style>
72
- """,
73
  unsafe_allow_html=True,
74
  )
75
 
@@ -78,58 +101,55 @@ predict_clicked = st.button("Prever GO terms")
78
 
79
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” PARSE DE MÚLTIPLAS SEQUÊNCIAS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
80
  def parse_fasta_multiple(fasta_str):
81
- entries = fasta_str.strip().split(">")
82
- parsed = []
83
-
84
  for i, entry in enumerate(entries):
85
  if not entry.strip():
86
  continue
87
-
88
  lines = entry.strip().splitlines()
89
-
90
- # Verifica se estamos num bloco com '>'
91
  if i > 0:
92
  header = lines[0].strip()
93
- seq = "".join(l.strip() for l in lines[1:]).replace(" ", "").upper()
94
  else:
95
- # Entrada sem '>', trata tudo como sequΓͺncia
96
  header = f"Seq_{i+1}"
97
- seq = "".join(l.strip() for l in lines).replace(" ", "").upper()
98
-
99
  if seq:
100
  parsed.append((header, seq))
101
  return parsed
102
 
 
103
  if predict_clicked:
104
  parsed_seqs = parse_fasta_multiple(fasta_input)
105
-
106
  if not parsed_seqs:
107
  st.warning("NΓ£o foi possΓ­vel encontrar nenhuma sequΓͺncia vΓ‘lida.")
108
  st.stop()
109
 
110
  for header, seq in parsed_seqs:
111
  with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
112
- emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
113
- emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
114
- emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
 
115
 
 
116
  y_pb = mlp_pb.predict(emb_pb)
117
  y_bfd = mlp_bfd.predict(emb_bfd)
118
  y_esm = mlp_esm.predict(emb_esm)[:, :597]
 
 
119
  X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
120
  y_ens = stacking.predict(X)
121
 
122
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” RESULTADOS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
123
- def mostrar(tag, y_pred):
124
  with st.expander(tag, expanded=True):
125
  hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
126
  st.markdown(f"**GO terms com prob β‰₯ {THRESH}**")
127
  if hits:
128
  for go_id in hits:
129
  name, defin = GO_INFO.get(go_id, ("β€” sem nome β€”", ""))
130
- limpa_def = re.sub(r'^\s*"?(.+?)"?\s*(\[[^\]]*\])?\s*$', r'\1', defin or "")
131
  st.write(f"**{go_id} β€” {name}**")
132
- st.caption(limpa_def)
133
  else:
134
  st.code("β€” nenhum β€”")
135
 
@@ -139,9 +159,8 @@ if predict_clicked:
139
  name, _ = GO_INFO.get(go_id, ("", ""))
140
  st.write(f"{go_id} β€” {name} : {y_pred[0][idx]:.4f}")
141
 
142
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” ESCOLHE QUAIS MOSTRAR β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
143
-
144
- # mostrar(f"{header} β€” ProtBERT (MLP)", y_pb)
145
- # mostrar(f"{header} β€” ProtBERT-BFD (MLP)", y_bfd)
146
- # mostrar(f"{header} β€” ESM-2 (MLP)", y_esm)
147
- mostrar(f"{header}", y_ens)
 
1
+ # -------------------------------------------------------------------------------------------------
2
+ # app.py – Streamlit app para predição de GO:MF
3
+ # Versão: usa ProtBERT & ProtBERT-BFD fine-tuned (melvinalves/FineTune) + ESM-2 base
4
+ # -------------------------------------------------------------------------------------------------
5
+ import os, re, numpy as np, torch, joblib, streamlit as st
6
+ from huggingface_hub import login
7
  from transformers import AutoTokenizer, AutoModel
 
8
  from keras.models import load_model
9
  from goatools.obo_parser import GODag
10
 
11
+ # ——————————————————— AUTENTICAÇÃO ——————————————————— #
 
 
12
  login(os.environ["HF_TOKEN"])
13
+
14
  # ——————————————————— CONFIG ——————————————————— #
15
  SPACE_ID = "melvinalves/protein_function_prediction"
16
  TOP_N = 10
 
18
  CHUNK_PB = 512
19
  CHUNK_ESM = 1024
20
 
21
+ # Repositórios dos modelos
22
+ FINETUNED_PB = ("melvinalves/FineTune", "fineTunedProtbert")
23
+ FINETUNED_BFD = ("melvinalves/FineTune", "fineTunedProtbertbfd")
24
+ BASE_ESM = "facebook/esm2_t33_650M_UR50D"
25
+
26
  # ——————————————————— HELPERS ——————————————————— #
27
  @st.cache_resource
28
  def download_file(path):
29
+ """Ficheiros pequenos guardados no repositΓ³rio do Space (≀1 GB total)."""
30
+ from huggingface_hub import hf_hub_download
31
  return hf_hub_download(repo_id=SPACE_ID, repo_type="space", filename=path)
32
 
33
  @st.cache_resource
34
  def load_keras(name):
35
+ """Carrega modelos Keras (MLPs + stacking)."""
36
  return load_model(download_file(f"models/{name}"), compile=False)
37
 
38
  @st.cache_resource
39
+ def load_hf_encoder(repo_id, subfolder=None, base_tok="Rostlab/prot_bert"):
40
+ """
41
+ Carrega um encoder HF (PyTorch) – se existir apenas tf_model.h5 no repo,
42
+ usa from_tf=True para converter on-the-fly.
43
+ """
44
+ tok = AutoTokenizer.from_pretrained(base_tok, do_lower_case=False)
45
+ mdl = AutoModel.from_pretrained(
46
+ repo_id,
47
+ subfolder=subfolder,
48
+ from_tf=True, # converte pesos TF se necessΓ‘rio
49
+ )
50
  mdl.eval()
51
  return tok, mdl
52
 
53
def embed_seq(model_ref, seq, chunk):
    """Return a single mean embedding of shape ``(1, hidden)`` for ``seq``.

    The sequence is cut into pieces of at most ``chunk`` residues. Each piece
    is tokenized with its residues separated by spaces, passed through the
    encoder, and represented by the hidden state of its first token (the CLS
    position). The per-piece vectors are then averaged.

    Args:
        model_ref: Either a string naming a base model, or a
            ``(repo_id, subfolder)`` tuple for a fine-tuned model.
        seq: Protein sequence as a string of residue letters.
        chunk: Maximum number of residues per piece.

    Returns:
        A NumPy array with one row: the mean of the per-piece vectors.

    Raises:
        ValueError: If ``seq`` is empty, since there is nothing to embed.
    """
    if not seq:
        raise ValueError("embed_seq() requires a non-empty sequence")

    if isinstance(model_ref, tuple):
        tok, mdl = load_hf_encoder(*model_ref)
    else:
        # Keep the tokenizer that matches the model family.
        base_tok = "Rostlab/prot_bert" if "prot_bert" in model_ref else model_ref
        tok, mdl = load_hf_encoder(model_ref, base_tok=base_tok)

    vecs = []
    for start in range(0, len(seq), chunk):
        piece = seq[start:start + chunk]
        tokens = tok(" ".join(piece), return_tensors="pt", truncation=False)
        with torch.no_grad():
            out = mdl(**{k: v.to(mdl.device) for k, v in tokens.items()})
        # First-token hidden state; the leading batch axis is kept here.
        vecs.append(out.last_hidden_state[:, 0, :].cpu().numpy())

    # Join along the batch axis before averaging. Averaging the raw list
    # would stack the per-piece arrays on a new axis and, with keepdims=True,
    # return a 3-D array that the downstream predict/concatenate calls
    # cannot consume as a single-row matrix.
    return np.concatenate(vecs, axis=0).mean(axis=0, keepdims=True)
73
 
74
  @st.cache_resource
75
  def load_go_info():
76
  obo_path = download_file("data/go.obo")
77
+ dag = GODag(obo_path, optional_attrs=["defn"])
78
  return {tid: (term.name, term.defn) for tid, term in dag.items()}
79
 
80
  GO_INFO = load_go_info()
 
92
  st.title("PrediΓ§Γ£o de FunΓ§Γ΅es Moleculares de ProteΓ­nas")
93
 
94
  st.markdown(
95
+ "<style> textarea { font-size: 0.9rem !important; } </style>",
 
 
96
  unsafe_allow_html=True,
97
  )
98
 
 
101
 
102
  # ——————————————————— PARSE DE MÚLTIPLAS SEQUÊNCIAS ——————————————————— #
103
  def parse_fasta_multiple(fasta_str):
104
+ entries, parsed = fasta_str.strip().split(">"), []
 
 
105
  for i, entry in enumerate(entries):
106
  if not entry.strip():
107
  continue
 
108
  lines = entry.strip().splitlines()
 
 
109
  if i > 0:
110
  header = lines[0].strip()
111
+ seq = "".join(lines[1:]).replace(" ", "").upper()
112
  else:
 
113
  header = f"Seq_{i+1}"
114
+ seq = "".join(lines).replace(" ", "").upper()
 
115
  if seq:
116
  parsed.append((header, seq))
117
  return parsed
118
 
119
+ # ——————————————————— INFERÊNCIA ——————————————————— #
120
  if predict_clicked:
121
  parsed_seqs = parse_fasta_multiple(fasta_input)
 
122
  if not parsed_seqs:
123
  st.warning("NΓ£o foi possΓ­vel encontrar nenhuma sequΓͺncia vΓ‘lida.")
124
  st.stop()
125
 
126
  for header, seq in parsed_seqs:
127
  with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
128
+ # β€”β€”β€” Embeddings β€”β€”β€” #
129
+ emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB)
130
+ emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB)
131
+ emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM)
132
 
133
+ # β€”β€”β€” PrediΓ§Γ΅es dos MLPs β€”β€”β€” #
134
  y_pb = mlp_pb.predict(emb_pb)
135
  y_bfd = mlp_bfd.predict(emb_bfd)
136
  y_esm = mlp_esm.predict(emb_esm)[:, :597]
137
+
138
+ # β€”β€”β€” Stacking β€”β€”β€” #
139
  X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
140
  y_ens = stacking.predict(X)
141
 
142
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” RESULTADOS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
143
+ def mostrar_resultados(tag, y_pred):
144
  with st.expander(tag, expanded=True):
145
  hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
146
  st.markdown(f"**GO terms com prob β‰₯ {THRESH}**")
147
  if hits:
148
  for go_id in hits:
149
  name, defin = GO_INFO.get(go_id, ("β€” sem nome β€”", ""))
150
+ limp = re.sub(r'^\s*"?(.+?)"?\s*(\[[^\]]*\])?\s*$', r'\1', defin or "")
151
  st.write(f"**{go_id} β€” {name}**")
152
+ st.caption(limp)
153
  else:
154
  st.code("β€” nenhum β€”")
155
 
 
159
  name, _ = GO_INFO.get(go_id, ("", ""))
160
  st.write(f"{go_id} β€” {name} : {y_pred[0][idx]:.4f}")
161
 
162
+ # Mostrar apenas ensemble (descomenta se quiseres os individuais)
163
+ # mostrar_resultados(f"{header} β€” ProtBERT", y_pb)
164
+ # mostrar_resultados(f"{header} β€” ProtBERT-BFD", y_bfd)
165
+ # mostrar_resultados(f"{header} β€” ESM-2", y_esm)
166
+ mostrar_resultados(header, y_ens)