{ "cells": [ { "cell_type": "code", "execution_count": 9, "id": "641053e3-7fec-4f9b-a75e-ddd957af03c4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n", "✓ Dataset preparado:\n", " - Training: (31142, 3)\n", " - Validation: (1724, 3)\n", " - Test: (1724, 3)\n", " - GO terms: 602\n" ] } ], "source": [ "# %%\n", "import pandas as pd\n", "import numpy as np\n", "from Bio import SeqIO\n", "from goatools.obo_parser import GODag\n", "from collections import Counter\n", "from sklearn.preprocessing import MultiLabelBinarizer\n", "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n", "import os, random\n", "\n", "# Carregar ficheiros principais\n", "FASTA = \"uniprot_sprot_exp.fasta\"\n", "ANNOT = \"uniprot_sprot_exp.txt\"\n", "GO_OBO = \"go.obo\"\n", "\n", "# Ler sequências\n", "seqs, ids = [], []\n", "for record in SeqIO.parse(FASTA, \"fasta\"):\n", " ids.append(record.id)\n", " seqs.append(str(record.seq))\n", "\n", "df_seq = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n", "\n", "# Ler anotações GO:MF\n", "df_ann = pd.read_csv(ANNOT, sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"category\"])\n", "df_ann = df_ann[df_ann[\"category\"] == \"F\"]\n", "\n", "# Propagação hierárquica dos GO terms\n", "go_dag = GODag(GO_OBO)\n", "mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n", "\n", "def propagate_terms(terms):\n", " expanded = set()\n", " for t in terms:\n", " if t in go_dag:\n", " expanded |= go_dag[t].get_all_parents()\n", " expanded.add(t)\n", " return list(expanded & mf_terms)\n", "\n", "grouped = df_ann.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n", "grouped[\"go_term\"] = grouped[\"go_term\"].apply(propagate_terms)\n", "\n", "# Juntar com sequência\n", "df = df_seq.merge(grouped, on=\"protein_id\")\n", "df = df[df[\"go_term\"].str.len() > 0]\n", "\n", "# Filtrar GO terms com ≥50 proteínas\n", "all_terms = [term for sublist in df[\"go_term\"] for term in sublist]\n", "term_counts = Counter(all_terms)\n", "valid_terms = {t for t, count in term_counts.items() if count >= 50}\n", "\n", "df[\"go_term\"] = df[\"go_term\"].apply(lambda ts: [t for t in ts if t in valid_terms])\n", "df = df[df[\"go_term\"].str.len() > 0]\n", "\n", "# Preparar labels e dividir por proteína\n", "df[\"go_terms\"] = df[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n", "df = df[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n", "\n", "mlb = MultiLabelBinarizer()\n", "Y = mlb.fit_transform(df[\"go_terms\"].str.split(\";\"))\n", "X = df[[\"protein_id\", \"sequence\"]].values\n", "\n", "mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n", "train_idx, temp_idx = next(mskf.split(X, Y))\n", "val_idx, test_idx = np.array_split(temp_idx, 2)\n", "\n", "df_train = df.iloc[train_idx].copy()\n", "df_val = df.iloc[val_idx].copy()\n", "df_test = df.iloc[test_idx].copy()\n", "\n", "os.makedirs(\"data\", exist_ok=True)\n", "df_train.to_csv(\"data/mf-training.csv\", index=False)\n", "df_val.to_csv(\"data/mf-validation.csv\", index=False)\n", "df_test.to_csv(\"data/mf-test.csv\", index=False)\n", "\n", "# Guardar o binarizador\n", "import joblib\n", "joblib.dump(mlb, \"data/mlb.pkl\")\n", "\n", "print(\"✓ Dataset preparado:\")\n", "print(\" - Training:\", df_train.shape)\n", "print(\" - Validation:\", df_val.shape)\n", "print(\" - Test:\", df_test.shape)\n", "print(\" - GO terms:\", len(mlb.classes_))\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "40ba1798-daf8-4649-ae3f-bfe81df6437f", "metadata": {}, "outputs": [], "source": [ "# %%\n", "import random\n", "from collections import defaultdict\n", "\n", "# PAM1\n", "# PAM matrix model of protein evolution\n", "# DOI:10.1093/oxfordjournals.molbev.a040360\n", "pam_data = {\n", " 'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n", " 'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n", " 'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n", " 'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n", " 'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n", " 'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n", " 'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n", " 'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n", " 'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n", " 'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n", " 'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n", " 'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n", " 'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n", " 'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n", " 'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n", " 'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n", " 'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n", " 'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n", " 'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n", " 'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n", "}\n", "\n", "pam_raw = pd.DataFrame(pam_data, index=pam_data.keys())\n", "pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n", "pam_dict = {aa: pam_matrix.loc[aa].to_dict() for aa in pam_matrix.index}\n", "\n", "def pam1_substitution(aa):\n", " if aa not in pam_dict:\n", " return aa\n", " subs = list(pam_dict[aa].keys())\n", " probs = list(pam_dict[aa].values())\n", " return np.random.choice(subs, p=probs)\n", "\n", "def augment_sequence(seq, sub_prob=0.05):\n", " return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n", "\n", "def slice_sequence(seq, win=1024):\n", " if len(seq) <= win:\n", " return [seq]\n", " return [seq[i:i+win] for i in range(0, len(seq), win)]\n", "\n", "def format_seq(seq):\n", " return \" \".join(seq)\n", "\n", "# Carregar labels e datasets\n", "import joblib\n", "mlb = joblib.load(\"data/mlb.pkl\")\n", "df_train = pd.read_csv(\"data/mf-training.csv\")\n", "df_val = pd.read_csv(\"data/mf-validation.csv\")\n", "df_test = pd.read_csv(\"data/mf-test.csv\")\n", "\n", "# Slicing + augmentação no treino\n", "X_train, y_train = [], []\n", "\n", "for _, row in df_train.iterrows():\n", " seq_aug = augment_sequence(row[\"sequence\"], sub_prob=0.05)\n", " slices = slice_sequence(seq_aug, win=1024)\n", " label = mlb.transform([row[\"go_terms\"].split(\";\")])[0]\n", " for sl in slices:\n", " X_train.append(format_seq(sl))\n", " y_train.append(label)\n", "\n", "# Sem slicing no val/test\n", "X_val = [format_seq(seq) for seq in df_val[\"sequence\"]]\n", "X_test = [format_seq(seq) for seq in df_test[\"sequence\"]]\n", "\n", "y_val = mlb.transform(df_val[\"go_terms\"].str.split(\";\"))\n", "y_test = mlb.transform(df_test[\"go_terms\"].str.split(\";\"))\n", "\n", "np.save(\"embeddings/y_test.npy\", y_test)" ] }, { "cell_type": "code", "execution_count": 11, "id": "80d5c1fb-9c84-463d-8d8c-bfcc2982afc9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\huggingface_hub\\file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n", "Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "100%|██████████| 2189/2189 [1:17:26<00:00, 2.12s/it]\n", "100%|██████████| 108/108 [03:43<00:00, 2.07s/it]\n", "100%|██████████| 108/108 [03:56<00:00, 2.19s/it]\n" ] } ], "source": [ "# %%\n", "from transformers import AutoTokenizer, AutoModel\n", "import torch\n", "from tqdm import tqdm\n", "import numpy as np\n", "import os\n", "\n", "# Configurações\n", "MODEL_NAME = \"facebook/esm2_t33_650M_UR50D\"\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "CHUNK_SIZE = 16\n", "\n", "# Carregar modelo\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)\n", "model = AutoModel.from_pretrained(MODEL_NAME)\n", "model.to(DEVICE)\n", "model.eval()\n", "\n", "def extract_embeddings(texts):\n", " embeddings = []\n", " for i in tqdm(range(0, len(texts), CHUNK_SIZE)):\n", " batch = texts[i:i+CHUNK_SIZE]\n", " with torch.no_grad():\n", " inputs = tokenizer(batch, return_tensors=\"pt\", padding=True, truncation=True, max_length=1024)\n", " inputs = {k: v.to(DEVICE) for k, v in inputs.items()}\n", " outputs = model(**inputs).last_hidden_state\n", " cls_tokens = outputs[:, 0, :] # token CLS\n", " embeddings.append(cls_tokens.cpu().numpy())\n", " return np.vstack(embeddings)\n", "\n", "# Extrair e guardar embeddings\n", "os.makedirs(\"embeddings\", exist_ok=True)\n", "\n", "emb_train = extract_embeddings(X_train)\n", "emb_val = extract_embeddings(X_val)\n", "emb_test = extract_embeddings(X_test)\n", "\n", "np.save(\"embeddings/esm2_train.npy\", emb_train)\n", "np.save(\"embeddings/esm2_val.npy\", emb_val)\n", "np.save(\"embeddings/esm2_test.npy\", emb_test)\n", "\n", "np.save(\"embeddings/y_train.npy\", np.array(y_train))\n", "np.save(\"embeddings/y_val.npy\", np.array(y_val))\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "592e4f6c-b871-4f0b-b84c-f3918c698544", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0557 - val_loss: 0.0448\n", "Epoch 2/100\n", "1095/1095 [==============================] - 13s 12ms/step - loss: 0.0444 - val_loss: 0.0413\n", "Epoch 3/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0418 - val_loss: 0.0393\n", "Epoch 4/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0404 - val_loss: 0.0385\n", "Epoch 5/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0392 - val_loss: 0.0373\n", "Epoch 6/100\n", "1095/1095 [==============================] - 13s 12ms/step - loss: 0.0382 - val_loss: 0.0372\n", "Epoch 7/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0374 - val_loss: 0.0355\n", "Epoch 8/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0368 - val_loss: 0.0350\n", "Epoch 9/100\n", "1095/1095 [==============================] - 13s 12ms/step - loss: 0.0362 - val_loss: 0.0349\n", "Epoch 10/100\n", "1095/1095 [==============================] - 13s 12ms/step - loss: 0.0357 - val_loss: 0.0342\n", "Epoch 11/100\n", "1095/1095 [==============================] - 16s 14ms/step - loss: 0.0353 - val_loss: 0.0339\n", "Epoch 12/100\n", "1095/1095 [==============================] - 16s 14ms/step - loss: 0.0348 - val_loss: 0.0336\n", "Epoch 13/100\n", "1095/1095 [==============================] - 15s 13ms/step - loss: 0.0344 - val_loss: 0.0335\n", "Epoch 14/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0341 - val_loss: 0.0337\n", "Epoch 15/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0338 - val_loss: 0.0331\n", "Epoch 16/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0335 - val_loss: 0.0327\n", "Epoch 17/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0332 - val_loss: 0.0328\n", "Epoch 18/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0330 - val_loss: 0.0326\n", "Epoch 19/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0326 - val_loss: 0.0326\n", "Epoch 20/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0324 - val_loss: 0.0319\n", "Epoch 21/100\n", "1095/1095 [==============================] - 13s 12ms/step - loss: 0.0321 - val_loss: 0.0319\n", "Epoch 22/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0320 - val_loss: 0.0321\n", "Epoch 23/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0319 - val_loss: 0.0314\n", "Epoch 24/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0316 - val_loss: 0.0315\n", "Epoch 25/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0315 - val_loss: 0.0314\n", "Epoch 26/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0314 - val_loss: 0.0316\n", "Epoch 27/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0310 - val_loss: 0.0315\n", "Epoch 28/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0311 - val_loss: 0.0312\n", "Epoch 29/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0307 - val_loss: 0.0312\n", "Epoch 30/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0307 - val_loss: 0.0309\n", "Epoch 31/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0305 - val_loss: 0.0310\n", "Epoch 32/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0305 - val_loss: 0.0311\n", "Epoch 33/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0303 - val_loss: 0.0307\n", "Epoch 34/100\n", "1095/1095 [==============================] - 13s 12ms/step - loss: 0.0301 - val_loss: 0.0309\n", "Epoch 35/100\n", "1095/1095 [==============================] - 13s 12ms/step - loss: 0.0300 - val_loss: 0.0310\n", "Epoch 36/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0299 - val_loss: 0.0311\n", "Epoch 37/100\n", "1095/1095 [==============================] - 14s 12ms/step - loss: 0.0298 - val_loss: 0.0305\n", "Epoch 38/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0296 - val_loss: 0.0308\n", "Epoch 39/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0296 - val_loss: 0.0310\n", "Epoch 40/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0295 - val_loss: 0.0313\n", "Epoch 41/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0293 - val_loss: 0.0306\n", "Epoch 42/100\n", "1095/1095 [==============================] - 14s 13ms/step - loss: 0.0292 - val_loss: 0.0306\n", "Modelo guardado em models/\n", "54/54 [==============================] - 0s 2ms/step\n", " Predições do ESM-2 salvas com forma: (1724, 602)\n" ] } ], "source": [ "# %%\n", "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow.keras import Input\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Dense, Dropout\n", "from tensorflow.keras.callbacks import EarlyStopping\n", "from sklearn.metrics import average_precision_score\n", "\n", "# Carregar os embeddings e labels\n", "X_train = np.load(\"embeddings/esm2_train.npy\")\n", "X_val = np.load(\"embeddings/esm2_val.npy\")\n", "X_test = np.load(\"embeddings/esm2_test.npy\")\n", "\n", "y_train = np.load(\"embeddings/y_train.npy\")\n", "y_val = np.load(\"embeddings/y_val.npy\")\n", "y_test = np.load(\"embeddings/y_test.npy\")\n", "\n", "# Definir o modelo\n", "model = Sequential([\n", " Dense(1024, activation='relu', input_shape=(X_train.shape[1],)),\n", " Dropout(0.3),\n", " Dense(512, activation='relu'),\n", " Dropout(0.3),\n", " Dense(y_train.shape[1], activation='sigmoid')\n", "])\n", "\n", "model.compile(optimizer='adam', loss='binary_crossentropy')\n", "\n", "# Treinar\n", "early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n", "\n", "history = model.fit(\n", " X_train, y_train,\n", " validation_data=(X_val, y_val),\n", " epochs=100,\n", " batch_size=32,\n", " callbacks=[early_stop],\n", " verbose=1\n", ")\n", "\n", "# Salvar o modelo\n", "model.save(\"models/mlp_esm2.h5\")\n", "model.save(\"models/mlp_esm2.keras\")\n", "print(\"Modelo guardado em models/\")\n", "\n", "# Fazer predições no conjunto de teste\n", "y_prob = model.predict(X_test)\n", "np.save(\"predictions/mf-esm2.npy\", y_prob)\n", "\n", "print(\" Predições do ESM-2 salvas com forma:\", y_prob.shape)\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "3dddb0df-3ea5-4e32-8cf0-45e90be8ba66", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Melvin\\anaconda3\\envs\\protein_env\\lib\\site-packages\\sklearn\\base.py:380: InconsistentVersionWarning: Trying to unpickle estimator MultiLabelBinarizer from version 1.1.3 when using version 1.6.1. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n", "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n", "✓ Dados carregados: (1724, 602) proteínas × 602 GO terms\n", "\n", " Resultados finais (ESM-2 + PAM1 + propagação):\n", "Fmax = 0.6377\n", "Thr. = 0.35\n", "AuPRC = 0.6848\n", "Smin = 14.4202\n" ] } ], "source": [ "# %%\n", "import numpy as np\n", "import joblib\n", "import math\n", "from goatools.obo_parser import GODag\n", "from sklearn.metrics import precision_recall_curve, auc\n", "\n", "# Carregar dados e parâmetros\n", "GO_FILE = \"go.obo\"\n", "THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n", "ALPHA = 0.5\n", "\n", "mlb = joblib.load(\"data/mlb.pkl\")\n", "y_true = np.load(\"embeddings/y_test.npy\")\n", "y_prob = np.load(\"predictions/mf-esm2.npy\")\n", "terms = mlb.classes_\n", "go_dag = GODag(GO_FILE)\n", "\n", "print(f\"✓ Dados carregados: {y_true.shape} proteínas × {len(terms)} GO terms\")\n", "\n", "# Fmax\n", "def compute_fmax(y_true, y_prob, thresholds):\n", " fmax, best_thr = 0, 0\n", " for t in thresholds:\n", " y_pred = (y_prob >= t).astype(int)\n", " tp = (y_true * y_pred).sum(axis=1)\n", " fp = ((1 - y_true) * y_pred).sum(axis=1)\n", " fn = (y_true * (1 - y_pred)).sum(axis=1)\n", " precision = tp / (tp + fp + 1e-8)\n", " recall = tp / (tp + fn + 1e-8)\n", " f1 = 2 * precision * recall / (precision + recall + 1e-8)\n", " avg_f1 = np.mean(f1)\n", " if avg_f1 > fmax:\n", " fmax, best_thr = avg_f1, t\n", " return fmax, best_thr\n", "\n", "# AuPRC (micro)\n", "def compute_auprc(y_true, y_prob):\n", " precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n", " return auc(recall, precision)\n", "\n", "# Smin\n", "def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n", " y_pred = (y_prob >= threshold).astype(int)\n", "\n", " # Informação semântica: IC (Information Content)\n", " ic = {}\n", " total = (y_true + y_pred).sum(axis=0).sum()\n", " for i, term in enumerate(terms):\n", " freq = (y_true[:, i] + y_pred[:, i]).sum()\n", " ic[term] = -np.log((freq + 1e-8) / total)\n", "\n", " # Para cada proteína, calcular RU e MI\n", " s_values = []\n", " for true_vec, pred_vec in zip(y_true, y_pred):\n", " true_terms = {terms[i] for i in np.where(true_vec)[0]}\n", " pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n", "\n", " anc_true = set()\n", " for t in true_terms:\n", " if t in go_dag:\n", " anc_true |= go_dag[t].get_all_parents()\n", " anc_pred = set()\n", " for t in pred_terms:\n", " if t in go_dag:\n", " anc_pred |= go_dag[t].get_all_parents()\n", "\n", " ru = pred_terms - true_terms\n", " mi = true_terms - pred_terms\n", " dist_ru = sum(ic.get(t, 0) for t in ru)\n", " dist_mi = sum(ic.get(t, 0) for t in mi)\n", " s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n", " s_values.append(s)\n", "\n", " return np.mean(s_values)\n", "\n", "# Avaliação\n", "fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n", "auprc = compute_auprc(y_true, y_prob)\n", "smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n", "\n", "print(f\"\\n Resultados finais (ESM-2 + PAM1 + propagação):\")\n", "print(f\"Fmax = {fmax:.4f}\")\n", "print(f\"Thr. = {thr:.2f}\")\n", "print(f\"AuPRC = {auprc:.4f}\")\n", "print(f\"Smin = {smin:.4f}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1a1ea084-01de-4dc4-88da-e7ffeb8c94c9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }