melvinalves commited on
Commit
0104888
·
verified ·
1 Parent(s): 4a83aa1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -24,6 +24,20 @@ def load_hf_model(name):
24
  model.eval()
25
  return tokenizer, model
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # ---------- Função para gerar embedding por chunk ----------
28
  def embed_sequence(model_name, seq, chunk_size):
29
  tokenizer, model = load_hf_model(model_name)
@@ -44,16 +58,6 @@ def embed_sequence(model_name, seq, chunk_size):
44
 
45
  return np.mean(embeddings, axis=0, keepdims=True)
46
 
47
- # ---------- Carregar modelos ----------
48
- mlp_pb = load_model(os.path.join(MODELS_DIR, "mlp_protbert.keras"), compile=False)
49
- mlp_bfd = load_model(os.path.join(MODELS_DIR, "mlp_protbertbfd.keras"), compile=False)
50
- mlp_esm = load_model(os.path.join(MODELS_DIR, "mlp_esm2.keras"), compile=False)
51
- stacking = load_model(os.path.join(MODELS_DIR, "modelo_ensemble_stack.keras"), compile=False)
52
-
53
- # ---------- Carregar MultiLabelBinarizer ----------
54
- mlb = joblib.load(MLB_PATH)
55
- go_terms = mlb.classes_
56
-
57
  # ---------- Interface Streamlit ----------
58
  st.title("Predição de Funções de Proteínas")
59
 
@@ -68,17 +72,17 @@ if st.button("Prever GO terms"):
68
  if not seq:
69
  st.warning("Por favor, insere uma sequência válida.")
70
  else:
71
- st.write("A gerar embeddings por chunks...")
72
 
73
  emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
74
  emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
75
  emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
76
 
77
- st.write("A fazer predições base...")
78
 
79
- y_pb = mlp_pb.predict(emb_pb)[:, :597]
80
- y_bfd = mlp_bfd.predict(emb_bfd)[:, :597]
81
- y_esm = mlp_esm.predict(emb_esm)[:, :597]
82
 
83
  X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
84
  y_pred = stacking.predict(X_stack)
@@ -94,4 +98,3 @@ if st.button("Prever GO terms"):
94
  top_idx = np.argsort(-y_pred[0])[:TOP_N]
95
  for i in top_idx:
96
  st.write(f"{go_terms[i]} : {y_pred[0][i]:.4f}")
97
-
 
24
  model.eval()
25
  return tokenizer, model
26
 
27
+ # ---------- Cache dos modelos locais ----------
28
+ @st.cache_resource
29
+ def load_local_model(path):
30
+ return load_model(path, compile=False)
31
+
32
+ mlp_pb = load_local_model(os.path.join(MODELS_DIR, "mlp_protbert.keras"))
33
+ mlp_bfd = load_local_model(os.path.join(MODELS_DIR, "mlp_protbertbfd.keras"))
34
+ mlp_esm = load_local_model(os.path.join(MODELS_DIR, "mlp_esm2.keras"))
35
+ stacking = load_local_model(os.path.join(MODELS_DIR, "modelo_ensemble_stack.keras"))
36
+
37
+ # ---------- Carregar MultiLabelBinarizer ----------
38
+ mlb = joblib.load(MLB_PATH)
39
+ go_terms = mlb.classes_
40
+
41
  # ---------- Função para gerar embedding por chunk ----------
42
  def embed_sequence(model_name, seq, chunk_size):
43
  tokenizer, model = load_hf_model(model_name)
 
58
 
59
  return np.mean(embeddings, axis=0, keepdims=True)
60
 
 
 
 
 
 
 
 
 
 
 
61
  # ---------- Interface Streamlit ----------
62
  st.title("Predição de Funções de Proteínas")
63
 
 
72
  if not seq:
73
  st.warning("Por favor, insere uma sequência válida.")
74
  else:
75
+ st.write("🔄 A gerar embeddings...")
76
 
77
  emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
78
  emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
79
  emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
80
 
81
+ st.write("🧠 A fazer predições com cada modelo...")
82
 
83
+ y_pb = mlp_pb.predict(emb_pb)
84
+ y_bfd = mlp_bfd.predict(emb_bfd)
85
+ y_esm = mlp_esm.predict(emb_esm)
86
 
87
  X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
88
  y_pred = stacking.predict(X_stack)
 
98
  top_idx = np.argsort(-y_pred[0])[:TOP_N]
99
  for i in top_idx:
100
  st.write(f"{go_terms[i]} : {y_pred[0][i]:.4f}")