melvinalves commited on
Commit
49f3a1b
·
verified ·
1 Parent(s): 5ab6b3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -97
app.py CHANGED
@@ -1,97 +1,97 @@
1
- import os
2
- import numpy as np
3
- import torch
4
- from transformers import AutoTokenizer, AutoModel
5
- from tensorflow.keras.models import load_model
6
- import joblib
7
- import streamlit as st
8
-
9
- # ---------- Caminhos ----------
10
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11
- MODELS_DIR = os.path.join(BASE_DIR, "models")
12
- MLB_PATH = os.path.join(BASE_DIR, "data", "mlb_597.pkl")
13
-
14
- # ---------- Parâmetros ----------
15
- TOP_N = 10
16
- CHUNK_PB = 512
17
- CHUNK_ESM = 1024
18
-
19
- # ---------- Cache dos modelos HuggingFace ----------
20
- @st.cache_resource
21
- def load_hf_model(name):
22
- tokenizer = AutoTokenizer.from_pretrained(name, do_lower_case=False)
23
- model = AutoModel.from_pretrained(name)
24
- model.eval()
25
- return tokenizer, model
26
-
27
- # ---------- Função para gerar embedding por chunk ----------
28
- def embed_sequence(model_name, seq, chunk_size):
29
- tokenizer, model = load_hf_model(model_name)
30
-
31
- def format_seq(s):
32
- return " ".join(list(s))
33
-
34
- chunks = [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]
35
- embeddings = []
36
-
37
- for chunk in chunks:
38
- formatted = format_seq(chunk)
39
- inputs = tokenizer(formatted, return_tensors="pt", truncation=True)
40
- with torch.no_grad():
41
- outputs = model(**inputs)
42
- cls = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
43
- embeddings.append(cls)
44
-
45
- return np.mean(embeddings, axis=0, keepdims=True)
46
-
47
- # ---------- Carregar modelos ----------
48
- mlp_pb = load_model(os.path.join(MODELS_DIR, "mlp_protbert.h5"), compile=False)
49
- mlp_bfd = load_model(os.path.join(MODELS_DIR, "mlp_protbertbfd.h5"), compile=False)
50
- mlp_esm = load_model(os.path.join(MODELS_DIR, "mlp_esm2.h5"), compile=False)
51
- stacking = load_model(os.path.join(MODELS_DIR, "modelo_ensemble_stack.h5"), compile=False)
52
-
53
- # ---------- Carregar MultiLabelBinarizer ----------
54
- mlb = joblib.load(MLB_PATH)
55
- go_terms = mlb.classes_
56
-
57
- # ---------- Interface Streamlit ----------
58
- st.title("Predição de Funções de Proteínas")
59
-
60
- seq = st.text_area("Insere a sequência FASTA:", height=200)
61
-
62
- # Limpar sequência: remover cabeçalhos (">") e espaços/quebras
63
- if seq:
64
- seq = "\n".join([line for line in seq.splitlines() if not line.startswith(">")])
65
- seq = seq.replace(" ", "").replace("\n", "").strip()
66
-
67
- if st.button("Prever GO terms"):
68
- if not seq:
69
- st.warning("Por favor, insere uma sequência válida.")
70
- else:
71
- st.write("A gerar embeddings por chunks...")
72
-
73
- emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
74
- emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
75
- emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
76
-
77
- st.write("A fazer predições base...")
78
-
79
- y_pb = mlp_pb.predict(emb_pb)[:, :597]
80
- y_bfd = mlp_bfd.predict(emb_bfd)[:, :597]
81
- y_esm = mlp_esm.predict(emb_esm)[:, :597]
82
-
83
- X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
84
- y_pred = stacking.predict(X_stack)
85
-
86
- st.subheader("GO terms com probabilidade ≥ 0.5:")
87
- predicted = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
88
- if predicted:
89
- st.code("\n".join(predicted))
90
- else:
91
- st.info("Nenhum GO term com probabilidade ≥ 0.5.")
92
-
93
- st.subheader(f"Top {TOP_N} GO terms mais prováveis:")
94
- top_idx = np.argsort(-y_pred[0])[:TOP_N]
95
- for i in top_idx:
96
- st.write(f"{go_terms[i]} : {y_pred[0][i]:.4f}")
97
-
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from tensorflow.keras.models import load_model
6
+ import joblib
7
+ import streamlit as st
8
+
9
+ # ---------- Caminhos ----------
10
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11
+ MODELS_DIR = os.path.join(BASE_DIR, "models")
12
+ MLB_PATH = os.path.join(BASE_DIR, "data", "mlb_597.pkl")
13
+
14
+ # ---------- Parâmetros ----------
15
+ TOP_N = 10
16
+ CHUNK_PB = 512
17
+ CHUNK_ESM = 1024
18
+
19
+ # ---------- Cache dos modelos HuggingFace ----------
20
+ @st.cache_resource
21
+ def load_hf_model(name):
22
+ tokenizer = AutoTokenizer.from_pretrained(name, do_lower_case=False)
23
+ model = AutoModel.from_pretrained(name)
24
+ model.eval()
25
+ return tokenizer, model
26
+
27
+ # ---------- Função para gerar embedding por chunk ----------
28
+ def embed_sequence(model_name, seq, chunk_size):
29
+ tokenizer, model = load_hf_model(model_name)
30
+
31
+ def format_seq(s):
32
+ return " ".join(list(s))
33
+
34
+ chunks = [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]
35
+ embeddings = []
36
+
37
+ for chunk in chunks:
38
+ formatted = format_seq(chunk)
39
+ inputs = tokenizer(formatted, return_tensors="pt", truncation=True)
40
+ with torch.no_grad():
41
+ outputs = model(**inputs)
42
+ cls = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
43
+ embeddings.append(cls)
44
+
45
+ return np.mean(embeddings, axis=0, keepdims=True)
46
+
47
+ # ---------- Carregar modelos ----------
48
+ mlp_pb = load_model(os.path.join(MODELS_DIR, "mlp_protbert.keras"), compile=False)
49
+ mlp_bfd = load_model(os.path.join(MODELS_DIR, "mlp_protbertbfd.keras"), compile=False)
50
+ mlp_esm = load_model(os.path.join(MODELS_DIR, "mlp_esm2.keras"), compile=False)
51
+ stacking = load_model(os.path.join(MODELS_DIR, "modelo_ensemble_stack.keras"), compile=False)
52
+
53
+ # ---------- Carregar MultiLabelBinarizer ----------
54
+ mlb = joblib.load(MLB_PATH)
55
+ go_terms = mlb.classes_
56
+
57
+ # ---------- Interface Streamlit ----------
58
+ st.title("Predição de Funções de Proteínas")
59
+
60
+ seq = st.text_area("Insere a sequência FASTA:", height=200)
61
+
62
+ # Limpar sequência: remover cabeçalhos (">") e espaços/quebras
63
+ if seq:
64
+ seq = "\n".join([line for line in seq.splitlines() if not line.startswith(">")])
65
+ seq = seq.replace(" ", "").replace("\n", "").strip()
66
+
67
+ if st.button("Prever GO terms"):
68
+ if not seq:
69
+ st.warning("Por favor, insere uma sequência válida.")
70
+ else:
71
+ st.write("A gerar embeddings por chunks...")
72
+
73
+ emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
74
+ emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
75
+ emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
76
+
77
+ st.write("A fazer predições base...")
78
+
79
+ y_pb = mlp_pb.predict(emb_pb)[:, :597]
80
+ y_bfd = mlp_bfd.predict(emb_bfd)[:, :597]
81
+ y_esm = mlp_esm.predict(emb_esm)[:, :597]
82
+
83
+ X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
84
+ y_pred = stacking.predict(X_stack)
85
+
86
+ st.subheader("GO terms com probabilidade ≥ 0.5:")
87
+ predicted = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
88
+ if predicted:
89
+ st.code("\n".join(predicted))
90
+ else:
91
+ st.info("Nenhum GO term com probabilidade ≥ 0.5.")
92
+
93
+ st.subheader(f"Top {TOP_N} GO terms mais prováveis:")
94
+ top_idx = np.argsort(-y_pred[0])[:TOP_N]
95
+ for i in top_idx:
96
+ st.write(f"{go_terms[i]} : {y_pred[0][i]:.4f}")
97
+