FateZ Explain

This notebook demonstrates how to use the explanation methods of FateZ models.

[1]:
import os
import sys
import torch
from torch.utils.data import DataLoader
import fatez.lib as lib
import fatez.test as test
import fatez.model as model
import fatez.tool.JSON as JSON
import fatez.process as process
import fatez.process.worker as worker
import fatez.process.fine_tuner as fine_tuner
import fatez.process.pre_trainer as pre_trainer
from pkg_resources import resource_filename

print('Done Import')
Done Import

First, build the model and generate some fake data.

[2]:
# Parameters
params = {
    'n_sample': 10,       # Number of fake samples to generate
    'batch_size': 1,      # Batch size
}

# Init worker env
config = JSON.decode(
    resource_filename(__name__, '../../fatez/data/config/gat_bert_config.json')
)
suppressor = process.Quiet_Mode()
device = [0]
dtype = torch.float32
worker.setup(device, master_port = '2307')

# Generate Fake data
faker = test.Faker(model_config = config, **params)
train_dataloader = faker.make_data_loader()

print('Done Init')
Done Init

Now we perform pre-training without labels.

Here the trainer's train_adj option is set to False, so the model does NOT reconstruct the adjacency matrices.

[3]:
trainer = pre_trainer.Set(config, dtype=dtype, device=device)
suppressor.on()
report = trainer.train(train_dataloader, report_batch = True)
suppressor.off()
print(report)
        Loss
0   1.307709
1   7.125144
2   7.212573
3   7.185053
4   1.215182
5   7.039515
6   6.969513
7   1.040767
8   1.182723
9   1.147289
10  4.142547

Pre-training can also reconstruct the adjacency matrices. To enable this, set train_adj to True before building the trainer.

[4]:
config['pre_trainer']['train_adj'] = True
trainer = pre_trainer.Set(config, dtype = dtype, device = device)
suppressor.on()
report = trainer.train(train_dataloader, report_batch = True)
suppressor.off()
print(report)
        Loss
0   1.197746
1   7.171856
2   7.135652
3   7.115325
4   1.138201
5   1.099432
6   7.150834
7   1.107591
8   1.277280
9   7.157175
10  4.155109

Next, we fine-tune the pre-trained model with class labels.

[5]:
tuner = fine_tuner.Set(config, prev_model = trainer, dtype = dtype, device = device)
report = tuner.train(train_dataloader, report_batch = True)
print(report)
        Loss  ACC
0   0.649586  1.0
1   0.743881  0.0
2   0.645328  1.0
3   0.744279  0.0
4   0.645847  1.0
5   0.645064  1.0
6   0.744720  0.0
7   0.747999  0.0
8   0.645316  1.0
9   0.746752  0.0
10  0.695877  0.5
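
Judging from the printed values, the last row of each report (index 10) looks like the mean over the ten per-batch rows rather than an eleventh batch. The check below is a minimal sketch in plain Python, with the numbers copied from the fine-tuning report above; the variable names are illustrative and not part of FateZ.

```python
from statistics import mean

# Per-batch losses copied from rows 0-9 of the fine-tuning report above.
batch_loss = [
    0.649586, 0.743881, 0.645328, 0.744279, 0.645847,
    0.645064, 0.744720, 0.747999, 0.645316, 0.746752,
]
# Per-batch accuracies from the same rows (batch_size is 1, so each is 0 or 1).
batch_acc = [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]

# Values printed in the final row (index 10) of the report.
reported_loss = 0.695877
reported_acc = 0.5

mean_loss = mean(batch_loss)
mean_acc = mean(batch_acc)

# The final row agrees with the per-batch means up to print rounding.
print(abs(mean_loss - reported_loss) < 1e-6)
print(mean_acc == reported_acc)
```

With a batch size of 1 and two classes, a summary accuracy of 0.5 means the model is still at chance level on this fake data, which is expected after a single pass.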

Finally, we explain the model.

Three kinds of explanations are available:
1. edge_explain
2. regulon_explain
3. node_explain

[6]:
# Initializing edge explain matrix and regulon explain matrix
adj_exp = torch.zeros((config['input_sizes']['n_reg'], config['input_sizes']['n_node']))
reg_exp = torch.zeros((config['input_sizes']['n_reg'], config['encoder']['d_model']))

# Make background data: a single batch holding all fake samples
bg = [a for a,_ in DataLoader(train_dataloader.dataset, faker.n_sample, collate_fn = lib.collate_fn)][0]
# Set up the explainer using the background data from the pseudo-dataloader
explain = tuner.model.make_explainer([a.to(tuner.device) for a in bg])

for x,_ in train_dataloader:
    data = [a.to(tuner.device) for a in x]
    # Third return value is unused here; named _vars to avoid shadowing the built-in vars()
    adj_temp, reg_temp, _vars = tuner.model.explain_batch(data, explain)
    adj_exp += adj_temp

    print(f'Explaining {len(reg_temp)} classes.')

    # Only the feature matrix explanation is expected to work here
    print(f'Each class has regulon explain in shape of {reg_temp[0][0].shape}.\n')

    # Only take explanations for class 0
    for exp in reg_temp[0]:
        reg_exp += abs(exp)
    # Explain only the first batch in this demo
    break

reg_exp = torch.sum(reg_exp, dim = -1)
node_exp = torch.matmul(reg_exp, adj_exp.type(reg_exp.dtype))
print('Edge Explain:\n', adj_exp, '\n')
print('Reg Explain:\n', reg_exp, '\n')
print('Node Explain:\n', node_exp, '\n')
Explaining 2 classes.
Each class has regulon explain in shape of (4, 4).

Edge Explain:
 tensor([[0.0302, 0.0293, 0.0292, 0.0280, 0.0272, 0.0294, 0.0258, 0.0309, 0.0000,
         0.0291],
        [0.0342, 0.0287, 0.0284, 0.0255, 0.0264, 0.0269, 0.0298, 0.0260, 0.0000,
         0.0301],
        [0.0294, 0.0231, 0.0260, 0.0229, 0.0267, 0.0267, 0.0304, 0.0262, 0.0000,
         0.0266],
        [0.0254, 0.0262, 0.0260, 0.0251, 0.0285, 0.0309, 0.0297, 0.0265, 0.0000,
         0.0287]])

Reg Explain:
 tensor([0.0008, 0.0008, 0.0003, 0.0016], dtype=torch.float64)

Node Explain:
 tensor([9.7898e-05, 9.2607e-05, 9.2633e-05, 8.7263e-05, 9.3675e-05, 9.9449e-05,
        9.7947e-05, 9.3009e-05, 0.0000e+00, 9.8218e-05], dtype=torch.float64)
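
The node-level scores above are obtained by collapsing the regulon explanation to one weight per regulon and multiplying that vector into the edge explanation. A minimal PyTorch sketch of this last step on toy tensors (2 regulons, 3 nodes, 2 embedding dimensions; all values are made up for illustration and do not come from a FateZ model):

```python
import torch

# Toy edge explanation: rows are regulons, columns are nodes.
adj_exp = torch.tensor([[1.0, 0.0, 2.0],
                        [0.0, 3.0, 1.0]])

# Toy regulon explanation: rows are regulons, columns are embedding dimensions.
reg_exp = torch.tensor([[0.5, 0.5],
                        [1.0, 1.0]])

# Collapse the embedding dimension: one importance weight per regulon.
reg_weight = torch.sum(reg_exp, dim=-1)          # shape (2,)

# Weight every regulon-to-node edge by its regulon and sum over regulons.
node_exp = torch.matmul(reg_weight, adj_exp)     # shape (3,)

print(reg_weight)
print(node_exp)
```

Under this reading, a node scores high when it is connected through strongly weighted edges to regulons that themselves carry large attributions.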

Cleanup Env

Clean up the worker environment once finished.

[7]:
worker.cleanup(device)
print('Clean up worker env.')
Clean up worker env.