melvinalves committed on
Commit
a08cc8f
·
verified ·
1 Parent(s): 431635b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -63
app.py CHANGED
@@ -7,42 +7,31 @@ from transformers import AutoTokenizer, AutoModel
7
  from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
9
 
10
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” CONFIGURAÇÃO β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
11
- SPACE_ID = "melvinalves/protein_function_prediction"
12
- TOP_N = 10
13
- THRESH = 0.50 # limiar para listar GO terms
14
- CHUNK_PB = 512
15
- CHUNK_ESM = 1024
16
-
17
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” HELPERS DE CACHE β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
18
  @st.cache_resource
19
- def download_file(path_in_repo: str):
20
- return hf_hub_download(repo_id=SPACE_ID, repo_type="space", filename=path_in_repo)
21
 
22
  @st.cache_resource
23
- def load_keras(file_name: str):
24
- return load_model(download_file(f"models/{file_name}"), compile=False)
25
 
26
  @st.cache_resource
27
- def load_hf_encoder(model_name: str):
28
- tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
29
- mdl = AutoModel.from_pretrained(model_name)
30
  mdl.eval()
31
  return tok, mdl
32
 
33
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” MODELOS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
34
- mlp_pb = load_keras("mlp_protbert.h5")
35
- mlp_bfd = load_keras("mlp_protbertbfd.h5")
36
- mlp_esm = load_keras("mlp_esm2.h5")
37
- stacking = load_keras("ensemble_stack.h5") # usa o nome que tiveres guardado
38
-
39
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” LABEL BINARIZER β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
40
- mlb = joblib.load(download_file("data/mlb_597.pkl"))
41
- GO_TERMS = mlb.classes_
42
-
43
- # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” EMBEDDINGS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
44
- def embed_seq(model_name: str, seq: str, chunk: int) -> np.ndarray:
45
- tok, mdl = load_hf_encoder(model_name)
46
  parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
47
  vecs = []
48
  for p in parts:
@@ -51,6 +40,15 @@ def embed_seq(model_name: str, seq: str, chunk: int) -> np.ndarray:
51
  vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
52
  return np.mean(vecs, axis=0, keepdims=True)
53
 
 
 
 
 
 
 
 
 
 
54
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” UI β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
55
  st.title("πŸ”¬ PrediΓ§Γ£o de FunΓ§Γ΅es de ProteΓ­nas")
56
 
@@ -61,47 +59,42 @@ st.markdown(
61
  unsafe_allow_html=True,
62
  )
63
 
64
- fasta = st.text_area("Insere a sequΓͺncia FASTA:", height=200)
 
65
 
66
- # ---------- BOTÃO ----------
67
- if fasta and st.button("Prever GO terms"):
68
- seq = "\n".join(l for l in fasta.splitlines() if not l.startswith(">"))
69
- seq = seq.replace(" ", "").replace("\n", "").upper()
70
 
 
 
71
  if not seq:
72
- st.warning("Por favor, insere uma sequΓͺncia vΓ‘lida.")
73
  st.stop()
74
 
75
- # 1) EMBEDDINGS
76
- st.write("⏳ A gerar embeddings…")
77
- emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
78
- emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
79
- emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
80
-
81
- # 2) PREDIÇÕES INDIVIDUAIS
82
- st.write("🧠 A fazer prediΓ§Γ΅es…")
83
- y_pb = mlp_pb.predict(emb_pb)
84
- y_bfd = mlp_bfd.predict(emb_bfd)
85
- y_esm = mlp_esm.predict(emb_esm)[:, :597] # corta 602 β†’ 597
86
-
87
- # 3) ENSEMBLE
88
- X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1) # (1, 1791)
89
- y_ens = stacking.predict(X_stack)
90
-
91
- # β€”β€”β€” FunΓ§Γ£o auxiliar para mostrar resultados β€”β€”β€”
92
- def show_results(label: str, y_pred):
93
- with st.expander(label, expanded=(label == "Ensemble (Stacking)")):
94
  hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
95
  st.markdown(f"**GO terms com prob β‰₯ {THRESH}**")
96
  st.code("\n".join(hits) if hits else "β€” nenhum β€”")
97
-
98
  st.markdown(f"**Top {TOP_N} GO terms mais provΓ‘veis**")
99
- top_idx = np.argsort(-y_pred[0])[:TOP_N]
100
- for i in top_idx:
101
- st.write(f"{GO_TERMS[i]} : {y_pred[0][i]:.4f}")
102
-
103
- # 4) OUTPUT
104
- show_results("ProtBERT (MLP)", y_pb)
105
- show_results("ProtBERT-BFD (MLP)", y_bfd)
106
- show_results("ESM-2 (MLP)", y_esm)
107
- show_results("Ensemble (Stacking)", y_ens)
 
7
  from huggingface_hub import hf_hub_download
8
  from keras.models import load_model
9
 
10
# ——————————— Configuration ——————————— #
SPACE_ID = "melvinalves/protein_function_prediction"
TOP_N = 10        # number of top-scoring GO terms shown per model
THRESH = 0.50     # probability cutoff for listing a GO term as a hit
CHUNK_PB = 512    # chunk length fed to the ProtBERT encoders
CHUNK_ESM = 1024  # chunk length fed to the ESM-2 encoder

# ——————————— Cached helpers ——————————— #
@st.cache_resource
def download_file(path):
    """Fetch *path* from this Space's repo on the Hub (cached across reruns)."""
    return hf_hub_download(repo_id=SPACE_ID, repo_type="space", filename=path)
21
 
22
@st.cache_resource
def load_keras(name):
    """Download ``models/<name>`` from the Space repo and load it uncompiled."""
    model_path = download_file(f"models/{name}")
    return load_model(model_path, compile=False)
25
 
26
@st.cache_resource
def load_hf_encoder(model):
    """Load a Hugging Face tokenizer/encoder pair, encoder in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(model, do_lower_case=False)
    encoder = AutoModel.from_pretrained(model)
    encoder.eval()  # inference only
    return tokenizer, encoder
32
 
33
+ def embed_seq(model, seq, chunk):
34
+ tok, mdl = load_hf_encoder(model)
 
 
 
 
 
 
 
 
 
 
 
35
  parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
36
  vecs = []
37
  for p in parts:
 
40
  vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
41
  return np.mean(vecs, axis=0, keepdims=True)
42
 
43
# ——————————— Model loading ——————————— #
# One MLP head per protein-language-model embedding, plus a stacking ensemble.
mlp_pb = load_keras("mlp_protbert.h5")
mlp_bfd = load_keras("mlp_protbertbfd.h5")
mlp_esm = load_keras("mlp_esm2.h5")
stacking = load_keras("ensemble_stack.h5")  # use the real saved file name here

# Multi-label binarizer that maps prediction vectors back to GO-term labels.
mlb = joblib.load(download_file("data/mlb_597.pkl"))
GO = mlb.classes_
52
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” UI β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” #
53
  st.title("πŸ”¬ PrediΓ§Γ£o de FunΓ§Γ΅es de ProteΓ­nas")
54
 
 
59
  unsafe_allow_html=True,
60
  )
61
 
62
# ——————————— Input ——————————— #
fasta_input = st.text_area("Insere a sequência FASTA:", height=200)
predict_clicked = st.button("Prever GO terms")

if predict_clicked:
    # Drop FASTA header lines (">..."), remove spaces, upper-case the rest.
    # BUG FIX: the previous code joined the lines with "\n" and only stripped
    # spaces, so newline characters remained inside the sequence that was then
    # fed to the encoders. Join with "" so the sequence is contiguous.
    seq = "".join(
        line for line in fasta_input.splitlines() if not line.startswith(">")
    ).replace(" ", "").upper()
    if not seq:
        st.warning("Por favor, insere primeiro uma sequência FASTA válida.")
        st.stop()
72
 
73
+ # β€”β€”β€” 1) EMBEDDINGS β€”β€”β€”
74
+ with st.spinner("⏳ A gerar embeddings…"):
75
+ emb_pb = embed_seq("Rostlab/prot_bert", seq, CHUNK_PB)
76
+ emb_bfd = embed_seq("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
77
+ emb_esm = embed_seq("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
78
+
79
+ # β€”β€”β€” 2) PREDIÇÕES β€”β€”β€”
80
+ with st.spinner("🧠 A fazer prediΓ§Γ΅es…"):
81
+ y_pb = mlp_pb.predict(emb_pb)
82
+ y_bfd = mlp_bfd.predict(emb_bfd)
83
+ y_esm = mlp_esm.predict(emb_esm)[:, :597]
84
+ X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
85
+ y_ens = stacking.predict(X)
86
+
87
+ # β€”β€”β€” 3) MOSTRAR RESULTADOS β€”β€”β€”
88
+ def mostrar(tag, y_pred):
89
+ with st.expander(tag, expanded=(tag == "Ensemble (Stacking)")):
 
 
90
  hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
91
  st.markdown(f"**GO terms com prob β‰₯ {THRESH}**")
92
  st.code("\n".join(hits) if hits else "β€” nenhum β€”")
 
93
  st.markdown(f"**Top {TOP_N} GO terms mais provΓ‘veis**")
94
+ for i in np.argsort(-y_pred[0])[:TOP_N]:
95
+ st.write(f"{GO[i]} : {y_pred[0][i]:.4f}")
96
+
97
+ mostrar("ProtBERT (MLP)", y_pb)
98
+ mostrar("ProtBERT-BFD (MLP)", y_bfd)
99
+ mostrar("ESM-2 (MLP)", y_esm)
100
+ mostrar("Ensemble (Stacking)", y_ens)