FateZ Multiomic Perturbation Effect Prediction (?)

This notebook demonstrates how to implement a perturbation effect prediction method with FateZ’s modules.

[1]:
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import fatez.lib as lib
import fatez.test as test
import fatez.model as model
import fatez.tool.JSON as JSON
import fatez.process as process
import fatez.process.worker as worker
import fatez.process.fine_tuner as fine_tuner
import fatez.process.pre_trainer as pre_trainer
from pkg_resources import resource_filename

print('Done Import')
Done Import

First, load the model configuration, initialize the worker environment, and make some fake data.

[2]:
# Parameters
params = {
    'n_sample': 10,       # Fake samples to make
    'batch_size': 2,      # Batch size
}

# Init worker env
config = JSON.decode(resource_filename(
        __name__, '../../fatez/data/config/gat_bert_config.json'
    )
)
suppressor = process.Quiet_Mode()
device = 'cuda'
# device = [0] # Apply DDP when using multiple devices
dtype = torch.float32
worker.setup(device)

print('Done Init')

# Generate Fake data
faker = test.Faker(model_config = config, dtype = dtype, **params)
pertubation_dataloader = faker.make_data_loader()
result_dataloader = faker.make_data_loader()

# Use the index of each sample's perturbation result as that sample's 'label'
for i,k in enumerate(pertubation_dataloader.dataset.samples):
    k.y = i

print('Done Fake Data')
Done Init
Done Fake Data
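The label trick above pairs each input sample with a perturbation result by position. The sketch below illustrates the idea in plain Python, independent of FateZ; the `Sample` dataclass, `pair_by_label`, and `lookup_targets` are hypothetical stand-ins that only model the `x` and `y` fields used in this notebook.

```python
from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class Sample:
    # Hypothetical stand-in for a FateZ sample; only `x` and `y` are modelled.
    x: Any
    y: Optional[int] = None

def pair_by_label(inputs: List[Sample], results: List[Sample]) -> None:
    """Store each input's position as its label, pointing at results[i]."""
    if len(inputs) != len(results):
        raise ValueError("inputs and results must have the same length")
    for i, sample in enumerate(inputs):
        sample.y = i

def lookup_targets(batch_labels: List[int], results: List[Sample]) -> List[Sample]:
    """Fetch the perturbation results referenced by a batch of labels."""
    return [results[int(label)] for label in batch_labels]

# Usage: pair two toy datasets, then fetch targets for a batch of two samples.
inputs = [Sample(x=f"input_{i}") for i in range(4)]
results = [Sample(x=f"result_{i}") for i in range(4)]
pair_by_label(inputs, results)
batch_labels = [s.y for s in inputs[2:4]]
targets = lookup_targets(batch_labels, results)
print([t.x for t in targets])  # ['result_2', 'result_3']
```

Because the index travels with each sample as its label, the pairing still holds if the input loader shuffles its samples.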

The model is architecturally similar to a pre-trainer, so it is built with pre_trainer.Set.

[3]:
trainer = pre_trainer.Set(config, dtype = dtype, device=device)

print('Model Set')
Model Set

However, the training loop is a little different.

This part is adapted from pre_trainer.Trainer.train().
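As a rough summary of the loop in the next cell, write $\ell$ for `trainer.criterion` (its concrete form comes from the config and is not assumed here), $\hat{X}_b, \hat{A}_b$ for `node_rec` and `adj_rec`, and $X^{\ast}_b, A^{\ast}_b$ for `node_results` and `adj_results` in batch $b$. Unlike plain pre-training, $A^{\ast}_b$ is built from the paired perturbation results `y` rather than from the input. With $B$ = `len(pertubation_dataloader)` batches:

$$
\mathcal{L}_b = \ell\big(\hat{X}_b, X^{\ast}_b\big) + \mathbb{1}\big[\hat{A}_b \neq \texttt{None}\big]\,\ell\big(\hat{A}_b, A^{\ast}_b\big),
\qquad
\bar{\mathcal{L}} = \frac{1}{B}\sum_{b=1}^{B} \mathcal{L}_b
$$

$\bar{\mathcal{L}}$ is the single value stored in `report` when `report_batch` is `False`.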

[7]:
report_batch = False
size = trainer.input_sizes

trainer.worker.train(True)
best_loss = 99
loss_all = 0
report = list()

for x,y in pertubation_dataloader:

    # Prepare input data as always
    input = [ele.to(trainer.device) for ele in x]

    # Mute some debug outputs
    suppressor.on()
    node_rec, adj_rec = trainer.worker(input)
    suppressor.off()

    # Prepare perturbation result data using a separate dataloader
    y = [result_dataloader.dataset.samples[int(ele)].to(trainer.device) for ele in y]
    # Note that this script only reconstructs the TF part.
    # To reconstruct the whole genome, an additional layer taking adj_rec and node_rec could be added.
    node_results = torch.stack([ele.x for ele in y], 0)
    adj_results = lib.get_dense_adjs(
        y, (size['n_reg'],size['n_node'],size['edge_attr'])
    )

    # Get total loss
    loss = trainer.criterion(node_rec, node_results)
    if adj_rec is not None:
        loss += trainer.criterion(adj_rec, adj_results)

    # Backpropagate, clip gradients, and update the weights
    loss.backward()
    nn.utils.clip_grad_norm_(trainer.model.parameters(), trainer.max_norm)
    trainer.optimizer.step()
    trainer.optimizer.zero_grad()

    # Accumulate
    best_loss = min(best_loss, loss.item())
    loss_all += loss.item()

    # Record the per-batch loss if requested
    if report_batch: report.append([loss.item()])


trainer.scheduler.step()
report.append([loss_all / len(pertubation_dataloader)])
report = pd.DataFrame(report)
report.columns = ['Loss', ]
print(report)
       Loss
0  4.120818
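`lib.get_dense_adjs` turns each sample's sparse graph into a dense adjacency so it can be compared with `adj_rec`. The `adj_rec` shape printed in the prediction cell, `[2, 4, 10]`, is consistent with a `[batch, n_reg, n_node]` layout in this config. The internals of `get_dense_adjs` are not shown here, so the following is only a hypothetical pure-Python illustration of the general idea. It assumes each graph is a COO edge list of `(regulator, node)` pairs with one scalar weight per edge; `dense_adj` and `dense_adjs` are illustrative names, not FateZ functions.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]  # (regulator index, target node index)

def dense_adj(edges: Sequence[Edge], weights: Sequence[float],
              n_reg: int, n_node: int) -> List[List[float]]:
    """Scatter one sparse graph into an n_reg x n_node matrix of edge weights."""
    if len(edges) != len(weights):
        raise ValueError("need exactly one weight per edge")
    adj = [[0.0] * n_node for _ in range(n_reg)]
    for (src, dst), w in zip(edges, weights):
        if not (0 <= src < n_reg and 0 <= dst < n_node):
            raise IndexError(f"edge ({src}, {dst}) out of bounds")
        adj[src][dst] = w  # a later duplicate edge overwrites an earlier one
    return adj

def dense_adjs(graphs, n_reg: int, n_node: int) -> List[List[List[float]]]:
    """Batch version: returns a nested list shaped [batch][n_reg][n_node]."""
    return [dense_adj(edges, weights, n_reg, n_node) for edges, weights in graphs]

# Usage: a batch of two toy graphs with 2 regulators and 3 nodes.
batch = [
    ([(0, 1), (1, 2)], [0.5, -1.0]),
    ([(1, 0)], [2.0]),
]
out = dense_adjs(batch, n_reg=2, n_node=3)
print(out[0])  # [[0.0, 0.5, 0.0], [0.0, 0.0, -1.0]]
```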

For unlabeled data, which has no perturbation results, we set up another trainer initialized from the previous model.

[8]:
tuner = pre_trainer.Set(config, prev_model = trainer, dtype = dtype, device = device)

# Some new fake data
tuner_dataloader = faker.make_data_loader()

# As in pre-training, tuning is based on input reconstruction
suppressor.on()
report = tuner.train(tuner_dataloader, report_batch = False,)
suppressor.off()
print(report)
       Loss
0  4.340387
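The notebook toggles `suppressor.on()` and `suppressor.off()` by hand around noisy calls. If the wrapped call raises, `off()` is never reached and later output may stay muted. A small context manager guarantees the pairing. The sketch below assumes only an object exposing `on()` and `off()`, as used above; `muted` and `ToySuppressor` are hypothetical names, with `ToySuppressor` standing in for `process.Quiet_Mode` so the example runs on its own.

```python
from contextlib import contextmanager

@contextmanager
def muted(suppressor):
    """Call suppressor.on() on entry and always call suppressor.off() on exit."""
    suppressor.on()
    try:
        yield suppressor
    finally:
        suppressor.off()

class ToySuppressor:
    # Hypothetical stand-in that only records whether it is active.
    def __init__(self):
        self.active = False

    def on(self):
        self.active = True

    def off(self):
        self.active = False

# Usage: the suppressor is switched off even when the body raises.
s = ToySuppressor()
try:
    with muted(s):
        raise RuntimeError("training failed")
except RuntimeError:
    pass
print(s.active)  # False
```

With the notebook's objects this would read `with muted(suppressor): report = tuner.train(tuner_dataloader, report_batch = False)`.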

Then we simply use the trainer object to make predictions.

This is similar to the training block above, except that there is no need to prepare y.

[9]:
trainer.model.eval()

# No gradients are needed for prediction
with torch.no_grad():
    for x,_ in tuner_dataloader:

        # Prepare input data as always
        input = [ele.to(trainer.device) for ele in x]

        # Mute some debug outputs
        suppressor.on()
        node_rec, adj_rec = trainer.model(input)
        suppressor.off()
        print(node_rec.shape, adj_rec.shape)
torch.Size([2, 10, 2]) torch.Size([2, 4, 10])
torch.Size([2, 10, 2]) torch.Size([2, 4, 10])
torch.Size([2, 10, 2]) torch.Size([2, 4, 10])
torch.Size([2, 10, 2]) torch.Size([2, 4, 10])
torch.Size([2, 10, 2]) torch.Size([2, 4, 10])
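Each printed line corresponds to one batch: with `n_sample = 10` and `batch_size = 2`, the loader yields 10 / 2 = 5 batches of two samples. To keep the predictions rather than print their shapes, one would typically collect them per batch and concatenate along the batch dimension (for example with `torch.cat(..., dim=0)`). The sketch below does only the shape bookkeeping in plain Python; `n_batches` and `concat_shape` are hypothetical helpers, and `drop_last` mirrors the PyTorch `DataLoader` option of the same name.

```python
import math
from typing import List, Sequence, Tuple

def n_batches(n_sample: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader-style iterator yields."""
    if drop_last:
        return n_sample // batch_size
    return math.ceil(n_sample / batch_size)

def concat_shape(shapes: Sequence[Tuple[int, ...]]) -> Tuple[int, ...]:
    """Shape after concatenating tensors along dim 0 (other dims must match)."""
    first = shapes[0]
    for s in shapes[1:]:
        if s[1:] != first[1:]:
            raise ValueError(f"incompatible shapes {first} and {s}")
    return (sum(s[0] for s in shapes),) + tuple(first[1:])

# Usage with the numbers from this notebook.
k = n_batches(10, 2)
node_shapes: List[Tuple[int, ...]] = [(2, 10, 2)] * k
adj_shapes: List[Tuple[int, ...]] = [(2, 4, 10)] * k
print(k, concat_shape(node_shapes), concat_shape(adj_shapes))
# 5 (10, 10, 2) (10, 4, 10)
```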

Cleanup Env

Clean up the environment once finished.

[10]:
worker.cleanup(device)
print('Clean up worker env.')
Clean up worker env.
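`worker.setup(device)` and `worker.cleanup(device)` must be paired. If a cell in between fails, the environment (presumably including DDP state when several devices are used) stays initialized until cleanup runs. When the same steps are moved into a script, a `try`/`finally` wrapper guarantees the pairing. The sketch assumes only that setup and cleanup are callables taking the device argument, as in this notebook; `run_with_env` is a hypothetical helper and the lambdas stand in for `worker.setup` and `worker.cleanup`.

```python
from typing import Any, Callable, List

def run_with_env(setup: Callable[[Any], None],
                 cleanup: Callable[[Any], None],
                 device: Any,
                 job: Callable[[Any], Any]) -> Any:
    """Call setup(device), run job(device), and always call cleanup(device)."""
    setup(device)
    try:
        return job(device)
    finally:
        # Runs on success and on failure, so the environment is never left set up.
        cleanup(device)

# Usage with toy callables that log each step.
log: List[str] = []

def failing_job(dev: Any) -> None:
    log.append(f"train on {dev}")
    raise RuntimeError("job failed")

try:
    run_with_env(lambda d: log.append(f"setup {d}"),
                 lambda d: log.append(f"cleanup {d}"),
                 "cuda", failing_job)
except RuntimeError:
    pass
print(log)  # ['setup cuda', 'train on cuda', 'cleanup cuda']
```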