{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-intro",
   "metadata": {},
   "source": [
    "# GO-term prediction for a single protein sequence\n",
    "\n",
    "Embeds one amino-acid sequence with ProtBERT, ProtBERT-BFD and ESM-2, scores each embedding with its own MLP, and combines the three score vectors with a stacking model.\n",
    "\n",
    "Run with **Restart Kernel → Run All**. The `models/` and `data/` paths are relative to the notebook's working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78731790-cecc-4e7b-9599-c35a9fad1c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import joblib\n",
    "import numpy as np\n",
    "import torch\n",
    "from tensorflow.keras.models import load_model\n",
    "from transformers import AutoModel, AutoTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-config",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query protein as a plain string of one-letter amino-acid codes.\n",
    "SEQ_FASTA = \"MFNVESVERVELCESLLTWIQTFNVDAPCQTAEDLTNGVVMSQVLQKIDPVYFDDNWLNRIKTEVGDNWRLKISNLKKILKGILDYNHEILGQQINDFTLPDVNLIGEHSDAAELGRMLQLILGCAVNCEQKQEYIQAIMMMEESVQHVVMTAIQELMSKESPVSAGHDAYVDLDRQLKKTTEELNEALSAKEEIAQRCHELDMQVAALQEEKSSLLAENQILMERLNQSDSIEDPNSPAGRRHLQLQTQLEQLQEETFRLEAAKDDYRIRCEELEKEISELRQQNDELTTLADEAQSLKDEIDVLRHSSDKVSKLEGQVESYKKKLEDLGDLRRQVKLLEEKNTMYMQNTVSLEEELRKANAARGQLETYKRQVVELQNRLSDESKKADKLDFEYKRLKEKVDGLQKEKDRLRTERDSLKETIEELRCVQAQEGQLTTQGLMPLGSQESSDSLAAEIVTPEIREKLIRLQHENKMLKLNQEDSDNEKIALLQSLLDDANLRKNELETENRLVNQRLLEVQSQVEELQKSLQDQGSKAEDSVLLKKKLEEHLEKLHEANNELQKKRAIIEDLEPRFNNSSLRIEELQEALRKKEEEMKQMEERYKKYLEKAKSVIRTLDPKQNQGAAPEIQALKNQLQERDRLFHSLEKEYEKTKSQRDMEEKYIVSAWYNMGMTLHKKAAEDRLASTGSGQSFLARQRQATSTRRSYPGHVQPATAR\"  # (keep your full sequence here)\n",
    "TOP_N = 10  # how many of the highest-scoring GO terms to list per model\n",
    "THRESH = 0.37  # scores >= THRESH are reported as predicted GO terms\n",
    "\n",
    "# Maximum number of residues per window handed to each encoder.\n",
    "PROTBERT_CHUNK = 512\n",
    "ESM2_CHUNK = 1024"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-helpers",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "helpers",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embedding_mean(model_name, seq, chunk):\n",
    "    \"\"\"Embed a sequence window by window and average the window vectors.\n",
    "\n",
    "    The sequence is cut into consecutive windows of at most ``chunk``\n",
    "    residues. Each window is tokenized with its residues separated by\n",
    "    spaces, run through the encoder, and represented by the last-layer\n",
    "    hidden state at token position 0. The window vectors are averaged.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model_name : str\n",
    "        Hugging Face identifier handed to ``from_pretrained``.\n",
    "    seq : str\n",
    "        Amino-acid sequence (one letter per residue).\n",
    "    chunk : int\n",
    "        Maximum number of residues per window; must be positive.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Array of shape ``(1, dim)``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``seq`` is empty or ``chunk`` is not positive.\n",
    "    \"\"\"\n",
    "    # Without these guards an empty input yields a NaN mean and a NumPy\n",
    "    # warning instead of a clear failure.\n",
    "    if not seq:\n",
    "        raise ValueError(\"seq must be a non-empty amino-acid sequence\")\n",
    "    if chunk <= 0:\n",
    "        raise ValueError(f\"chunk must be a positive integer, got {chunk}\")\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)\n",
    "    model = AutoModel.from_pretrained(model_name)\n",
    "    model.eval()\n",
    "\n",
    "    window_vectors = []\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, len(seq), chunk):\n",
    "            window = seq[start:start + chunk]\n",
    "            tokens = tokenizer(\n",
    "                \" \".join(window),\n",
    "                return_tensors=\"pt\",\n",
    "                truncation=False,\n",
    "                padding=False,\n",
    "            )\n",
    "            hidden = model(**tokens).last_hidden_state\n",
    "            # Batch size is 1, so [0, 0, :] is the position-0 vector.\n",
    "            window_vectors.append(hidden[0, 0, :].numpy())\n",
    "    return np.mean(window_vectors, axis=0, keepdims=True)\n",
    "\n",
    "\n",
    "def print_results(name, y_pred):\n",
    "    \"\"\"Print the thresholded and the top-ranked GO terms for one model.\n",
    "\n",
    "    Reads the notebook-level names ``mlb``, ``GO``, ``THRESH`` and\n",
    "    ``TOP_N``, so the label binarizer must be loaded before calling this.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    name : str\n",
    "        Heading printed above the results.\n",
    "    y_pred : numpy.ndarray\n",
    "        2-D score array; only row 0 is used for the top-N listing.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the number of columns differs from the number of classes in\n",
    "        the label binarizer.\n",
    "    \"\"\"\n",
    "    n_cols = y_pred.shape[1]\n",
    "    if n_cols != len(GO):\n",
    "        raise ValueError(\n",
    "            f\"{name}: prediction has {n_cols} columns but the label binarizer \"\n",
    "            f\"defines {len(GO)} classes\"\n",
    "        )\n",
    "\n",
    "    print(f\"\\n {name}\")\n",
    "    # GO terms at or above the threshold.\n",
    "    terms = mlb.inverse_transform((y_pred >= THRESH).astype(int))\n",
    "    print(f\" GO terms com prob ≥ {THRESH}:\")\n",
    "    print(\" \", terms[0] if terms[0] else \"Nenhum\")\n",
    "\n",
    "    # Highest-scoring terms regardless of the threshold.\n",
    "    scores = y_pred[0]\n",
    "    top_idx = np.argsort(-scores)[:TOP_N]\n",
    "    print(f\" Top {TOP_N} mais prováveis:\")\n",
    "    for i in top_idx:\n",
    "        print(f\" {GO[i]} : {scores[i]:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-embeddings",
   "metadata": {},
   "source": [
    "## Embeddings\n",
    "\n",
    "Each call loads the encoder through `from_pretrained`, which fetches the weights on first use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "embeddings",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"A gerar embeddings …\")\n",
    "emb_pb = get_embedding_mean(\"Rostlab/prot_bert\", SEQ_FASTA, PROTBERT_CHUNK)\n",
    "emb_bfd = get_embedding_mean(\"Rostlab/prot_bert_bfd\", SEQ_FASTA, PROTBERT_CHUNK)\n",
    "emb_esm = get_embedding_mean(\"facebook/esm2_t33_650M_UR50D\", SEQ_FASTA, ESM2_CHUNK)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-artifacts",
   "metadata": {},
   "source": [
    "## Load classifiers and label binarizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "load-artifacts",
   "metadata": {},
   "outputs": [],
   "source": [
    "mlp_pb = load_model(\"models/mlp_protbert.h5\")\n",
    "mlp_bfd = load_model(\"models/mlp_protbertbfd.h5\")\n",
    "mlp_esm = load_model(\"models/mlp_esm2.h5\")\n",
    "stacking = load_model(\"models/ensemble_stack.h5\")\n",
    "\n",
    "# joblib.load unpickles the file, so only load artifacts you trust.\n",
    "mlb = joblib.load(\"data/mlb_597.pkl\")\n",
    "GO = mlb.classes_\n",
    "N_TERMS = len(GO)  # replaces the previously hard-coded 597"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-base",
   "metadata": {},
   "source": [
    "## Base-model predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "base-predictions",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"A fazer predições individuais …\")\n",
    "# Keep the first N_TERMS columns so every score lines up with mlb.classes_.\n",
    "# NOTE(review): this assumes any columns beyond N_TERMS are not GO-term\n",
    "# scores - confirm against the training code.\n",
    "y_pb = mlp_pb.predict(emb_pb)[:, :N_TERMS]\n",
    "y_bfd = mlp_bfd.predict(emb_bfd)[:, :N_TERMS]\n",
    "y_esm = mlp_esm.predict(emb_esm)[:, :N_TERMS]\n",
    "\n",
    "# Reported here, before the ensemble step, so a stacking failure cannot\n",
    "# hide the base-model results.\n",
    "print_results(\"ProtBERT (MLP)\", y_pb)\n",
    "print_results(\"ProtBERT-BFD (MLP)\", y_bfd)\n",
    "print_results(\"ESM-2 (MLP)\", y_esm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-ensemble",
   "metadata": {},
   "source": [
    "## Ensemble (stacking)\n",
    "\n",
    "The stacking model takes the three base score vectors concatenated side by side. When this notebook was last executed, `models/ensemble_stack.h5` expected 1779 input features while the concatenation had 1791 (3 × 597), and `predict` failed inside Keras. The cell below checks the width first and reports the mismatch directly.\n",
    "\n",
    "**TODO:** confirm which label set the stacking model was trained on and re-export it together with the matching label binarizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ensemble",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_stack = np.concatenate([y_pb, y_bfd, y_esm], axis=1)\n",
    "\n",
    "expected_width = stacking.input_shape[-1]\n",
    "actual_width = X_stack.shape[1]\n",
    "if actual_width != expected_width:\n",
    "    raise ValueError(\n",
    "        f\"models/ensemble_stack.h5 expects {expected_width} input features, \"\n",
    "        f\"but the concatenated base predictions have {actual_width} \"\n",
    "        f\"(3 x {N_TERMS}). Check that the stacking model was trained on the \"\n",
    "        \"outputs of these three base models with the same label set as \"\n",
    "        \"data/mlb_597.pkl.\"\n",
    "    )\n",
    "\n",
    "y_ens = stacking.predict(X_stack)\n",
    "print_results(\"Ensemble (Stacking)\", y_ens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70a3035b-01cd-4c63-b34d-d520d2aa88bf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}