{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "overview",
   "metadata": {},
   "source": [
    "# MF GO-term prediction: ProtBERT, ProtBERT-BFD and ESM-2 ensembles\n",
    "\n",
    "Scores three protein-language-model predictors on the test set (presumably the molecular-function sub-ontology, given the `mf-` prefix of the prediction files) and two ways of combining them: a plain average of their scores and a stacked MLP.\n",
    "\n",
    "**Methodology notes**\n",
    "\n",
    "- The stacking MLP is fitted on rows of the *test* set, so it is scored **only on held-out test proteins** (`EVAL_FRACTION`); the base models and the average are re-scored on those same proteins for a like-for-like comparison. A cleaner setup would fit the stacker on base-model predictions for a separate validation set. An earlier version of this notebook scored the stacker on all test proteins, including the 70% it was trained on (reported Fmax = 0.6956), so that figure was optimistic.\n",
    "- `fmax` is the best *mean per-protein F1* over a 0.01-step threshold grid, not the CAFA variant that averages precision and recall separately.\n",
    "- `smin` is the semantic distance at the Fmax threshold, with information content estimated from term frequencies in the ground truth plus the binarised predictions. It is neither minimised over thresholds nor derived from the GO graph, so it is not directly comparable with CAFA Smin."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "imports",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import joblib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.metrics import auc, precision_recall_curve\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tensorflow.keras.callbacks import EarlyStopping\n",
    "from tensorflow.keras.layers import Dense, Dropout\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.utils import set_random_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs (the ProtBERT pickle also defines the reference GO-term column order)\n",
    "TEST_PROTBERT_PKL = \"embeddings/test_protbert.pkl\"\n",
    "PRED_PROTBERT = \"predictions/mf-protbert-pam1.npy\"\n",
    "PRED_PROTBERT_BFD = \"predictions/mf-protbertbfd-pam1.npy\"\n",
    "PRED_ESM2 = \"predictions/mf-esm2.npy\"\n",
    "ESM2_MLB_PKL = \"data/mlb.pkl\"  # MultiLabelBinarizer whose classes_ give the ESM-2 column order\n",
    "MODEL_DIR = Path(\"models\")\n",
    "\n",
    "SEED = 42\n",
    "EVAL_FRACTION = 0.3  # share of test proteins held out to score the stacker\n",
    "ES_FRACTION = 0.15   # share of the stacker's fitting rows used only for early stopping\n",
    "THR = np.linspace(0, 1, 101)  # threshold grid (step 0.01) shared by every metric\n",
    "\n",
    "set_random_seed(SEED)  # seeds Python `random`, NumPy and TensorFlow"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "load-md",
   "metadata": {},
   "source": [
    "## Load predictions and align GO-term columns\n",
    "\n",
    "The ESM-2 predictions use the column order of `data/mlb.pkl`. They are re-indexed to the ProtBERT order so that all three score matrices line up term by term with `y_true`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fbbb46c-1a00-4585-9ecd-a490a46e8b99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# joblib.load unpickles arbitrary objects: only point it at trusted local files.\n",
    "test_pb = joblib.load(TEST_PROTBERT_PKL)\n",
    "y_true = test_pb[\"labels\"]          # ground truth, one row per protein, one column per GO term\n",
    "go_ref = list(test_pb[\"go_terms\"])  # exact column order of y_true\n",
    "n_go = len(go_ref)\n",
    "\n",
    "y_pb = np.load(PRED_PROTBERT)\n",
    "y_bfd = np.load(PRED_PROTBERT_BFD)\n",
    "y_esm_raw = np.load(PRED_ESM2)      # columns follow the ESM-2 label order\n",
    "\n",
    "# Re-index the ESM-2 columns to the ProtBERT order. mlb.pkl was pickled with an older\n",
    "# scikit-learn (InconsistentVersionWarning); only its classes_ attribute is read here.\n",
    "mlb_esm = joblib.load(ESM2_MLB_PKL)\n",
    "esm_col = {term: i for i, term in enumerate(mlb_esm.classes_)}\n",
    "missing = [t for t in go_ref if t not in esm_col]\n",
    "if missing:\n",
    "    raise ValueError(f\"{len(missing)} GO terms have no ESM-2 column, e.g. {missing[:5]}\")\n",
    "y_esm = y_esm_raw[:, [esm_col[t] for t in go_ref]]\n",
    "\n",
    "n_proteins = y_true.shape[0]\n",
    "assert (y_true.shape == y_pb.shape == y_bfd.shape\n",
    "        == y_esm.shape == (n_proteins, n_go)), \"Ainda há desalinhamento!\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "metrics-md",
   "metadata": {},
   "source": [
    "## Metrics\n",
    "\n",
    "Defined once and reused for every predictor, all on the same threshold grid `THR`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "metrics",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1_curve(y_t, y_p, thresholds=THR):\n",
    "    \"\"\"Mean per-protein F1 at each threshold in `thresholds`.\n",
    "\n",
    "    y_t is the binary ground truth and y_p the scores, both (proteins x terms).\n",
    "    A protein with no true and no predicted terms scores F1 = 0 (1e-8 guard).\n",
    "    \"\"\"\n",
    "    curve = []\n",
    "    for t in thresholds:\n",
    "        y_b = (y_p >= t).astype(int)\n",
    "        tp = (y_t * y_b).sum(1)\n",
    "        fp = ((1 - y_t) * y_b).sum(1)\n",
    "        fn = (y_t * (1 - y_b)).sum(1)\n",
    "        curve.append((2 * tp / (2 * tp + fp + fn + 1e-8)).mean())\n",
    "    return np.array(curve)\n",
    "\n",
    "\n",
    "def fmax(y_t, y_p, thresholds=THR):\n",
    "    \"\"\"Return (best mean per-protein F1, threshold reaching it); ties go to the first grid threshold.\"\"\"\n",
    "    curve = f1_curve(y_t, y_p, thresholds)\n",
    "    best = int(np.argmax(curve))\n",
    "    return curve[best], thresholds[best]\n",
    "\n",
    "\n",
    "def auprc(y_t, y_p):\n",
    "    \"\"\"Micro-averaged area under the precision-recall curve (trapezoidal rule over all cells).\"\"\"\n",
    "    p, r, _ = precision_recall_curve(y_t.ravel(), y_p.ravel())\n",
    "    return auc(r, p)\n",
    "\n",
    "\n",
    "def smin(y_t, y_p, thr, alpha=0.5):\n",
    "    \"\"\"Mean semantic distance over proteins at the single threshold `thr`.\n",
    "\n",
    "    Information content of a term is -log(term count / total count), with counts\n",
    "    taken from y_t plus the binarised predictions. The IC of false positives is\n",
    "    weighted by `alpha`, that of false negatives by `1 - alpha`.\n",
    "    \"\"\"\n",
    "    y_b = (y_p >= thr).astype(int)\n",
    "    ic = -(np.log((y_t + y_b).sum(0) + 1e-8) - np.log((y_t + y_b).sum() + 1e-8))\n",
    "    # NOTE(review): CAFA calls the false-negative part remaining uncertainty and the\n",
    "    # false-positive part misinformation; the names below are swapped, which only\n",
    "    # changes the result when alpha != 0.5.\n",
    "    ru = np.logical_and(y_b, np.logical_not(y_t)) * ic\n",
    "    mi = np.logical_and(y_t, np.logical_not(y_b)) * ic\n",
    "    return np.sqrt((alpha * ru.sum(1))**2 + ((1 - alpha) * mi.sum(1))**2).mean()\n",
    "\n",
    "\n",
    "def report(name, y_t, y_p):\n",
    "    \"\"\"Print Fmax, its threshold, AuPRC and Smin for one predictor; return (fmax, threshold).\"\"\"\n",
    "    f, thr = fmax(y_t, y_p)\n",
    "    print(f\"{name:>13s} Fmax={f:.4f} Thr={thr:.2f} \"\n",
    "          f\"AuPRC={auprc(y_t, y_p):.4f} Smin={smin(y_t, y_p, thr):.4f}\")\n",
    "    return f, thr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "base-md",
   "metadata": {},
   "source": [
    "## Base models and averaging ensemble (all test proteins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "base-models",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_avg = (y_pb + y_bfd + y_esm) / 3  # averaging ensemble\n",
    "predictors = {\"ProtBERT\": y_pb, \"ProtBERT-BFD\": y_bfd, \"ESM-2\": y_esm, \"Ensemble\": y_avg}\n",
    "\n",
    "for name, scores in predictors.items():\n",
    "    report(name, y_true, scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "stack-md",
   "metadata": {},
   "source": [
    "## Stacking MLP\n",
    "\n",
    "The stacker takes the concatenated scores of the three base models as input. Proteins are split once into fitting rows and held-out evaluation rows. Early stopping monitors a slice of the fitting rows, so the evaluation rows are never used for training or model selection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1807404-c2ce-48d0-b87c-a7e0fecc1728",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)  # (n_proteins, 3 * n_go)\n",
    "\n",
    "# Split row indices (not arrays) so the base models can be scored on the same held-out proteins.\n",
    "idx_fit, idx_eval = train_test_split(\n",
    "    np.arange(n_proteins), test_size=EVAL_FRACTION, random_state=SEED)\n",
    "X_fit, y_fit = X_stack[idx_fit], y_true[idx_fit]\n",
    "X_eval, y_eval = X_stack[idx_eval], y_true[idx_eval]\n",
    "\n",
    "set_random_seed(SEED)  # re-seed so re-running this cell starts from the same random state\n",
    "stacker = Sequential([\n",
    "    Dense(512, activation=\"relu\", input_shape=(X_stack.shape[1],)),\n",
    "    Dropout(0.3),\n",
    "    Dense(256, activation=\"relu\"),\n",
    "    Dropout(0.3),\n",
    "    Dense(n_go, activation=\"sigmoid\"),\n",
    "])\n",
    "stacker.compile(optimizer=Adam(1e-3), loss=\"binary_crossentropy\")\n",
    "\n",
    "# validation_split takes the last ES_FRACTION of X_fit, which train_test_split already shuffled.\n",
    "history = stacker.fit(\n",
    "    X_fit, y_fit, validation_split=ES_FRACTION,\n",
    "    epochs=50, batch_size=64, verbose=2,\n",
    "    callbacks=[EarlyStopping(patience=5, restore_best_weights=True)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "stack-eval",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred_eval = stacker.predict(X_eval, batch_size=64, verbose=0)\n",
    "\n",
    "print(f\"Held-out proteins: {len(idx_eval)} of {n_proteins}\")\n",
    "for name, scores in predictors.items():\n",
    "    report(name, y_eval, scores[idx_eval])\n",
    "fmax_stack, thr_stack = report(\"Stacking MLP\", y_eval, y_pred_eval)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "plot-md",
   "metadata": {},
   "source": [
    "### Mean F1 vs threshold for the stacker (held-out proteins)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00695029-3d24-4803-a6e1-8ac5fd70b710",
   "metadata": {},
   "outputs": [],
   "source": [
    "f1_stack = f1_curve(y_eval, y_pred_eval)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 5))\n",
    "ax.plot(THR, f1_stack, label=\"F1 médio (Fmax)\")\n",
    "ax.axvline(thr_stack, color=\"red\", linestyle=\"--\", label=f\"Threshold ótimo = {thr_stack:.2f}\")\n",
    "ax.set(xlabel=\"Threshold\", ylabel=\"F1-score médio\", title=\"Fmax vs Threshold (Stacking MLP)\")\n",
    "ax.legend()\n",
    "ax.grid(True)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "save-md",
   "metadata": {},
   "source": [
    "## Save the stacker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fac1b06-2695-4c94-855e-24e5bd993e1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_DIR.mkdir(parents=True, exist_ok=True)\n",
    "stacker.save(str(MODEL_DIR / \"ensemble_stack.keras\"))\n",
    "# Legacy HDF5 copy (Keras warns the format is deprecated); kept in case other code loads the .h5 file.\n",
    "stacker.save(str(MODEL_DIR / \"ensemble_stack.h5\"))\n",
    "print(f\"Modelo guardado em {MODEL_DIR}/\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}