Update app.py
Browse files
app.py
CHANGED
|
@@ -24,6 +24,20 @@ def load_hf_model(name):
|
|
| 24 |
model.eval()
|
| 25 |
return tokenizer, model
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# ---------- Função para gerar embedding por chunk ----------
|
| 28 |
def embed_sequence(model_name, seq, chunk_size):
|
| 29 |
tokenizer, model = load_hf_model(model_name)
|
|
@@ -44,16 +58,6 @@ def embed_sequence(model_name, seq, chunk_size):
|
|
| 44 |
|
| 45 |
return np.mean(embeddings, axis=0, keepdims=True)
|
| 46 |
|
| 47 |
-
# ---------- Carregar modelos ----------
|
| 48 |
-
mlp_pb = load_model(os.path.join(MODELS_DIR, "mlp_protbert.keras"), compile=False)
|
| 49 |
-
mlp_bfd = load_model(os.path.join(MODELS_DIR, "mlp_protbertbfd.keras"), compile=False)
|
| 50 |
-
mlp_esm = load_model(os.path.join(MODELS_DIR, "mlp_esm2.keras"), compile=False)
|
| 51 |
-
stacking = load_model(os.path.join(MODELS_DIR, "modelo_ensemble_stack.keras"), compile=False)
|
| 52 |
-
|
| 53 |
-
# ---------- Carregar MultiLabelBinarizer ----------
|
| 54 |
-
mlb = joblib.load(MLB_PATH)
|
| 55 |
-
go_terms = mlb.classes_
|
| 56 |
-
|
| 57 |
# ---------- Interface Streamlit ----------
|
| 58 |
st.title("Predição de Funções de Proteínas")
|
| 59 |
|
|
@@ -68,17 +72,17 @@ if st.button("Prever GO terms"):
|
|
| 68 |
if not seq:
|
| 69 |
st.warning("Por favor, insere uma sequência válida.")
|
| 70 |
else:
|
| 71 |
-
st.write("A gerar embeddings
|
| 72 |
|
| 73 |
emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
|
| 74 |
emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
|
| 75 |
emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
|
| 76 |
|
| 77 |
-
st.write("A fazer predições
|
| 78 |
|
| 79 |
-
y_pb = mlp_pb.predict(emb_pb)
|
| 80 |
-
y_bfd = mlp_bfd.predict(emb_bfd)
|
| 81 |
-
y_esm = mlp_esm.predict(emb_esm)
|
| 82 |
|
| 83 |
X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
|
| 84 |
y_pred = stacking.predict(X_stack)
|
|
@@ -94,4 +98,3 @@ if st.button("Prever GO terms"):
|
|
| 94 |
top_idx = np.argsort(-y_pred[0])[:TOP_N]
|
| 95 |
for i in top_idx:
|
| 96 |
st.write(f"{go_terms[i]} : {y_pred[0][i]:.4f}")
|
| 97 |
-
|
|
|
|
| 24 |
model.eval()
|
| 25 |
return tokenizer, model
|
| 26 |
|
| 27 |
+
# ---------- Cache dos modelos locais ----------
|
| 28 |
+
@st.cache_resource
|
| 29 |
+
def load_local_model(path):
|
| 30 |
+
return load_model(path, compile=False)
|
| 31 |
+
|
| 32 |
+
mlp_pb = load_local_model(os.path.join(MODELS_DIR, "mlp_protbert.keras"))
|
| 33 |
+
mlp_bfd = load_local_model(os.path.join(MODELS_DIR, "mlp_protbertbfd.keras"))
|
| 34 |
+
mlp_esm = load_local_model(os.path.join(MODELS_DIR, "mlp_esm2.keras"))
|
| 35 |
+
stacking = load_local_model(os.path.join(MODELS_DIR, "modelo_ensemble_stack.keras"))
|
| 36 |
+
|
| 37 |
+
# ---------- Carregar MultiLabelBinarizer ----------
|
| 38 |
+
mlb = joblib.load(MLB_PATH)
|
| 39 |
+
go_terms = mlb.classes_
|
| 40 |
+
|
| 41 |
# ---------- Função para gerar embedding por chunk ----------
|
| 42 |
def embed_sequence(model_name, seq, chunk_size):
|
| 43 |
tokenizer, model = load_hf_model(model_name)
|
|
|
|
| 58 |
|
| 59 |
return np.mean(embeddings, axis=0, keepdims=True)
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
# ---------- Interface Streamlit ----------
|
| 62 |
st.title("Predição de Funções de Proteínas")
|
| 63 |
|
|
|
|
| 72 |
if not seq:
|
| 73 |
st.warning("Por favor, insere uma sequência válida.")
|
| 74 |
else:
|
| 75 |
+
st.write("🔄 A gerar embeddings...")
|
| 76 |
|
| 77 |
emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
|
| 78 |
emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
|
| 79 |
emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
|
| 80 |
|
| 81 |
+
st.write("🧠 A fazer predições com cada modelo...")
|
| 82 |
|
| 83 |
+
y_pb = mlp_pb.predict(emb_pb)
|
| 84 |
+
y_bfd = mlp_bfd.predict(emb_bfd)
|
| 85 |
+
y_esm = mlp_esm.predict(emb_esm)
|
| 86 |
|
| 87 |
X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
|
| 88 |
y_pred = stacking.predict(X_stack)
|
|
|
|
| 98 |
top_idx = np.argsort(-y_pred[0])[:TOP_N]
|
| 99 |
for i in top_idx:
|
| 100 |
st.write(f"{go_terms[i]} : {y_pred[0][i]:.4f}")
|
|
|