FateZ Explain
This notebook demonstrates how to use the explanatory methods of FateZ models.
[1]:
import os
import sys
import torch
from torch.utils.data import DataLoader
import fatez.lib as lib
import fatez.test as test
import fatez.model as model
import fatez.tool.JSON as JSON
import fatez.process as process
import fatez.process.worker as worker
import fatez.process.fine_tuner as fine_tuner
import fatez.process.pre_trainer as pre_trainer
from pkg_resources import resource_filename
print('Done Import')
Done Import
First, load the model configuration, set up the worker environment, and make some fake data.
[2]:
# Parameters
params = {
    'n_sample': 10,    # Fake samples to make
    'batch_size': 1,   # Batch size
}
# Init worker env
config = JSON.decode(resource_filename(
    __name__, '../../fatez/data/config/gat_bert_config.json'
))
suppressor = process.Quiet_Mode()
device = [0]
dtype = torch.float32
worker.setup(device, master_port = '2307')
# Generate Fake data
faker = test.Faker(model_config = config, **params)
train_dataloader = faker.make_data_loader()
print('Done Init')
Done Init
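The training reports below list one row per batch. With n_sample = 10 and batch_size = 1, the loader yields ten batches, which is why the reports have rows 0 to 9. The sketch below uses a plain torch TensorDataset as a hypothetical stand-in for the Faker dataset; it only illustrates how a PyTorch DataLoader counts batches and is not part of FateZ.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the fake dataset: 10 samples with random features and binary labels
n_sample, batch_size = 10, 1
dataset = TensorDataset(torch.randn(n_sample, 4), torch.randint(0, 2, (n_sample,)))
loader = DataLoader(dataset, batch_size=batch_size)

# With drop_last=False (the default), the number of batches is ceil(n_sample / batch_size)
n_batches = len(loader)
print(n_batches)
assert n_batches == math.ceil(n_sample / batch_size)
```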
Now we perform pre-training without labels.
Here the trainer's train_adj option (config['pre_trainer']['train_adj']) is False, so the model does not reconstruct the adjacency matrices.
[3]:
trainer = pre_trainer.Set(config, dtype=dtype, device=device)
suppressor.on()
report = trainer.train(train_dataloader, report_batch = True)
suppressor.off()
print(report)
        Loss
0   1.307709
1   7.125144
2   7.212573
3   7.185053
4   1.215182
5   7.039515
6   6.969513
7   1.040767
8   1.182723
9   1.147289
10  4.142547
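Row 10 of the report is not an eleventh batch: it matches the mean of the ten per-batch losses in rows 0 to 9. A quick check in plain Python, using only the numbers printed above:

```python
# Per-batch losses copied from the pre-training report above (rows 0-9)
losses = [1.307709, 7.125144, 7.212573, 7.185053, 1.215182,
          7.039515, 6.969513, 1.040767, 1.182723, 1.147289]

# The summary row corresponds to the arithmetic mean over batches
mean_loss = sum(losses) / len(losses)
print(f'{mean_loss:.6f}')
```

The same relation holds for the other reports in this notebook.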
Next, we pre-train again with train_adj set to True, so the model also reconstructs the adjacency matrices.
[4]:
config['pre_trainer']['train_adj'] = True
trainer = pre_trainer.Set(config, dtype = dtype, device = device)
suppressor.on()
report = trainer.train(train_dataloader, report_batch = True)
suppressor.off()
print(report)
        Loss
0   1.197746
1   7.171856
2   7.135652
3   7.115325
4   1.138201
5   1.099432
6   7.150834
7   1.107591
8   1.277280
9   7.157175
10  4.155109
Then we fine-tune the pre-trained model (prev_model = trainer) with class labels.
[5]:
tuner = fine_tuner.Set(config, prev_model = trainer, dtype = dtype, device = device)
report = tuner.train(train_dataloader, report_batch = True,)
print(report)
        Loss  ACC
0   0.649586  1.0
1   0.743881  0.0
2   0.645328  1.0
3   0.744279  0.0
4   0.645847  1.0
5   0.645064  1.0
6   0.744720  0.0
7   0.747999  0.0
8   0.645316  1.0
9   0.746752  0.0
10  0.695877  0.5
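Because batch_size is 1, each per-batch ACC value can only be 0.0 or 1.0, assuming ACC is the fraction of correctly classified samples in a batch. The summary row 0.5 therefore means that 5 of the 10 fake samples were classified correctly. A small check on the printed values:

```python
# Per-batch accuracies copied from the fine-tuning report above (rows 0-9)
acc = [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]

# With one sample per batch, summing the accuracies counts the correct predictions
n_correct = int(sum(acc))
mean_acc = sum(acc) / len(acc)
print(n_correct, mean_acc)
```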
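Explain the Model
Three kinds of explanations are available:
1. edge_explain: an importance score for each regulator-node edge, an n_reg by n_node matrix (adj_exp below).
2. regulon_explain: one importance score per regulon (reg_exp below).
3. node_explain: one importance score per node, obtained by weighting the edge explanations with the regulon explanations (node_exp below).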
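Reading the code of the next cell, the three explanations can be written as follows. Here A is the edge explanation (n_reg by n_node), R is the class-0 regulon explanation accumulated as elementwise absolute values (n_reg by d_model), r is the regulon explanation, and n is the node explanation. This is an interpretation of the notebook code, not a formula taken from the FateZ documentation.

```latex
R = \sum_{e \in \texttt{reg\_temp}[0]} \lvert e \rvert, \qquad
r_i = \sum_{k=1}^{d_{\text{model}}} R_{ik}, \qquad
n_j = \sum_{i=1}^{n_{\text{reg}}} r_i \, A_{ij} \quad (\text{i.e. } n = r^{\top} A)
```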
[6]:
# Initialize the edge explanation matrix and the regulon explanation matrix
adj_exp = torch.zeros((config['input_sizes']['n_reg'], config['input_sizes']['n_node']))
reg_exp = torch.zeros((config['input_sizes']['n_reg'], config['encoder']['d_model']))
# Make background data: all fake samples collated into a single batch
bg = [a for a,_ in DataLoader(train_dataloader.dataset, faker.n_sample, collate_fn = lib.collate_fn,)][0]
# Build the explainer from the background data
explain = tuner.model.make_explainer([a.to(tuner.device) for a in bg])
for x,_ in train_dataloader:
    data = [a.to(tuner.device) for a in x]
    adj_temp, reg_temp, vars = tuner.model.explain_batch(data, explain)
    adj_exp += adj_temp
    print(f'Explaining {len(reg_temp)} classes.')
    # Only the feat mat explanation should be working
    print(f'Each class has regulon explain in shape of {reg_temp[0][0].shape}.\n')
    # Only take explanations for class 0
    for exp in reg_temp[0]: reg_exp += abs(exp)
    # Only explain the first batch
    break
# Sum over the feature dimension to get one score per regulon
reg_exp = torch.sum(reg_exp, dim = -1)
# Weight each regulon's edge explanations by its regulon score
node_exp = torch.matmul(reg_exp, adj_exp.type(reg_exp.dtype))
print('Edge Explain:\n', adj_exp, '\n')
print('Reg Explain:\n', reg_exp, '\n')
print('Node Explain:\n', node_exp, '\n')
Explaining 2 classes.
Each class has regulon explain in shape of (4, 4).
Edge Explain:
tensor([[0.0302, 0.0293, 0.0292, 0.0280, 0.0272, 0.0294, 0.0258, 0.0309, 0.0000,
         0.0291],
        [0.0342, 0.0287, 0.0284, 0.0255, 0.0264, 0.0269, 0.0298, 0.0260, 0.0000,
         0.0301],
        [0.0294, 0.0231, 0.0260, 0.0229, 0.0267, 0.0267, 0.0304, 0.0262, 0.0000,
         0.0266],
        [0.0254, 0.0262, 0.0260, 0.0251, 0.0285, 0.0309, 0.0297, 0.0265, 0.0000,
         0.0287]])
Reg Explain:
tensor([0.0008, 0.0008, 0.0003, 0.0016], dtype=torch.float64)
Node Explain:
tensor([9.7898e-05, 9.2607e-05, 9.2633e-05, 8.7263e-05, 9.3675e-05, 9.9449e-05,
        9.7947e-05, 9.3009e-05, 0.0000e+00, 9.8218e-05], dtype=torch.float64)
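In the output above, node 8 has a zero edge explanation for every regulator, so its node explanation is exactly 0. The toy example below (hypothetical values, not the notebook's numbers) repeats the same torch.matmul aggregation on 2 regulators and 3 nodes to show that a node's score is the regulon-weighted sum of its edge column, and that an all-zero column gives a zero score.

```python
import torch

# Hypothetical toy values: per-regulon scores (float64, as in the notebook output)
reg_exp = torch.tensor([0.5, 2.0], dtype=torch.float64)
# Hypothetical regulator-by-node edge scores (float32 by default); node 1 has no explained edges
adj_exp = torch.tensor([[0.1, 0.0, 0.3],
                        [0.2, 0.0, 0.1]])

# Same aggregation as the notebook cell: node_exp[j] = sum_i reg_exp[i] * adj_exp[i, j]
node_exp = torch.matmul(reg_exp, adj_exp.type(reg_exp.dtype))
print(node_exp)
```

The .type(reg_exp.dtype) cast matters because torch.matmul does not mix float32 and float64 operands.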
Cleanup Env
Clean up the worker environment once finished.
[7]:
worker.cleanup(device)
print('Clean up worker env.')
Clean up worker env.