{ "cells": [ { "cell_type": "markdown", "id": "b69b2ab7", "metadata": {}, "source": [ "# FateZ Clustering \n", "\n", "This notebook demonstrates how to apply clustering methods to data representations produced by FateZ's encoder, and compares the results with clustering of the original (PCA-reduced) data." ] }, { "cell_type": "code", "execution_count": 5, "id": "dd050393", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "This part is yet to be modified!\n", "Done Import\n" ] } ], "source": [ "print('This part is yet to be modified!')\n", "\n", "import os\n", "import torch\n", "import numpy as np\n", "from torch.utils.data import DataLoader\n", "from pkg_resources import resource_filename\n", "from sklearn import cluster\n", "import fatez.test as test\n", "import fatez.model as model\n", "# Required: the PCA cells below call sc.pp.pca(..., return_info=True).\n", "import scanpy as sc\n", "\n", "print('Done Import')" ] }, { "cell_type": "markdown", "id": "10210c8d", "metadata": {}, "source": [ "### Initialize testing model first." ] }, { "cell_type": "code", "execution_count": 6, "id": "3948e182", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Testing Full Model.\n", "\n", "\tPre-Trainer Green.\n", "\n", "\tFine-Tuner Green.\n", "\n", "Edge Explain:\n", " tensor([[0.1573, 0.1524, 0.1506, 0.1357, 0.1349, 0.1365, 0.1361, 0.1353, 0.0000,\n", " 0.1383],\n", " [0.1358, 0.1370, 0.1243, 0.1215, 0.1255, 0.1237, 0.1196, 0.1621, 0.0000,\n", " 0.1627],\n", " [0.1638, 0.1475, 0.1416, 0.1444, 0.1399, 0.1464, 0.1395, 0.1368, 0.0000,\n", " 0.1393],\n", " [0.1305, 0.1263, 0.1264, 0.1260, 0.1198, 0.1493, 0.1356, 0.1604, 0.0000,\n", " 0.1371]]) \n", "\n", "Reg Explain:\n", " tensor([0., 0., 0., 0.], dtype=torch.float64) \n", "\n", "Node Explain:\n", " tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64) \n", "\n", "\tExplainer Green.\n", "\n" ] } ], "source": [ "faker = test.Faker()\n", "testM, _ = faker.test_full_model()\n", "# model.Save(faker.test_gat(), '../data/ignore/gat.model')\n", "# model.Save(testM, '../data/ignore/trainer.model')" ] }, { "cell_type": 
"markdown", "id": "255aa807", "metadata": {}, "source": [ "### Get the fake dataset" ] }, { "cell_type": "code", "execution_count": 8, "id": "82ce2fb9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
       " in <cell line: 2>:2                                                                              \n",
       "                                                                                                  \n",
       "   1 dataset = faker.make_data_loader().dataset                                                   \n",
       " 2 for x in DataLoader(dataset, batch_size = len(dataset)):                                     \n",
       "   3 │   all_fea_mat = x[0]                                                                       \n",
       "   4 │   all_adj_mat = x[1]                                                                       \n",
       "   5 print(f'Labels:\\n{labels.tolist()}')                                                         \n",
       "                                                                                                  \n",
       " /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:634 in __next__           \n",
       "                                                                                                  \n",
       "    631 │   │   │   if self._sampler_iter is None:                                                \n",
       "    632 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/76750)                   \n",
       "    633 │   │   │   │   self._reset()  # type: ignore[call-arg]                                   \n",
       "  634 │   │   │   data = self._next_data()                                                      \n",
       "    635 │   │   │   self._num_yielded += 1                                                        \n",
       "    636 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and \\                          \n",
       "    637 │   │   │   │   │   self._IterableDataset_len_called is not None and \\                    \n",
       "                                                                                                  \n",
       " /usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:678 in _next_data         \n",
       "                                                                                                  \n",
       "    675 │                                                                                         \n",
       "    676 │   def _next_data(self):                                                                 \n",
       "    677 │   │   index = self._next_index()  # may raise StopIteration                             \n",
       "  678 │   │   data = self._dataset_fetcher.fetch(index)  # may raise StopIteration              \n",
       "    679 │   │   if self._pin_memory:                                                              \n",
       "    680 │   │   │   data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)            \n",
       "    681 │   │   return data                                                                       \n",
       "                                                                                                  \n",
       " /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py:54 in fetch             \n",
       "                                                                                                  \n",
       "   51 │   │   │   │   data = [self.dataset[idx] for idx in possibly_batched_index]                \n",
       "   52 │   │   else:                                                                               \n",
       "   53 │   │   │   data = self.dataset[possibly_batched_index]                                     \n",
       " 54 │   │   return self.collate_fn(data)                                                        \n",
       "   55                                                                                             \n",
       "                                                                                                  \n",
       " /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py:264 in                \n",
       " default_collate                                                                                  \n",
       "                                                                                                  \n",
       "   261 │   │   │   >>> default_collate_fn_map.update(CustoType, collate_customtype_fn)            \n",
       "   262 │   │   │   >>> default_collate(batch)  # Handle `CustomType` automatically                \n",
       "   263 \"\"\"                                                                                    \n",
       " 264 return collate(batch, collate_fn_map=default_collate_fn_map)                           \n",
       "   265                                                                                            \n",
       "                                                                                                  \n",
       " /usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/collate.py:150 in collate        \n",
       "                                                                                                  \n",
       "   147 │   │   │   │   # The sequence type may not support `__init__(iterable)` (e.g., `range`)   \n",
       "   148 │   │   │   │   return [collate(samples, collate_fn_map=collate_fn_map) for samples in t   \n",
       "   149 │                                                                                          \n",
       " 150 raise TypeError(default_collate_err_msg_format.format(elem_type))                      \n",
       "   151                                                                                            \n",
       "   152                                                                                            \n",
       "   153 def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ..   \n",
       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
       "TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class \n",
       "'torch_geometric.data.data.Data'>\n",
       "
\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m2\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1 \u001b[0mdataset = faker.make_data_loader().dataset \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2 \u001b[94mfor\u001b[0m x \u001b[95min\u001b[0m DataLoader(dataset, batch_size = \u001b[96mlen\u001b[0m(dataset)): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m3 \u001b[0m\u001b[2m│ \u001b[0mall_fea_mat = x[\u001b[94m0\u001b[0m] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m4 \u001b[0m\u001b[2m│ \u001b[0mall_adj_mat = x[\u001b[94m1\u001b[0m] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m5 \u001b[0m\u001b[96mprint\u001b[0m(\u001b[33mf\u001b[0m\u001b[33m'\u001b[0m\u001b[33mLabels:\u001b[0m\u001b[33m\\n\u001b[0m\u001b[33m{\u001b[0mlabels.tolist()\u001b[33m}\u001b[0m\u001b[33m'\u001b[0m) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/torch/utils/data/\u001b[0m\u001b[1;33mdataloader.py\u001b[0m:\u001b[94m634\u001b[0m in \u001b[92m__next__\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 631 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mself\u001b[0m._sampler_iter \u001b[95mis\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 632 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 633 \u001b[0m\u001b[2m│ │ │ │ 
\u001b[0m\u001b[96mself\u001b[0m._reset() \u001b[2m# type: ignore[call-arg]\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 634 \u001b[2m│ │ │ \u001b[0mdata = \u001b[96mself\u001b[0m._next_data() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 635 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m._num_yielded += \u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 636 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mself\u001b[0m._dataset_kind == _DatasetKind.Iterable \u001b[95mand\u001b[0m \\ \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 637 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._IterableDataset_len_called \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m \u001b[95mand\u001b[0m \\ \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/torch/utils/data/\u001b[0m\u001b[1;33mdataloader.py\u001b[0m:\u001b[94m678\u001b[0m in \u001b[92m_next_data\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 675 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 676 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92m_next_data\u001b[0m(\u001b[96mself\u001b[0m): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 677 \u001b[0m\u001b[2m│ │ \u001b[0mindex = \u001b[96mself\u001b[0m._next_index() \u001b[2m# may raise StopIteration\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 678 \u001b[2m│ │ \u001b[0mdata = \u001b[96mself\u001b[0m._dataset_fetcher.fetch(index) \u001b[2m# may raise StopIteration\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 679 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mself\u001b[0m._pin_memory: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 680 
\u001b[0m\u001b[2m│ │ │ \u001b[0mdata = _utils.pin_memory.pin_memory(data, \u001b[96mself\u001b[0m._pin_memory_device) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 681 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m data \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/\u001b[0m\u001b[1;33mfetch.py\u001b[0m:\u001b[94m54\u001b[0m in \u001b[92mfetch\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m51 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mdata = [\u001b[96mself\u001b[0m.dataset[idx] \u001b[94mfor\u001b[0m idx \u001b[95min\u001b[0m possibly_batched_index] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m52 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m53 \u001b[0m\u001b[2m│ │ │ \u001b[0mdata = \u001b[96mself\u001b[0m.dataset[possibly_batched_index] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m54 \u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[96mself\u001b[0m.collate_fn(data) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m55 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/\u001b[0m\u001b[1;33mcollate.py\u001b[0m:\u001b[94m264\u001b[0m in \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[92mdefault_collate\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m261 \u001b[0m\u001b[2;33m│ │ │ \u001b[0m\u001b[33m>>> default_collate_fn_map.update(CustoType, collate_customtype_fn)\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m262 \u001b[0m\u001b[2;33m│ │ │ \u001b[0m\u001b[33m>>> default_collate(batch) # Handle `CustomType` automatically\u001b[0m 
\u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m263 \u001b[0m\u001b[2;33m│ \u001b[0m\u001b[33m\"\"\"\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m264 \u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m collate(batch, collate_fn_map=default_collate_fn_map) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m265 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/\u001b[0m\u001b[1;33mcollate.py\u001b[0m:\u001b[94m150\u001b[0m in \u001b[92mcollate\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m147 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# The sequence type may not support `__init__(iterable)` (e.g., `range`)\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m148 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[94mreturn\u001b[0m [collate(samples, collate_fn_map=collate_fn_map) \u001b[94mfor\u001b[0m samples \u001b[95min\u001b[0m t \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m149 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m150 \u001b[2m│ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mTypeError\u001b[0m(default_collate_err_msg_format.format(elem_type)) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m151 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m152 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m153 \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mcollate_tensor_fn\u001b[0m(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, .. 
\u001b[31m│\u001b[0m\n", "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mTypeError: \u001b[0mdefault_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found \u001b[1m<\u001b[0m\u001b[1;95mclass\u001b[0m\u001b[39m \u001b[0m\n", "\u001b[32m'torch_geometric.data.data.Data'\u001b[0m\u001b[1m>\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dataset = faker.make_data_loader().dataset\n", "for x in DataLoader(dataset, batch_size = len(dataset)):\n", " all_fea_mat = x[0]\n", " all_adj_mat = x[1]\n", "print(f'Labels:\\n{labels.tolist()}')" ] }, { "cell_type": "markdown", "id": "f82982f2", "metadata": {}, "source": [ "### Process origin data" ] }, { "cell_type": "code", "execution_count": 9, "id": "59d0aef5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
       " in <cell line: 2>:2                                                                              \n",
       "                                                                                                  \n",
       "   1 # Flatten Data                                                                               \n",
       " 2 origin = np.array([torch.reshape(ele.to_dense(), (-1,)).tolist() for ele in all_fea_mat]     \n",
       "   3                                                                                              \n",
       "   4 # PCA analysis for dimensionality deduction                                                  \n",
       "   5 pca_analysis = sc.pp.pca(origin, n_comps = 9, return_info = True,)                           \n",
       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
       "NameError: name 'all_fea_mat' is not defined\n",
       "
\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m2\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1 \u001b[0m\u001b[2m# Flatten Data\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2 origin = np.array([torch.reshape(ele.to_dense(), (-\u001b[94m1\u001b[0m,)).tolist() \u001b[94mfor\u001b[0m ele \u001b[95min\u001b[0m all_fea_mat] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m3 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m4 \u001b[0m\u001b[2m# PCA analysis for dimensionality deduction\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m5 \u001b[0mpca_analysis = sc.pp.pca(origin, n_comps = \u001b[94m9\u001b[0m, return_info = \u001b[94mTrue\u001b[0m,) \u001b[31m│\u001b[0m\n", "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mNameError: \u001b[0mname \u001b[32m'all_fea_mat'\u001b[0m is not defined\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Flatten Data\n", "origin = np.array([torch.reshape(ele.to_dense(), (-1,)).tolist() for ele in all_fea_mat])\n", "\n", "# PCA analysis for dimensionality deduction\n", "pca_analysis = sc.pp.pca(origin, n_comps = 9, return_info = True,)\n", "origin_pca = pca_analysis[0]\n", "var_ratios = pca_analysis[2]\n", "print(f'Origin Data Var Ratios:\\n{var_ratios}')\n" ] }, { "cell_type": "markdown", "id": "96be974b", "metadata": {}, "source": [ "### Process data with encoder" ] }, { "cell_type": "code", "execution_count": 10, "id": "d90f7dc8", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
       " in <cell line: 2>:3                                                                              \n",
       "                                                                                                  \n",
       "    1 # Get encoded representaions made by GAT -> BERT encoder                                    \n",
       "    2 encode = np.array([                                                                         \n",
       "  3 torch.reshape(ele, (-1,)).tolist() for ele in testM.get_encoder_output(                 \n",
       "    4 │   │   all_fea_mat, all_adj_mat                                                            \n",
       "    5 │   )                                                                                       \n",
       "    6 ])                                                                                          \n",
       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
       "AttributeError: 'Trainer' object has no attribute 'get_encoder_output'\n",
       "
\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m3\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 1 \u001b[0m\u001b[2m# Get encoded representaions made by GAT -> BERT encoder\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 2 \u001b[0mencode = np.array([ \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 3 \u001b[2m│ \u001b[0mtorch.reshape(ele, (-\u001b[94m1\u001b[0m,)).tolist() \u001b[94mfor\u001b[0m ele \u001b[95min\u001b[0m testM.get_encoder_output( \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 4 \u001b[0m\u001b[2m│ │ \u001b[0mall_fea_mat, all_adj_mat \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 5 \u001b[0m\u001b[2m│ \u001b[0m) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 6 \u001b[0m]) \u001b[31m│\u001b[0m\n", "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mAttributeError: \u001b[0m\u001b[32m'Trainer'\u001b[0m object has no attribute \u001b[32m'get_encoder_output'\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Get encoded representaions made by GAT -> BERT encoder\n", "encode = np.array([\n", " torch.reshape(ele, (-1,)).tolist() for ele in testM.get_encoder_output(\n", " all_fea_mat, all_adj_mat\n", " )\n", "])\n", "\n", "# PCA analysis for dimensionality deduction\n", "pca_analysis = sc.pp.pca(encode, n_comps = 9, return_info = True,)\n", "encode_pca = pca_analysis[0]\n", "var_ratios = pca_analysis[2]\n", "print(f'Encoded Rep Var Ratios:\\n{var_ratios}')" ] }, { "cell_type": "markdown", "id": "b5299e3c", "metadata": 
{}, "source": [ "### Set up clustering models and fit them to the original (PCA-reduced) data" ] }, { "cell_type": "code", "execution_count": 11, "id": "48d8fea6", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
       " in <cell line: 2>:2                                                                              \n",
       "                                                                                                  \n",
       "    1 eps = 0.5                                                                                   \n",
       "  2 n_clusters = len(np.unique(labels))                                                         \n",
       "    3 min_samples = 5                                                                             \n",
       "    4                                                                                             \n",
       "    5 dbscan = cluster.DBSCAN(eps = eps)                                                          \n",
       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
       "NameError: name 'labels' is not defined\n",
       "
\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m2\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 1 \u001b[0meps = \u001b[94m0.5\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 2 n_clusters = \u001b[96mlen\u001b[0m(np.unique(labels)) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 3 \u001b[0mmin_samples = \u001b[94m5\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 4 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 5 \u001b[0mdbscan = cluster.DBSCAN(eps = eps) \u001b[31m│\u001b[0m\n", "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mNameError: \u001b[0mname \u001b[32m'labels'\u001b[0m is not defined\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "eps = 0.5\n", "n_clusters = len(np.unique(labels))\n", "min_samples = 5\n", "\n", "dbscan = cluster.DBSCAN(eps = eps)\n", "kmeans = cluster.KMeans(n_clusters = n_clusters)\n", "optics = cluster.OPTICS(min_samples = min_samples)\n", "\n", "dbscan.fit(origin_pca)\n", "kmeans.fit(origin_pca)\n", "optics.fit(origin_pca)\n", "\n", "# Get labels\n", "print(dbscan.labels_.astype(int))\n", "print(kmeans.labels_.astype(int))\n", "print(optics.labels_.astype(int))" ] }, { "cell_type": "markdown", "id": "eba48c10", "metadata": {}, "source": [ "### Reset models and fit with encoded representaions" ] }, { "cell_type": "code", "execution_count": 12, "id": "1f1098d5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
       " in <cell line: 2>:2                                                                              \n",
       "                                                                                                  \n",
       "    1 dbscan = cluster.DBSCAN(eps = eps)                                                          \n",
       "  2 kmeans = cluster.KMeans(n_clusters = n_clusters)                                            \n",
       "    3 optics = cluster.OPTICS(min_samples = min_samples)                                          \n",
       "    4 dbscan.fit(encode_pca)                                                                      \n",
       "    5 kmeans.fit(encode_pca)                                                                      \n",
       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
       "NameError: name 'n_clusters' is not defined\n",
       "
\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m2\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 1 \u001b[0mdbscan = cluster.DBSCAN(eps = eps) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 2 kmeans = cluster.KMeans(n_clusters = n_clusters) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 3 \u001b[0moptics = cluster.OPTICS(min_samples = min_samples) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 4 \u001b[0mdbscan.fit(encode_pca) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 5 \u001b[0mkmeans.fit(encode_pca) \u001b[31m│\u001b[0m\n", "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mNameError: \u001b[0mname \u001b[32m'n_clusters'\u001b[0m is not defined\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dbscan = cluster.DBSCAN(eps = eps)\n", "kmeans = cluster.KMeans(n_clusters = n_clusters)\n", "optics = cluster.OPTICS(min_samples = min_samples)\n", "dbscan.fit(encode_pca)\n", "kmeans.fit(encode_pca)\n", "optics.fit(encode_pca)\n", "\n", "# Get labels\n", "print(dbscan.labels_.astype(int))\n", "print(kmeans.labels_.astype(int))\n", "print(optics.labels_.astype(int))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }