{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [ "# FateZ Explain \n", "\n", "This notebook demonstrates how to use the explanatory methods of FateZ models." ] },
 { "cell_type": "code", "execution_count": 1, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done Import\n" ] } ], "source": [ "import os\n", "import sys\n", "import torch\n", "from torch.utils.data import DataLoader\n", "import fatez.lib as lib\n", "import fatez.test as test\n", "import fatez.model as model\n", "import fatez.tool.JSON as JSON\n", "import fatez.process as process\n", "import fatez.process.worker as worker\n", "import fatez.process.fine_tuner as fine_tuner\n", "import fatez.process.pre_trainer as pre_trainer\n", "from pkg_resources import resource_filename\n", "\n", "print('Done Import')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Build model and make some fake data first." ] },
 { "cell_type": "code", "execution_count": 2, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done Init\n" ] } ], "source": [ "# Parameters\n", "params = {\n", "    'n_sample': 10,   # Fake samples to make\n", "    'batch_size': 1,  # Batch size\n", "}\n", "\n", "# Init worker env\n", "# Locate the model config first, then decode it\n", "config_path = resource_filename(\n", "    __name__, '../../fatez/data/config/gat_bert_config.json'\n", ")\n", "config = JSON.decode(config_path)\n", "suppressor = process.Quiet_Mode()\n", "# Device list and dtype are shared by worker.setup and every trainer below\n", "device = [0]\n", "dtype = torch.float32\n", "worker.setup(device, master_port = '2307')\n", "\n", "# Generate Fake data\n", "faker = test.Faker(model_config = config, **params)\n", "train_dataloader = faker.make_data_loader()\n", "\n", "print('Done Init')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Now we perform pre-training with no label.\n" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Here the trainer's $train\\_adj$ is set to False, and the model is NOT reconstructing the adjacency matrices, etc." ] },
 { "cell_type": "code", "execution_count": 3, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Loss\n", "0 1.307709\n", "1 7.125144\n", "2 7.212573\n", "3 7.185053\n", "4 1.215182\n", "5 7.039515\n", "6 6.969513\n", "7 1.040767\n", "8 1.182723\n", "9 1.147289\n", "10 4.142547\n" ] } ], "source": [ "trainer = pre_trainer.Set(config, dtype = dtype, device = device)\n", "suppressor.on()\n", "try:\n", "    report = trainer.train(train_dataloader, report_batch = True)\n", "finally:\n", "    # Always switch the suppressor back off, even if training raises\n", "    suppressor.off()\n", "print(report)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "Pre-training can also reconstruct the adjacency matrices by setting $train\\_adj$ to True." ] },
 { "cell_type": "code", "execution_count": 4, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Loss\n", "0 1.197746\n", "1 7.171856\n", "2 7.135652\n", "3 7.115325\n", "4 1.138201\n", "5 1.099432\n", "6 7.150834\n", "7 1.107591\n", "8 1.277280\n", "9 7.157175\n", "10 4.155109\n" ] } ], "source": [ "# NOTE: this changes the shared config in place, so every later cell\n", "# that reuses config (e.g. the fine tuner) also sees train_adj = True\n", "config['pre_trainer']['train_adj'] = True\n", "trainer = pre_trainer.Set(config, dtype = dtype, device = device)\n", "suppressor.on()\n", "try:\n", "    report = trainer.train(train_dataloader, report_batch = True)\n", "finally:\n", "    # Always switch the suppressor back off, even if training raises\n", "    suppressor.off()\n", "print(report)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Then, we can move on to fine-tuning with class labels."
] },
 { "cell_type": "code", "execution_count": 5, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Loss ACC\n", "0 0.649586 1.0\n", "1 0.743881 0.0\n", "2 0.645328 1.0\n", "3 0.744279 0.0\n", "4 0.645847 1.0\n", "5 0.645064 1.0\n", "6 0.744720 0.0\n", "7 0.747999 0.0\n", "8 0.645316 1.0\n", "9 0.746752 0.0\n", "10 0.695877 0.5\n" ] } ], "source": [ "# The fine tuner is built on top of the most recently pre-trained model\n", "tuner = fine_tuner.Set(config, prev_model = trainer, dtype = dtype, device = device)\n", "report = tuner.train(train_dataloader, report_batch = True)\n", "print(report)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Explaining the model\n", "\n", "Three kinds of explanations are available:\n", "1. edge_explain\n", "2. regulon_explain\n", "3. node_explain" ] },
 { "cell_type": "code", "execution_count": 6, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Explaining 2 classes.\n", "Each class has regulon explain in shape of (4, 4).\n", "\n", "Edge Explain:\n", " tensor([[0.0302, 0.0293, 0.0292, 0.0280, 0.0272, 0.0294, 0.0258, 0.0309, 0.0000,\n", " 0.0291],\n", " [0.0342, 0.0287, 0.0284, 0.0255, 0.0264, 0.0269, 0.0298, 0.0260, 0.0000,\n", " 0.0301],\n", " [0.0294, 0.0231, 0.0260, 0.0229, 0.0267, 0.0267, 0.0304, 0.0262, 0.0000,\n", " 0.0266],\n", " [0.0254, 0.0262, 0.0260, 0.0251, 0.0285, 0.0309, 0.0297, 0.0265, 0.0000,\n", " 0.0287]]) \n", "\n", "Reg Explain:\n", " tensor([0.0008, 0.0008, 0.0003, 0.0016], dtype=torch.float64) \n", "\n", "Node Explain:\n", " tensor([9.7898e-05, 9.2607e-05, 9.2633e-05, 8.7263e-05, 9.3675e-05, 9.9449e-05,\n", " 9.7947e-05, 9.3009e-05, 0.0000e+00, 9.8218e-05], dtype=torch.float64) \n", "\n" ] } ], "source": [ "# Initializing edge explain matrix and regulon explain matrix\n", "n_reg = config['input_sizes']['n_reg']\n", "adj_exp = torch.zeros((n_reg, config['input_sizes']['n_node']))\n", "reg_exp = torch.zeros((n_reg, config['encoder']['d_model']))\n", "\n", "# Make background data: first batch of a loader whose batch size is faker.n_sample\n", "bg_loader = DataLoader(\n", "    train_dataloader.dataset, faker.n_sample, collate_fn = lib.collate_fn,\n", ")\n", "bg, _ = next(iter(bg_loader))\n", "# Set explainer through taking input data from pseudo-dataloader\n", "explain = tuner.model.make_explainer([a.to(tuner.device) for a in bg])\n", "\n", "# Only the first batch is explained: the loop breaks after one iteration\n", "for x, _ in train_dataloader:\n", "    data = [a.to(tuner.device) for a in x]\n", "    # Third return value is unused here; `_` avoids shadowing the builtin vars()\n", "    adj_temp, reg_temp, _ = tuner.model.explain_batch(data, explain)\n", "    adj_exp += adj_temp\n", "\n", "    print(f'Explaining {len(reg_temp)} classes.')\n", "\n", "    # Only the feat mat explanation should be working\n", "    print(f'Each class has regulon explain in shape of {reg_temp[0][0].shape}.\\n')\n", "\n", "    # Only taking explanations for class 0\n", "    for exp in reg_temp[0]:\n", "        reg_exp += abs(exp)\n", "    break\n", "\n", "# Collapse the last axis of the regulon matrix, then project onto nodes\n", "reg_exp = torch.sum(reg_exp, dim = -1)\n", "node_exp = torch.matmul(reg_exp, adj_exp.type(reg_exp.dtype))\n", "print('Edge Explain:\\n', adj_exp, '\\n')\n", "print('Reg Explain:\\n', reg_exp, '\\n')\n", "print('Node Explain:\\n', node_exp, '\\n')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Cleanup Env\n", "Need to clean up the environment once finished." ] },
 { "cell_type": "code", "execution_count": 7, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Clean up worker env.\n" ] } ], "source": [ "worker.cleanup(device)\n", "print('Clean up worker env.')" ] }
 ],
 "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 2 }