Delete notebooks
- notebooks/Ensemble.ipynb +0 -320
- notebooks/Input.ipynb +0 -157
- notebooks/PAM1_ESM2.ipynb +0 -548
- notebooks/PAM1_protbert.ipynb +0 -971
- notebooks/PAM1_protbertBFD.ipynb +0 -872
notebooks/Ensemble.ipynb
DELETED
@@ -1,320 +0,0 @@

In [1]:
# %%
import numpy as np, joblib
from sklearn.metrics import precision_recall_curve, auc
from goatools.obo_parser import GODag
from sklearn.preprocessing import MultiLabelBinarizer

GO_FILE = "go.obo"
dag = GODag(GO_FILE)

# ---------- 1. y_true + GO terms (ProtBERT reference) ----------
test_pb = joblib.load("embeddings/test_protbert.pkl")
y_true = test_pb["labels"]           # (1724, 597) <- ground truth
go_ref = list(test_pb["go_terms"])   # exact column order

n_go = len(go_ref)                   # 597

# --- Rebuild the MultiLabelBinarizer with the correct 597 terms ---
mlb = MultiLabelBinarizer(classes=go_ref)
mlb.fit([go_ref])  # needed so inverse_transform works later

# ---------- 2. Load predictions ----------
y_pb   = np.load("predictions/mf-protbert-pam1.npy")     # 1724 x 597
y_bfd  = np.load("predictions/mf-protbertbfd-pam1.npy")  # 1724 x 597
y_esm0 = np.load("predictions/mf-esm2.npy")              # 1724 x 602

# ---------- 3. Remap ESM-2 to the ProtBERT column order ----------
mlb_esm = joblib.load("data/mlb.pkl")  # 602 GO terms
idx_map = [list(mlb_esm.classes_).index(t) for t in go_ref]
y_esm = y_esm0[:, idx_map]  # 1724 x 597

# ---------- 4. Check that all shapes match ----------
assert (y_true.shape == y_pb.shape == y_bfd.shape
        == y_esm.shape == (1724, n_go)), "Columns are still misaligned!"

# ---------- 5. Save the aligned mlb (y_true) ----------
joblib.dump(mlb, "data/mlb_597.pkl")
print("mlb with 597 GO terms saved to data/mlb_597.pkl")

# ---------- 6. Metrics ----------
THR = np.linspace(0, 1, 101)

def fmax(y_t, y_p):
    best, thr = 0, 0
    for t in THR:
        y_b = (y_p >= t).astype(int)
        tp = (y_t*y_b).sum(1); fp = ((1-y_t)*y_b).sum(1); fn = (y_t*(1-y_b)).sum(1)
        f1 = 2*tp/(2*tp+fp+fn+1e-8); m = f1.mean()
        if m > best: best, thr = m, t
    return best, thr

def auprc(y_t, y_p):
    p, r, _ = precision_recall_curve(y_t.ravel(), y_p.ravel()); return auc(r, p)

def smin(y_t, y_p, thr, alpha=0.5):
    y_b = (y_p >= thr).astype(int)
    ic = -(np.log((y_t+y_b).sum(0)+1e-8) - np.log((y_t+y_b).sum()+1e-8))
    ru = np.logical_and(y_b, np.logical_not(y_t))*ic
    mi = np.logical_and(y_t, np.logical_not(y_b))*ic
    return np.sqrt((alpha*ru.sum(1))**2 + ((1-alpha)*mi.sum(1))**2).mean()

def show(name, y_p):
    f, thr = fmax(y_true, y_p)
    print(f"{name:>13s}  Fmax={f:.4f}  Thr={thr:.2f}  "
          f"AuPRC={auprc(y_true,y_p):.4f}  Smin={smin(y_true,y_p,thr):.4f}")

show("ProtBERT", y_pb)
show("ProtBERT-BFD", y_bfd)
show("ESM-2", y_esm)
show("Ensemble", (y_pb + y_bfd + y_esm)/3)

Out (stderr warnings omitted: a RequestsDependencyWarning about chardet/charset_normalizer, and a scikit-learn InconsistentVersionWarning from unpickling a MultiLabelBinarizer saved with 1.1.3 under 1.3.2):
go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms
mlb with 597 GO terms saved to data/mlb_597.pkl
     ProtBERT  Fmax=0.6665  Thr=0.41  AuPRC=0.7037  Smin=13.3996
 ProtBERT-BFD  Fmax=0.6579  Thr=0.43  AuPRC=0.6947  Smin=13.7870
        ESM-2  Fmax=0.6350  Thr=0.35  AuPRC=0.6823  Smin=14.2115
     Ensemble  Fmax=0.6893  Thr=0.36  AuPRC=0.7350  Smin=12.6319
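A note on the remap in step 3: it calls list.index once per GO term, which is quadratic in the number of terms. A minimal alternative sketch (same result, reusing mlb_esm and go_ref from the cell above) builds a position dictionary first:

# Build the index map with one dict lookup per term instead of a linear scan.
pos = {t: i for i, t in enumerate(mlb_esm.classes_)}
idx_map = [pos[t] for t in go_ref]  # identical to the list.index version

At 597 terms either version is fast; the dict version also fails loudly with a KeyError if a reference term is missing from the ESM-2 binarizer.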

In [3]:
# %%
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
import numpy as np

# --- Prepare the stacking inputs ---
# (y_pb, y_bfd and y_esm already have shape (1724, 597))
X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)  # (1724, 597*3)
y_stack = y_true.copy()                                 # (1724, 597)

# --- Train/validation split ---
X_train, X_val, y_train, y_val = train_test_split(
    X_stack, y_stack, test_size=0.3, random_state=42)

# --- MLP model (runs on the GPU automatically if one is available) ---
model = Sequential([
    Dense(512, activation="relu", input_shape=(X_train.shape[1],)),
    Dropout(0.3),
    Dense(256, activation="relu"),
    Dropout(0.3),
    Dense(y_stack.shape[1], activation="sigmoid")
])

model.compile(optimizer=Adam(1e-3), loss="binary_crossentropy")

model.fit(X_train, y_train, validation_data=(X_val, y_val),
          epochs=50, batch_size=64, verbose=1,
          callbacks=[EarlyStopping(patience=5, restore_best_weights=True)])

# --- Predict with the stacked model ---
y_pred_stack = model.predict(X_stack, batch_size=64)

# --- Metrics (fmax, auprc and smin as defined in the first cell) ---
f, thr = fmax(y_stack, y_pred_stack)
print("\n STACKING (GPU-Keras MLP)")
print(f"Fmax  = {f:.4f}")
print(f"Thr.  = {thr:.2f}")
print(f"AuPRC = {auprc(y_stack, y_pred_stack):.4f}")
print(f"Smin  = {smin(y_stack, y_pred_stack, thr):.4f}")

Out (excerpt; early stopping restored the best weights after 33 of 50 epochs):
Epoch 1/50
19/19 [==============================] - 1s 13ms/step - loss: 0.3857 - val_loss: 0.0857
...
Epoch 33/50
19/19 [==============================] - 0s 9ms/step - loss: 0.0217 - val_loss: 0.0329
27/27 [==============================] - 0s 2ms/step

 STACKING (GPU-Keras MLP)
Fmax  = 0.7020
Thr.  = 0.34
AuPRC = 0.7637
Smin  = 12.1382
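Note that the metrics above are computed over all 1724 proteins, 70% of which the stacker just trained on, so they are optimistic compared to the per-model numbers from the first cell. A sketch of a held-out-only check, reusing the split and metric functions already in scope:

# Evaluate the stacker only on the 30% validation split it never trained on.
y_pred_val = model.predict(X_val, batch_size=64)
f_val, thr_val = fmax(y_val, y_pred_val)
print(f"Held-out Fmax = {f_val:.4f} at threshold {thr_val:.2f}")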

In [4]:
model.save("models/ensemble_stacking.keras")
print("saved to models/ensemble_stacking.keras")

Out:
saved to models/ensemble_stacking.keras

(The notebook ends with an empty code cell; kernel: Python 3 (ipykernel), Python 3.10.16.)
notebooks/Input.ipynb
DELETED
@@ -1,157 +0,0 @@

In [1]:
# %%
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from tensorflow.keras.models import load_model
import joblib

# --- Parameters ---
SEQ_FASTA = "MPISSSSSSSTKSMRRAASELERSDSVTSPRFIGRRQSLIEDARKEREAAAAAAEAAEATEQIVFEEEDGKALLNLFFTLRSSKTPALSRSLKVFETFEAKIHHLETRPCRKPRDSLEGLEYFVRCEVHLSDVSTLISSIKRIAEDVKTTKEVKFHWFPKKISELDRCHHLITKFDPDLDQEHPGFTDPVYRQRRKMIGDIAFRYKQGEPIPRVEYTEEEIGTWREVYSTLRDLYTTHACSEHLEAFNLLERHCGYSPENIPQLEDVSRFLRERTGFQLRPVAGLLSARDFLASLAFRVFQCTQYIRHASSPMHSPEPDCVHELLGHVPILADRVFAQFSQNIGLASLGASEEDIEKLSTLYWFTVEFGLCKQGGIVKAYGAGLLSSYGELVHALSDEPERREFDPEAAAIQPYQDQNYQSVYFVSESFTDAKEKLRSYVAGIKRPFSVRFDPYTYSIEVLDNPLKIRGGLESVKDELKMLTDALNVLA"
TOP_N = 10

# --- 1. Split the sequence into chunks (512 for ProtBERT/ProtBERT-BFD, 1024 for ESM-2) ---
def slice_sequence(seq, chunk_size):
    return [seq[i:i+chunk_size] for i in range(0, len(seq), chunk_size)]

# --- 2. Mean CLS embedding over the chunks ---
def get_embedding_mean(model_name, seq, chunk_size):
    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    model = AutoModel.from_pretrained(model_name)
    model.eval()

    embeddings = []
    for chunk in slice_sequence(seq, chunk_size):
        seq_chunk = " ".join(chunk)
        # tokenize WITHOUT truncation (each chunk is already <= 512 or 1024)
        inputs = tokenizer(seq_chunk,
                           return_tensors="pt",
                           truncation=False,
                           padding=False)
        with torch.no_grad():
            cls = model(**inputs).last_hidden_state[:, 0, :].squeeze().numpy()
        embeddings.append(cls)

    return np.mean(embeddings, axis=0, keepdims=True)  # (1, dim)

print("Generating embeddings per chunk...")
emb_pb  = get_embedding_mean("Rostlab/prot_bert", SEQ_FASTA, 512)
emb_bfd = get_embedding_mean("Rostlab/prot_bert_bfd", SEQ_FASTA, 512)
emb_esm = get_embedding_mean("facebook/esm2_t33_650M_UR50D", SEQ_FASTA, 1024)

# --- 3. Load the base MLPs ---
mlp_pb  = load_model("models/protbert_mlp.keras")
mlp_bfd = load_model("models/protbertbfd_mlp.keras")
mlp_esm = load_model("models/mlp_esm2.keras")  # name used by PAM1_ESM2.ipynb when saving (was "esm2_mlp.keras")

# --- 4. Base predictions, aligned to the 597 GO-term column order ---
print("Making base predictions...")
mlb = joblib.load("data/mlb_597.pkl")   # 597 GO terms, ProtBERT column order
go_terms = mlb.classes_
y_pb  = mlp_pb.predict(emb_pb)[:, :597]
y_bfd = mlp_bfd.predict(emb_bfd)[:, :597]
# The ESM-2 head outputs 602 columns in the order of data/mlb.pkl; remap them
# as Ensemble.ipynb does (simply slicing [:, :597] would not align the columns).
mlb_esm = joblib.load("data/mlb.pkl")
idx_map = [list(mlb_esm.classes_).index(t) for t in go_terms]
y_esm = mlp_esm.predict(emb_esm)[:, idx_map]

# --- 5. Concatenate for stacking ---
X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)

# --- 6. Load the stacking model ---
stacking = load_model("models/ensemble_stacking.keras")  # path saved by Ensemble.ipynb (was "modelo_ensemble_stacking.keras")
y_pred = stacking.predict(X_stack)

# --- 7. Show results ---
print("\n GO terms with prob ≥ 0.5:")
predicted_terms = mlb.inverse_transform((y_pred >= 0.5).astype(int))
print(predicted_terms[0] if predicted_terms[0] else "No GO term above 0.5")

print(f"\n Top {TOP_N} most likely GO terms:")
top_idx = np.argsort(-y_pred[0])[:TOP_N]
for i in top_idx:
    print(f"{go_terms[i]} : {y_pred[0][i]:.4f}")

Out (stderr warnings omitted: huggingface_hub FutureWarnings about resume_download, and a notice that EsmModel's pooler weights are newly initialized):
Generating embeddings per chunk...
Making base predictions...

 GO terms with prob ≥ 0.5:
('GO:0003674', 'GO:0003824', 'GO:0005488', 'GO:0016491', 'GO:0036094', 'GO:0043167')

 Top 10 most likely GO terms:
GO:0003674 : 0.9975
GO:0003824 : 0.9156
GO:0036094 : 0.6652
GO:0043167 : 0.6336
GO:0016491 : 0.6327
GO:0005488 : 0.5595
GO:0043169 : 0.4801
GO:0140096 : 0.4790
GO:0051213 : 0.4551
GO:0046872 : 0.4098

(The notebook ends with an empty code cell; kernel: Python 3 (ipykernel), Python 3.8.18.)
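As a quick sanity check on the chunking logic above (a minimal sketch with a made-up input), slice_sequence should cover the whole sequence with no gaps and no overlap:

seq = "ACDEFGHIKLMNPQRSTVWY" * 40          # hypothetical 800-residue sequence
chunks = slice_sequence(seq, 512)
assert "".join(chunks) == seq               # no gaps, no overlap
assert all(len(c) <= 512 for c in chunks)   # every chunk fits the model window

One design consequence worth noting: averaging the per-chunk CLS vectors weights every chunk equally, so a short trailing chunk counts as much as a full 512-residue one; a length-weighted mean would be the alternative.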
notebooks/PAM1_ESM2.ipynb
DELETED
@@ -1,548 +0,0 @@

In [9]:
# %%
import pandas as pd
import numpy as np
import os
import joblib
from Bio import SeqIO
from goatools.obo_parser import GODag
from collections import Counter
from sklearn.preprocessing import MultiLabelBinarizer
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

# --- 1. Main input files ---
FASTA  = "uniprot_sprot_exp.fasta"
ANNOT  = "uniprot_sprot_exp.txt"
GO_OBO = "go.obo"

# --- 2. Read the sequences ---
seqs, ids = [], []
for record in SeqIO.parse(FASTA, "fasta"):
    ids.append(record.id)
    seqs.append(str(record.seq))

df_seq = pd.DataFrame({"protein_id": ids, "sequence": seqs})

# --- 3. Read the GO:MF annotations ---
df_ann = pd.read_csv(ANNOT, sep="\t", names=["protein_id", "go_term", "category"])
df_ann = df_ann[df_ann["category"] == "F"]

# --- 4. Hierarchical propagation of GO terms ---
go_dag = GODag(GO_OBO)
mf_terms = {t for t, o in go_dag.items() if o.namespace == "molecular_function"}

def propagate_terms(terms):
    expanded = set()
    for t in terms:
        if t in go_dag:
            expanded |= go_dag[t].get_all_parents()
            expanded.add(t)
    return list(expanded & mf_terms)

grouped = df_ann.groupby("protein_id")["go_term"].apply(list).reset_index()
grouped["go_term"] = grouped["go_term"].apply(propagate_terms)

# --- 5. Join with the sequences ---
df = df_seq.merge(grouped, on="protein_id")
df = df[df["go_term"].str.len() > 0]

# --- 6. Keep only GO terms with >= 50 proteins ---
all_terms = [term for sublist in df["go_term"] for term in sublist]
term_counts = Counter(all_terms)
valid_terms = {t for t, count in term_counts.items() if count >= 50}

df["go_term"] = df["go_term"].apply(lambda ts: [t for t in ts if t in valid_terms])
df = df[df["go_term"].str.len() > 0]

# --- 7. Prepare labels and split by protein ---
# The first fold of a 10-fold multilabel-stratified split gives a 90% training
# set; the remaining 10% is halved into validation and test.
df["go_terms"] = df["go_term"].apply(lambda x: ';'.join(sorted(set(x))))
df = df[["protein_id", "sequence", "go_terms"]].drop_duplicates()

mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(df["go_terms"].str.split(";"))
X = df[["protein_id", "sequence"]].values

mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)
train_idx, temp_idx = next(mskf.split(X, Y))
val_idx, test_idx = np.array_split(temp_idx, 2)

df_train = df.iloc[train_idx].copy()
df_val   = df.iloc[val_idx].copy()
df_test  = df.iloc[test_idx].copy()

os.makedirs("data", exist_ok=True)
df_train.to_csv("data/mf-training.csv", index=False)
df_val.to_csv("data/mf-validation.csv", index=False)
df_test.to_csv("data/mf-test.csv", index=False)

# --- 8. Save the binarizer ---
joblib.dump(mlb, "data/mlb.pkl")

print("✓ Dataset prepared:")
print("  - Training:  ", df_train.shape)
print("  - Validation:", df_val.shape)
print("  - Test:      ", df_test.shape)
print("  - GO terms:  ", len(mlb.classes_))

Out:
go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms
✓ Dataset prepared:
  - Training:   (31142, 3)
  - Validation: (1724, 3)
  - Test:       (1724, 3)
  - GO terms:   602

In [10]:
# %%
import random

# --- Row-normalized PAM1 matrix ---
pam_data = {
    'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],
    'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],
    'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],
    'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],
    'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],
    'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],
    'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],
    'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],
    'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],
    'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],
    'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],
    'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],
    'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],
    'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],
    'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],
    'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],
    'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],
    'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],
    'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],
    'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]
}

pam_raw = pd.DataFrame(pam_data, index=pam_data.keys())
pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)  # rows sum to 1
pam_dict = {aa: pam_matrix.loc[aa].to_dict() for aa in pam_matrix.index}

def pam1_substitution(aa):
    # Sample a replacement residue from the PAM1 row of aa
    # (non-standard residues such as X are returned unchanged).
    if aa not in pam_dict:
        return aa
    subs  = list(pam_dict[aa].keys())
    probs = list(pam_dict[aa].values())
    return np.random.choice(subs, p=probs)

def augment_sequence(seq, sub_prob=0.05):
    # Each position has a sub_prob chance of being resampled via PAM1.
    return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])

def slice_sequence(seq, win=1024):
    if len(seq) <= win:
        return [seq]
    return [seq[i:i+win] for i in range(0, len(seq), win)]

def format_seq(seq):
    return " ".join(seq)

# --- Load labels and datasets ---
mlb = joblib.load("data/mlb.pkl")
df_train = pd.read_csv("data/mf-training.csv")
df_val   = pd.read_csv("data/mf-validation.csv")
df_test  = pd.read_csv("data/mf-test.csv")

# --- Slicing + augmentation on the training set only ---
X_train, y_train = [], []

for _, row in df_train.iterrows():
    seq_aug = augment_sequence(row["sequence"], sub_prob=0.05)
    slices = slice_sequence(seq_aug, win=1024)
    label = mlb.transform([row["go_terms"].split(";")])[0]
    for sl in slices:
        X_train.append(format_seq(sl))
        y_train.append(label)

# --- No slicing on validation/test ---
X_val  = [format_seq(seq) for seq in df_val["sequence"]]
X_test = [format_seq(seq) for seq in df_test["sequence"]]

y_val  = mlb.transform(df_val["go_terms"].str.split(";"))
y_test = mlb.transform(df_test["go_terms"].str.split(";"))

os.makedirs("embeddings", exist_ok=True)  # the next cell creates this too, but the save below needs it first
np.save("embeddings/y_test.npy", y_test)

In [11]:
# %%
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm

# --- Settings ---
MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHUNK_SIZE = 16  # batch size for embedding extraction

# --- Load the model ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
model = AutoModel.from_pretrained(MODEL_NAME)
model.to(DEVICE)
model.eval()

def extract_embeddings(texts):
    embeddings = []
    for i in tqdm(range(0, len(texts), CHUNK_SIZE)):
        batch = texts[i:i+CHUNK_SIZE]
        with torch.no_grad():
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=1024)
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            outputs = model(**inputs).last_hidden_state
            cls_tokens = outputs[:, 0, :]  # CLS token
            embeddings.append(cls_tokens.cpu().numpy())
    return np.vstack(embeddings)

# --- Extract and save the embeddings ---
os.makedirs("embeddings", exist_ok=True)

emb_train = extract_embeddings(X_train)
emb_val   = extract_embeddings(X_val)
emb_test  = extract_embeddings(X_test)

np.save("embeddings/esm2_train.npy", emb_train)
np.save("embeddings/esm2_val.npy", emb_val)
np.save("embeddings/esm2_test.npy", emb_test)

np.save("embeddings/y_train.npy", np.array(y_train))
np.save("embeddings/y_val.npy", np.array(y_val))

Out (a huggingface_hub FutureWarning and the newly-initialized-pooler notice omitted):
100%|██████████| 2189/2189 [1:17:26<00:00, 2.12s/it]
100%|██████████| 108/108 [03:43<00:00, 2.07s/it]
100%|██████████| 108/108 [03:56<00:00, 2.19s/it]

In [1]:
# %%
import numpy as np
import os
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping

# --- Load the embeddings and labels ---
X_train = np.load("embeddings/esm2_train.npy")
X_val   = np.load("embeddings/esm2_val.npy")
X_test  = np.load("embeddings/esm2_test.npy")

y_train = np.load("embeddings/y_train.npy")
y_val   = np.load("embeddings/y_val.npy")
y_test  = np.load("embeddings/y_test.npy")

# --- Define the model ---
model = Sequential([
    Dense(1024, activation='relu', input_shape=(X_train.shape[1],)),
    Dropout(0.3),
    Dense(512, activation='relu'),
    Dropout(0.3),
    Dense(y_train.shape[1], activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy')

# --- Train ---
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=100,
    batch_size=32,
    callbacks=[early_stop],
    verbose=1
)

# --- Save the model ---
os.makedirs("models", exist_ok=True)
model.save("models/mlp_esm2.keras")
print("Model saved to models/mlp_esm2.keras")

# --- Predict on the test set ---
os.makedirs("predictions", exist_ok=True)
y_prob = model.predict(X_test)
np.save("predictions/mf-esm2.npy", y_prob)

print(" ESM-2 predictions saved with shape:", y_prob.shape)

Out (excerpt; a RequestsDependencyWarning on stderr omitted; early stopping triggered after 36 of 100 epochs):
Epoch 1/100
1095/1095 [==============================] - 14s 12ms/step - loss: 0.0552 - val_loss: 0.0452
...
Epoch 36/100
1095/1095 [==============================] - 14s 13ms/step - loss: 0.0299 - val_loss: 0.0309
Model saved to models/mlp_esm2.keras
54/54 [==============================] - 0s 2ms/step
 ESM-2 predictions saved with shape: (1724, 602)

In [15]:
# %%
import math
import joblib
from goatools.obo_parser import GODag
from sklearn.metrics import precision_recall_curve, auc

# --- 1. Load data and parameters ---
GO_FILE = "go.obo"
THRESHOLDS = np.arange(0.0, 1.01, 0.01)
ALPHA = 0.5

mlb = joblib.load("data/mlb.pkl")
y_true = np.load("embeddings/y_test.npy")
y_prob = np.load("predictions/mf-esm2.npy")
terms = mlb.classes_
go_dag = GODag(GO_FILE)

print(f"✓ Data loaded: {y_true.shape[0]} proteins × {len(terms)} GO terms")

# --- 2. Fmax: best protein-centric F1 over all thresholds ---
def compute_fmax(y_true, y_prob, thresholds):
    fmax, best_thr = 0, 0
    for t in thresholds:
        y_pred = (y_prob >= t).astype(int)
        tp = (y_true * y_pred).sum(axis=1)
        fp = ((1 - y_true) * y_pred).sum(axis=1)
        fn = (y_true * (1 - y_pred)).sum(axis=1)
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        avg_f1 = np.mean(f1)
        if avg_f1 > fmax:
            fmax, best_thr = avg_f1, t
    return fmax, best_thr

# --- 3. AuPRC (micro-averaged) ---
def compute_auprc(y_true, y_prob):
    precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())
    return auc(recall, precision)

# --- 4. Smin: semantic distance built from remaining uncertainty (RU) and misinformation (MI) ---
def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):
    y_pred = (y_prob >= threshold).astype(int)

    # Information content per term, estimated from annotation frequency
    ic = {}
    total = (y_true + y_pred).sum(axis=0).sum()
    for i, term in enumerate(terms):
        freq = (y_true[:, i] + y_pred[:, i]).sum()
        ic[term] = -np.log((freq + 1e-8) / total)

    # Per protein: RU sums the IC of predicted-but-wrong terms, MI of missed terms
    s_values = []
    for true_vec, pred_vec in zip(y_true, y_pred):
        true_terms = {terms[i] for i in np.where(true_vec)[0]}
        pred_terms = {terms[i] for i in np.where(pred_vec)[0]}

        ru = pred_terms - true_terms
        mi = true_terms - pred_terms
        dist_ru = sum(ic.get(t, 0) for t in ru)
        dist_mi = sum(ic.get(t, 0) for t in mi)
        s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)
        s_values.append(s)

    return np.mean(s_values)

# --- 5. Evaluate ---
fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)
auprc = compute_auprc(y_true, y_prob)
smin = compute_smin(y_true, y_prob, terms, thr, go_dag)

print("\n Final results (ESM-2 + PAM1 + propagation):")
print(f"Fmax  = {fmax:.4f}")
print(f"Thr.  = {thr:.2f}")
print(f"AuPRC = {auprc:.4f}")
print(f"Smin  = {smin:.4f}")

Out:
go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms
✓ Data loaded: 1724 proteins × 602 GO terms

 Final results (ESM-2 + PAM1 + propagation):
Fmax  = 0.6439
Thr.  = 0.34
AuPRC = 0.6948
Smin  = 14.1500

(The notebook ends with an empty code cell; kernel: Python 3 (ipykernel), Python 3.10.16.)
notebooks/PAM1_protbert.ipynb
DELETED
|
@@ -1,971 +0,0 @@
|
|
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "c6dbc330-062a-48f0-8242-3f21cc1c9c2b",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
- "✓ Files created:\n",
- " - data/mf-training.csv : (31142, 3)\n",
- " - data/mf-validation.csv: (1724, 3)\n",
- " - data/mf-test.csv : (1724, 3)\n",
- "Unique GO terms (after propagation and filtering): 602\n"
- ]
- }
- ],
- "source": [
- "import pandas as pd\n",
- "from Bio import SeqIO\n",
- "from collections import Counter\n",
- "from goatools.obo_parser import GODag\n",
- "from sklearn.model_selection import train_test_split\n",
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
- "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
- "import numpy as np\n",
- "import os\n",
- "\n",
- "# --- 1. Load GO annotations ---------------------------------------------\n",
- "annotations = pd.read_csv(\"uniprot_sprot_exp.txt\", sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"go_category\"])\n",
- "annotations_f = annotations[annotations[\"go_category\"] == \"F\"]\n",
- "\n",
- "# --- 2. Load the DAG and propagate GO terms -----------------------------\n",
- "# hierarchical propagation\n",
- "# https://geneontology.org/docs/download-ontology/\n",
- "go_dag = GODag(\"go.obo\")\n",
- "mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n",
- "\n",
- "def propagate_terms(term_list):\n",
- "    full = set()\n",
- "    for t in term_list:\n",
- "        if t not in go_dag:\n",
- "            continue\n",
- "        full.add(t)\n",
- "        full.update(go_dag[t].get_all_parents())\n",
- "    return list(full & mf_terms)\n",
- "\n",
- "# --- 3. Load sequences ---------------------------------------------------\n",
- "seqs, ids = [], []\n",
- "for record in SeqIO.parse(\"uniprot_sprot_exp.fasta\", \"fasta\"):\n",
- "    ids.append(record.id)\n",
- "    seqs.append(str(record.seq))\n",
- "\n",
- "seq_df = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n",
- "\n",
- "# --- 4. Join with the GO annotations and propagate -----------------------\n",
- "grouped = annotations_f.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n",
- "data = seq_df.merge(grouped, on=\"protein_id\")\n",
- "data = data[data[\"go_term\"].apply(len) > 0]\n",
- "data[\"go_term\"] = data[\"go_term\"].apply(propagate_terms)\n",
- "data = data[data[\"go_term\"].apply(len) > 0]\n",
- "\n",
- "# --- 5. Filter rare GO terms ---------------------------------------------\n",
- "# all terms with fewer than 50 associated proteins\n",
- "all_terms = [term for sublist in data[\"go_term\"] for term in sublist]\n",
- "term_counts = Counter(all_terms)\n",
- "valid_terms = {term for term, count in term_counts.items() if count >= 50}\n",
- "data[\"go_term\"] = data[\"go_term\"].apply(lambda terms: [t for t in terms if t in valid_terms])\n",
- "data = data[data[\"go_term\"].apply(len) > 0]\n",
- "\n",
- "# --- 6. Prepare the final dataset ----------------------------------------\n",
- "data[\"go_terms\"] = data[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n",
- "data = data[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n",
- "\n",
- "# --- 7. Binarize the labels and split ------------------------------------\n",
- "mlb = MultiLabelBinarizer()\n",
- "Y = mlb.fit_transform(data[\"go_terms\"].str.split(\";\"))\n",
- "X = data[[\"protein_id\", \"sequence\"]].values\n",
- "\n",
- "mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n",
- "train_idx, temp_idx = next(mskf.split(X, Y))\n",
- "val_idx, test_idx = np.array_split(temp_idx, 2)\n",
- "\n",
- "df_train = data.iloc[train_idx].copy()\n",
- "df_val = data.iloc[val_idx].copy()\n",
- "df_test = data.iloc[test_idx].copy()\n",
- "\n",
- "# --- 8. Save to CSV -------------------------------------------------------\n",
- "os.makedirs(\"data\", exist_ok=True)\n",
- "df_train.to_csv(\"data/mf-training.csv\", index=False)\n",
- "df_val.to_csv(\"data/mf-validation.csv\", index=False)\n",
- "df_test.to_csv(\"data/mf-test.csv\", index=False)\n",
- "\n",
- "# --- 9. Confirm -----------------------------------------------------------\n",
- "print(\"✓ Files created:\")\n",
- "print(\" - data/mf-training.csv :\", df_train.shape)\n",
- "print(\" - data/mf-validation.csv:\", df_val.shape)\n",
- "print(\" - data/mf-test.csv :\", df_test.shape)\n",
- "print(f\"Unique GO terms (after propagation and filtering): {len(mlb.classes_)}\")\n"
- ]
- },
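For readers skimming this deleted notebook: the key step in the cell above is propagating every annotated GO term up to all of its ancestors, which makes the labels consistent with the ontology's true-path rule before rare terms are filtered. A minimal self-contained sketch of that idea with goatools (assuming a local go.obo; GO:0005524, ATP binding, is just an illustrative input):

from goatools.obo_parser import GODag

# GODag behaves like a dict mapping GO id -> term object.
go_dag = GODag("go.obo")

def propagate(terms, namespace="molecular_function"):
    """Return the input terms plus every ancestor, restricted to one namespace."""
    full = set()
    for t in terms:
        if t in go_dag:
            full.add(t)
            full.update(go_dag[t].get_all_parents())  # all ancestors, not only direct parents
    return sorted(t for t in full if go_dag[t].namespace == namespace)

# A protein annotated only with ATP binding also counts as nucleotide binding,
# binding, and so on up to the molecular_function root.
print(propagate(["GO:0005524"]))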
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "6cf7aaa6-4941-4951-8d73-1f4f1f4362f3",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
- " _torch_pytree._register_pytree_node(\n",
- "100%|██████████| 31142/31142 [00:24<00:00, 1262.18it/s]\n",
- "100%|██████████| 1724/1724 [00:00<00:00, 2628.24it/s]\n",
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:382: UserWarning: The class_names argument is replacing the classes argument. Please update your code.\n",
- " warnings.warn(\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "preprocessing train...\n",
- "language: de\n",
- "train sequence lengths:\n",
- "\tmean : 423\n",
- "\t95percentile : 604\n",
- "\t99percentile : 715\n"
- ]
- },
[two "display_data" outputs elided: the HTML progress-bar <style> boilerplate that ktrain emits, plus an empty HTML display]
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Is Multi-Label? True\n",
- "preprocessing test...\n",
- "language: de\n",
- "test sequence lengths:\n",
- "\tmean : 408\n",
- "\t95percentile : 603\n",
- "\t99percentile : 714\n"
- ]
- },
[two more "display_data" outputs elided: the same HTML progress-bar <style> boilerplate and empty HTML display, repeated for the test split]
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:1093: UserWarning: Could not load a Tensorflow version of model. (If this worked before, it might be an out-of-memory issue.) Attempting to download/load PyTorch version as TensorFlow model using from_pt=True. You will need PyTorch installed for this.\n",
- " warnings.warn(\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- "\n",
- "begin training using triangular learning rate policy with max lr of 1e-05...\n",
- "Epoch 1/10\n",
- "40995/40995 [==============================] - 13053s 318ms/step - loss: 0.0745 - binary_accuracy: 0.9866 - val_loss: 0.0582 - val_binary_accuracy: 0.9859\n",
- "Epoch 2/10\n",
- "40995/40995 [==============================] - 14484s 353ms/step - loss: 0.0504 - binary_accuracy: 0.9873 - val_loss: 0.0499 - val_binary_accuracy: 0.9867\n",
- "Epoch 3/10\n",
- "40995/40995 [==============================] - 14472s 353ms/step - loss: 0.0450 - binary_accuracy: 0.9879 - val_loss: 0.0449 - val_binary_accuracy: 0.9873\n",
- "Epoch 4/10\n",
- "40995/40995 [==============================] - 14445s 352ms/step - loss: 0.0407 - binary_accuracy: 0.9884 - val_loss: 0.0413 - val_binary_accuracy: 0.9878\n",
- "Epoch 5/10\n",
- "40995/40995 [==============================] - 12524s 305ms/step - loss: 0.0378 - binary_accuracy: 0.9888 - val_loss: 0.0394 - val_binary_accuracy: 0.9881\n",
- "Epoch 6/10\n",
- "40995/40995 [==============================] - 14737s 359ms/step - loss: 0.0359 - binary_accuracy: 0.9891 - val_loss: 0.0383 - val_binary_accuracy: 0.9883\n",
- "Epoch 7/10\n",
- "40995/40995 [==============================] - 20317s 495ms/step - loss: 0.0343 - binary_accuracy: 0.9894 - val_loss: 0.0371 - val_binary_accuracy: 0.9885\n",
- "Epoch 8/10\n",
- "40995/40995 [==============================] - 9073s 221ms/step - loss: 0.0331 - binary_accuracy: 0.9896 - val_loss: 0.0364 - val_binary_accuracy: 0.9887\n",
- "Epoch 9/10\n",
- "40995/40995 [==============================] - 9001s 219ms/step - loss: 0.0320 - binary_accuracy: 0.9898 - val_loss: 0.0360 - val_binary_accuracy: 0.9888\n",
- "Epoch 10/10\n",
- "40995/40995 [==============================] - 8980s 219ms/step - loss: 0.0311 - binary_accuracy: 0.9900 - val_loss: 0.0356 - val_binary_accuracy: 0.9890\n"
- ]
- },
- {
- "ename": "RuntimeError",
- "evalue": "Can't decrement id ref count (unable to extend file properly)",
- "output_type": "error",
- "traceback": [
[ANSI-escaped traceback frames elided. The recorded failure chain: keras's ModelCheckpoint callback failed while saving 'mf-fine-tuned-protbert\\weights-10-0.04.hdf5' at the end of epoch 10 with "OSError: [Errno 28] Can't write data (file write failed ... 'No space left on device')", after which h5py raised "RuntimeError: Can't decrement id ref count (unable to extend file properly)" while closing the partially written file.]
- ]
- }
- ],
- "source": [
- "import pandas as pd\n",
- "import numpy as np\n",
- "from tqdm import tqdm\n",
- "import random\n",
- "import os\n",
- "import ktrain\n",
- "from ktrain import text\n",
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
- "\n",
- "\n",
- "# PAM1\n",
- "# PAM matrix model of protein evolution\n",
- "# DOI:10.1093/oxfordjournals.molbev.a040360\n",
- "pam_data = {\n",
- "    'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n",
- "    'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n",
- "    'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n",
- "    'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n",
- "    'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n",
- "    'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n",
- "    'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n",
- "    'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n",
- "    'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n",
- "    'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n",
- "    'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n",
- "    'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n",
- "    'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n",
- "    'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n",
- "    'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n",
- "    'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n",
- "    'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n",
- "    'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n",
- "    'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n",
- "    'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n",
- "}\n",
- "pam_raw = pd.DataFrame(pam_data, index=list(pam_data.keys()))\n",
- "pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n",
- "list_amino = pam_raw.columns.tolist()\n",
- "pam_dict = {\n",
- "    aa: {sub: pam_matrix.loc[aa, sub] for sub in list_amino}\n",
- "    for aa in list_amino\n",
- "}\n",
- "\n",
- "def pam1_substitution(aa):\n",
- "    if aa not in pam_dict:\n",
- "        return aa\n",
- "    subs = list(pam_dict[aa].keys())\n",
- "    probs = list(pam_dict[aa].values())\n",
- "    return np.random.choice(subs, p=probs)\n",
- "\n",
- "def augment_sequence(seq, sub_prob=0.05):\n",
- "    return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n",
- "\n",
- "def slice_sequence(seq, win=500, min_overlap=250):\n",
- "    if len(seq) <= win:\n",
- "        return [seq]\n",
- "    slices, start = [], 0\n",
- "    while start + win <= len(seq):\n",
- "        slices.append(seq[start:start+win])\n",
- "        start += win\n",
- "    leftover = seq[start:]\n",
- "    if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
- "        extra = slices[-1][-min_overlap:] + leftover\n",
- "        slices.append(extra)\n",
- "    return slices\n",
- "\n",
- "def generate_data(df, augment=False):\n",
- "    X, y = [], []\n",
- "    label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
- "    for _, row in tqdm(df.iterrows(), total=len(df)):\n",
- "        seq = row[\"sequence\"]\n",
- "        if augment:\n",
- "            seq = augment_sequence(seq)\n",
- "        seq_slices = slice_sequence(seq)\n",
- "        X.extend(seq_slices)\n",
- "        lbl = row[label_cols].values.astype(int)\n",
- "        y.extend([lbl] * len(seq_slices))\n",
- "    return X, np.array(y), label_cols\n",
- "\n",
- "def format_sequence(seq): return \" \".join(list(seq))\n",
- "\n",
- "# Function to load and binarize\n",
- "def load_and_binarize(csv_path, mlb=None):\n",
- "    df = pd.read_csv(csv_path)\n",
- "    df[\"go_terms\"] = df[\"go_terms\"].str.split(\";\")\n",
- "    if mlb is None:\n",
- "        mlb = MultiLabelBinarizer()\n",
- "        labels = mlb.fit_transform(df[\"go_terms\"])\n",
- "    else:\n",
- "        labels = mlb.transform(df[\"go_terms\"])\n",
- "    labels_df = pd.DataFrame(labels, columns=mlb.classes_)\n",
- "    df = df.reset_index(drop=True).join(labels_df)\n",
- "    return df, mlb\n",
- "\n",
- "# Load the data\n",
- "df_train, mlb = load_and_binarize(\"data/mf-training.csv\")\n",
- "df_val, _ = load_and_binarize(\"data/mf-validation.csv\", mlb=mlb)\n",
- "\n",
- "# Generate, with augmentation on the training set\n",
- "X_train, y_train, term_cols = generate_data(df_train, augment=True)\n",
- "X_val, y_val, _ = generate_data(df_val, augment=False)\n",
- "\n",
- "# Prepare the text for the tokenizer\n",
- "X_train_fmt = list(map(format_sequence, X_train))\n",
- "X_val_fmt = list(map(format_sequence, X_val))\n",
- "\n",
- "# Fine-tune ProtBERT\n",
- "# https://huggingface.co/Rostlab/prot_bert\n",
- "# https://doi.org/10.1093/bioinformatics/btac020\n",
- "# training data -> UniRef100 (216 million sequences)\n",
- "MODEL_NAME = \"Rostlab/prot_bert\"\n",
- "MAX_LEN = 512\n",
- "BATCH_SIZE = 1\n",
- "\n",
- "t = text.Transformer(MODEL_NAME, maxlen=MAX_LEN, classes=term_cols)\n",
- "trn = t.preprocess_train(X_train_fmt, y_train)\n",
- "val = t.preprocess_test(X_val_fmt, y_val)\n",
- "\n",
- "model = t.get_classifier()\n",
- "learner = ktrain.get_learner(model,\n",
- "                             train_data=trn,\n",
- "                             val_data=val,\n",
- "                             batch_size=BATCH_SIZE)\n",
- "\n",
- "learner.autofit(lr=1e-5,\n",
- "                epochs=10,\n",
- "                early_stopping=1,\n",
- "                checkpoint_folder=\"mf-fine-tuned-protbert\")\n"
- ]
- },
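The PAM1 augmentation above is easy to sanity-check in isolation. A self-contained toy version with a 3-residue alphabet (hypothetical counts standing in for the 20×20 PAM1 table, but with the same structure: dominant diagonal, each row summing to 10000):

import random
import numpy as np
import pandas as pd

# Toy 3-residue stand-in for the notebook's PAM1 count table.
counts = pd.DataFrame(
    [[9948, 30, 22],
     [25, 9950, 25],
     [20, 28, 9952]],
    index=list("ARN"), columns=list("ARN"),
)
pam = counts.div(counts.sum(axis=1), axis=0)   # each row -> a substitution distribution
assert np.allclose(pam.sum(axis=1), 1.0)

def pam1_substitution(aa):
    if aa not in pam.index:
        return aa                               # leave unknown residues untouched
    return np.random.choice(pam.columns, p=pam.loc[aa].values)

def augment_sequence(seq, sub_prob=0.05):
    # Each position has a sub_prob chance of being resampled from its PAM1 row;
    # since the diagonal is ~0.995, most resamples keep the residue anyway.
    return "".join(pam1_substitution(aa) if random.random() < sub_prob else aa
                   for aa in seq)

random.seed(0); np.random.seed(0)
print(augment_sequence("ARNRANARN" * 3))

The design point carried over from the notebook: after row-normalisation each row is a probability distribution, and because the diagonal dominates, a 5% per-position substitution rate perturbs far fewer than 5% of residues in practice.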
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "c66774b3-6cf0-41c5-bb01-9467a5283102",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "✅ Exists: weights/mf-fine-tuned-protbert-epoch10\n",
- "📁 Contents:\n",
- " - config.json\n",
- " - tf_model.h5\n"
- ]
- }
- ],
- "source": [
- "import os\n",
- "\n",
- "path = \"weights/mf-fine-tuned-protbert-epoch10\"\n",
- "\n",
- "if os.path.exists(path):\n",
- "    print(f\"✅ Exists: {path}\")\n",
- "    print(\"📁 Contents:\")\n",
- "    for f in os.listdir(path):\n",
- "        print(\" -\", f)\n",
- "else:\n",
- "    print(f\"❌ Does not exist: {path}\")\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "id": "9b39c439-5708-4787-bfee-d3a4d3aa190d",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
- " _torch_pytree._register_pytree_node(\n",
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\huggingface_hub\\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
- " warnings.warn(\n",
- "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:309: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
- " _torch_pytree._register_pytree_node(\n",
- "Some layers from the model checkpoint at weights/mf-fine-tuned-protbert-epoch10 were not used when initializing TFBertModel: ['classifier', 'dropout_183']\n",
- "- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
- "- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
- "All the layers of TFBertModel were initialized from the model checkpoint at weights/mf-fine-tuned-protbert-epoch10.\n",
- "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "✓ Base tokenizer and fine-tuned model loaded successfully\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Processing data/mf-training.csv: 0%| | 25/31142 [00:06<2:23:28, 3.61it/s]\n"
- ]
- },
- {
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
[ANSI-escaped KeyboardInterrupt traceback elided: the run was interrupted by hand inside process_split("data/mf-training.csv", ...), at the per-row gc.collect() call, after 25 of 31142 proteins.]
- ]
- }
- ],
- "source": [
- "import os\n",
- "import pandas as pd\n",
- "import numpy as np\n",
- "from tqdm import tqdm\n",
- "import joblib\n",
- "import gc\n",
- "from transformers import AutoTokenizer, TFAutoModel\n",
- "\n",
- "# --- 1. Parameters --------------------------------------------------------\n",
- "MODEL_DIR = \"weights/mf-fine-tuned-protbert-epoch10\"\n",
- "BASE_MODEL = \"Rostlab/prot_bert\"\n",
- "OUT_DIR = \"embeddings\"\n",
- "BATCH_TOK = 16\n",
- "\n",
- "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, do_lower_case=False)\n",
- "model = TFAutoModel.from_pretrained(MODEL_DIR, from_pt=False)\n",
- "\n",
- "print(\"✓ Base tokenizer and fine-tuned model loaded successfully\")\n",
- "\n",
- "# --- 3. Helper functions ----------------------------------------------------\n",
- "def format_sequence(seq):\n",
- "    return \" \".join(list(seq))\n",
- "\n",
- "def slice_sequence(seq, win=500, min_overlap=250):\n",
- "    if len(seq) <= win:\n",
- "        return [seq]\n",
- "    slices, start = [], 0\n",
- "    while start + win <= len(seq):\n",
- "        slices.append(seq[start:start+win])\n",
- "        start += win\n",
- "    leftover = seq[start:]\n",
- "    if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
- "        extra = slices[-1][-min_overlap:] + leftover\n",
- "        slices.append(extra)\n",
- "    return slices\n",
- "\n",
- "def get_embeddings(batch, tokenizer, model):\n",
- "    tokens = tokenizer(batch, return_tensors=\"tf\", padding=True, truncation=True, max_length=512)\n",
- "    output = model(**tokens)\n",
- "    return output.last_hidden_state[:, 0, :].numpy()\n",
- "\n",
- "def process_split(csv_path, out_path):\n",
- "    df = pd.read_csv(csv_path)\n",
- "    label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
- "    prot_ids, embeds, labels = [], [], []\n",
- "\n",
- "    for _, row in tqdm(df.iterrows(), total=len(df), desc=f\"Processing {csv_path}\"):\n",
- "        slices = slice_sequence(row[\"sequence\"])\n",
- "        slices_fmt = list(map(format_sequence, slices))\n",
- "\n",
- "        slice_embeds = []\n",
- "        for i in range(0, len(slices_fmt), BATCH_TOK):\n",
- "            batch = slices_fmt[i:i+BATCH_TOK]\n",
- "            slice_embeds.append(get_embeddings(batch, tokenizer, model))\n",
- "        slice_embeds = np.vstack(slice_embeds)\n",
- "\n",
- "        prot_embed = slice_embeds.mean(axis=0)\n",
- "        prot_ids.append(row[\"protein_id\"])\n",
- "        embeds.append(prot_embed.astype(np.float32))\n",
- "        labels.append(row[label_cols].values.astype(np.int8))\n",
- "        gc.collect()\n",
- "\n",
- "    embeds = np.vstack(embeds)\n",
- "    labels = np.vstack(labels)\n",
- "\n",
- "    joblib.dump({\n",
- "        \"protein_ids\": prot_ids,\n",
- "        \"embeddings\": embeds,\n",
- "        \"labels\": labels,\n",
- "        \"go_terms\": label_cols\n",
- "    }, out_path, compress=3)\n",
- "\n",
- "    print(f\"✓ Saved {out_path} — {embeds.shape[0]} proteins\")\n",
- "\n",
- "# --- 4. Apply ---------------------------------------------------------------\n",
- "os.makedirs(OUT_DIR, exist_ok=True)\n",
- "\n",
- "process_split(\"data/mf-training.csv\", os.path.join(OUT_DIR, \"train_protbert.pkl\"))\n",
- "process_split(\"data/mf-validation.csv\", os.path.join(OUT_DIR, \"val_protbert.pkl\"))\n",
- "process_split(\"data/mf-test.csv\", os.path.join(OUT_DIR, \"test_protbert.pkl\"))\n"
- ]
- },
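One subtlety in the cell above: long proteins are cut into 500-residue windows, each window is embedded separately, and the per-slice [CLS] vectors are averaged into a single 1024-dimensional protein embedding. A sketch of just the slicing logic (pure Python, no model required), showing the extra overlap window that stitches the tail onto the last full slice:

def slice_sequence(seq, win=500, min_overlap=250):
    # Same logic as the notebook cell: non-overlapping windows, plus one
    # overlap window covering the leftover tail.
    if len(seq) <= win:
        return [seq]
    slices, start = [], 0
    while start + win <= len(seq):
        slices.append(seq[start:start + win])
        start += win
    leftover = seq[start:]
    # Note: a tail shorter than min_overlap is dropped entirely and its
    # residues never reach the model.
    if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:
        slices.append(slices[-1][-min_overlap:] + leftover)
    return slices

seq = "A" * 1300
print([len(s) for s in slice_sequence(seq)])  # [500, 500, 550]: 2 full windows + tail slice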
- {
- "cell_type": "code",
- "execution_count": 27,
- "id": "ad0c5421-e0a1-4a6a-8ace-2c69aeab0e0d",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "✓ Fixed: embeddings/train_protbert.pkl — 31142 examples, 597 GO terms\n",
- "✓ Fixed: embeddings/val_protbert.pkl — 1724 examples, 597 GO terms\n",
- "✓ Fixed: embeddings/test_protbert.pkl — 1724 examples, 597 GO terms\n"
- ]
- }
- ],
- "source": [
- "import pandas as pd\n",
- "import joblib\n",
- "from sklearn.preprocessing import MultiLabelBinarizer\n",
- "\n",
- "# --- 1. Get the GO terms from the test file --------------------------------\n",
- "df_test = pd.read_csv(\"data/mf-test.csv\")\n",
- "test_terms = sorted(set(term for row in df_test[\"go_terms\"].str.split(\";\") for term in row))\n",
- "\n",
- "# --- 2. Function to patch a .pkl against the test-set GO terms -------------\n",
- "def patch_to_common_terms(csv_path, pkl_path, common_terms):\n",
- "    df = pd.read_csv(csv_path)\n",
- "    terms_split = df[\"go_terms\"].str.split(\";\")\n",
- "    \n",
- "    # Keep only terms present in common_terms\n",
- "    terms_filtered = terms_split.apply(lambda lst: [t for t in lst if t in common_terms])\n",
- "    \n",
- "    mlb = MultiLabelBinarizer(classes=common_terms)\n",
- "    Y = mlb.fit_transform(terms_filtered)\n",
- "\n",
- "    data = joblib.load(pkl_path)\n",
- "    data[\"labels\"] = Y\n",
- "    data[\"go_terms\"] = mlb.classes_.tolist()\n",
- "    \n",
- "    joblib.dump(data, pkl_path, compress=3)\n",
- "    print(f\"✓ Fixed: {pkl_path} — {Y.shape[0]} examples, {Y.shape[1]} GO terms\")\n",
- "\n",
- "# --- 3. Apply to the 3 splits -----------------------------------------------\n",
- "patch_to_common_terms(\"data/mf-training.csv\", \"embeddings/train_protbert.pkl\", test_terms)\n",
- "patch_to_common_terms(\"data/mf-validation.csv\", \"embeddings/val_protbert.pkl\", test_terms)\n",
- "patch_to_common_terms(\"data/mf-test.csv\", \"embeddings/test_protbert.pkl\", test_terms)\n"
- ]
- },
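Pinning MultiLabelBinarizer to an explicit classes= list is what guarantees every split ends up with the same 597 columns in the same order, regardless of which terms each split actually contains. A tiny self-contained illustration (toy label sets; the three GO ids are only examples):

from sklearn.preprocessing import MultiLabelBinarizer

common_terms = ["GO:0003674", "GO:0003824", "GO:0005488"]  # fixed column order

mlb = MultiLabelBinarizer(classes=common_terms)
# With classes= given, fit_transform just records that order; terms outside
# the list must be filtered out beforehand, or sklearn warns and ignores them.
Y = mlb.fit_transform([["GO:0003824"], ["GO:0005488", "GO:0003674"], []])
print(Y)
# [[0 1 0]
#  [1 0 1]
#  [0 0 0]]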
| 647 |
-
{
|
| 648 |
-
"cell_type": "code",
|
| 649 |
-
"execution_count": 2,
|
| 650 |
-
"id": "e01950ba-aaea-4403-9b37-11e653cfa6de",
|
| 651 |
-
"metadata": {},
|
| 652 |
-
"outputs": [
|
| 653 |
-
{
|
| 654 |
-
"name": "stdout",
|
| 655 |
-
"output_type": "stream",
|
| 656 |
-
"text": [
|
| 657 |
-
"Collecting tensorflow==2.13.1\n",
|
| 658 |
-
" Downloading tensorflow-2.13.1-cp38-cp38-win_amd64.whl.metadata (2.6 kB)\n",
|
| 659 |
-
"Note: you may need to restart the kernel to use updated packages.\n"
|
| 660 |
-
]
|
| 661 |
-
},
|
| 662 |
-
{
|
| 663 |
-
"name": "stderr",
|
| 664 |
-
"output_type": "stream",
|
| 665 |
-
"text": [
|
| 666 |
-
"ERROR: Ignored the following versions that require a different python version: 2.14.0 Requires-Python >=3.9; 2.14.0rc0 Requires-Python >=3.9; 3.0.0 Requires-Python >=3.9; 3.0.1 Requires-Python >=3.9; 3.0.2 Requires-Python >=3.9; 3.0.3 Requires-Python >=3.9; 3.0.4 Requires-Python >=3.9; 3.0.5 Requires-Python >=3.9; 3.1.0 Requires-Python >=3.9; 3.1.1 Requires-Python >=3.9; 3.10.0 Requires-Python >=3.9; 3.2.0 Requires-Python >=3.9; 3.2.1 Requires-Python >=3.9; 3.3.0 Requires-Python >=3.9; 3.3.1 Requires-Python >=3.9; 3.3.2 Requires-Python >=3.9; 3.3.3 Requires-Python >=3.9; 3.4.0 Requires-Python >=3.9; 3.4.1 Requires-Python >=3.9; 3.5.0 Requires-Python >=3.9; 3.6.0 Requires-Python >=3.9; 3.7.0 Requires-Python >=3.9; 3.8.0 Requires-Python >=3.9; 3.9.0 Requires-Python >=3.9; 3.9.1 Requires-Python >=3.9; 3.9.2 Requires-Python >=3.9\n",
|
| 667 |
-
"ERROR: Could not find a version that satisfies the requirement keras==3.0.5 (from versions: 0.2.0, 0.3.0, 0.3.1, 0.3.2, 0.3.3, 1.0.0, 1.0.1, 1.0.2, 1.0.3, 1.0.4, 1.0.5, 1.0.6, 1.0.7, 1.0.8, 1.1.0, 1.1.1, 1.1.2, 1.2.0, 1.2.1, 1.2.2, 2.0.0, 2.0.1, 2.0.2, 2.0.3, 2.0.4, 2.0.5, 2.0.6, 2.0.7, 2.0.8, 2.0.9, 2.1.0, 2.1.1, 2.1.2, 2.1.3, 2.1.4, 2.1.5, 2.1.6, 2.2.0, 2.2.1, 2.2.2, 2.2.3, 2.2.4, 2.2.5, 2.3.0, 2.3.1, 2.4.0, 2.4.1, 2.4.2, 2.4.3, 2.5.0rc0, 2.6.0rc0, 2.6.0rc1, 2.6.0rc2, 2.6.0rc3, 2.6.0, 2.7.0rc0, 2.7.0rc2, 2.7.0, 2.8.0rc0, 2.8.0rc1, 2.8.0, 2.9.0rc0, 2.9.0rc1, 2.9.0rc2, 2.9.0, 2.10.0rc0, 2.10.0rc1, 2.10.0, 2.11.0rc0, 2.11.0rc1, 2.11.0rc2, 2.11.0rc3, 2.11.0, 2.12.0rc0, 2.12.0rc1, 2.12.0, 2.13.1rc0, 2.13.1rc1, 2.13.1, 2.15.0rc0, 2.15.0rc1, 2.15.0)\n",
|
| 668 |
-
"ERROR: No matching distribution found for keras==3.0.5\n"
|
| 669 |
-
]
|
| 670 |
-
}
|
| 671 |
-
],
|
| 672 |
-
"source": [
|
| 673 |
-
"pip install tensorflow==2.13.1 keras==3.0.5"
|
| 674 |
-
]
|
| 675 |
-
},
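The failure above is expected rather than transient: the log shows keras 3.0.5 requires Python >= 3.9 while this environment is Python 3.8 (the cp38 wheel), and TensorFlow 2.13.1 already bundles a matching Keras 2.13.1, so no separate keras pin is needed. A quick check after installing only tensorflow==2.13.1 (a sketch; assumes the install succeeded):

import tensorflow as tf

# TF 2.13.x ships its own Keras 2.13.x under tf.keras
print(tf.__version__)        # expected: 2.13.1
print(tf.keras.__version__)  # expected: 2.13.x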
|
| 676 |
-
{
|
| 677 |
-
"cell_type": "code",
|
| 678 |
-
"execution_count": 1,
|
| 679 |
-
"id": "dbd5c35f-4a08-4906-9cf4-e1df501d1ecb",
|
| 680 |
-
"metadata": {},
|
| 681 |
-
"outputs": [],
|
| 682 |
-
"source": [
|
| 683 |
-
"import joblib\n",
|
| 684 |
-
"train = joblib.load(\"embeddings/train_protbert.pkl\")\n",
|
| 685 |
-
"val = joblib.load(\"embeddings/val_protbert.pkl\")\n",
|
| 686 |
-
"test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
|
| 687 |
-
"\n",
|
| 688 |
-
"X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
|
| 689 |
-
"X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
|
| 690 |
-
"X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n"
|
| 691 |
-
]
|
| 692 |
-
},
|
| 693 |
-
{
|
| 694 |
-
"cell_type": "code",
|
| 695 |
-
"execution_count": 2,
|
| 696 |
-
"id": "1785d8a9-23fc-4490-8d71-29cc91a4cb57",
|
| 697 |
-
"metadata": {},
|
| 698 |
-
"outputs": [
|
| 699 |
-
{
|
| 700 |
-
"name": "stderr",
|
| 701 |
-
"output_type": "stream",
|
| 702 |
-
"text": [
|
| 703 |
-
"C:\\Users\\Melvin\\anaconda3\\envs\\protein_prediction_env\\lib\\site-packages\\requests\\__init__.py:86: RequestsDependencyWarning: Unable to find acceptable character detection dependency (chardet or charset_normalizer).\n",
|
| 704 |
-
" warnings.warn(\n"
|
| 705 |
-
]
|
| 706 |
-
},
|
| 707 |
-
{
|
| 708 |
-
"name": "stdout",
|
| 709 |
-
"output_type": "stream",
|
| 710 |
-
"text": [
|
| 711 |
-
"✓ Embeddings carregados: (31142, 1024) → 597 GO terms\n",
|
| 712 |
-
"Epoch 1/100\n",
|
| 713 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0357 - binary_accuracy: 0.9894 - val_loss: 0.0336 - val_binary_accuracy: 0.9901\n",
|
| 714 |
-
"Epoch 2/100\n",
|
| 715 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0276 - binary_accuracy: 0.9914 - val_loss: 0.0329 - val_binary_accuracy: 0.9903\n",
|
| 716 |
-
"Epoch 3/100\n",
|
| 717 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0268 - binary_accuracy: 0.9915 - val_loss: 0.0330 - val_binary_accuracy: 0.9904\n",
|
| 718 |
-
"Epoch 4/100\n",
|
| 719 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0263 - binary_accuracy: 0.9917 - val_loss: 0.0324 - val_binary_accuracy: 0.9903\n",
|
| 720 |
-
"Epoch 5/100\n",
|
| 721 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0260 - binary_accuracy: 0.9917 - val_loss: 0.0319 - val_binary_accuracy: 0.9905\n",
|
| 722 |
-
"Epoch 6/100\n",
|
| 723 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0257 - binary_accuracy: 0.9918 - val_loss: 0.0325 - val_binary_accuracy: 0.9903\n",
|
| 724 |
-
"Epoch 7/100\n",
|
| 725 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0255 - binary_accuracy: 0.9918 - val_loss: 0.0318 - val_binary_accuracy: 0.9905\n",
|
| 726 |
-
"Epoch 8/100\n",
|
| 727 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0253 - binary_accuracy: 0.9919 - val_loss: 0.0321 - val_binary_accuracy: 0.9905\n",
|
| 728 |
-
"Epoch 9/100\n",
|
| 729 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0251 - binary_accuracy: 0.9919 - val_loss: 0.0316 - val_binary_accuracy: 0.9905\n",
|
| 730 |
-
"Epoch 10/100\n",
|
| 731 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0250 - binary_accuracy: 0.9919 - val_loss: 0.0320 - val_binary_accuracy: 0.9906\n",
|
| 732 |
-
"Epoch 11/100\n",
|
| 733 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0247 - binary_accuracy: 0.9920 - val_loss: 0.0318 - val_binary_accuracy: 0.9905\n",
|
| 734 |
-
"Epoch 12/100\n",
|
| 735 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0247 - binary_accuracy: 0.9920 - val_loss: 0.0317 - val_binary_accuracy: 0.9906\n",
|
| 736 |
-
"Epoch 13/100\n",
|
| 737 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0246 - binary_accuracy: 0.9920 - val_loss: 0.0316 - val_binary_accuracy: 0.9906\n",
|
| 738 |
-
"Epoch 14/100\n",
|
| 739 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0245 - binary_accuracy: 0.9920 - val_loss: 0.0318 - val_binary_accuracy: 0.9906\n",
|
| 740 |
-
"54/54 [==============================] - 0s 2ms/step\n",
|
| 741 |
-
"Previsões guardadas em mf-protbert-pam1.npy\n",
|
| 742 |
-
"Modelo guardado em models/mlp_protbert.keras\n"
|
| 743 |
-
]
|
| 744 |
-
}
|
| 745 |
-
],
|
| 746 |
-
"source": [
|
| 747 |
-
"import tensorflow as tf\n",
|
| 748 |
-
"import joblib\n",
|
| 749 |
-
"import numpy as np\n",
|
| 750 |
-
"from tensorflow.keras.models import Sequential\n",
|
| 751 |
-
"from tensorflow.keras.layers import Dense, Dropout\n",
|
| 752 |
-
"from tensorflow.keras.callbacks import EarlyStopping\n",
|
| 753 |
-
"\n",
|
| 754 |
-
"# --- 1. Carregar embeddings ----------------------------------------------\n",
|
| 755 |
-
"train = joblib.load(\"embeddings/train_protbert.pkl\")\n",
|
| 756 |
-
"val = joblib.load(\"embeddings/val_protbert.pkl\")\n",
|
| 757 |
-
"test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
|
| 758 |
-
"\n",
|
| 759 |
-
"X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
|
| 760 |
-
"X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
|
| 761 |
-
"X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n",
|
| 762 |
-
"\n",
|
| 763 |
-
"print(f\"✓ Embeddings carregados: {X_train.shape} → {y_train.shape[1]} GO terms\")\n",
|
| 764 |
-
"\n",
|
| 765 |
-
"# --- 2. Garantir consistência de classes ---------------------------------\n",
|
| 766 |
-
"max_classes = y_train.shape[1] # 602 GO terms (do treino)\n",
|
| 767 |
-
"\n",
|
| 768 |
-
"def pad_labels(y, target_dim=max_classes):\n",
|
| 769 |
-
" if y.shape[1] < target_dim:\n",
|
| 770 |
-
" padding = np.zeros((y.shape[0], target_dim - y.shape[1]), dtype=np.int8)\n",
|
| 771 |
-
" return np.hstack([y, padding])\n",
|
| 772 |
-
" return y\n",
|
| 773 |
-
"\n",
|
| 774 |
-
"y_val = pad_labels(y_val)\n",
|
| 775 |
-
"y_test = pad_labels(y_test)\n",
|
| 776 |
-
"\n",
|
| 777 |
-
"# --- 3. Modelo MLP ------------------------------------------------------\n",
|
| 778 |
-
"model = Sequential([\n",
|
| 779 |
-
" Dense(1024, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
|
| 780 |
-
" Dropout(0.3),\n",
|
| 781 |
-
" Dense(512, activation=\"relu\"),\n",
|
| 782 |
-
" Dropout(0.3),\n",
|
| 783 |
-
" Dense(max_classes, activation=\"sigmoid\")\n",
|
| 784 |
-
"])\n",
|
| 785 |
-
"\n",
|
| 786 |
-
"model.compile(loss=\"binary_crossentropy\",\n",
|
| 787 |
-
" optimizer=\"adam\",\n",
|
| 788 |
-
" metrics=[\"binary_accuracy\"])\n",
|
| 789 |
-
"\n",
|
| 790 |
-
"# --- 4. Early stopping e treino -----------------------------------------\n",
|
| 791 |
-
"callbacks = [\n",
|
| 792 |
-
" EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n",
|
| 793 |
-
"]\n",
|
| 794 |
-
"\n",
|
| 795 |
-
"model.fit(X_train, y_train,\n",
|
| 796 |
-
" validation_data=(X_val, y_val),\n",
|
| 797 |
-
" epochs=100,\n",
|
| 798 |
-
" batch_size=32,\n",
|
| 799 |
-
" callbacks=callbacks,\n",
|
| 800 |
-
" verbose=1)\n",
|
| 801 |
-
"\n",
|
| 802 |
-
"# --- 5. Previsões --------------------------------------------------------\n",
|
| 803 |
-
"y_prob = model.predict(X_test)\n",
|
| 804 |
-
"np.save(\"predictions/mf-protbert-pam1.npy\", y_prob)\n",
|
| 805 |
-
"print(\"Previsões guardadas em mf-protbert-pam1.npy\")\n",
|
| 806 |
-
"\n",
|
| 807 |
-
"# --- 6. Modelo ----------------------------------------------------------\n",
|
| 808 |
-
"model.save(\"models/mlp_protbert.keras\")\n",
|
| 809 |
-
"print(\"Modelo guardado em models/mlp_protbert.keras\")"
|
| 810 |
-
]
|
| 811 |
-
},
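For reuse without retraining, the saved .keras file can be reloaded directly; a minimal sketch, assuming the embeddings pickle layout used above:

import joblib
import tensorflow as tf

# reload the trained MLP head and score ProtBERT embeddings
model = tf.keras.models.load_model("models/mlp_protbert.keras")
test = joblib.load("embeddings/test_protbert.pkl")
probs = model.predict(test["embeddings"])  # shape: (n_proteins, 597)
print(probs.shape)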
|
| 812 |
-
{
|
| 813 |
-
"cell_type": "code",
|
| 814 |
-
"execution_count": 30,
|
| 815 |
-
"id": "fdb66630-76dc-43a0-bd56-45052175fdba",
|
| 816 |
-
"metadata": {},
|
| 817 |
-
"outputs": [
|
| 818 |
-
{
|
| 819 |
-
"name": "stdout",
|
| 820 |
-
"output_type": "stream",
|
| 821 |
-
"text": [
|
| 822 |
-
"go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
|
| 823 |
-
"✓ Embeddings: (1724, 597) labels × 597 GO terms\n",
|
| 824 |
-
"\n",
|
| 825 |
-
"📊 Resultados finais (ProtBERT + PAM1 + propagação):\n",
|
| 826 |
-
"Fmax = 0.6666\n",
|
| 827 |
-
"Thr. = 0.50\n",
|
| 828 |
-
"AuPRC = 0.7028\n",
|
| 829 |
-
"Smin = 13.1745\n"
|
| 830 |
-
]
|
| 831 |
-
}
|
| 832 |
-
],
|
| 833 |
-
"source": [
|
| 834 |
-
"import numpy as np\n",
|
| 835 |
-
"from sklearn.metrics import precision_recall_curve, auc\n",
|
| 836 |
-
"from goatools.obo_parser import GODag\n",
|
| 837 |
-
"import joblib\n",
|
| 838 |
-
"import math\n",
|
| 839 |
-
"\n",
|
| 840 |
-
"# --- 1. Parâmetros -------------------------------------------------------\n",
|
| 841 |
-
"GO_FILE = \"go.obo\"\n",
|
| 842 |
-
"THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
|
| 843 |
-
"ALPHA = 0.5\n",
|
| 844 |
-
"\n",
|
| 845 |
-
"# --- 2. Carregar dados ---------------------------------------------------\n",
|
| 846 |
-
"test = joblib.load(\"embeddings/test_protbert.pkl\")\n",
|
| 847 |
-
"y_true = test[\"labels\"]\n",
|
| 848 |
-
"terms = test[\"go_terms\"]\n",
|
| 849 |
-
"y_prob = np.load(\"predictions/mf-protbert-pam1.npy\")\n",
|
| 850 |
-
"go_dag = GODag(GO_FILE)\n",
|
| 851 |
-
"\n",
|
| 852 |
-
"print(f\"✓ Embeddings: {y_true.shape} labels × {len(terms)} GO terms\")\n",
|
| 853 |
-
"\n",
|
| 854 |
-
"# --- 3. Fmax -------------------------------------------------------------\n",
|
| 855 |
-
"def compute_fmax(y_true, y_prob, thresholds):\n",
|
| 856 |
-
" fmax, best_thr = 0, 0\n",
|
| 857 |
-
" for t in thresholds:\n",
|
| 858 |
-
" y_pred = (y_prob >= t).astype(int)\n",
|
| 859 |
-
" tp = (y_true * y_pred).sum(axis=1)\n",
|
| 860 |
-
" fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
|
| 861 |
-
" fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
|
| 862 |
-
" precision = tp / (tp + fp + 1e-8)\n",
|
| 863 |
-
" recall = tp / (tp + fn + 1e-8)\n",
|
| 864 |
-
" f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
|
| 865 |
-
" avg_f1 = np.mean(f1)\n",
|
| 866 |
-
" if avg_f1 > fmax:\n",
|
| 867 |
-
" fmax, best_thr = avg_f1, t\n",
|
| 868 |
-
" return fmax, best_thr\n",
|
| 869 |
-
"\n",
|
| 870 |
-
"# --- 4. AuPRC micro ------------------------------------------------------\n",
|
| 871 |
-
"def compute_auprc(y_true, y_prob):\n",
|
| 872 |
-
" precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
|
| 873 |
-
" return auc(recall, precision)\n",
|
| 874 |
-
"\n",
|
| 875 |
-
"# --- 5. Smin -------------------------------------------------------------\n",
|
| 876 |
-
"def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
|
| 877 |
-
" y_pred = (y_prob >= threshold).astype(int)\n",
|
| 878 |
-
" ic = {}\n",
|
| 879 |
-
" total = (y_true + y_pred).sum(axis=0).sum()\n",
|
| 880 |
-
" for i, term in enumerate(terms):\n",
|
| 881 |
-
" freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
|
| 882 |
-
" ic[term] = -np.log((freq + 1e-8) / total)\n",
|
| 883 |
-
"\n",
|
| 884 |
-
" s_values = []\n",
|
| 885 |
-
" for true_vec, pred_vec in zip(y_true, y_pred):\n",
|
| 886 |
-
" true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
|
| 887 |
-
" pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
|
| 888 |
-
"\n",
|
| 889 |
-
" anc_true = set()\n",
|
| 890 |
-
" for t in true_terms:\n",
|
| 891 |
-
" if t in go_dag:\n",
|
| 892 |
-
" anc_true |= go_dag[t].get_all_parents()\n",
|
| 893 |
-
" anc_pred = set()\n",
|
| 894 |
-
" for t in pred_terms:\n",
|
| 895 |
-
" if t in go_dag:\n",
|
| 896 |
-
" anc_pred |= go_dag[t].get_all_parents()\n",
|
| 897 |
-
"\n",
|
| 898 |
-
" ru = pred_terms - true_terms\n",
|
| 899 |
-
" mi = true_terms - pred_terms\n",
|
| 900 |
-
" dist_ru = sum(ic.get(t, 0) for t in ru)\n",
|
| 901 |
-
" dist_mi = sum(ic.get(t, 0) for t in mi)\n",
|
| 902 |
-
" s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
|
| 903 |
-
" s_values.append(s)\n",
|
| 904 |
-
"\n",
|
| 905 |
-
" return np.mean(s_values)\n",
|
| 906 |
-
"\n",
|
| 907 |
-
"# --- 6. Avaliar ----------------------------------------------------------\n",
|
| 908 |
-
"fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
|
| 909 |
-
"auprc = compute_auprc(y_true, y_prob)\n",
|
| 910 |
-
"smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
|
| 911 |
-
"\n",
|
| 912 |
-
"print(f\"\\n📊 Resultados finais (ProtBERT + PAM1 + propagação):\")\n",
|
| 913 |
-
"print(f\"Fmax = {fmax:.4f}\")\n",
|
| 914 |
-
"print(f\"Thr. = {thr:.2f}\")\n",
|
| 915 |
-
"print(f\"AuPRC = {auprc:.4f}\")\n",
|
| 916 |
-
"print(f\"Smin = {smin:.4f}\")\n"
|
| 917 |
-
]
|
| 918 |
-
},
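To make compute_fmax above concrete: it is protein-centric, i.e. precision and recall are computed per protein, F1 is averaged over proteins, and the threshold that maximizes that average is kept. A tiny worked sketch with two proteins and three terms:

import numpy as np

y_true = np.array([[1, 0, 1],
                   [0, 1, 0]])
y_prob = np.array([[0.9, 0.2, 0.6],
                   [0.1, 0.8, 0.7]])

y_pred = (y_prob >= 0.5).astype(int)  # [[1 0 1], [0 1 1]]
tp = (y_true * y_pred).sum(axis=1)
fp = ((1 - y_true) * y_pred).sum(axis=1)
fn = (y_true * (1 - y_pred)).sum(axis=1)
f1 = 2 * tp / (2 * tp + fp + fn)
# protein 1: F1 = 1.0; protein 2: one false positive, F1 = 2/3
print(f1.mean())  # 0.8333... at threshold 0.5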
|
| 919 |
-
{
|
| 920 |
-
"cell_type": "code",
|
| 921 |
-
"execution_count": 3,
|
| 922 |
-
"id": "70d131ef-ef84-42ee-953b-0d3f1268694d",
|
| 923 |
-
"metadata": {},
|
| 924 |
-
"outputs": [
|
| 925 |
-
{
|
| 926 |
-
"data": {
|
| 927 |
-
"text/plain": [
|
| 928 |
-
"['data/mlb_protbert.pkl']"
|
| 929 |
-
]
|
| 930 |
-
},
|
| 931 |
-
"execution_count": 3,
|
| 932 |
-
"metadata": {},
|
| 933 |
-
"output_type": "execute_result"
|
| 934 |
-
}
|
| 935 |
-
],
|
| 936 |
-
"source": [
|
| 937 |
-
"import joblib, pickle\n",
|
| 938 |
-
"joblib.dump(mlb, \"data/mlb_protbert.pkl\")"
|
| 939 |
-
]
|
| 940 |
-
},
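The saved binarizer is what later maps thresholded prediction vectors back to GO identifiers; a short sketch of the round trip (the all-zeros row with two columns set is just an illustration):

import joblib
import numpy as np

mlb = joblib.load("data/mlb_protbert.pkl")
row = np.zeros((1, len(mlb.classes_)), dtype=int)
row[0, :2] = 1  # pretend the first two terms fired
print(mlb.inverse_transform(row))  # tuple with the first two GO IDs in mlb.classes_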
|
| 941 |
-
{
|
| 942 |
-
"cell_type": "code",
|
| 943 |
-
"execution_count": null,
|
| 944 |
-
"id": "9f89c3bc-6b78-4a4c-8ddd-b69c7d3d0e65",
|
| 945 |
-
"metadata": {},
|
| 946 |
-
"outputs": [],
|
| 947 |
-
"source": []
|
| 948 |
-
}
|
| 949 |
-
],
|
| 950 |
-
"metadata": {
|
| 951 |
-
"kernelspec": {
|
| 952 |
-
"display_name": "Python 3 (ipykernel)",
|
| 953 |
-
"language": "python",
|
| 954 |
-
"name": "python3"
|
| 955 |
-
},
|
| 956 |
-
"language_info": {
|
| 957 |
-
"codemirror_mode": {
|
| 958 |
-
"name": "ipython",
|
| 959 |
-
"version": 3
|
| 960 |
-
},
|
| 961 |
-
"file_extension": ".py",
|
| 962 |
-
"mimetype": "text/x-python",
|
| 963 |
-
"name": "python",
|
| 964 |
-
"nbconvert_exporter": "python",
|
| 965 |
-
"pygments_lexer": "ipython3",
|
| 966 |
-
"version": "3.10.16"
|
| 967 |
-
}
|
| 968 |
-
},
|
| 969 |
-
"nbformat": 4,
|
| 970 |
-
"nbformat_minor": 5
|
| 971 |
-
}
|
notebooks/PAM1_protbertBFD.ipynb
DELETED
|
@@ -1,872 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": 1,
|
| 6 |
-
"id": "c6dbc330-062a-48f0-8242-3f21cc1c9c2b",
|
| 7 |
-
"metadata": {},
|
| 8 |
-
"outputs": [
|
| 9 |
-
{
|
| 10 |
-
"name": "stdout",
|
| 11 |
-
"output_type": "stream",
|
| 12 |
-
"text": [
|
| 13 |
-
"go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
|
| 14 |
-
"✓ Ficheiros criados:\n",
|
| 15 |
-
" - data/mf-training.csv : (31142, 3)\n",
|
| 16 |
-
" - data/mf-validation.csv: (1724, 3)\n",
|
| 17 |
-
" - data/mf-test.csv : (1724, 3)\n",
|
| 18 |
-
"GO terms únicos (após propagação e filtro): 602\n"
|
| 19 |
-
]
|
| 20 |
-
}
|
| 21 |
-
],
|
| 22 |
-
"source": [
|
| 23 |
-
"import pandas as pd\n",
|
| 24 |
-
"from Bio import SeqIO\n",
|
| 25 |
-
"from collections import Counter\n",
|
| 26 |
-
"from goatools.obo_parser import GODag\n",
|
| 27 |
-
"from sklearn.model_selection import train_test_split\n",
|
| 28 |
-
"from sklearn.preprocessing import MultiLabelBinarizer\n",
|
| 29 |
-
"from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
|
| 30 |
-
"import numpy as np\n",
|
| 31 |
-
"import os\n",
|
| 32 |
-
"\n",
|
| 33 |
-
"# --- 1. Carregar GO anotações ------------------------------------------\n",
|
| 34 |
-
"annotations = pd.read_csv(\"uniprot_sprot_exp.txt\", sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"go_category\"])\n",
|
| 35 |
-
"annotations_f = annotations[annotations[\"go_category\"] == \"F\"]\n",
|
| 36 |
-
"\n",
|
| 37 |
-
"# --- 2. Carregar DAG e propagar GO terms -------------------------------\n",
|
| 38 |
-
"# propagação hierárquica\n",
|
| 39 |
-
"# https://geneontology.org/docs/download-ontology/\n",
|
| 40 |
-
"go_dag = GODag(\"go.obo\")\n",
|
| 41 |
-
"mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n",
|
| 42 |
-
"\n",
|
| 43 |
-
"def propagate_terms(term_list):\n",
|
| 44 |
-
" full = set()\n",
|
| 45 |
-
" for t in term_list:\n",
|
| 46 |
-
" if t not in go_dag:\n",
|
| 47 |
-
" continue\n",
|
| 48 |
-
" full.add(t)\n",
|
| 49 |
-
" full.update(go_dag[t].get_all_parents())\n",
|
| 50 |
-
" return list(full & mf_terms)\n",
|
| 51 |
-
"\n",
|
| 52 |
-
"# --- 3. Carregar sequências --------------------------------------------\n",
|
| 53 |
-
"seqs, ids = [], []\n",
|
| 54 |
-
"for record in SeqIO.parse(\"uniprot_sprot_exp.fasta\", \"fasta\"):\n",
|
| 55 |
-
" ids.append(record.id)\n",
|
| 56 |
-
" seqs.append(str(record.seq))\n",
|
| 57 |
-
"\n",
|
| 58 |
-
"seq_df = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n",
|
| 59 |
-
"\n",
|
| 60 |
-
"# --- 4. Juntar com GO anotado e propagar -------------------------------\n",
|
| 61 |
-
"grouped = annotations_f.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n",
|
| 62 |
-
"data = seq_df.merge(grouped, on=\"protein_id\")\n",
|
| 63 |
-
"data = data[data[\"go_term\"].apply(len) > 0]\n",
|
| 64 |
-
"data[\"go_term\"] = data[\"go_term\"].apply(propagate_terms)\n",
|
| 65 |
-
"data = data[data[\"go_term\"].apply(len) > 0]\n",
|
| 66 |
-
"\n",
|
| 67 |
-
"# --- 5. Filtrar GO terms raros -----------------------------------------\n",
|
| 68 |
-
"# todos os terms com menos de 50 proteinas associadas\n",
|
| 69 |
-
"all_terms = [term for sublist in data[\"go_term\"] for term in sublist]\n",
|
| 70 |
-
"term_counts = Counter(all_terms)\n",
|
| 71 |
-
"valid_terms = {term for term, count in term_counts.items() if count >= 50}\n",
|
| 72 |
-
"data[\"go_term\"] = data[\"go_term\"].apply(lambda terms: [t for t in terms if t in valid_terms])\n",
|
| 73 |
-
"data = data[data[\"go_term\"].apply(len) > 0]\n",
|
| 74 |
-
"\n",
|
| 75 |
-
"# --- 6. Preparar dataset final -----------------------------------------\n",
|
| 76 |
-
"data[\"go_terms\"] = data[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n",
|
| 77 |
-
"data = data[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n",
|
| 78 |
-
"\n",
|
| 79 |
-
"# --- 7. Binarizar labels e dividir -------------------------------------\n",
|
| 80 |
-
"mlb = MultiLabelBinarizer()\n",
|
| 81 |
-
"Y = mlb.fit_transform(data[\"go_terms\"].str.split(\";\"))\n",
|
| 82 |
-
"X = data[[\"protein_id\", \"sequence\"]].values\n",
|
| 83 |
-
"\n",
|
| 84 |
-
"mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n",
|
| 85 |
-
"train_idx, temp_idx = next(mskf.split(X, Y))\n",
|
| 86 |
-
"val_idx, test_idx = np.array_split(temp_idx, 2)\n",
|
| 87 |
-
"\n",
|
| 88 |
-
"df_train = data.iloc[train_idx].copy()\n",
|
| 89 |
-
"df_val = data.iloc[val_idx].copy()\n",
|
| 90 |
-
"df_test = data.iloc[test_idx].copy()\n",
|
| 91 |
-
"\n",
|
| 92 |
-
"# --- 8. Guardar em CSV -------------------------------------------------\n",
|
| 93 |
-
"os.makedirs(\"data\", exist_ok=True)\n",
|
| 94 |
-
"df_train.to_csv(\"data/mf-training.csv\", index=False)\n",
|
| 95 |
-
"df_val.to_csv(\"data/mf-validation.csv\", index=False)\n",
|
| 96 |
-
"df_test.to_csv(\"data/mf-test.csv\", index=False)\n",
|
| 97 |
-
"\n",
|
| 98 |
-
"# --- 9. Confirmar ------------------------------------------------------\n",
|
| 99 |
-
"print(\"✓ Ficheiros criados:\")\n",
|
| 100 |
-
"print(\" - data/mf-training.csv :\", df_train.shape)\n",
|
| 101 |
-
"print(\" - data/mf-validation.csv:\", df_val.shape)\n",
|
| 102 |
-
"print(\" - data/mf-test.csv :\", df_test.shape)\n",
|
| 103 |
-
"print(f\"GO terms únicos (após propagação e filtro): {len(mlb.classes_)}\")\n"
|
| 104 |
-
]
|
| 105 |
-
},
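A cheap after-the-fact check of the rare-term filter is to recount term frequencies from the three saved CSVs; a sketch (the >= 50 cutoff was applied before splitting, so it is the pooled counts, not the per-split ones, that must clear it):

import pandas as pd
from collections import Counter

counts = Counter()
for split in ("training", "validation", "test"):
    df = pd.read_csv(f"data/mf-{split}.csv")
    for terms in df["go_terms"].str.split(";"):
        counts.update(terms)

print(min(counts.values()))  # should be >= 50, matching the filter above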
|
| 106 |
-
{
|
| 107 |
-
"cell_type": "code",
|
| 108 |
-
"execution_count": 2,
|
| 109 |
-
"id": "6cf7aaa6-4941-4951-8d73-1f4f1f4362f3",
|
| 110 |
-
"metadata": {},
|
| 111 |
-
"outputs": [
|
| 112 |
-
{
|
| 113 |
-
"name": "stderr",
|
| 114 |
-
"output_type": "stream",
|
| 115 |
-
"text": [
|
| 116 |
-
"C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 117 |
-
" from .autonotebook import tqdm as notebook_tqdm\n",
|
| 118 |
-
"C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
|
| 119 |
-
" _torch_pytree._register_pytree_node(\n",
|
| 120 |
-
"100%|██████████| 31142/31142 [00:26<00:00, 1192.86it/s]\n",
|
| 121 |
-
"100%|██████████| 1724/1724 [00:00<00:00, 2570.68it/s]\n",
|
| 122 |
-
"C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:382: UserWarning: The class_names argument is replacing the classes argument. Please update your code.\n",
|
| 123 |
-
" warnings.warn(\n"
|
| 124 |
-
]
|
| 125 |
-
},
|
| 126 |
-
{
|
| 127 |
-
"name": "stdout",
|
| 128 |
-
"output_type": "stream",
|
| 129 |
-
"text": [
|
| 130 |
-
"preprocessing train...\n",
|
| 131 |
-
"language: en\n",
|
| 132 |
-
"train sequence lengths:\n",
|
| 133 |
-
"\tmean : 423\n",
|
| 134 |
-
"\t95percentile : 604\n",
|
| 135 |
-
"\t99percentile : 715\n"
|
| 136 |
-
]
|
| 137 |
-
},
|
| 138 |
-
{
|
| 139 |
-
"data": {
|
| 140 |
-
"text/html": [
|
| 141 |
-
"\n",
|
| 142 |
-
"<style>\n",
|
| 143 |
-
" /* Turns off some styling */\n",
|
| 144 |
-
" progress {\n",
|
| 145 |
-
" /* gets rid of default border in Firefox and Opera. */\n",
|
| 146 |
-
" border: none;\n",
|
| 147 |
-
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
|
| 148 |
-
" background-size: auto;\n",
|
| 149 |
-
" }\n",
|
| 150 |
-
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
|
| 151 |
-
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
|
| 152 |
-
" }\n",
|
| 153 |
-
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
|
| 154 |
-
" background: #F44336;\n",
|
| 155 |
-
" }\n",
|
| 156 |
-
"</style>\n"
|
| 157 |
-
],
|
| 158 |
-
"text/plain": [
|
| 159 |
-
"<IPython.core.display.HTML object>"
|
| 160 |
-
]
|
| 161 |
-
},
|
| 162 |
-
"metadata": {},
|
| 163 |
-
"output_type": "display_data"
|
| 164 |
-
},
|
| 165 |
-
{
|
| 166 |
-
"data": {
|
| 167 |
-
"text/html": [],
|
| 168 |
-
"text/plain": [
|
| 169 |
-
"<IPython.core.display.HTML object>"
|
| 170 |
-
]
|
| 171 |
-
},
|
| 172 |
-
"metadata": {},
|
| 173 |
-
"output_type": "display_data"
|
| 174 |
-
},
|
| 175 |
-
{
|
| 176 |
-
"name": "stdout",
|
| 177 |
-
"output_type": "stream",
|
| 178 |
-
"text": [
|
| 179 |
-
"Is Multi-Label? True\n",
|
| 180 |
-
"preprocessing test...\n",
|
| 181 |
-
"language: en\n",
|
| 182 |
-
"test sequence lengths:\n",
|
| 183 |
-
"\tmean : 408\n",
|
| 184 |
-
"\t95percentile : 603\n",
|
| 185 |
-
"\t99percentile : 714\n"
|
| 186 |
-
]
|
| 187 |
-
},
|
| 188 |
-
{
|
| 189 |
-
"data": {
|
| 190 |
-
"text/html": [
|
| 191 |
-
"\n",
|
| 192 |
-
"<style>\n",
|
| 193 |
-
" /* Turns off some styling */\n",
|
| 194 |
-
" progress {\n",
|
| 195 |
-
" /* gets rid of default border in Firefox and Opera. */\n",
|
| 196 |
-
" border: none;\n",
|
| 197 |
-
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
|
| 198 |
-
" background-size: auto;\n",
|
| 199 |
-
" }\n",
|
| 200 |
-
" progress:not([value]), progress:not([value])::-webkit-progress-bar {\n",
|
| 201 |
-
" background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);\n",
|
| 202 |
-
" }\n",
|
| 203 |
-
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
|
| 204 |
-
" background: #F44336;\n",
|
| 205 |
-
" }\n",
|
| 206 |
-
"</style>\n"
|
| 207 |
-
],
|
| 208 |
-
"text/plain": [
|
| 209 |
-
"<IPython.core.display.HTML object>"
|
| 210 |
-
]
|
| 211 |
-
},
|
| 212 |
-
"metadata": {},
|
| 213 |
-
"output_type": "display_data"
|
| 214 |
-
},
|
| 215 |
-
{
|
| 216 |
-
"data": {
|
| 217 |
-
"text/html": [],
|
| 218 |
-
"text/plain": [
|
| 219 |
-
"<IPython.core.display.HTML object>"
|
| 220 |
-
]
|
| 221 |
-
},
|
| 222 |
-
"metadata": {},
|
| 223 |
-
"output_type": "display_data"
|
| 224 |
-
},
|
| 225 |
-
{
|
| 226 |
-
"name": "stdout",
|
| 227 |
-
"output_type": "stream",
|
| 228 |
-
"text": [
|
| 229 |
-
"\n",
|
| 230 |
-
"\n",
|
| 231 |
-
"begin training using triangular learning rate policy with max lr of 1e-05...\n",
|
| 232 |
-
"Epoch 1/10\n",
|
| 233 |
-
"40995/40995 [==============================] - 9020s 219ms/step - loss: 0.0740 - binary_accuracy: 0.9869 - val_loss: 0.0526 - val_binary_accuracy: 0.9866\n",
|
| 234 |
-
"Epoch 2/10\n",
|
| 235 |
-
"40995/40995 [==============================] - 8939s 218ms/step - loss: 0.0464 - binary_accuracy: 0.9877 - val_loss: 0.0457 - val_binary_accuracy: 0.9871\n",
|
| 236 |
-
"Epoch 3/10\n",
|
| 237 |
-
"40995/40995 [==============================] - 8881s 217ms/step - loss: 0.0413 - binary_accuracy: 0.9883 - val_loss: 0.0418 - val_binary_accuracy: 0.9877\n",
|
| 238 |
-
"Epoch 4/10\n",
|
| 239 |
-
"40995/40995 [==============================] - 10277s 251ms/step - loss: 0.0380 - binary_accuracy: 0.9888 - val_loss: 0.0396 - val_binary_accuracy: 0.9881\n",
|
| 240 |
-
"Epoch 5/10\n",
|
| 241 |
-
"40995/40995 [==============================] - 10565s 258ms/step - loss: 0.0357 - binary_accuracy: 0.9892 - val_loss: 0.0380 - val_binary_accuracy: 0.9883\n",
|
| 242 |
-
"Epoch 6/10\n",
|
| 243 |
-
"40995/40995 [==============================] - 10693s 261ms/step - loss: 0.0338 - binary_accuracy: 0.9895 - val_loss: 0.0369 - val_binary_accuracy: 0.9885\n",
|
| 244 |
-
"Epoch 7/10\n",
|
| 245 |
-
"40995/40995 [==============================] - 12055s 294ms/step - loss: 0.0323 - binary_accuracy: 0.9898 - val_loss: 0.0360 - val_binary_accuracy: 0.9888\n",
|
| 246 |
-
"Epoch 8/10\n",
|
| 247 |
-
"40995/40995 [==============================] - 10225s 249ms/step - loss: 0.0309 - binary_accuracy: 0.9901 - val_loss: 0.0353 - val_binary_accuracy: 0.9890\n",
|
| 248 |
-
"Epoch 9/10\n",
|
| 249 |
-
"40995/40995 [==============================] - 10308s 251ms/step - loss: 0.0297 - binary_accuracy: 0.9904 - val_loss: 0.0347 - val_binary_accuracy: 0.9891\n",
|
| 250 |
-
"Epoch 10/10\n",
|
| 251 |
-
"40995/40995 [==============================] - 10275s 251ms/step - loss: 0.0286 - binary_accuracy: 0.9907 - val_loss: 0.0346 - val_binary_accuracy: 0.9893\n",
|
| 252 |
-
"Weights from best epoch have been loaded into model.\n"
|
| 253 |
-
]
|
| 254 |
-
},
|
| 255 |
-
{
|
| 256 |
-
"data": {
|
| 257 |
-
"text/plain": [
|
| 258 |
-
"<keras.callbacks.History at 0x2b644b84fd0>"
|
| 259 |
-
]
|
| 260 |
-
},
|
| 261 |
-
"execution_count": 2,
|
| 262 |
-
"metadata": {},
|
| 263 |
-
"output_type": "execute_result"
|
| 264 |
-
}
|
| 265 |
-
],
|
| 266 |
-
"source": [
|
| 267 |
-
"import pandas as pd\n",
|
| 268 |
-
"import numpy as np\n",
|
| 269 |
-
"from tqdm import tqdm\n",
|
| 270 |
-
"import random\n",
|
| 271 |
-
"import os\n",
|
| 272 |
-
"import ktrain\n",
|
| 273 |
-
"from ktrain import text\n",
|
| 274 |
-
"from sklearn.preprocessing import MultiLabelBinarizer\n",
|
| 275 |
-
"\n",
|
| 276 |
-
"\n",
|
| 277 |
-
"# PAM1\n",
|
| 278 |
-
"# PAM matrix model of protein evolution\n",
|
| 279 |
-
"# DOI:10.1093/oxfordjournals.molbev.a040360\n",
|
| 280 |
-
"pam_data = {\n",
|
| 281 |
-
" 'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n",
|
| 282 |
-
" 'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n",
|
| 283 |
-
" 'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n",
|
| 284 |
-
" 'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n",
|
| 285 |
-
" 'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n",
|
| 286 |
-
" 'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n",
|
| 287 |
-
" 'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n",
|
| 288 |
-
" 'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n",
|
| 289 |
-
" 'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n",
|
| 290 |
-
" 'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n",
|
| 291 |
-
" 'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n",
|
| 292 |
-
" 'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n",
|
| 293 |
-
" 'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n",
|
| 294 |
-
" 'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n",
|
| 295 |
-
" 'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n",
|
| 296 |
-
" 'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n",
|
| 297 |
-
" 'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n",
|
| 298 |
-
" 'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n",
|
| 299 |
-
" 'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n",
|
| 300 |
-
" 'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n",
|
| 301 |
-
"}\n",
|
| 302 |
-
"pam_raw = pd.DataFrame(pam_data, index=list(pam_data.keys()))\n",
|
| 303 |
-
"pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n",
|
| 304 |
-
"list_amino = pam_raw.columns.tolist()\n",
|
| 305 |
-
"pam_dict = {\n",
|
| 306 |
-
" aa: {sub: pam_matrix.loc[aa, sub] for sub in list_amino}\n",
|
| 307 |
-
" for aa in list_amino\n",
|
| 308 |
-
"}\n",
|
| 309 |
-
"\n",
|
| 310 |
-
"def pam1_substitution(aa):\n",
|
| 311 |
-
" if aa not in pam_dict:\n",
|
| 312 |
-
" return aa\n",
|
| 313 |
-
" subs = list(pam_dict[aa].keys())\n",
|
| 314 |
-
" probs = list(pam_dict[aa].values())\n",
|
| 315 |
-
" return np.random.choice(subs, p=probs)\n",
|
| 316 |
-
"\n",
|
| 317 |
-
"def augment_sequence(seq, sub_prob=0.05):\n",
|
| 318 |
-
" return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n",
|
| 319 |
-
"\n",
|
| 320 |
-
"def slice_sequence(seq, win=500, min_overlap=250):\n",
|
| 321 |
-
" if len(seq) <= win:\n",
|
| 322 |
-
" return [seq]\n",
|
| 323 |
-
" slices, start = [], 0\n",
|
| 324 |
-
" while start + win <= len(seq):\n",
|
| 325 |
-
" slices.append(seq[start:start+win])\n",
|
| 326 |
-
" start += win\n",
|
| 327 |
-
" leftover = seq[start:]\n",
|
| 328 |
-
" if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
|
| 329 |
-
" extra = slices[-1][-min_overlap:] + leftover\n",
|
| 330 |
-
" slices.append(extra)\n",
|
| 331 |
-
" return slices\n",
|
| 332 |
-
"\n",
|
| 333 |
-
"def generate_data(df, augment=False):\n",
|
| 334 |
-
" X, y = [], []\n",
|
| 335 |
-
" label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
|
| 336 |
-
" for _, row in tqdm(df.iterrows(), total=len(df)):\n",
|
| 337 |
-
" seq = row[\"sequence\"]\n",
|
| 338 |
-
" if augment:\n",
|
| 339 |
-
" seq = augment_sequence(seq)\n",
|
| 340 |
-
" seq_slices = slice_sequence(seq)\n",
|
| 341 |
-
" X.extend(seq_slices)\n",
|
| 342 |
-
" lbl = row[label_cols].values.astype(int)\n",
|
| 343 |
-
" y.extend([lbl] * len(seq_slices))\n",
|
| 344 |
-
" return X, np.array(y), label_cols\n",
|
| 345 |
-
"\n",
|
| 346 |
-
"def format_sequence(seq): return \" \".join(list(seq))\n",
|
| 347 |
-
"\n",
|
| 348 |
-
"# Função para carregar e binarizar\n",
|
| 349 |
-
"def load_and_binarize(csv_path, mlb=None):\n",
|
| 350 |
-
" df = pd.read_csv(csv_path)\n",
|
| 351 |
-
" df[\"go_terms\"] = df[\"go_terms\"].str.split(\";\")\n",
|
| 352 |
-
" if mlb is None:\n",
|
| 353 |
-
" mlb = MultiLabelBinarizer()\n",
|
| 354 |
-
" labels = mlb.fit_transform(df[\"go_terms\"])\n",
|
| 355 |
-
" else:\n",
|
| 356 |
-
" labels = mlb.transform(df[\"go_terms\"])\n",
|
| 357 |
-
" labels_df = pd.DataFrame(labels, columns=mlb.classes_)\n",
|
| 358 |
-
" df = df.reset_index(drop=True).join(labels_df)\n",
|
| 359 |
-
" return df, mlb\n",
|
| 360 |
-
"\n",
|
| 361 |
-
"# Carregar os dados\n",
|
| 362 |
-
"df_train, mlb = load_and_binarize(\"data/mf-training.csv\")\n",
|
| 363 |
-
"df_val, _ = load_and_binarize(\"data/mf-validation.csv\", mlb=mlb)\n",
|
| 364 |
-
"\n",
|
| 365 |
-
"# Gerar com augmentation no treino\n",
|
| 366 |
-
"X_train, y_train, term_cols = generate_data(df_train, augment=True)\n",
|
| 367 |
-
"X_val, y_val, _ = generate_data(df_val, augment=False)\n",
|
| 368 |
-
"\n",
|
| 369 |
-
"# Preparar texto para tokenizer\n",
|
| 370 |
-
"X_train_fmt = list(map(format_sequence, X_train))\n",
|
| 371 |
-
"X_val_fmt = list(map(format_sequence, X_val))\n",
|
| 372 |
-
"\n",
|
| 373 |
-
"# Fine-tune ProtBERT\n",
|
| 374 |
-
"# https://huggingface.co/Rostlab/prot_bert\n",
|
| 375 |
-
"# https://doi.org/10.1093/bioinformatics/btac020\n",
|
| 376 |
-
"# dados de treino-> UniRef100 (216 milhões de sequências)\n",
|
| 377 |
-
"MODEL_NAME = \"Rostlab/prot_bert_bfd\"\n",
|
| 378 |
-
"MAX_LEN = 512\n",
|
| 379 |
-
"BATCH_SIZE = 1\n",
|
| 380 |
-
"\n",
|
| 381 |
-
"t = text.Transformer(MODEL_NAME, maxlen=MAX_LEN, classes=term_cols)\n",
|
| 382 |
-
"trn = t.preprocess_train(X_train_fmt, y_train)\n",
|
| 383 |
-
"val = t.preprocess_test(X_val_fmt, y_val)\n",
|
| 384 |
-
"\n",
|
| 385 |
-
"model = t.get_classifier()\n",
|
| 386 |
-
"learner = ktrain.get_learner(model,\n",
|
| 387 |
-
" train_data=trn,\n",
|
| 388 |
-
" val_data=val,\n",
|
| 389 |
-
" batch_size=BATCH_SIZE)\n",
|
| 390 |
-
"\n",
|
| 391 |
-
"learner.autofit(lr=1e-5,\n",
|
| 392 |
-
" epochs=10,\n",
|
| 393 |
-
" early_stopping=1,\n",
|
| 394 |
-
" checkpoint_folder=\"mf-fine-tuned-protbertbfd\")\n"
|
| 395 |
-
]
|
| 396 |
-
},
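One property of the augmentation above worth noting: a residue is only considered for substitution with probability sub_prob, and the row-normalized PAM1 matrix then keeps the same residue the vast majority of the time, so the fraction of residues that actually change is roughly sub_prob × (1 − p_self), well under 1%. A standalone simulation with a PAM1-like toy alphabet (not the notebook's pam_dict):

import random

# toy 3-letter alphabet with a PAM1-like diagonal of 0.99
toy = {aa: {aa: 0.99, **{b: 0.005 for b in "ARN" if b != aa}} for aa in "ARN"}

def substitute(aa):
    subs, probs = zip(*toy[aa].items())
    return random.choices(subs, weights=probs)[0]

seq = "ARN" * 10000
sub_prob = 0.05
mutated = "".join(substitute(a) if random.random() < sub_prob else a for a in seq)
changed = sum(a != b for a, b in zip(seq, mutated)) / len(seq)
print(f"changed fraction ≈ {changed:.4%} (expected ≈ {sub_prob * 0.01:.4%})")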
|
| 397 |
-
{
|
| 398 |
-
"cell_type": "code",
|
| 399 |
-
"execution_count": 6,
|
| 400 |
-
"id": "c66774b3-6cf0-41c5-bb01-9467a5283102",
|
| 401 |
-
"metadata": {},
|
| 402 |
-
"outputs": [
|
| 403 |
-
{
|
| 404 |
-
"name": "stdout",
|
| 405 |
-
"output_type": "stream",
|
| 406 |
-
"text": [
|
| 407 |
-
"✅ Existe: weights/mf-fine-tuned-protbertbfd\n",
|
| 408 |
-
"📁 Conteúdo:\n",
|
| 409 |
-
" - config.json\n",
|
| 410 |
-
" - tf_model.h5\n"
|
| 411 |
-
]
|
| 412 |
-
}
|
| 413 |
-
],
|
| 414 |
-
"source": [
|
| 415 |
-
"import os\n",
|
| 416 |
-
"learner.save_model('weights/mf-fine-tuned-protbertbfd')\n",
|
| 417 |
-
"path = \"weights/mf-fine-tuned-protbertbfd\"\n",
|
| 418 |
-
"\n",
|
| 419 |
-
"if os.path.exists(path):\n",
|
| 420 |
-
" print(f\"✅ Existe: {path}\")\n",
|
| 421 |
-
" print(\"📁 Conteúdo:\")\n",
|
| 422 |
-
" for f in os.listdir(path):\n",
|
| 423 |
-
" print(\" -\", f)\n",
|
| 424 |
-
"else:\n",
|
| 425 |
-
" print(f\"❌ Não existe: {path}\")\n",
|
| 426 |
-
"\n"
|
| 427 |
-
]
|
| 428 |
-
},
|
| 429 |
-
{
|
| 430 |
-
"cell_type": "code",
|
| 431 |
-
"execution_count": 8,
|
| 432 |
-
"id": "9b39c439-5708-4787-bfee-d3a4d3aa190d",
|
| 433 |
-
"metadata": {},
|
| 434 |
-
"outputs": [
|
| 435 |
-
{
|
| 436 |
-
"name": "stdout",
|
| 437 |
-
"output_type": "stream",
|
| 438 |
-
"text": [
|
| 439 |
-
"✓ Tokenizer base e modelo fine-tuned carregados com sucesso\n"
|
| 440 |
-
]
|
| 441 |
-
},
|
| 442 |
-
{
|
| 443 |
-
"name": "stderr",
|
| 444 |
-
"output_type": "stream",
|
| 445 |
-
"text": [
|
| 446 |
-
"Processando data/mf-training.csv: 100%|██████████| 31142/31142 [5:17:56<00:00, 1.63it/s] \n"
|
| 447 |
-
]
|
| 448 |
-
},
|
| 449 |
-
{
|
| 450 |
-
"name": "stdout",
|
| 451 |
-
"output_type": "stream",
|
| 452 |
-
"text": [
|
| 453 |
-
"✓ Guardado embeddings\\train_protbertbfd.pkl — 31142 proteínas\n"
|
| 454 |
-
]
|
| 455 |
-
},
|
| 456 |
-
{
|
| 457 |
-
"name": "stderr",
|
| 458 |
-
"output_type": "stream",
|
| 459 |
-
"text": [
|
| 460 |
-
"Processando data/mf-validation.csv: 100%|██████████| 1724/1724 [19:15<00:00, 1.49it/s]\n"
|
| 461 |
-
]
|
| 462 |
-
},
|
| 463 |
-
{
|
| 464 |
-
"name": "stdout",
|
| 465 |
-
"output_type": "stream",
|
| 466 |
-
"text": [
|
| 467 |
-
"✓ Guardado embeddings\\val_protbertbfd.pkl — 1724 proteínas\n"
|
| 468 |
-
]
|
| 469 |
-
},
|
| 470 |
-
{
|
| 471 |
-
"name": "stderr",
|
| 472 |
-
"output_type": "stream",
|
| 473 |
-
"text": [
|
| 474 |
-
"Processando data/mf-test.csv: 100%|██████████| 1724/1724 [17:15<00:00, 1.66it/s]\n"
|
| 475 |
-
]
|
| 476 |
-
},
|
| 477 |
-
{
|
| 478 |
-
"name": "stdout",
|
| 479 |
-
"output_type": "stream",
|
| 480 |
-
"text": [
|
| 481 |
-
"✓ Guardado embeddings\\test_protbertbfd.pkl — 1724 proteínas\n"
|
| 482 |
-
]
|
| 483 |
-
}
|
| 484 |
-
],
|
| 485 |
-
"source": [
|
| 486 |
-
"import os\n",
|
| 487 |
-
"import pandas as pd\n",
|
| 488 |
-
"import numpy as np\n",
|
| 489 |
-
"from tqdm import tqdm\n",
|
| 490 |
-
"import joblib\n",
|
| 491 |
-
"import gc\n",
|
| 492 |
-
"from transformers import AutoTokenizer, TFAutoModel\n",
|
| 493 |
-
"\n",
|
| 494 |
-
"# --- 1. Parâmetros --------------------------------------------------------\n",
|
| 495 |
-
"MODEL_DIR = \"weights/mf-fine-tuned-protbertbfd\"\n",
|
| 496 |
-
"MODEL_NAME = \"Rostlab/prot_bert_bfd\"\n",
|
| 497 |
-
"OUT_DIR = \"embeddings\"\n",
|
| 498 |
-
"BATCH_TOK = 16\n",
|
| 499 |
-
"\n",
|
| 500 |
-
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)\n",
|
| 501 |
-
"model = TFAutoModel.from_pretrained(MODEL_DIR, from_pt=False)\n",
|
| 502 |
-
"\n",
|
| 503 |
-
"print(\"✓ Tokenizer base e modelo fine-tuned carregados com sucesso\")\n",
|
| 504 |
-
"\n",
|
| 505 |
-
"# --- 3. Funções auxiliares ------------------------------------------------\n",
|
| 506 |
-
"def format_sequence(seq):\n",
|
| 507 |
-
" return \" \".join(list(seq))\n",
|
| 508 |
-
"\n",
|
| 509 |
-
"def slice_sequence(seq, win=500, min_overlap=250):\n",
|
| 510 |
-
" if len(seq) <= win:\n",
|
| 511 |
-
" return [seq]\n",
|
| 512 |
-
" slices, start = [], 0\n",
|
| 513 |
-
" while start + win <= len(seq):\n",
|
| 514 |
-
" slices.append(seq[start:start+win])\n",
|
| 515 |
-
" start += win\n",
|
| 516 |
-
" leftover = seq[start:]\n",
|
| 517 |
-
" if leftover and len(leftover) >= min_overlap and len(slices[-1]) >= min_overlap:\n",
|
| 518 |
-
" extra = slices[-1][-min_overlap:] + leftover\n",
|
| 519 |
-
" slices.append(extra)\n",
|
| 520 |
-
" return slices\n",
|
| 521 |
-
"\n",
|
| 522 |
-
"def get_embeddings(batch, tokenizer, model):\n",
|
| 523 |
-
" tokens = tokenizer(batch, return_tensors=\"tf\", padding=True, truncation=True, max_length=512)\n",
|
| 524 |
-
" output = model(**tokens)\n",
|
| 525 |
-
" return output.last_hidden_state[:, 0, :].numpy()\n",
|
| 526 |
-
"\n",
|
| 527 |
-
"def process_split(csv_path, out_path):\n",
|
| 528 |
-
" df = pd.read_csv(csv_path)\n",
|
| 529 |
-
" label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
|
| 530 |
-
" prot_ids, embeds, labels = [], [], []\n",
|
| 531 |
-
"\n",
|
| 532 |
-
" for _, row in tqdm(df.iterrows(), total=len(df), desc=f\"Processando {csv_path}\"):\n",
|
| 533 |
-
" slices = slice_sequence(row[\"sequence\"])\n",
|
| 534 |
-
" slices_fmt = list(map(format_sequence, slices))\n",
|
| 535 |
-
"\n",
|
| 536 |
-
" slice_embeds = []\n",
|
| 537 |
-
" for i in range(0, len(slices_fmt), BATCH_TOK):\n",
|
| 538 |
-
" batch = slices_fmt[i:i+BATCH_TOK]\n",
|
| 539 |
-
" slice_embeds.append(get_embeddings(batch, tokenizer, model))\n",
|
| 540 |
-
" slice_embeds = np.vstack(slice_embeds)\n",
|
| 541 |
-
"\n",
|
| 542 |
-
" prot_embed = slice_embeds.mean(axis=0)\n",
|
| 543 |
-
" prot_ids.append(row[\"protein_id\"])\n",
|
| 544 |
-
" embeds.append(prot_embed.astype(np.float32))\n",
|
| 545 |
-
" labels.append(row[label_cols].values.astype(np.int8))\n",
|
| 546 |
-
" gc.collect()\n",
|
| 547 |
-
"\n",
|
| 548 |
-
" embeds = np.vstack(embeds)\n",
|
| 549 |
-
" labels = np.vstack(labels)\n",
|
| 550 |
-
"\n",
|
| 551 |
-
" joblib.dump({\n",
|
| 552 |
-
" \"protein_ids\": prot_ids,\n",
|
| 553 |
-
" \"embeddings\": embeds,\n",
|
| 554 |
-
" \"labels\": labels,\n",
|
| 555 |
-
" \"go_terms\": label_cols\n",
|
| 556 |
-
" }, out_path, compress=3)\n",
|
| 557 |
-
"\n",
|
| 558 |
-
" print(f\"✓ Guardado {out_path} — {embeds.shape[0]} proteínas\")\n",
|
| 559 |
-
"\n",
|
| 560 |
-
"# --- 4. Aplicar -----------------------------------------------------------\n",
|
| 561 |
-
"os.makedirs(OUT_DIR, exist_ok=True)\n",
|
| 562 |
-
"\n",
|
| 563 |
-
"process_split(\"data/mf-training.csv\", os.path.join(OUT_DIR, \"train_protbertbfd.pkl\"))\n",
|
| 564 |
-
"process_split(\"data/mf-validation.csv\", os.path.join(OUT_DIR, \"val_protbertbfd.pkl\"))\n",
|
| 565 |
-
"process_split(\"data/mf-test.csv\", os.path.join(OUT_DIR, \"test_protbertbfd.pkl\"))\n"
|
| 566 |
-
]
|
| 567 |
-
},
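get_embeddings above keeps only the [CLS] token (last_hidden_state[:, 0, :]). A common alternative is masked mean pooling over the residue tokens; a sketch of a drop-in variant, assuming the same tokenizer and model objects as in this cell:

import tensorflow as tf

def get_mean_embeddings(batch, tokenizer, model):
    # average over real tokens only, excluding padding via the attention mask
    tokens = tokenizer(batch, return_tensors="tf", padding=True,
                       truncation=True, max_length=512)
    hidden = model(**tokens).last_hidden_state                 # (B, T, H)
    mask = tf.cast(tokens["attention_mask"][:, :, None], hidden.dtype)
    return (tf.reduce_sum(hidden * mask, axis=1)
            / tf.reduce_sum(mask, axis=1)).numpy()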
|
| 568 |
-
{
|
| 569 |
-
"cell_type": "code",
|
| 570 |
-
"execution_count": 9,
|
| 571 |
-
"id": "ad0c5421-e0a1-4a6a-8ace-2c69aeab0e0d",
|
| 572 |
-
"metadata": {},
|
| 573 |
-
"outputs": [
|
| 574 |
-
{
|
| 575 |
-
"name": "stdout",
|
| 576 |
-
"output_type": "stream",
|
| 577 |
-
"text": [
|
| 578 |
-
"✓ Corrigido: embeddings/train_protbertbfd.pkl — 31142 exemplos, 597 GO terms\n",
|
| 579 |
-
"✓ Corrigido: embeddings/val_protbertbfd.pkl — 1724 exemplos, 597 GO terms\n",
|
| 580 |
-
"✓ Corrigido: embeddings/test_protbertbfd.pkl — 1724 exemplos, 597 GO terms\n"
|
| 581 |
-
]
|
| 582 |
-
}
|
| 583 |
-
],
|
| 584 |
-
"source": [
|
| 585 |
-
"import pandas as pd\n",
|
| 586 |
-
"import joblib\n",
|
| 587 |
-
"from sklearn.preprocessing import MultiLabelBinarizer\n",
|
| 588 |
-
"\n",
|
| 589 |
-
"# --- 1. Obter GO terms do ficheiro de teste --------------------------------\n",
|
| 590 |
-
"df_test = pd.read_csv(\"data/mf-test.csv\")\n",
|
| 591 |
-
"test_terms = sorted(set(term for row in df_test[\"go_terms\"].str.split(\";\") for term in row))\n",
|
| 592 |
-
"\n",
|
| 593 |
-
"# --- 2. Função para corrigir um .pkl com base nos GO terms do teste --------\n",
|
| 594 |
-
"def patch_to_common_terms(csv_path, pkl_path, common_terms):\n",
|
| 595 |
-
" df = pd.read_csv(csv_path)\n",
|
| 596 |
-
" terms_split = df[\"go_terms\"].str.split(\";\")\n",
|
| 597 |
-
" \n",
|
| 598 |
-
" # Apenas termos presentes nos common_terms\n",
|
| 599 |
-
" terms_filtered = terms_split.apply(lambda lst: [t for t in lst if t in common_terms])\n",
|
| 600 |
-
" \n",
|
| 601 |
-
" mlb = MultiLabelBinarizer(classes=common_terms)\n",
|
| 602 |
-
" Y = mlb.fit_transform(terms_filtered)\n",
|
| 603 |
-
"\n",
|
| 604 |
-
" data = joblib.load(pkl_path)\n",
|
| 605 |
-
" data[\"labels\"] = Y\n",
|
| 606 |
-
" data[\"go_terms\"] = mlb.classes_.tolist()\n",
|
| 607 |
-
" \n",
|
| 608 |
-
" joblib.dump(data, pkl_path, compress=3)\n",
|
| 609 |
-
" print(f\"✓ Corrigido: {pkl_path} — {Y.shape[0]} exemplos, {Y.shape[1]} GO terms\")\n",
|
| 610 |
-
"\n",
|
| 611 |
-
"# --- 3. Aplicar às 3 partições --------------------------------------------\n",
|
| 612 |
-
"patch_to_common_terms(\"data/mf-training.csv\", \"embeddings/train_protbertbfd.pkl\", test_terms)\n",
|
| 613 |
-
"patch_to_common_terms(\"data/mf-validation.csv\", \"embeddings/val_protbertbfd.pkl\", test_terms)\n",
|
| 614 |
-
"patch_to_common_terms(\"data/mf-test.csv\", \"embeddings/test_protbertbfd.pkl\", test_terms)\n"
|
| 615 |
-
]
|
| 616 |
-
},
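A one-off sanity check that the patch worked: after it, the three pickles must expose identical GO-term columns, otherwise label matrices from different splits cannot be compared column-wise. A sketch:

import joblib

cols = [joblib.load(f"embeddings/{s}_protbertbfd.pkl")["go_terms"]
        for s in ("train", "val", "test")]
assert cols[0] == cols[1] == cols[2], "GO term columns differ between splits"
print(f"✓ {len(cols[0])} GO term columns, identical across all splits")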
|
| 617 |
-
{
|
| 618 |
-
"cell_type": "code",
|
| 619 |
-
"execution_count": 1,
|
| 620 |
-
"id": "dbd5c35f-4a08-4906-9cf4-e1df501d1ecb",
|
| 621 |
-
"metadata": {},
|
| 622 |
-
"outputs": [],
|
| 623 |
-
"source": [
|
| 624 |
-
"import joblib\n",
|
| 625 |
-
"train = joblib.load(\"embeddings/train_protbertbfd.pkl\")\n",
|
| 626 |
-
"val = joblib.load(\"embeddings/val_protbertbfd.pkl\")\n",
|
| 627 |
-
"test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
|
| 628 |
-
"\n",
|
| 629 |
-
"X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
|
| 630 |
-
"X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
|
| 631 |
-
"X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n"
|
| 632 |
-
]
|
| 633 |
-
},
|
| 634 |
-
{
|
| 635 |
-
"cell_type": "code",
|
| 636 |
-
"execution_count": 5,
|
| 637 |
-
"id": "1785d8a9-23fc-4490-8d71-29cc91a4cb57",
|
| 638 |
-
"metadata": {},
|
| 639 |
-
"outputs": [
|
| 640 |
-
{
|
| 641 |
-
"name": "stdout",
|
| 642 |
-
"output_type": "stream",
|
| 643 |
-
"text": [
|
| 644 |
-
"✓ Embeddings carregados: (31142, 1024) → 597 GO terms\n",
|
| 645 |
-
"Epoch 1/100\n",
|
| 646 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0339 - binary_accuracy: 0.9899 - val_loss: 0.0332 - val_binary_accuracy: 0.9904\n",
|
| 647 |
-
"Epoch 2/100\n",
|
| 648 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0252 - binary_accuracy: 0.9921 - val_loss: 0.0328 - val_binary_accuracy: 0.9904\n",
|
| 649 |
-
"Epoch 3/100\n",
|
| 650 |
-
"974/974 [==============================] - 11s 12ms/step - loss: 0.0244 - binary_accuracy: 0.9924 - val_loss: 0.0328 - val_binary_accuracy: 0.9907\n",
|
| 651 |
-
"Epoch 4/100\n",
|
| 652 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0240 - binary_accuracy: 0.9924 - val_loss: 0.0322 - val_binary_accuracy: 0.9905\n",
|
| 653 |
-
"Epoch 5/100\n",
|
| 654 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0235 - binary_accuracy: 0.9925 - val_loss: 0.0330 - val_binary_accuracy: 0.9908\n",
|
| 655 |
-
"Epoch 6/100\n",
|
| 656 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0233 - binary_accuracy: 0.9926 - val_loss: 0.0330 - val_binary_accuracy: 0.9907\n",
|
| 657 |
-
"Epoch 7/100\n",
|
| 658 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0231 - binary_accuracy: 0.9926 - val_loss: 0.0323 - val_binary_accuracy: 0.9908\n",
|
| 659 |
-
"Epoch 8/100\n",
|
| 660 |
-
"974/974 [==============================] - 10s 11ms/step - loss: 0.0229 - binary_accuracy: 0.9927 - val_loss: 0.0325 - val_binary_accuracy: 0.9906\n",
|
| 661 |
-
"Epoch 9/100\n",
|
| 662 |
-
"974/974 [==============================] - 11s 11ms/step - loss: 0.0227 - binary_accuracy: 0.9927 - val_loss: 0.0325 - val_binary_accuracy: 0.9907\n",
|
| 663 |
-
"54/54 [==============================] - 0s 2ms/step\n",
|
| 664 |
-
"Previsões guardadas em mf-protbertbfd-pam1.npy\n",
|
| 665 |
-
"Modelo guardado em models/mlp_protbertbfd.keras\n"
|
| 666 |
-
]
|
| 667 |
-
}
|
| 668 |
-
],
|
| 669 |
-
"source": [
|
| 670 |
-
"import tensorflow as tf\n",
|
| 671 |
-
"import joblib\n",
|
| 672 |
-
"import numpy as np\n",
|
| 673 |
-
"from tensorflow.keras.models import Sequential\n",
|
| 674 |
-
"from tensorflow.keras.layers import Dense, Dropout\n",
|
| 675 |
-
"from tensorflow.keras.callbacks import EarlyStopping\n",
|
| 676 |
-
"\n",
|
| 677 |
-
"# --- 1. Carregar embeddings ----------------------------------------------\n",
|
| 678 |
-
"train = joblib.load(\"embeddings/train_protbertbfd.pkl\")\n",
|
| 679 |
-
"val = joblib.load(\"embeddings/val_protbertbfd.pkl\")\n",
|
| 680 |
-
"test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
|
| 681 |
-
"\n",
|
| 682 |
-
"X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
|
| 683 |
-
"X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
|
| 684 |
-
"X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n",
|
| 685 |
-
"\n",
|
| 686 |
-
"print(f\"✓ Embeddings carregados: {X_train.shape} → {y_train.shape[1]} GO terms\")\n",
|
| 687 |
-
"\n",
|
| 688 |
-
"# --- 2. Garantir consistência de classes ---------------------------------\n",
|
| 689 |
-
"max_classes = y_train.shape[1] # 602 GO terms (do treino)\n",
|
| 690 |
-
"\n",
|
| 691 |
-
"def pad_labels(y, target_dim=max_classes):\n",
|
| 692 |
-
" if y.shape[1] < target_dim:\n",
|
| 693 |
-
" padding = np.zeros((y.shape[0], target_dim - y.shape[1]), dtype=np.int8)\n",
|
| 694 |
-
" return np.hstack([y, padding])\n",
|
| 695 |
-
" return y\n",
|
| 696 |
-
"\n",
|
| 697 |
-
"y_val = pad_labels(y_val)\n",
|
| 698 |
-
"y_test = pad_labels(y_test)\n",
|
| 699 |
-
"\n",
|
| 700 |
-
"# --- 3. Modelo MLP ------------------------------------------------------\n",
|
| 701 |
-
"model = Sequential([\n",
|
| 702 |
-
" Dense(1024, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
|
| 703 |
-
" Dropout(0.3),\n",
|
| 704 |
-
" Dense(512, activation=\"relu\"),\n",
|
| 705 |
-
" Dropout(0.3),\n",
|
| 706 |
-
" Dense(max_classes, activation=\"sigmoid\")\n",
|
| 707 |
-
"])\n",
|
| 708 |
-
"\n",
|
| 709 |
-
"model.compile(loss=\"binary_crossentropy\",\n",
|
| 710 |
-
" optimizer=\"adam\",\n",
|
| 711 |
-
" metrics=[\"binary_accuracy\"])\n",
|
| 712 |
-
"\n",
|
| 713 |
-
"# --- 4. Early stopping e treino -----------------------------------------\n",
|
| 714 |
-
"callbacks = [\n",
|
| 715 |
-
" EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n",
|
| 716 |
-
"]\n",
|
| 717 |
-
"\n",
|
| 718 |
-
"model.fit(X_train, y_train,\n",
|
| 719 |
-
" validation_data=(X_val, y_val),\n",
|
| 720 |
-
" epochs=100,\n",
|
| 721 |
-
" batch_size=32,\n",
|
| 722 |
-
" callbacks=callbacks,\n",
|
| 723 |
-
" verbose=1)\n",
|
| 724 |
-
"\n",
|
| 725 |
-
"# --- 5. Previsões --------------------------------------------------------\n",
|
| 726 |
-
"y_prob = model.predict(X_test)\n",
|
| 727 |
-
"np.save(\"predictions/mf-protbertbfd-pam1.npy\", y_prob)\n",
|
| 728 |
-
"print(\"Previsões guardadas em mf-protbertbfd-pam1.npy\")\n",
|
| 729 |
-
"\n",
|
| 730 |
-
"# --- 6. Modelo ----------------------------------------------------------\n",
|
| 731 |
-
"model.save(\"models/mlp_protbertbfd.keras\")\n",
|
| 732 |
-
"print(\"Modelo guardado em models/mlp_protbertbfd.keras\")"
|
| 733 |
-
]
|
| 734 |
-
},
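A caveat on the ~0.99 binary_accuracy logged above: with 597 mostly-zero labels per protein, an all-zeros predictor already scores close to it, so the metric mainly reflects label sparsity; the Fmax/AuPRC/Smin cell below is the real evaluation. A sketch reusing y_test and y_prob from this cell for a quick micro-F1:

import numpy as np

print(f"positive label rate: {y_test.mean():.4f}")  # all-zeros accuracy = 1 - this
y_hat = (y_prob >= 0.5).astype(int)
tp = (y_test * y_hat).sum()
fp = ((1 - y_test) * y_hat).sum()
fn = (y_test * (1 - y_hat)).sum()
print(f"micro-F1 @ 0.5: {2 * tp / (2 * tp + fp + fn):.4f}")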
|
| 735 |
-
{
|
| 736 |
-
"cell_type": "code",
|
| 737 |
-
"execution_count": 12,
|
| 738 |
-
"id": "fdb66630-76dc-43a0-bd56-45052175fdba",
|
| 739 |
-
"metadata": {},
|
| 740 |
-
"outputs": [
|
| 741 |
-
{
|
| 742 |
-
"name": "stdout",
|
| 743 |
-
"output_type": "stream",
|
| 744 |
-
"text": [
|
| 745 |
-
"go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
|
| 746 |
-
"✓ Embeddings: (1724, 597) labels × 597 GO terms\n",
|
| 747 |
-
"\n",
|
| 748 |
-
"📊 Resultados finais (ProtBERTBFD + PAM1 + propagação):\n",
|
| 749 |
-
"Fmax = 0.6570\n",
|
| 750 |
-
"Thr. = 0.41\n",
|
| 751 |
-
"AuPRC = 0.6929\n",
|
| 752 |
-
"Smin = 13.8114\n"
|
| 753 |
-
]
|
| 754 |
-
}
|
| 755 |
-
],
|
| 756 |
-
"source": [
|
| 757 |
-
"import numpy as np\n",
|
| 758 |
-
"from sklearn.metrics import precision_recall_curve, auc\n",
|
| 759 |
-
"from goatools.obo_parser import GODag\n",
|
| 760 |
-
"import joblib\n",
|
| 761 |
-
"import math\n",
|
| 762 |
-
"\n",
|
| 763 |
-
"# --- 1. Parâmetros -------------------------------------------------------\n",
|
| 764 |
-
"GO_FILE = \"go.obo\"\n",
|
| 765 |
-
"THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
|
| 766 |
-
"ALPHA = 0.5\n",
|
| 767 |
-
"\n",
|
| 768 |
-
"# --- 2. Carregar dados ---------------------------------------------------\n",
|
| 769 |
-
"test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
|
| 770 |
-
"y_true = test[\"labels\"]\n",
|
| 771 |
-
"terms = test[\"go_terms\"]\n",
|
| 772 |
-
"y_prob = np.load(\"predictions/mf-protbertbfd-pam1.npy\")\n",
|
| 773 |
-
"go_dag = GODag(GO_FILE)\n",
|
| 774 |
-
"\n",
|
| 775 |
-
"print(f\"✓ Embeddings: {y_true.shape} labels × {len(terms)} GO terms\")\n",
|
| 776 |
-
"\n",
|
| 777 |
-
"# --- 3. Fmax -------------------------------------------------------------\n",
|
| 778 |
-
"def compute_fmax(y_true, y_prob, thresholds):\n",
|
| 779 |
-
" fmax, best_thr = 0, 0\n",
|
| 780 |
-
" for t in thresholds:\n",
|
| 781 |
-
" y_pred = (y_prob >= t).astype(int)\n",
|
| 782 |
-
" tp = (y_true * y_pred).sum(axis=1)\n",
|
| 783 |
-
" fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
|
| 784 |
-
" fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
|
| 785 |
-
" precision = tp / (tp + fp + 1e-8)\n",
|
| 786 |
-
" recall = tp / (tp + fn + 1e-8)\n",
|
| 787 |
-
" f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
|
| 788 |
-
" avg_f1 = np.mean(f1)\n",
|
| 789 |
-
" if avg_f1 > fmax:\n",
|
| 790 |
-
" fmax, best_thr = avg_f1, t\n",
|
| 791 |
-
" return fmax, best_thr\n",
|
| 792 |
-
"\n",
|
| 793 |
-
"# --- 4. AuPRC micro ------------------------------------------------------\n",
|
| 794 |
-
"def compute_auprc(y_true, y_prob):\n",
|
| 795 |
-
" precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
|
| 796 |
-
" return auc(recall, precision)\n",
|
| 797 |
-
"\n",
|
| 798 |
-
"# --- 5. Smin -------------------------------------------------------------\n",
|
| 799 |
-
"def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
|
| 800 |
-
" y_pred = (y_prob >= threshold).astype(int)\n",
|
| 801 |
-
" ic = {}\n",
|
| 802 |
-
" total = (y_true + y_pred).sum(axis=0).sum()\n",
|
| 803 |
-
" for i, term in enumerate(terms):\n",
|
| 804 |
-
" freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
|
| 805 |
-
" ic[term] = -np.log((freq + 1e-8) / total)\n",
|
| 806 |
-
"\n",
|
| 807 |
-
" s_values = []\n",
|
| 808 |
-
" for true_vec, pred_vec in zip(y_true, y_pred):\n",
|
| 809 |
-
" true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
|
| 810 |
-
" pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
|
| 811 |
-
"\n",
|
| 812 |
-
" anc_true = set()\n",
|
| 813 |
-
" for t in true_terms:\n",
|
| 814 |
-
" if t in go_dag:\n",
|
| 815 |
-
" anc_true |= go_dag[t].get_all_parents()\n",
|
| 816 |
-
" anc_pred = set()\n",
|
| 817 |
-
" for t in pred_terms:\n",
|
| 818 |
-
" if t in go_dag:\n",
|
| 819 |
-
" anc_pred |= go_dag[t].get_all_parents()\n",
|
| 820 |
-
"\n",
|
| 821 |
-
" ru = pred_terms - true_terms\n",
|
| 822 |
-
" mi = true_terms - pred_terms\n",
|
| 823 |
-
" dist_ru = sum(ic.get(t, 0) for t in ru)\n",
|
| 824 |
-
" dist_mi = sum(ic.get(t, 0) for t in mi)\n",
|
| 825 |
-
" s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
|
| 826 |
-
" s_values.append(s)\n",
|
| 827 |
-
"\n",
|
| 828 |
-
" return np.mean(s_values)\n",
|
| 829 |
-
"\n",
|
| 830 |
-
"# --- 6. Avaliar ----------------------------------------------------------\n",
|
| 831 |
-
"fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
|
| 832 |
-
"auprc = compute_auprc(y_true, y_prob)\n",
|
| 833 |
-
"smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
|
| 834 |
-
"\n",
|
| 835 |
-
"print(f\"\\n📊 Resultados finais (ProtBERTBFD + PAM1 + propagação):\")\n",
|
| 836 |
-
"print(f\"Fmax = {fmax:.4f}\")\n",
|
| 837 |
-
"print(f\"Thr. = {thr:.2f}\")\n",
|
| 838 |
-
"print(f\"AuPRC = {auprc:.4f}\")\n",
|
| 839 |
-
"print(f\"Smin = {smin:.4f}\")\n"
|
| 840 |
-
]
|
| 841 |
-
},
|
| 842 |
-
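In formula form, the two metric functions in this cell compute the following. compute_fmax returns the threshold-maximized mean of per-protein F1 scores, a protein-centric variant of the CAFA Fmax; the 1e-8 terms in the code are only numerical guards:

$$
F_{\max} = \max_{t \in [0,1]} \frac{1}{n} \sum_{i=1}^{n} \frac{2\, p_i(t)\, r_i(t)}{p_i(t) + r_i(t)},
\qquad
p_i(t) = \frac{TP_i(t)}{TP_i(t) + FP_i(t)},
\quad
r_i(t) = \frac{TP_i(t)}{TP_i(t) + FN_i(t)}
$$

compute_smin scores a single threshold (the one that maximized Fmax) rather than minimizing over all thresholds, with information content estimated from term frequencies in the union of true and predicted annotations:

$$
\mathrm{IC}(c) = -\log\frac{\mathrm{freq}(c)}{\mathrm{total}},
\qquad
S = \frac{1}{n} \sum_{i=1}^{n} \sqrt{\big(\alpha\, RU_i\big)^{2} + \big((1-\alpha)\, MI_i\big)^{2}}
$$

where, following the variable names in the code, $RU_i$ sums IC over terms predicted but not true and $MI_i$ over terms true but not predicted. Two caveats: this is the reverse of the usual CAFA naming (RU conventionally covers the missed true terms), although with $\alpha = 0.5$ the value is unaffected; and the ancestor sets anc_true / anc_pred are built but never used, so the GO-DAG propagation mentioned in the final print does not actually enter the score as written.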
{
"cell_type": "code",
"execution_count": null,
"id": "70d131ef-ef84-42ee-953b-0d3f1268694d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
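One last cross-check worth noting: auc(recall, precision) applies trapezoidal interpolation to the precision-recall curve, which can be slightly optimistic, while scikit-learn's average_precision_score is the step-wise alternative. A quick sanity check on the micro AuPRC reported above (a sketch, assuming the same y_true and y_prob arrays from the evaluation cell are in scope):

from sklearn.metrics import average_precision_score

# Step-wise micro-averaged AuPRC; compare against the trapezoidal value above.
micro_ap = average_precision_score(y_true.ravel(), y_prob.ravel())
print(f"AuPRC (average_precision_score) = {micro_ap:.4f}")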