{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# FateZ Multiomic Pertubation Effect Prediction(?)\n", "This notebook demonstrate how to implement Pertubation Effect Prediction method with FateZ's modules." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done Import\n" ] } ], "source": [ "import os\n", "import sys\n", "import torch\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "import pandas as pd\n", "import fatez.lib as lib\n", "import fatez.test as test\n", "import fatez.model as model\n", "import fatez.tool.JSON as JSON\n", "import fatez.process as process\n", "import fatez.process.worker as worker\n", "import fatez.process.fine_tuner as fine_tuner\n", "import fatez.process.pre_trainer as pre_trainer\n", "from pkg_resources import resource_filename\n", "\n", "print('Done Import')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build model and make some fake data first." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done Init\n", "Done Fake Data\n" ] } ], "source": [ "# Parameters\n", "params = {\n", " 'n_sample': 10, # Fake samples to make\n", " 'batch_size': 2, # Batch size\n", "}\n", "\n", "# Init worker env\n", "config = JSON.decode(resource_filename(\n", " __name__, '../../fatez/data/config/gat_bert_config.json'\n", " )\n", ")\n", "suppressor = process.Quiet_Mode()\n", "device = 'cuda'\n", "# device = [0] # Applying DDP if having multiple devices\n", "dtype = torch.float32\n", "worker.setup(device)\n", "\n", "print('Done Init')\n", "\n", "# Generate Fake data\n", "faker = test.Faker(model_config = config, dtype = dtype, **params)\n", "pertubation_dataloader = faker.make_data_loader()\n", "result_dataloader = faker.make_data_loader()\n", "\n", "# Make id of pertubation result the 'label' of each sample\n", "for i,k in enumerate(pertubation_dataloader.dataset.samples):\n", " k.y = i\n", " \n", "print('Done Fake Data')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The model will be architecturally similar with a pretrainer" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model Set\n" ] } ], "source": [ "trainer = pre_trainer.Set(config, dtype = dtype, device=device)\n", "\n", "print('Model Set')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### However, the training part will be littel bit different\n", "This part is modified based on pre_trainer.Trainer.train()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Loss\n", "0 4.120818\n" ] } ], "source": [ "report_batch = False\n", "size = trainer.input_sizes\n", "\n", "trainer.worker.train(True)\n", "best_loss = 99\n", "loss_all = 0\n", "report = list()\n", "\n", "for x,y in pertubation_dataloader:\n", " \n", " # Prepare input data as always\n", " input = [ele.to(trainer.device) for ele in x]\n", " \n", " # Mute some debug outputs\n", " suppressor.on()\n", " node_rec, adj_rec = trainer.worker(input)\n", " suppressor.off()\n", " \n", " # Prepare pertubation result data using a seperate dataloader\n", " y = [result_dataloader.dataset.samples[ele].to(trainer.device) for ele in y]\n", " # Please be noted here that this script is only reconstructing TF parts\n", " # To reconstruct whole genome, we can certainly add an additionaly layer which takes adj_rec and node_rec to do the job.\n", " node_results = torch.stack([ele.x for ele in input], 0)\n", " adj_results = lib.get_dense_adjs(\n", " y, (size['n_reg'],size['n_node'],size['edge_attr'])\n", " )\n", " \n", " # Get total loss\n", " loss = trainer.criterion(node_rec, node_results)\n", " if adj_rec is not None:\n", " loss += trainer.criterion(adj_rec, adj_results)\n", " \n", " # Some backward stuffs here\n", " loss.backward()\n", " nn.utils.clip_grad_norm_(trainer.model.parameters(), trainer.max_norm)\n", " trainer.optimizer.step()\n", " trainer.optimizer.zero_grad()\n", "\n", " # Accumulate\n", " best_loss = min(best_loss, loss.item())\n", " loss_all += loss.item()\n", "\n", " # Some logs\n", " if report_batch: report.append([loss.item()])\n", "\n", "\n", "trainer.scheduler.step()\n", "report.append([loss_all / len(pertubation_dataloader)])\n", "report = pd.DataFrame(report)\n", "report.columns = ['Loss', ]\n", "print(report)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### In the case of tuning unlabeled data, which does not have pertubation results... \n", "We shall set another trainer using previous model." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Loss\n", "0 4.340387\n" ] } ], "source": [ "tuner = pre_trainer.Set(config, prev_model = trainer, dtype = dtype, device = device)\n", "\n", "# Some new fake data\n", "tuner_dataloader = faker.make_data_loader()\n", "\n", "# And the tuning process is also based on input reconstruction as pretraining\n", "suppressor.on()\n", "report = tuner.train(tuner_dataloader, report_batch = False,)\n", "suppressor.off()\n", "print(report)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Then we shall just use trainer object to make predictions.\n", "Similar with the training block above for trainer, but no need to prepare y." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([2, 10, 2]) torch.Size([2, 4, 10])\n", "torch.Size([2, 10, 2]) torch.Size([2, 4, 10])\n", "torch.Size([2, 10, 2]) torch.Size([2, 4, 10])\n", "torch.Size([2, 10, 2]) torch.Size([2, 4, 10])\n", "torch.Size([2, 10, 2]) torch.Size([2, 4, 10])\n" ] } ], "source": [ "trainer.model.eval()\n", "\n", "for x,_ in tuner_dataloader:\n", " \n", " # Prepare input data as always\n", " input = [ele.to(trainer.device) for ele in x]\n", " \n", " # Mute some debug outputs\n", " suppressor.on()\n", " node_rec, adj_rec = trainer.model(input)\n", " suppressor.off()\n", " print(node_rec.shape, adj_rec.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cleanup Env\n", "Need to clean up environment once finsihed." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Clean up worker env.\n" ] } ], "source": [ "worker.cleanup(device)\n", "print('Clean up worker env.')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 2 }