melvinalves committed on
Commit 4990c94 · verified · 1 Parent(s): 0104888

Update app.py

Files changed (1):
  app.py +75 -78
app.py CHANGED
@@ -1,100 +1,97 @@
  import os
  import numpy as np
  import torch
  from transformers import AutoTokenizer, AutoModel
  from tensorflow.keras.models import load_model
- import joblib
- import streamlit as st
-
- # ---------- Paths ----------
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
- MODELS_DIR = os.path.join(BASE_DIR, "models")
- MLB_PATH = os.path.join(BASE_DIR, "data", "mlb_597.pkl")
-
- # ---------- Parameters ----------
- TOP_N = 10
- CHUNK_PB = 512
- CHUNK_ESM = 1024
-
- # ---------- Cache for the Hugging Face models ----------
  @st.cache_resource
- def load_hf_model(name):
-     tokenizer = AutoTokenizer.from_pretrained(name, do_lower_case=False)
-     model = AutoModel.from_pretrained(name)
-     model.eval()
-     return tokenizer, model
-
- # ---------- Cache for the local models ----------
  @st.cache_resource
- def load_local_model(path):
-     return load_model(path, compile=False)
-
- mlp_pb = load_local_model(os.path.join(MODELS_DIR, "mlp_protbert.keras"))
- mlp_bfd = load_local_model(os.path.join(MODELS_DIR, "mlp_protbertbfd.keras"))
- mlp_esm = load_local_model(os.path.join(MODELS_DIR, "mlp_esm2.keras"))
- stacking = load_local_model(os.path.join(MODELS_DIR, "modelo_ensemble_stack.keras"))
-
- # ---------- Load the MultiLabelBinarizer ----------
- mlb = joblib.load(MLB_PATH)
  go_terms = mlb.classes_

- # ---------- Per-chunk embedding function ----------
- def embed_sequence(model_name, seq, chunk_size):
-     tokenizer, model = load_hf_model(model_name)
-
-     def format_seq(s):
-         return " ".join(list(s))
-
-     chunks = [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]
-     embeddings = []
-
-     for chunk in chunks:
-         formatted = format_seq(chunk)
-         inputs = tokenizer(formatted, return_tensors="pt", truncation=True)
-         with torch.no_grad():
-             outputs = model(**inputs)
-         cls = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
-         embeddings.append(cls)
-
-     return np.mean(embeddings, axis=0, keepdims=True)
-
- # ---------- Streamlit UI ----------
- st.title("Protein Function Prediction")
-
- seq = st.text_area("Paste the FASTA sequence:", height=200)
-
- # Clean the sequence: drop headers (">") and spaces/line breaks
- if seq:
-     seq = "\n".join([line for line in seq.splitlines() if not line.startswith(">")])
-     seq = seq.replace(" ", "").replace("\n", "").strip()
-
- if st.button("Predict GO terms"):
-     if not seq:
-         st.warning("Please enter a valid sequence.")
-     else:
-         st.write("🔄 Generating embeddings...")
-
-         emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
-         emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
-         emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
-
-         st.write("🧠 Running predictions with each model...")
-
-         y_pb = mlp_pb.predict(emb_pb)
-         y_bfd = mlp_bfd.predict(emb_bfd)
-         y_esm = mlp_esm.predict(emb_esm)
-
-         X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
-         y_pred = stacking.predict(X_stack)
-
-         st.subheader("GO terms with probability ≥ 0.5:")
-         predicted = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
-         if predicted:
-             st.code("\n".join(predicted))
-         else:
-             st.info("No GO term reached probability ≥ 0.5.")
-
-         st.subheader(f"Top {TOP_N} most likely GO terms:")
-         top_idx = np.argsort(-y_pred[0])[:TOP_N]
-         for i in top_idx:
-             st.write(f"{go_terms[i]} : {y_pred[0][i]:.4f}")
 
  import os
  import numpy as np
  import torch
+ import streamlit as st
+ import joblib
  from transformers import AutoTokenizer, AutoModel
+ from huggingface_hub import hf_hub_download
  from tensorflow.keras.models import load_model
+
+ # ----------- Space config -----------
+ SPACE_REPO = "melvinalves/protein_function_prediction"  # <- your Space
+ MODELS_DIR = "models"
+ DATA_DIR = "data"
+
+ TOP_N = 10
+ CHUNK_PB = 512
+ CHUNK_ESM = 1024
+
+ # ----------- Helpers -----------
  @st.cache_resource
+ def hf_cached(path_inside_repo: str):
+     """Download the file from this Space once and return its local path."""
+     return hf_hub_download(
+         repo_id=SPACE_REPO,
+         repo_type="space",
+         filename=path_inside_repo,
+     )
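+
+ # Note: two cache layers apply here. hf_hub_download keeps the file in
+ # the local Hugging Face cache on disk, and st.cache_resource memoizes
+ # the returned path in memory across Streamlit reruns.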
 
 
  @st.cache_resource
+ def load_hf_model(model_name):
+     tok = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
+     mdl = AutoModel.from_pretrained(model_name)
+     mdl.eval()
+     return tok, mdl
+
+ @st.cache_resource
+ def load_local_model(file_name):
+     local_path = hf_cached(f"{MODELS_DIR}/{file_name}")
+     return load_model(local_path, compile=False)
+
+ # ----------- Load the .keras models -----------
+ mlp_pb = load_local_model("mlp_protbert.keras")
+ mlp_bfd = load_local_model("mlp_protbertbfd.keras")
+ mlp_esm = load_local_model("mlp_esm2.keras")
+ stacking = load_local_model("ensemble_stacking.keras")
+
+ # ----------- MultiLabelBinarizer -----------
+ mlb_path = hf_cached(f"{DATA_DIR}/mlb_597.pkl")
+ mlb = joblib.load(mlb_path)
  go_terms = mlb.classes_

+ # ----------- Chunked embedding -----------
+ def embed_sequence(model_name: str, seq: str, chunk: int) -> np.ndarray:
+     tok, mdl = load_hf_model(model_name)
+     fmt = lambda s: " ".join(list(s))  # ProtBERT expects space-separated residues
+     parts = [seq[i:i+chunk] for i in range(0, len(seq), chunk)]
+     vecs = []
+     for p in parts:
+         with torch.no_grad():
+             out = mdl(**tok(fmt(p), return_tensors="pt", truncation=True))
+         vecs.append(out.last_hidden_state[:, 0, :].squeeze().numpy())
+     # Mean over chunks gives one fixed-size (1, hidden) embedding.
+     return np.mean(vecs, axis=0, keepdims=True)
+
+ # ----------- UI -----------
+ st.title("Protein Function Prediction 🔬")
+
+ fa_input = st.text_area("Paste the FASTA sequence:", height=200)
+
+ if fa_input and st.button("Predict GO terms"):
+     # Clean the FASTA input: drop header lines, spaces and line breaks
+     seq = "\n".join(l for l in fa_input.splitlines() if not l.startswith(">"))
+     seq = seq.replace(" ", "").replace("\n", "").upper()
+     if not seq:
+         st.warning("Empty sequence.")
+         st.stop()
+
+     st.write("🔄 Generating embeddings…")
+     emb_pb = embed_sequence("Rostlab/prot_bert", seq, CHUNK_PB)
+     emb_bfd = embed_sequence("Rostlab/prot_bert_bfd", seq, CHUNK_PB)
+     emb_esm = embed_sequence("facebook/esm2_t33_650M_UR50D", seq, CHUNK_ESM)
+
+     st.write("🧠 Running predictions…")
+     y_pb = mlp_pb.predict(emb_pb)
+     y_bfd = mlp_bfd.predict(emb_bfd)
+     y_esm = mlp_esm.predict(emb_esm)[:, :597]  # keep the first 597 columns to match mlb.classes_
+
+     X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
+     y_pred = stacking.predict(X_stack)
+
+     # ----------- Output -----------
+     st.subheader("GO terms with probability ≥ 0.5")
+     hits = mlb.inverse_transform((y_pred >= 0.5).astype(int))[0]
+     st.code("\n".join(hits) or "(none)")
+
+     st.subheader(f"Top {TOP_N} most likely GO terms")
+     for idx in np.argsort(-y_pred[0])[:TOP_N]:
+         st.write(f"{go_terms[idx]} : {y_pred[0][idx]:.4f}")
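For readers who want to sanity-check the chunking strategy outside the app, below is a minimal, self-contained sketch of the chunk-then-mean-pool logic that embed_sequence implements. The dummy encoder and the HIDDEN constant are stand-ins (ProtBERT's hidden size is 1024; ESM-2 650M's is 1280), so it runs without downloading any model:

    import numpy as np

    CHUNK = 512    # same value as CHUNK_PB in app.py
    HIDDEN = 1024  # stand-in hidden size (ProtBERT uses 1024)

    def dummy_cls_vector(chunk: str) -> np.ndarray:
        # Stand-in for the [CLS] embedding a real encoder would return.
        rng = np.random.default_rng(len(chunk))
        return rng.standard_normal(HIDDEN)

    seq = "MVLSPADKTNVKAAWGKV" * 60  # toy sequence, 1080 residues
    parts = [seq[i:i + CHUNK] for i in range(0, len(seq), CHUNK)]
    vecs = [dummy_cls_vector(p) for p in parts]

    # Averaging over chunks yields one fixed-size (1, HIDDEN) row, the same
    # shape embed_sequence returns via np.mean(..., axis=0, keepdims=True).
    emb = np.mean(vecs, axis=0, keepdims=True)
    print(len(parts), emb.shape)  # 3 (1, 1024)

The fixed-size output is what lets the downstream MLPs accept sequences of any length: however many chunks a protein produces, the mean pool collapses them into a single row.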