melvinalves commited on
Commit
c971d5f
·
verified ·
1 Parent(s): 0725542

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -70
app.py CHANGED
@@ -14,10 +14,10 @@ login(os.environ["HF_TOKEN"])
14
 
15
  # ——————————————————— CONFIG ——————————————————— #
16
  SPACE_ID = "melvinalves/protein_function_prediction"
17
- TOP_N = 20 # top-20 mais prováveis
18
  THRESH = 0.37
19
- CHUNK_PB = 512 # janela ProtBERT / ProtBERT-BFD
20
- CHUNK_ESM = 1024 # janela ESM-2
21
 
22
  # repositórios HF
23
  FINETUNED_PB = ("melvinalves/FineTune", "fineTunedProtbert")
@@ -41,9 +41,9 @@ def load_keras(name):
41
  def load_hf_encoder(repo_id, subfolder=None, base_tok=None):
42
  """
43
  • repo_id : repositório HF ou caminho local
44
- • subfolder : subpasta dos pesos (None se não houver)
45
- • base_tok : repo do tokenizer (None usa repo_id)
46
- Converte tf_model.h5 → PyTorch on-the-fly (from_tf=True).
47
  """
48
  if base_tok is None:
49
  base_tok = repo_id
@@ -59,8 +59,7 @@ def load_hf_encoder(repo_id, subfolder=None, base_tok=None):
59
  # ---------- extrair embedding ----------
60
  def embed_seq(model_ref, seq, chunk):
61
  """
62
- model_ref = string (modelo base) OU tuple(repo_id, subfolder) (modelo fine-tuned)
63
- Retorna embedding CLS médio (caso a sequência seja dividida em chunks).
64
  """
65
  if isinstance(model_ref, tuple): # ProtBERT fine-tuned
66
  repo_id, subf = model_ref
@@ -100,28 +99,65 @@ GO = mlb.classes_
100
  st.set_page_config(page_title="Predição de Funções Moleculares de Proteínas",
101
  page_icon="🧬", layout="centered")
102
 
103
- # CSS global : fundo branco, texto preto, textarea branca + traço colunas
104
  st.markdown(
105
  """
106
  <style>
107
- body, .stApp { background:#FFFFFF !important; color:#000000 !important; }
108
- textarea { background:#FFFFFF !important; color:#000000 !important;
109
- font-size:0.9rem !important; }
110
- /* traço vertical entre as duas colunas (segunda coluna) */
111
- div[data-testid="column"]:nth-of-type(3) {
112
- border-left:1px solid #e0e0e0;
113
- padding-left:1rem;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  }
115
- .block-container { padding-top:1.5rem; }
116
  </style>
117
  """,
118
  unsafe_allow_html=True
119
  )
120
 
121
- # Logo (coloca logo.png na raiz do Space)
122
  LOGO_PATH = "logo.png"
123
  if os.path.exists(LOGO_PATH):
124
- st.image(LOGO_PATH, width=180)
125
 
126
  st.title("Predição de Funções Moleculares de Proteínas (GO:MF)")
127
 
@@ -131,67 +167,71 @@ predict_clicked = st.button("Prever GO terms")
131
  # ——————————————————— PARSE DE MÚLTIPLAS SEQUÊNCIAS ——————————————————— #
132
  def parse_fasta_multiple(fasta_str):
133
  """
134
- Devolve lista de (header, seq) a partir de texto FASTA possivelmente múltiplo.
135
- Suporta bloco inicial sem '>'.
136
  """
137
  entries, parsed = fasta_str.strip().split(">"), []
138
  for i, entry in enumerate(entries):
139
  if not entry.strip():
140
  continue
141
  lines = entry.strip().splitlines()
142
- if i > 0: # bloco típico FASTA
143
  header = lines[0].strip()
144
- seq = "".join(lines[1:]).replace(" ", "").upper()
145
- else: # sequência sem '>'
146
  header = f"Seq_{i+1}"
147
- seq = "".join(lines).replace(" ", "").upper()
148
  if seq:
149
  parsed.append((header, seq))
150
  return parsed
151
 
152
- # ——————————————————— FUNÇÕES AUXILIARES DE LAYOUT ——————————————————— #
153
  def go_link(go_id, name=""):
154
- """Cria link para página do GO term (QuickGO)."""
155
- url = f"https://www.ebi.ac.uk/QuickGO/term/{go_id}"
156
  label = f"{go_id} — {name}" if name else go_id
157
  return f"[{label}]({url})"
158
 
159
- def prot_link(header):
160
- """Tenta gerar link para UniProt usando o primeiro token do header."""
161
  pid = header.split()[0]
162
- url = f"https://www.uniprot.org/uniprotkb/{pid}"
163
- return f"[{header}]({url})"
164
-
165
- # ——————————————————— FUNÇÃO PRINCIPAL DE RESULTADOS ——————————————————— #
166
- def mostrar(tag, y_pred):
167
- """Mostra resultados em duas colunas separadas por traço."""
168
- # 3 colunas: esquerda | traço (muito estreito) | direita
169
- col1, col_mid, col2 = st.columns([1, 0.04, 1])
170
-
171
- with col1:
172
- st.markdown(f"**GO terms com prob ≥ {THRESH}**")
173
- hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
174
- if hits:
175
- for go_id in hits:
176
- name, defin = GO_INFO.get(go_id, ("— sem nome —", ""))
177
- defin = re.sub(r'^\\s*\"?(.+?)\"?\\s*(\\[[^\\]]*\\])?\\s*$', r'\\1',
178
- defin or "")
179
- st.markdown(f"- {go_link(go_id, name)} ")
180
- if defin:
181
- st.caption(defin)
182
- else:
183
- st.code("— nenhum —")
184
-
185
- # coluna do meio já tem a linha (CSS) — fica vazia
186
- with col_mid:
187
- st.write("")
188
-
189
- with col2:
190
- st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
191
- for rank, idx in enumerate(np.argsort(-y_pred[0])[:TOP_N], start=1):
192
- go_id = GO[idx]
193
- name, _ = GO_INFO.get(go_id, ("", ""))
194
- st.markdown(f"{rank}. {go_link(go_id, name)} : {y_pred[0][idx]:.4f}")
 
 
 
 
 
 
 
195
 
196
  # ——————————————————— INFERÊNCIA ——————————————————— #
197
  if predict_clicked:
@@ -202,24 +242,23 @@ if predict_clicked:
202
 
203
  for header, seq in parsed_seqs:
204
  with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
205
- # ———————————— EMBEDDINGS ———————————— #
206
  emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB)
207
  emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB)
208
  emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM)
209
 
210
- # ———————————— PREDIÇÕES MLPs ———————————— #
211
  y_pb = mlp_pb.predict(emb_pb)
212
  y_bfd = mlp_bfd.predict(emb_bfd)
213
- y_esm = mlp_esm.predict(emb_esm)[:, :597] # alinhar nº termos
214
 
215
- # ———————————— STACKING ———————————— #
216
  X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
217
  y_ens = stacking.predict(X)
218
 
219
- st.markdown(f"### {prot_link(header)}", unsafe_allow_html=True)
220
- mostrar("", y_ens)
221
 
222
- # ——————————————————— LISTA COMPLETA DE TERMOS SUPORTADOS ——————————————————— #
223
  with st.expander("Mostrar lista completa dos 597 GO terms possíveis", expanded=False):
224
  cols = st.columns(3)
225
  for i, go_id in enumerate(GO):
 
14
 
15
  # ——————————————————— CONFIG ——————————————————— #
16
  SPACE_ID = "melvinalves/protein_function_prediction"
17
+ TOP_N = 20 # top-20 mais prováveis
18
  THRESH = 0.37
19
+ CHUNK_PB = 512 # janela ProtBERT / ProtBERT-BFD
20
+ CHUNK_ESM = 1024 # janela ESM-2
21
 
22
  # repositórios HF
23
  FINETUNED_PB = ("melvinalves/FineTune", "fineTunedProtbert")
 
41
  def load_hf_encoder(repo_id, subfolder=None, base_tok=None):
42
  """
43
  • repo_id : repositório HF ou caminho local
44
+ • subfolder : subpasta onde vivem pesos/config (None se não houver)
45
+ • base_tok : repo para o tokenizer (None => usa repo_id)
46
+ Converte tf_model.h5 → PyTorch on-the-fly.
47
  """
48
  if base_tok is None:
49
  base_tok = repo_id
 
59
  # ---------- extrair embedding ----------
60
  def embed_seq(model_ref, seq, chunk):
61
  """
62
+ Retorna embedding CLS médio (divide sequência em chunks se necessário).
 
63
  """
64
  if isinstance(model_ref, tuple): # ProtBERT fine-tuned
65
  repo_id, subf = model_ref
 
99
  st.set_page_config(page_title="Predição de Funções Moleculares de Proteínas",
100
  page_icon="🧬", layout="centered")
101
 
102
+ # ---------- CSS global ----------
103
  st.markdown(
104
  """
105
  <style>
106
+ /* fundo branco + texto preto */
107
+ body, .stApp { background-color:#FFFFFF !important; color:#000000 !important; }
108
+ /* reduz top padding para o logo caber completo */
109
+ .block-container { padding-top:3rem; }
110
+
111
+ /* logo centralizado e afastado do topo */
112
+ img.logo-top { display:block; margin:0 auto 1.5rem; }
113
+
114
+ /* textarea/input brancos */
115
+ textarea, input, .stTextArea textarea, .stTextInput input {
116
+ background-color:#FFFFFF !important;
117
+ color:#000000 !important;
118
+ }
119
+
120
+ /* botões Streamlit */
121
+ .stButton>button {
122
+ background:#F8F9FA !important; /* cinza muito claro */
123
+ color:#000000 !important;
124
+ border:1px solid #007BFF !important;
125
+ border-radius:4px;
126
+ }
127
+ .stButton>button:hover {
128
+ background:#007BFF !important;
129
+ color:#FFFFFF !important;
130
+ }
131
+
132
+ /* botão UniProt custom */
133
+ .prot-btn {
134
+ background:#007BFF; color:#FFFFFF; border:none;
135
+ padding:6px 12px; border-radius:4px; cursor:pointer;
136
+ }
137
+ .prot-btn:hover {
138
+ background:#0056B3;
139
+ }
140
+
141
+ /* tiramos cores de hover vermelhas dos expanders; seta + texto azuis */
142
+ .st-expander:focus:not(:active) .streamlit-expanderHeader,
143
+ .streamlit-expanderHeader:hover {
144
+ color:#007BFF !important;
145
+ }
146
+
147
+ /* divisória vertical entre colunas */
148
+ div[data-testid='column']:nth-of-type(1) {
149
+ border-right:1px solid #DDDDDD;
150
+ padding-right:1rem;
151
  }
 
152
  </style>
153
  """,
154
  unsafe_allow_html=True
155
  )
156
 
157
+ # ---------- Logo ----------
158
  LOGO_PATH = "logo.png"
159
  if os.path.exists(LOGO_PATH):
160
+ st.markdown(f'<img src="app://{LOGO_PATH}" width="180" class="logo-top">', unsafe_allow_html=True)
161
 
162
  st.title("Predição de Funções Moleculares de Proteínas (GO:MF)")
163
 
 
167
  # ——————————————————— PARSE DE MÚLTIPLAS SEQUÊNCIAS ——————————————————— #
168
  def parse_fasta_multiple(fasta_str):
169
  """
170
+ Devolve lista (header, seq). Suporta bloco inicial sem '>'.
 
171
  """
172
  entries, parsed = fasta_str.strip().split(">"), []
173
  for i, entry in enumerate(entries):
174
  if not entry.strip():
175
  continue
176
  lines = entry.strip().splitlines()
177
+ if i > 0: # FASTA normal
178
  header = lines[0].strip()
179
+ seq = "".join(lines[1:]).replace(" ", "").upper()
180
+ else: # sequência sem '>'
181
  header = f"Seq_{i+1}"
182
+ seq = "".join(lines).replace(" ", "").upper()
183
  if seq:
184
  parsed.append((header, seq))
185
  return parsed
186
 
187
+ # ——————————————————— FUNÇÕES AUX COLUNA/LINKS ——————————————————— #
188
  def go_link(go_id, name=""):
189
+ url = f"https://www.ebi.ac.uk/QuickGO/term/{go_id}"
 
190
  label = f"{go_id} — {name}" if name else go_id
191
  return f"[{label}]({url})"
192
 
193
+ def prot_url(header):
 
194
  pid = header.split()[0]
195
+ return f"https://www.uniprot.org/uniprotkb/{pid}"
196
+
197
+ # ——————————————————— MOSTRAR RESULTADOS ——————————————————— #
198
+ def mostrar(header, y_pred):
199
+ """Expander com coluna-esq (hits) + coluna-dir (Top-20)."""
200
+ url = prot_url(header)
201
+
202
+ # botão UniProt fora do expander
203
+ st.markdown(
204
+ f'<a href="{url}" target="_blank">'
205
+ f'<button class="prot-btn">🔗 Ver UniProt ({header.split()[0]})</button>'
206
+ f'</a>',
207
+ unsafe_allow_html=True
208
+ )
209
+
210
+ with st.expander(header, expanded=True):
211
+ col1, col2 = st.columns(2)
212
+
213
+ # coluna 1 – hits acima do threshold
214
+ with col1:
215
+ st.markdown(f"**GO terms com prob ≥ {THRESH}**")
216
+ hits = mlb.inverse_transform((y_pred >= THRESH).astype(int))[0]
217
+ if hits:
218
+ for go_id in hits:
219
+ name, defin = GO_INFO.get(go_id, ("— sem nome —", ""))
220
+ defin = re.sub(r'^\\s*\"?(.+?)\"?\\s*(\\[[^\\]]*\\])?\\s*$', r'\\1',
221
+ defin or "")
222
+ st.markdown(f"- {go_link(go_id, name)}")
223
+ if defin:
224
+ st.caption(defin)
225
+ else:
226
+ st.code(" nenhum —")
227
+
228
+ # coluna 2 – top-20
229
+ with col2:
230
+ st.markdown(f"**Top {TOP_N} GO terms mais prováveis**")
231
+ for rank, idx in enumerate(np.argsort(-y_pred[0])[:TOP_N], start=1):
232
+ go_id = GO[idx]
233
+ name, _ = GO_INFO.get(go_id, ("", ""))
234
+ st.markdown(f"{rank}. {go_link(go_id, name)} : {y_pred[0][idx]:.4f}")
235
 
236
  # ——————————————————— INFERÊNCIA ——————————————————— #
237
  if predict_clicked:
 
242
 
243
  for header, seq in parsed_seqs:
244
  with st.spinner(f"A processar {header}… (pode demorar alguns minutos)"):
245
+ # embeddings
246
  emb_pb = embed_seq(FINETUNED_PB, seq, CHUNK_PB)
247
  emb_bfd = embed_seq(FINETUNED_BFD, seq, CHUNK_PB)
248
  emb_esm = embed_seq(BASE_ESM, seq, CHUNK_ESM)
249
 
250
+ # predições MLPs
251
  y_pb = mlp_pb.predict(emb_pb)
252
  y_bfd = mlp_bfd.predict(emb_bfd)
253
+ y_esm = mlp_esm.predict(emb_esm)[:, :597] # alinhar nº de termos
254
 
255
+ # stacking
256
  X = np.concatenate([y_pb, y_bfd, y_esm], axis=1)
257
  y_ens = stacking.predict(X)
258
 
259
+ mostrar(header, y_ens)
 
260
 
261
+ # ——————————————————— LISTA COMPLETA DE TERMOS ——————————————————— #
262
  with st.expander("Mostrar lista completa dos 597 GO terms possíveis", expanded=False):
263
  cols = st.columns(3)
264
  for i, go_id in enumerate(GO):