{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c6dbc330-062a-48f0-8242-3f21cc1c9c2b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
"✓ Ficheiros criados:\n",
" - data/mf-training.csv : (31142, 3)\n",
" - data/mf-validation.csv: (1724, 3)\n",
" - data/mf-test.csv : (1724, 3)\n",
"GO terms únicos (após propagação e filtro): 602\n"
]
}
],
"source": [
"import pandas as pd\n",
"from Bio import SeqIO\n",
"from collections import Counter\n",
"from goatools.obo_parser import GODag\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import MultiLabelBinarizer\n",
"from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
"import numpy as np\n",
"import os\n",
"\n",
"# Carregar GO anotações\n",
"annotations = pd.read_csv(\"uniprot_sprot_exp.txt\", sep=\"\\t\", names=[\"protein_id\", \"go_term\", \"go_category\"])\n",
"annotations_f = annotations[annotations[\"go_category\"] == \"F\"]\n",
"\n",
"# Carregar DAG e propagar GO terms\n",
"# propagação hierárquica\n",
"# https://geneontology.org/docs/download-ontology/\n",
"go_dag = GODag(\"go.obo\")\n",
"mf_terms = {t for t, o in go_dag.items() if o.namespace == \"molecular_function\"}\n",
"\n",
"def propagate_terms(term_list):\n",
" full = set()\n",
" for t in term_list:\n",
" if t not in go_dag:\n",
" continue\n",
" full.add(t)\n",
" full.update(go_dag[t].get_all_parents())\n",
" return list(full & mf_terms)\n",
"\n",
"# Carregar sequências\n",
"seqs, ids = [], []\n",
"for record in SeqIO.parse(\"uniprot_sprot_exp.fasta\", \"fasta\"):\n",
" ids.append(record.id)\n",
" seqs.append(str(record.seq))\n",
"\n",
"seq_df = pd.DataFrame({\"protein_id\": ids, \"sequence\": seqs})\n",
"\n",
"# Juntar com GO anotado e propagar\n",
"grouped = annotations_f.groupby(\"protein_id\")[\"go_term\"].apply(list).reset_index()\n",
"data = seq_df.merge(grouped, on=\"protein_id\")\n",
"data = data[data[\"go_term\"].apply(len) > 0]\n",
"data[\"go_term\"] = data[\"go_term\"].apply(propagate_terms)\n",
"data = data[data[\"go_term\"].apply(len) > 0]\n",
"\n",
"# Filtrar GO terms raros\n",
"# todos os terms com menos de 50 proteinas associadas\n",
"all_terms = [term for sublist in data[\"go_term\"] for term in sublist]\n",
"term_counts = Counter(all_terms)\n",
"valid_terms = {term for term, count in term_counts.items() if count >= 50}\n",
"data[\"go_term\"] = data[\"go_term\"].apply(lambda terms: [t for t in terms if t in valid_terms])\n",
"data = data[data[\"go_term\"].apply(len) > 0]\n",
"\n",
"# Preparar dataset final\n",
"data[\"go_terms\"] = data[\"go_term\"].apply(lambda x: ';'.join(sorted(set(x))))\n",
"data = data[[\"protein_id\", \"sequence\", \"go_terms\"]].drop_duplicates()\n",
"\n",
"# Binarizar labels e dividir\n",
"mlb = MultiLabelBinarizer()\n",
"Y = mlb.fit_transform(data[\"go_terms\"].str.split(\";\"))\n",
"X = data[[\"protein_id\", \"sequence\"]].values\n",
"\n",
"mskf = MultilabelStratifiedKFold(n_splits=10, random_state=42, shuffle=True)\n",
"train_idx, temp_idx = next(mskf.split(X, Y))\n",
"val_idx, test_idx = np.array_split(temp_idx, 2)\n",
"\n",
"df_train = data.iloc[train_idx].copy()\n",
"df_val = data.iloc[val_idx].copy()\n",
"df_test = data.iloc[test_idx].copy()\n",
"\n",
"# Guardar em CSV\n",
"os.makedirs(\"data\", exist_ok=True)\n",
"df_train.to_csv(\"data/mf-training.csv\", index=False)\n",
"df_val.to_csv(\"data/mf-validation.csv\", index=False)\n",
"df_test.to_csv(\"data/mf-test.csv\", index=False)\n",
"\n",
"# Confirmar\n",
"print(\"✓ Ficheiros criados:\")\n",
"print(\" - data/mf-training.csv :\", df_train.shape)\n",
"print(\" - data/mf-validation.csv:\", df_val.shape)\n",
"print(\" - data/mf-test.csv :\", df_test.shape)\n",
"print(f\"GO terms únicos (após propagação e filtro): {len(mlb.classes_)}\")\n"
]
},
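  {
   "cell_type": "markdown",
   "id": "b8f3d2a1-4c5e-4f6a-9b7c-1d2e3f405061",
   "metadata": {},
   "source": [
    "The propagation step above applies the GO *true-path rule*: a protein annotated with a term is implicitly annotated with every ancestor of that term in the DAG. A minimal self-contained sketch of the same idea on a hypothetical three-term hierarchy (illustration only; the cell above does this against `go.obo` via `get_all_parents()`):\n",
    "\n",
    "```python\n",
    "# Hypothetical mini-DAG: child -> set of direct parents\n",
    "toy_dag = {\n",
    "    \"GO:C\": {\"GO:B\"},\n",
    "    \"GO:B\": {\"GO:A\"},\n",
    "    \"GO:A\": set(),\n",
    "}\n",
    "\n",
    "def all_ancestors(term):\n",
    "    # Walk up the DAG, collecting every transitive parent\n",
    "    parents = set(toy_dag.get(term, set()))\n",
    "    for p in list(parents):\n",
    "        parents |= all_ancestors(p)\n",
    "    return parents\n",
    "\n",
    "annotated = [\"GO:C\"]\n",
    "propagated = set(annotated) | {a for t in annotated for a in all_ancestors(t)}\n",
    "print(sorted(propagated))  # ['GO:A', 'GO:B', 'GO:C']\n",
    "```"
   ]
  },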
{
"cell_type": "code",
"execution_count": 2,
"id": "6cf7aaa6-4941-4951-8d73-1f4f1f4362f3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\transformers\\utils\\generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n",
" _torch_pytree._register_pytree_node(\n",
"100%|██████████| 31142/31142 [00:26<00:00, 1192.86it/s]\n",
"100%|██████████| 1724/1724 [00:00<00:00, 2570.68it/s]\n",
"C:\\Users\\Melvin\\anaconda3\\envs\\projeto_proteina2\\lib\\site-packages\\ktrain\\text\\preprocessor.py:382: UserWarning: The class_names argument is replacing the classes argument. Please update your code.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"preprocessing train...\n",
"language: en\n",
"train sequence lengths:\n",
"\tmean : 423\n",
"\t95percentile : 604\n",
"\t99percentile : 715\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Is Multi-Label? True\n",
"preprocessing test...\n",
"language: en\n",
"test sequence lengths:\n",
"\tmean : 408\n",
"\t95percentile : 603\n",
"\t99percentile : 714\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"begin training using triangular learning rate policy with max lr of 1e-05...\n",
"Epoch 1/10\n",
"40995/40995 [==============================] - 9020s 219ms/step - loss: 0.0740 - binary_accuracy: 0.9869 - val_loss: 0.0526 - val_binary_accuracy: 0.9866\n",
"Epoch 2/10\n",
"40995/40995 [==============================] - 8939s 218ms/step - loss: 0.0464 - binary_accuracy: 0.9877 - val_loss: 0.0457 - val_binary_accuracy: 0.9871\n",
"Epoch 3/10\n",
"40995/40995 [==============================] - 8881s 217ms/step - loss: 0.0413 - binary_accuracy: 0.9883 - val_loss: 0.0418 - val_binary_accuracy: 0.9877\n",
"Epoch 4/10\n",
"40995/40995 [==============================] - 10277s 251ms/step - loss: 0.0380 - binary_accuracy: 0.9888 - val_loss: 0.0396 - val_binary_accuracy: 0.9881\n",
"Epoch 5/10\n",
"40995/40995 [==============================] - 10565s 258ms/step - loss: 0.0357 - binary_accuracy: 0.9892 - val_loss: 0.0380 - val_binary_accuracy: 0.9883\n",
"Epoch 6/10\n",
"40995/40995 [==============================] - 10693s 261ms/step - loss: 0.0338 - binary_accuracy: 0.9895 - val_loss: 0.0369 - val_binary_accuracy: 0.9885\n",
"Epoch 7/10\n",
"40995/40995 [==============================] - 12055s 294ms/step - loss: 0.0323 - binary_accuracy: 0.9898 - val_loss: 0.0360 - val_binary_accuracy: 0.9888\n",
"Epoch 8/10\n",
"40995/40995 [==============================] - 10225s 249ms/step - loss: 0.0309 - binary_accuracy: 0.9901 - val_loss: 0.0353 - val_binary_accuracy: 0.9890\n",
"Epoch 9/10\n",
"40995/40995 [==============================] - 10308s 251ms/step - loss: 0.0297 - binary_accuracy: 0.9904 - val_loss: 0.0347 - val_binary_accuracy: 0.9891\n",
"Epoch 10/10\n",
"40995/40995 [==============================] - 10275s 251ms/step - loss: 0.0286 - binary_accuracy: 0.9907 - val_loss: 0.0346 - val_binary_accuracy: 0.9893\n",
"Weights from best epoch have been loaded into model.\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"import random\n",
"import os\n",
"import ktrain\n",
"from ktrain import text\n",
"from sklearn.preprocessing import MultiLabelBinarizer\n",
"\n",
"\n",
"# PAM1\n",
"# PAM matrix model of protein evolution\n",
"# DOI:10.1093/oxfordjournals.molbev.a040360\n",
"pam_data = {\n",
" 'A': [9948, 19, 27, 42, 31, 46, 50, 92, 17, 7, 40, 88, 42, 41, 122, 279, 255, 9, 72, 723],\n",
" 'R': [14, 9871, 24, 38, 37, 130, 38, 62, 49, 4, 58, 205, 26, 33, 47, 103, 104, 5, 36, 52],\n",
" 'N': [20, 22, 9860, 181, 29, 36, 41, 67, 31, 5, 22, 49, 23, 10, 33, 83, 66, 3, 43, 32],\n",
" 'D': [40, 34, 187, 9818, 11, 63, 98, 61, 23, 5, 25, 54, 43, 13, 27, 88, 55, 4, 29, 36],\n",
" 'C': [20, 16, 26, 9, 9987, 10, 17, 37, 12, 2, 16, 26, 10, 19, 27, 26, 25, 2, 6, 67],\n",
" 'Q': [29, 118, 29, 49, 8, 9816, 72, 55, 36, 4, 60, 158, 35, 22, 39, 86, 74, 3, 34, 28],\n",
" 'E': [35, 29, 41, 101, 12, 71, 9804, 56, 33, 5, 36, 107, 42, 20, 38, 87, 69, 4, 30, 42],\n",
" 'G': [96, 61, 77, 70, 38, 51, 58, 9868, 26, 6, 37, 53, 39, 28, 69, 134, 116, 5, 47, 60],\n",
" 'H': [17, 53, 33, 19, 15, 39, 34, 24, 9907, 3, 32, 57, 24, 15, 27, 47, 43, 2, 22, 19],\n",
" 'I': [6, 3, 6, 6, 3, 5, 6, 7, 3, 9973, 23, 13, 12, 41, 93, 84, 115, 3, 8, 102],\n",
" 'L': [26, 39, 17, 15, 7, 33, 22, 20, 19, 27, 9864, 49, 24, 78, 117, 148, 193, 5, 24, 70],\n",
" 'K': [60, 198, 43, 52, 12, 142, 96, 53, 42, 10, 63, 9710, 33, 26, 54, 109, 102, 5, 43, 42],\n",
" 'M': [21, 22, 15, 18, 6, 20, 18, 18, 17, 11, 27, 32, 9945, 26, 34, 61, 71, 3, 12, 31],\n",
" 'F': [18, 17, 8, 6, 8, 11, 10, 16, 10, 44, 92, 24, 29, 9899, 89, 88, 142, 7, 14, 68],\n",
" 'P': [97, 47, 35, 29, 23, 35, 38, 57, 21, 24, 47, 56, 28, 76, 9785, 115, 77, 4, 24, 35],\n",
" 'S': [241, 87, 76, 73, 17, 56, 60, 99, 32, 13, 69, 92, 42, 67, 100, 9605, 212, 8, 63, 70],\n",
" 'T': [186, 78, 54, 37, 14, 42, 42, 83, 28, 23, 84, 85, 53, 93, 66, 182, 9676, 8, 39, 90],\n",
" 'W': [2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 2, 2, 1, 5, 3, 4, 4, 9960, 3, 4],\n",
" 'Y': [29, 21, 17, 9, 4, 13, 9, 21, 10, 7, 20, 17, 11, 23, 19, 41, 31, 3, 9935, 23],\n",
" 'V': [368, 27, 18, 18, 50, 23, 34, 64, 15, 85, 72, 42, 33, 88, 42, 112, 137, 4, 20, 9514]\n",
"}\n",
"pam_raw = pd.DataFrame(pam_data, index=list(pam_data.keys()))\n",
"pam_matrix = pam_raw.div(pam_raw.sum(axis=1), axis=0)\n",
"list_amino = pam_raw.columns.tolist()\n",
"pam_dict = {\n",
" aa: {sub: pam_matrix.loc[aa, sub] for sub in list_amino}\n",
" for aa in list_amino\n",
"}\n",
"\n",
"def pam1_substitution(aa):\n",
" if aa not in pam_dict:\n",
" return aa\n",
" subs = list(pam_dict[aa].keys())\n",
" probs = list(pam_dict[aa].values())\n",
" return np.random.choice(subs, p=probs)\n",
"\n",
"def augment_sequence(seq, sub_prob=0.05):\n",
" return ''.join([pam1_substitution(aa) if random.random() < sub_prob else aa for aa in seq])\n",
"\n",
"def slice_sequence(seq, win=512):\n",
" return [seq[i:i+win] for i in range(0, len(seq), win)]\n",
"\n",
"def generate_data(df, augment=False):\n",
" X, y = [], []\n",
" label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
" for _, row in tqdm(df.iterrows(), total=len(df)):\n",
" seq = row[\"sequence\"]\n",
" if augment:\n",
" seq = augment_sequence(seq)\n",
" seq_slices = slice_sequence(seq)\n",
" X.extend(seq_slices)\n",
" lbl = row[label_cols].values.astype(int)\n",
" y.extend([lbl] * len(seq_slices))\n",
" return X, np.array(y), label_cols\n",
"\n",
"def format_sequence(seq): return \" \".join(list(seq))\n",
"\n",
"# Função para carregar e binarizar\n",
"def load_and_binarize(csv_path, mlb=None):\n",
" df = pd.read_csv(csv_path)\n",
" df[\"go_terms\"] = df[\"go_terms\"].str.split(\";\")\n",
" if mlb is None:\n",
" mlb = MultiLabelBinarizer()\n",
" labels = mlb.fit_transform(df[\"go_terms\"])\n",
" else:\n",
" labels = mlb.transform(df[\"go_terms\"])\n",
" labels_df = pd.DataFrame(labels, columns=mlb.classes_)\n",
" df = df.reset_index(drop=True).join(labels_df)\n",
" return df, mlb\n",
"\n",
"# Carregar os dados\n",
"df_train, mlb = load_and_binarize(\"data/mf-training.csv\")\n",
"df_val, _ = load_and_binarize(\"data/mf-validation.csv\", mlb=mlb)\n",
"\n",
"# Gerar com augmentation no treino\n",
"X_train, y_train, term_cols = generate_data(df_train, augment=True)\n",
"X_val, y_val, _ = generate_data(df_val, augment=False)\n",
"\n",
"# Preparar texto para tokenizer\n",
"X_train_fmt = list(map(format_sequence, X_train))\n",
"X_val_fmt = list(map(format_sequence, X_val))\n",
"\n",
"# Fine-tune ProtBERT-BFD\n",
"# https://huggingface.co/Rostlab/prot_bert_bfd\n",
"# https://doi.org/10.1093/bioinformatics/btac020\n",
"# Dados de treino -> BFD (Big Fantastic Database) (2.1 bilhões de sequências)\n",
"MODEL_NAME = \"Rostlab/prot_bert_bfd\"\n",
"MAX_LEN = 512\n",
"BATCH_SIZE = 1\n",
"\n",
"t = text.Transformer(MODEL_NAME, maxlen=MAX_LEN, classes=term_cols)\n",
"trn = t.preprocess_train(X_train_fmt, y_train)\n",
"val = t.preprocess_test(X_val_fmt, y_val)\n",
"\n",
"model = t.get_classifier()\n",
"learner = ktrain.get_learner(model,\n",
" train_data=trn,\n",
" val_data=val,\n",
" batch_size=BATCH_SIZE)\n",
"\n",
"learner.autofit(lr=1e-5,\n",
" epochs=10,\n",
" early_stopping=1,\n",
" checkpoint_folder=\"mf-fine-tuned-protbertbfd\")\n"
]
},
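  {
   "cell_type": "markdown",
   "id": "c9a4e3b2-5d6f-4a7b-8c9d-2e3f40516273",
   "metadata": {},
   "source": [
    "With `sub_prob=0.05`, `augment_sequence` draws a PAM1 substitution for roughly 5% of residues; since the PAM1 diagonal dominates, most draws return the original residue, so the effective mutation rate is well below 5%. A quick sketch of this behaviour, assuming the cell above has been run (the sequence is hypothetical and the output depends on the seeds):\n",
    "\n",
    "```python\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "random.seed(0)   # seed both RNGs used by the augmentation for a reproducible run\n",
    "np.random.seed(0)\n",
    "\n",
    "seq = \"MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ\"  # hypothetical sequence\n",
    "aug = augment_sequence(seq, sub_prob=0.05)   # defined in the cell above\n",
    "n_diff = sum(a != b for a, b in zip(seq, aug))\n",
    "print(f\"{n_diff} of {len(seq)} residues changed\")\n",
    "```"
   ]
  },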
{
"cell_type": "code",
"execution_count": 8,
"id": "9b39c439-5708-4787-bfee-d3a4d3aa190d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Tokenizer base e modelo fine-tuned carregados com sucesso\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processando data/mf-training.csv: 100%|██████████| 31142/31142 [5:17:56<00:00, 1.63it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Guardado embeddings\\train_protbertbfd.pkl — 31142 proteínas\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processando data/mf-validation.csv: 100%|██████████| 1724/1724 [19:15<00:00, 1.49it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Guardado embeddings\\val_protbertbfd.pkl — 1724 proteínas\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processando data/mf-test.csv: 100%|██████████| 1724/1724 [17:15<00:00, 1.66it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Guardado embeddings\\test_protbertbfd.pkl — 1724 proteínas\n"
]
}
],
"source": [
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"import joblib\n",
"import gc\n",
"from transformers import AutoTokenizer, TFAutoModel\n",
"\n",
"# Parâmetros\n",
"MODEL_DIR = \"weights/mf-fine-tuned-protbertbfd\"\n",
"MODEL_NAME = \"Rostlab/prot_bert_bfd\"\n",
"OUT_DIR = \"embeddings\"\n",
"BATCH_TOK = 16\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)\n",
"model = TFAutoModel.from_pretrained(MODEL_DIR, from_pt=False)\n",
"\n",
"print(\"✓ Tokenizer base e modelo fine-tuned carregados com sucesso\")\n",
"\n",
"# Funções auxiliares\n",
"\n",
"def get_embeddings(batch, tokenizer, model):\n",
" tokens = tokenizer(batch, return_tensors=\"tf\", padding=True, truncation=True, max_length=512)\n",
" output = model(**tokens)\n",
" return output.last_hidden_state[:, 0, :].numpy()\n",
"\n",
"def process_split(csv_path, out_path):\n",
" df = pd.read_csv(csv_path)\n",
" label_cols = [col for col in df.columns if col.startswith(\"GO:\")]\n",
" prot_ids, embeds, labels = [], [], []\n",
"\n",
" for _, row in tqdm(df.iterrows(), total=len(df), desc=f\"Processando {csv_path}\"):\n",
" slices = slice_sequence(row[\"sequence\"])\n",
" slices_fmt = list(map(format_sequence, slices))\n",
"\n",
" slice_embeds = []\n",
" for i in range(0, len(slices_fmt), BATCH_TOK):\n",
" batch = slices_fmt[i:i+BATCH_TOK]\n",
" slice_embeds.append(get_embeddings(batch, tokenizer, model))\n",
" slice_embeds = np.vstack(slice_embeds)\n",
"\n",
" prot_embed = slice_embeds.mean(axis=0)\n",
" prot_ids.append(row[\"protein_id\"])\n",
" embeds.append(prot_embed.astype(np.float32))\n",
" labels.append(row[label_cols].values.astype(np.int8))\n",
" gc.collect()\n",
"\n",
" embeds = np.vstack(embeds)\n",
" labels = np.vstack(labels)\n",
"\n",
" joblib.dump({\n",
" \"protein_ids\": prot_ids,\n",
" \"embeddings\": embeds,\n",
" \"labels\": labels,\n",
" \"go_terms\": label_cols\n",
" }, out_path, compress=3)\n",
"\n",
" print(f\"✓ Guardado {out_path} — {embeds.shape[0]} proteínas\")\n",
"\n",
"# Aplicar\n",
"os.makedirs(OUT_DIR, exist_ok=True)\n",
"\n",
"process_split(\"data/mf-training.csv\", os.path.join(OUT_DIR, \"train_protbertbfd.pkl\"))\n",
"process_split(\"data/mf-validation.csv\", os.path.join(OUT_DIR, \"val_protbertbfd.pkl\"))\n",
"process_split(\"data/mf-test.csv\", os.path.join(OUT_DIR, \"test_protbertbfd.pkl\"))\n"
]
},
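  {
   "cell_type": "markdown",
   "id": "d0b5f4c3-6e7a-4b8c-9d0e-3f4051627384",
   "metadata": {},
   "source": [
    "Each `.pkl` written above bundles the protein IDs, one mean-pooled 1024-dimensional [CLS] embedding per protein, a label matrix, and the label column names; the `labels` entry is still a placeholder at this point and is rebuilt in the next cell. A quick sanity check of the stored structure (expected values per the run above):\n",
    "\n",
    "```python\n",
    "import joblib\n",
    "\n",
    "data = joblib.load(\"embeddings/val_protbertbfd.pkl\")\n",
    "print(sorted(data.keys()))       # ['embeddings', 'go_terms', 'labels', 'protein_ids']\n",
    "print(data[\"embeddings\"].shape)  # expected: (1724, 1024), dtype float32\n",
    "print(len(data[\"protein_ids\"]))  # 1724\n",
    "```"
   ]
  },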
{
"cell_type": "code",
"execution_count": 9,
"id": "ad0c5421-e0a1-4a6a-8ace-2c69aeab0e0d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Corrigido: embeddings/train_protbertbfd.pkl — 31142 exemplos, 597 GO terms\n",
"✓ Corrigido: embeddings/val_protbertbfd.pkl — 1724 exemplos, 597 GO terms\n",
"✓ Corrigido: embeddings/test_protbertbfd.pkl — 1724 exemplos, 597 GO terms\n"
]
}
],
"source": [
"import pandas as pd\n",
"import joblib\n",
"from sklearn.preprocessing import MultiLabelBinarizer\n",
"\n",
"# Obter GO terms do ficheiro de teste\n",
"df_test = pd.read_csv(\"data/mf-test.csv\")\n",
"test_terms = sorted(set(term for row in df_test[\"go_terms\"].str.split(\";\") for term in row))\n",
"\n",
"# Função para corrigir um .pkl com base nos GO terms do teste\n",
"def patch_to_common_terms(csv_path, pkl_path, common_terms):\n",
" df = pd.read_csv(csv_path)\n",
" terms_split = df[\"go_terms\"].str.split(\";\")\n",
" \n",
" # Apenas termos presentes nos common_terms\n",
" terms_filtered = terms_split.apply(lambda lst: [t for t in lst if t in common_terms])\n",
" \n",
" mlb = MultiLabelBinarizer(classes=common_terms)\n",
" Y = mlb.fit_transform(terms_filtered)\n",
"\n",
" data = joblib.load(pkl_path)\n",
" data[\"labels\"] = Y\n",
" data[\"go_terms\"] = mlb.classes_.tolist()\n",
" \n",
" joblib.dump(data, pkl_path, compress=3)\n",
" print(f\"✓ Corrigido: {pkl_path} — {Y.shape[0]} exemplos, {Y.shape[1]} GO terms\")\n",
"\n",
"# Aplicar às 3 partições\n",
"patch_to_common_terms(\"data/mf-training.csv\", \"embeddings/train_protbertbfd.pkl\", test_terms)\n",
"patch_to_common_terms(\"data/mf-validation.csv\", \"embeddings/val_protbertbfd.pkl\", test_terms)\n",
"patch_to_common_terms(\"data/mf-test.csv\", \"embeddings/test_protbertbfd.pkl\", test_terms)\n"
]
},
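  {
   "cell_type": "markdown",
   "id": "e1c6a5d4-7f8b-4c9d-8e1f-405162738495",
   "metadata": {},
   "source": [
    "After the patch, the three `.pkl` files must share an identical, identically ordered 597-term label space; a mismatch would silently misalign the classifier's output units. A minimal sketch of that check:\n",
    "\n",
    "```python\n",
    "import joblib\n",
    "\n",
    "term_lists = [\n",
    "    joblib.load(f\"embeddings/{split}_protbertbfd.pkl\")[\"go_terms\"]\n",
    "    for split in (\"train\", \"val\", \"test\")\n",
    "]\n",
    "assert term_lists[0] == term_lists[1] == term_lists[2], \"label spaces differ\"\n",
    "print(f\"OK: {len(term_lists[0])} GO terms, same order in all three splits\")\n",
    "```"
   ]
  },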
{
"cell_type": "code",
"execution_count": 2,
"id": "1785d8a9-23fc-4490-8d71-29cc91a4cb57",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✓ Embeddings carregados: (31142, 1024) → 597 GO terms\n",
"Epoch 1/100\n",
"974/974 [==============================] - 12s 11ms/step - loss: 0.0339 - binary_accuracy: 0.9900 - val_loss: 0.0327 - val_binary_accuracy: 0.9905\n",
"Epoch 2/100\n",
"974/974 [==============================] - 11s 11ms/step - loss: 0.0253 - binary_accuracy: 0.9922 - val_loss: 0.0323 - val_binary_accuracy: 0.9906\n",
"Epoch 3/100\n",
"974/974 [==============================] - 11s 11ms/step - loss: 0.0244 - binary_accuracy: 0.9923 - val_loss: 0.0326 - val_binary_accuracy: 0.9906\n",
"Epoch 4/100\n",
"974/974 [==============================] - 11s 11ms/step - loss: 0.0239 - binary_accuracy: 0.9925 - val_loss: 0.0328 - val_binary_accuracy: 0.9906\n",
"Epoch 5/100\n",
"974/974 [==============================] - 11s 11ms/step - loss: 0.0236 - binary_accuracy: 0.9925 - val_loss: 0.0321 - val_binary_accuracy: 0.9906\n",
"Epoch 6/100\n",
"974/974 [==============================] - 11s 11ms/step - loss: 0.0233 - binary_accuracy: 0.9926 - val_loss: 0.0328 - val_binary_accuracy: 0.9907\n",
"Epoch 7/100\n",
"974/974 [==============================] - 11s 11ms/step - loss: 0.0232 - binary_accuracy: 0.9926 - val_loss: 0.0330 - val_binary_accuracy: 0.9908\n",
"Epoch 8/100\n",
"974/974 [==============================] - 11s 12ms/step - loss: 0.0229 - binary_accuracy: 0.9927 - val_loss: 0.0325 - val_binary_accuracy: 0.9907\n",
"Epoch 9/100\n",
"974/974 [==============================] - 12s 12ms/step - loss: 0.0226 - binary_accuracy: 0.9927 - val_loss: 0.0327 - val_binary_accuracy: 0.9906\n",
"Epoch 10/100\n",
"974/974 [==============================] - 12s 12ms/step - loss: 0.0226 - binary_accuracy: 0.9927 - val_loss: 0.0327 - val_binary_accuracy: 0.9907\n",
"54/54 [==============================] - 0s 2ms/step\n",
"Previsões guardadas em mf-protbertbfd-pam1.npy\n",
"Modelo guardado em models/\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import joblib\n",
"import numpy as np\n",
"from tensorflow.keras import Input\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Dropout\n",
"from tensorflow.keras.callbacks import EarlyStopping\n",
"\n",
"# Carregar embeddings\n",
"train = joblib.load(\"embeddings/train_protbertbfd.pkl\")\n",
"val = joblib.load(\"embeddings/val_protbertbfd.pkl\")\n",
"test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
"\n",
"X_train, y_train = train[\"embeddings\"], train[\"labels\"]\n",
"X_val, y_val = val[\"embeddings\"], val[\"labels\"]\n",
"X_test, y_test = test[\"embeddings\"], test[\"labels\"]\n",
"\n",
"print(f\"✓ Embeddings carregados: {X_train.shape} → {y_train.shape[1]} GO terms\")\n",
"\n",
"# Garantir consistência de classes\n",
"max_classes = y_train.shape[1] # 602 GO terms (do treino)\n",
"\n",
"def pad_labels(y, target_dim=max_classes):\n",
" if y.shape[1] < target_dim:\n",
" padding = np.zeros((y.shape[0], target_dim - y.shape[1]), dtype=np.int8)\n",
" return np.hstack([y, padding])\n",
" return y\n",
"\n",
"y_val = pad_labels(y_val)\n",
"y_test = pad_labels(y_test)\n",
"\n",
"# Modelo MLP\n",
"model = Sequential([\n",
" Dense(1024, activation=\"relu\", input_shape=(X_train.shape[1],)),\n",
" Dropout(0.3),\n",
" Dense(512, activation=\"relu\"),\n",
" Dropout(0.3),\n",
" Dense(max_classes, activation=\"sigmoid\")\n",
"])\n",
"\n",
"model.compile(loss=\"binary_crossentropy\",\n",
" optimizer=\"adam\",\n",
" metrics=[\"binary_accuracy\"])\n",
"\n",
"# Early stopping e treino\n",
"callbacks = [\n",
" EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True)\n",
"]\n",
"\n",
"model.fit(X_train, y_train,\n",
" validation_data=(X_val, y_val),\n",
" epochs=100,\n",
" batch_size=32,\n",
" callbacks=callbacks,\n",
" verbose=1)\n",
"\n",
"# Previsões\n",
"y_prob = model.predict(X_test)\n",
"np.save(\"predictions/mf-protbertbfd-pam1.npy\", y_prob)\n",
"print(\"Previsões guardadas em mf-protbertbfd-pam1.npy\")\n",
"\n",
"# Modelo\n",
"model.save(\"models/mlp_protbertbfd.h5\")\n",
"model.save(\"models/mlp_protbertbfd.keras\")\n",
"print(\"Modelo guardado em models/\")"
]
},
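  {
   "cell_type": "markdown",
   "id": "f2d7b6e5-8a9c-4d0e-9f2a-516273849506",
   "metadata": {},
   "source": [
    "The saved `y_prob` matrix holds one independent sigmoid probability per (protein, GO term) pair, so turning it into discrete predictions means thresholding each entry. A sketch with a hypothetical fixed cutoff of 0.5 (the evaluation cell below selects the threshold by maximizing Fmax instead):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "y_prob = np.load(\"predictions/mf-protbertbfd-pam1.npy\")\n",
    "y_pred = (y_prob >= 0.5).astype(int)  # hypothetical cutoff; Fmax picks 0.46 below\n",
    "print(y_pred.shape)                   # expected: (1724, 597)\n",
    "print(\"mean GO terms predicted per protein:\", y_pred.sum(axis=1).mean())\n",
    "```"
   ]
  },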
{
"cell_type": "code",
"execution_count": 3,
"id": "fdb66630-76dc-43a0-bd56-45052175fdba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"go.obo: fmt(1.2) rel(2025-03-16) 43,544 Terms\n",
"✓ Embeddings: (1724, 597) labels × 597 GO terms\n",
"\n",
"📊 Resultados finais (ProtBERTBFD + PAM1 + propagação):\n",
"Fmax = 0.6588\n",
"Thr. = 0.46\n",
"AuPRC = 0.6991\n",
"Smin = 13.5461\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn.metrics import precision_recall_curve, auc\n",
"from goatools.obo_parser import GODag\n",
"import joblib\n",
"import math\n",
"\n",
"# Parâmetros\n",
"GO_FILE = \"go.obo\"\n",
"THRESHOLDS = np.arange(0.0, 1.01, 0.01)\n",
"ALPHA = 0.5\n",
"\n",
"# Carregar dados\n",
"test = joblib.load(\"embeddings/test_protbertbfd.pkl\")\n",
"y_true = test[\"labels\"]\n",
"terms = test[\"go_terms\"]\n",
"y_prob = np.load(\"predictions/mf-protbertbfd-pam1.npy\")\n",
"go_dag = GODag(GO_FILE)\n",
"\n",
"print(f\"✓ Embeddings: {y_true.shape} labels × {len(terms)} GO terms\")\n",
"\n",
"# Fmax\n",
"def compute_fmax(y_true, y_prob, thresholds):\n",
" fmax, best_thr = 0, 0\n",
" for t in thresholds:\n",
" y_pred = (y_prob >= t).astype(int)\n",
" tp = (y_true * y_pred).sum(axis=1)\n",
" fp = ((1 - y_true) * y_pred).sum(axis=1)\n",
" fn = (y_true * (1 - y_pred)).sum(axis=1)\n",
" precision = tp / (tp + fp + 1e-8)\n",
" recall = tp / (tp + fn + 1e-8)\n",
" f1 = 2 * precision * recall / (precision + recall + 1e-8)\n",
" avg_f1 = np.mean(f1)\n",
" if avg_f1 > fmax:\n",
" fmax, best_thr = avg_f1, t\n",
" return fmax, best_thr\n",
"\n",
"# AuPRC micro\n",
"def compute_auprc(y_true, y_prob):\n",
" precision, recall, _ = precision_recall_curve(y_true.ravel(), y_prob.ravel())\n",
" return auc(recall, precision)\n",
"\n",
"# Smin\n",
"def compute_smin(y_true, y_prob, terms, threshold, go_dag, alpha=ALPHA):\n",
" y_pred = (y_prob >= threshold).astype(int)\n",
" ic = {}\n",
" total = (y_true + y_pred).sum(axis=0).sum()\n",
" for i, term in enumerate(terms):\n",
" freq = (y_true[:, i] + y_pred[:, i]).sum()\n",
" ic[term] = -np.log((freq + 1e-8) / total)\n",
"\n",
" s_values = []\n",
" for true_vec, pred_vec in zip(y_true, y_pred):\n",
" true_terms = {terms[i] for i in np.where(true_vec)[0]}\n",
" pred_terms = {terms[i] for i in np.where(pred_vec)[0]}\n",
"\n",
" anc_true = set()\n",
" for t in true_terms:\n",
" if t in go_dag:\n",
" anc_true |= go_dag[t].get_all_parents()\n",
" anc_pred = set()\n",
" for t in pred_terms:\n",
" if t in go_dag:\n",
" anc_pred |= go_dag[t].get_all_parents()\n",
"\n",
" ru = pred_terms - true_terms\n",
" mi = true_terms - pred_terms\n",
" dist_ru = sum(ic.get(t, 0) for t in ru)\n",
" dist_mi = sum(ic.get(t, 0) for t in mi)\n",
" s = math.sqrt((alpha * dist_ru)**2 + ((1 - alpha) * dist_mi)**2)\n",
" s_values.append(s)\n",
"\n",
" return np.mean(s_values)\n",
"\n",
"# Avaliar\n",
"fmax, thr = compute_fmax(y_true, y_prob, THRESHOLDS)\n",
"auprc = compute_auprc(y_true, y_prob)\n",
"smin = compute_smin(y_true, y_prob, terms, thr, go_dag)\n",
"\n",
"print(f\"\\n📊 Resultados finais (ProtBERTBFD + PAM1 + propagação):\")\n",
"print(f\"Fmax = {fmax:.4f}\")\n",
"print(f\"Thr. = {thr:.2f}\")\n",
"print(f\"AuPRC = {auprc:.4f}\")\n",
"print(f\"Smin = {smin:.4f}\")\n"
]
},
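  {
   "cell_type": "markdown",
   "id": "a3e8c7f6-9b0d-4e1f-8a3b-627384950617",
   "metadata": {},
   "source": [
    "For reference, the quantities computed above, with the small smoothing constants omitted. `compute_fmax` maximizes the mean per-protein F1 over thresholds $t$:\n",
    "\n",
    "$$F_{\\max} = \\max_t \\frac{1}{N} \\sum_{i=1}^{N} \\frac{2\\, p_i(t)\\, r_i(t)}{p_i(t) + r_i(t)}$$\n",
    "\n",
    "where $p_i(t)$ and $r_i(t)$ are the precision and recall of protein $i$'s thresholded prediction. `compute_smin` then evaluates the semantic distance at that single best threshold, a simplification of the usual $S_{\\min}$, which minimizes over all thresholds:\n",
    "\n",
    "$$S = \\frac{1}{N} \\sum_{i=1}^{N} \\sqrt{\\left(\\alpha\\, ru_i\\right)^2 + \\left((1-\\alpha)\\, mi_i\\right)^2}, \\qquad \\alpha = 0.5,$$\n",
    "\n",
    "with $ru_i$ the summed information content of terms annotated but not predicted for protein $i$ (remaining uncertainty) and $mi_i$ the same sum over terms predicted but not annotated (misinformation)."
   ]
  },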
{
"cell_type": "code",
"execution_count": null,
"id": "70d131ef-ef84-42ee-953b-0d3f1268694d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}