# Source code for dscript.commands.train

"""
Train a new model.
"""

from __future__ import annotations

import argparse
import sys
from collections.abc import Callable
from typing import NamedTuple

import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import average_precision_score as average_precision
from torch.autograd import Variable
from tqdm import tqdm

from .. import __version__
from ..fasta import parse_dict
from ..foldseek import fold_vocab, get_foldseek_onehot
from ..glider import glide_compute_map, glider_score
from ..models.contact import ContactCNN
from ..models.embedding import FullyConnectedEmbed
from ..models.interaction import ModelInteraction
from ..utils import (
    PairedDataset,
    collate_paired_sequences,
    log,
)


class TrainArguments(NamedTuple):
    """
    Typed view of the parsed command-line options of the ``train`` command.

    The field names mirror the ``dest`` names registered by :func:`add_args`.
    The Foldseek options are declared last and carry defaults so that code
    constructing this tuple with the original 28 fields keeps working.
    """

    cmd: str
    device: int
    train: str
    test: str
    embedding: str
    no_augment: bool
    input_dim: int
    projection_dim: int
    dropout: float
    hidden_dim: int
    kernel_width: int
    no_w: bool
    no_sigmoid: bool
    do_pool: bool
    pool_width: int
    num_epochs: int
    batch_size: int
    weight_decay: float
    lr: float
    interaction_weight: float
    run_tt: bool
    glider_weight: float
    glider_thresh: float
    outfile: str | None
    save_prefix: str | None
    checkpoint: str | None
    seed: int | None
    func: Callable[[TrainArguments], None]
    # Foldseek options read by ``train_model`` (``--allow_foldseek`` and
    # ``--foldseek_fasta``). Defaulted to keep the constructor
    # backward-compatible.
    allow_foldseek: bool = False
    foldseek_fasta: str | None = None

    @property
    def dropout_p(self) -> float:
        """
        Alias of :attr:`dropout`.

        The command line registers the option as ``--dropout-p`` and
        ``train_model`` reads ``args.dropout_p``; this keeps the typed view
        consistent with that attribute name.
        """
        return self.dropout
def add_args(parser):
    """
    Register every option of the ``train`` command on ``parser``.

    Options are organised into titled argument groups (data, projection,
    contact, interaction, training, output/device and Foldseek). The same
    parser object is returned to allow chaining.

    :meta private:
    """
    # Groups are created in the order they should appear in ``--help``.
    data, projection, contact, interaction, training, misc, foldseek = (
        parser.add_argument_group(title)
        for title in (
            "Data",
            "Projection Module",
            "Contact Module",
            "Interaction Module",
            "Training",
            "Output and Device",
            "Foldseek related commands",
        )
    )

    # --- Data -------------------------------------------------------------
    data.add_argument("--train", help="list of training pairs", required=True)
    data.add_argument(
        "--test", help="list of validation/testing pairs", required=True
    )
    data.add_argument(
        "--embedding",
        help="h5py path containing embedded sequences",
        required=True,
    )
    data.add_argument(
        "--no-augment",
        help="data is automatically augmented by adding (B A) for all pairs (A B). Set this flag to not augment data",
        action="store_true",
    )

    # --- Embedding (projection) model -------------------------------------
    projection.add_argument(
        "--input-dim",
        default=6165,
        type=int,
        help="dimension of input language model embedding (per amino acid) (default: 6165)",
    )
    projection.add_argument(
        "--projection-dim",
        default=100,
        type=int,
        help="dimension of embedding projection layer (default: 100)",
    )
    projection.add_argument(
        "--dropout-p",
        default=0.5,
        type=float,
        help="parameter p for embedding dropout layer (default: 0.5)",
    )

    # --- Contact model ----------------------------------------------------
    contact.add_argument(
        "--hidden-dim",
        default=50,
        type=int,
        help="number of hidden units for comparison layer in contact prediction (default: 50)",
    )
    contact.add_argument(
        "--kernel-width",
        default=7,
        type=int,
        help="width of convolutional filter for contact prediction (default: 7)",
    )

    # --- Interaction model ------------------------------------------------
    interaction.add_argument(
        "--no-w",
        help="no use of weight matrix in interaction prediction model",
        action="store_true",
    )
    interaction.add_argument(
        "--no-sigmoid",
        help="no use of sigmoid activation at end of interaction model",
        action="store_true",
    )
    interaction.add_argument(
        "--do-pool",
        help="use max pool layer in interaction prediction model",
        action="store_true",
    )
    interaction.add_argument(
        "--pool-width",
        default=9,
        type=int,
        help="size of max-pool in interaction model (default: 9)",
    )

    # --- Training ---------------------------------------------------------
    training.add_argument(
        "--num-epochs",
        default=10,
        type=int,
        help="number of epochs (default: 10)",
    )
    training.add_argument(
        "--batch-size",
        default=25,
        type=int,
        help="minibatch size (default: 25)",
    )
    training.add_argument(
        "--weight-decay",
        default=0,
        type=float,
        help="L2 regularization (default: 0)",
    )
    training.add_argument(
        "--lr",
        default=0.001,
        type=float,
        help="learning rate (default: 0.001)",
    )
    training.add_argument(
        "--lambda",
        default=0.35,
        type=float,
        dest="interaction_weight",
        help="weight on the similarity objective (default: 0.35)",
    )

    # --- Training: Topsy-Turvy --------------------------------------------
    training.add_argument(
        "--topsy-turvy",
        action="store_true",
        dest="run_tt",
        help="run in Topsy-Turvy mode -- use top-down GLIDER scoring to guide training",
    )
    training.add_argument(
        "--glider-weight",
        default=0.2,
        type=float,
        dest="glider_weight",
        help="weight on the GLIDER accuracy objective (default: 0.2)",
    )
    training.add_argument(
        "--glider-thresh",
        default=0.925,
        type=float,
        dest="glider_thresh",
        help="threshold beyond which GLIDER scores treated as positive edges (0 < gt < 1) (default: 0.925)",
    )

    # --- Output and device ------------------------------------------------
    misc.add_argument("-o", "--outfile", help="output file path (default: stdout)")
    misc.add_argument("--save-prefix", help="path prefix for saving models")
    misc.add_argument(
        "-d", "--device", default=-1, type=int, help="compute device to use"
    )
    misc.add_argument("--checkpoint", help="checkpoint model to start training from")
    misc.add_argument("--seed", type=int, help="Set random seed")

    # --- Foldseek ---------------------------------------------------------
    foldseek.add_argument(
        "--allow_foldseek",
        action="store_true",
        default=False,
        help="If set to true, adds the foldseek one-hot representation",
    )
    foldseek.add_argument(
        "--foldseek_fasta",
        help="foldseek fasta file containing the foldseek representation",
    )
    # NOTE: an "--add_foldseek_after_projection" switch is intentionally not
    # registered; it was left disabled (commented out) in the original code.

    return parser
def predict_cmap_interaction(
    model,
    n0,
    n1,
    tensors,
    use_cuda,
    ### Foldseek added here
    allow_foldseek=False,
    fold_record=None,
    fold_vocab=None,
    add_first=True,
    ###
):
    """
    Predict whether a list of protein pairs will interact, as well as their contact map.

    :param model: Model to be trained
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names
    :type n1: list[str]
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :param allow_foldseek: Whether to add a Foldseek one-hot encoding for each protein
    :type allow_foldseek: bool
    :param fold_record: Foldseek records passed through to ``get_foldseek_onehot``; required when ``allow_foldseek`` is set
    :param fold_vocab: Foldseek vocabulary passed through to ``get_foldseek_onehot``; required when ``allow_foldseek`` is set
    :param add_first: If true, concatenate the Foldseek encoding onto the input embedding (last dimension); otherwise pass it to ``model.map_predict`` separately
    :type add_first: bool
    :return: (mean contact-map value per pair, predicted interaction per pair), each stacked along dimension 0
    :rtype: (torch.Tensor, torch.Tensor)
    :raises ValueError: if ``allow_foldseek`` is set but ``fold_record`` or ``fold_vocab`` is missing
    """
    # Validate once, with a real exception: an ``assert`` would be stripped
    # under ``python -O`` and let ``None`` reach ``get_foldseek_onehot``.
    if allow_foldseek and (fold_record is None or fold_vocab is None):
        raise ValueError(
            "fold_record and fold_vocab are required when allow_foldseek is set"
        )

    p_hat = []
    c_map_mag = []
    for i, name_a in enumerate(n0):
        name_b = n1[i]
        z_a = tensors[name_a]  # 1 x seqlen x dim
        z_b = tensors[name_b]
        if use_cuda:
            z_a = z_a.cuda()
            z_b = z_b.cuda()

        f_a = f_b = None
        if allow_foldseek:
            # The one-hot encoding gets a leading dimension of size 1 so it
            # lines up with the embedding layout noted above.
            f_a = get_foldseek_onehot(
                name_a, z_a.shape[1], fold_record, fold_vocab
            ).unsqueeze(0)
            f_b = get_foldseek_onehot(
                name_b, z_b.shape[1], fold_record, fold_vocab
            ).unsqueeze(0)
            if use_cuda:
                f_a = f_a.cuda()
                f_b = f_b.cuda()
            if add_first:
                # Early fusion: widen the per-residue feature dimension.
                z_a = torch.cat([z_a, f_a], dim=2)
                z_b = torch.cat([z_b, f_b], dim=2)

        if allow_foldseek and not add_first:
            # Late fusion: the model receives the Foldseek encodings itself.
            cm, ph = model.map_predict(z_a, z_b, True, f_a, f_b)
        else:
            cm, ph = model.map_predict(z_a, z_b)

        p_hat.append(ph)
        c_map_mag.append(torch.mean(cm))

    p_hat = torch.stack(p_hat, 0)
    c_map_mag = torch.stack(c_map_mag, 0)
    return c_map_mag, p_hat
def predict_interaction(
    model,
    n0,
    n1,
    tensors,
    use_cuda,
    ### Foldseek added here
    allow_foldseek=False,
    fold_record=None,
    fold_vocab=None,
    add_first=True,
    ###
):
    """
    Predict interaction for each protein pair, discarding the contact-map summary.

    Delegates to :func:`predict_cmap_interaction` with the same arguments and
    returns only its second result.

    :param model: Model to be trained
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names
    :type n1: list[str]
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :return: Predicted interaction for each pair
    :rtype: torch.Tensor
    """
    _c_map_mag, interaction_prob = predict_cmap_interaction(
        model,
        n0,
        n1,
        tensors,
        use_cuda,
        allow_foldseek=allow_foldseek,
        fold_record=fold_record,
        fold_vocab=fold_vocab,
        add_first=add_first,
    )
    return interaction_prob
def interaction_grad(
    model,
    n0,
    n1,
    y,
    tensors,
    accuracy_weight=0.35,
    run_tt=False,
    glider_weight=0,
    glider_map=None,
    glider_mat=None,
    use_cuda=True,
    ### Foldseek added here
    allow_foldseek=False,
    fold_record=None,
    fold_vocab=None,
    add_first=True,
    ###
):
    """
    Run the forward pass for one batch, build the training loss and backpropagate it.

    The loss is ``accuracy_weight * accuracy_loss + (1 - accuracy_weight) *
    representation_loss``, where the representation loss is the mean of the
    per-pair contact-map magnitudes. The optimizer step is left to the caller.

    :param model: Model to be trained
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names
    :type n1: list[str]
    :param y: Interaction labels
    :type y: torch.Tensor
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param accuracy_weight: Weight on the accuracy objective. The representation loss is weighted by :math:`1 - \\text{accuracy_weight}`.
    :type accuracy_weight: float
    :param run_tt: Use GLIDE top-down supervision
    :type run_tt: bool
    :param glider_weight: Weight on the GLIDE objective loss. Accuracy loss is :math:`(\\text{GLIDER_BCE}*\\text{glider_weight}) + (\\text{D-SCRIPT_BCE}*(1-\\text{glider_weight}))`.
    :type glider_weight: float
    :param glider_map: Map from protein identifier to index
    :type glider_map: dict[str, int]
    :param glider_mat: Matrix with pairwise GLIDE scores
    :type glider_mat: np.ndarray
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :return: (Loss, number correct, mean square error, batch size); the loss is a tensor, the count and the error are Python scalars
    :rtype: (torch.Tensor, int, float, int)
    """
    c_map_mag, p_hat = predict_cmap_interaction(
        model,
        n0,
        n1,
        tensors,
        use_cuda,
        allow_foldseek,
        fold_record,
        fold_vocab,
        add_first,
    )

    if use_cuda:
        y = y.cuda()
    y = Variable(y)

    p_hat = p_hat.float()
    bce_loss = F.binary_cross_entropy(p_hat, y.float())

    # Plain BCE against the labels unless Topsy-Turvy supervision is on.
    accuracy_loss = bce_loss
    if run_tt:
        g_score = torch.stack(
            [
                torch.tensor(
                    glider_score(name_a, n1[idx], glider_map, glider_mat),
                    dtype=torch.float64,
                )
                for idx, name_a in enumerate(n0)
            ],
            0,
        )
        if use_cuda:
            g_score = g_score.cuda()
        glider_loss = F.binary_cross_entropy(p_hat, g_score.float())
        accuracy_loss = (glider_weight * glider_loss) + (
            (1 - glider_weight) * bce_loss
        )

    representation_loss = torch.mean(c_map_mag)
    loss = (accuracy_weight * accuracy_loss) + (
        (1 - accuracy_weight) * representation_loss
    )
    b = len(p_hat)

    # Backprop Loss
    loss.backward()

    # Batch statistics are computed on the CPU, outside of autograd.
    if use_cuda:
        y = y.cpu()
        p_hat = p_hat.cpu()

    with torch.no_grad():
        labels = y.float()
        probs = p_hat.float()
        threshold = torch.full((b,), 0.5)
        p_guess = (probs > threshold).float()
        correct = torch.sum(p_guess == labels).item()
        mse = torch.mean((labels - probs) ** 2).item()

    return loss, correct, mse, b
def interaction_eval(
    model,
    test_iterator,
    tensors,
    use_cuda,
    ### Foldseek added here
    allow_foldseek=False,
    fold_record=None,
    fold_vocab=None,
    add_first=True,
    ###
):
    """
    Evaluate test data set performance.

    Predictions are gathered batch by batch, then all metrics are computed on
    the CPU. Precision and recall are "soft": the true-positive mass is
    ``sum(y * p_hat)``, normalised by ``sum(p_hat)`` and ``sum(y)``
    respectively. A metric whose denominator is zero is reported as ``0.0``
    instead of raising ``ZeroDivisionError``.

    :param model: Model to be trained
    :type model: dscript.models.interaction.ModelInteraction
    :param test_iterator: Test data iterator
    :type test_iterator: torch.utils.data.DataLoader
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :param allow_foldseek: Forwarded to :func:`predict_interaction`
    :type allow_foldseek: bool
    :param fold_record: Forwarded to :func:`predict_interaction`
    :param fold_vocab: Forwarded to :func:`predict_interaction`
    :param add_first: Forwarded to :func:`predict_interaction`
    :type add_first: bool
    :return: (Loss, number correct, mean square error, precision, recall, F1 Score, AUPR), all as Python scalars
    :rtype: (float, int, float, float, float, float, float)
    """
    batch_preds = []
    batch_labels = []
    for n0, n1, y in test_iterator:
        batch_preds.append(
            predict_interaction(
                model,
                n0,
                n1,
                tensors,
                use_cuda,
                allow_foldseek,
                fold_record,
                fold_vocab,
                add_first,
            )
        )
        batch_labels.append(y)

    # ``Tensor.cuda()`` / ``Tensor.cpu()`` return a new tensor; they never
    # move a tensor in place. Labels and predictions are brought to the CPU
    # explicitly so the metric code below is device-independent.
    y = torch.cat(batch_labels, 0).cpu().float()
    p_hat = torch.cat(batch_preds, 0).detach().cpu().float()

    loss = F.binary_cross_entropy(p_hat, y).item()

    with torch.no_grad():
        guess_cutoff = 0.5
        p_guess = (p_hat > guess_cutoff).float()
        correct = torch.sum(p_guess == y).item()
        mse = torch.mean((y - p_hat) ** 2).item()

        tp = torch.sum(y * p_hat).item()
        predicted_mass = torch.sum(p_hat).item()
        positive_mass = torch.sum(y).item()
        # Guard the degenerate cases (no predicted mass / no positives).
        pr = tp / predicted_mass if predicted_mass > 0 else 0.0
        re = tp / positive_mass if positive_mass > 0 else 0.0
        f1 = 2 * pr * re / (pr + re) if (pr + re) > 0 else 0.0

    aupr = average_precision(y.numpy(), p_hat.numpy())

    return loss, correct, mse, pr, re, f1, aupr
def train_model(args, output):
    """
    Train a model from parsed command-line arguments, logging to ``output``.

    Steps: read the tab-separated training and test pair files, load the
    embeddings of every referenced protein from the HDF5 file, optionally
    compute the GLIDER matrix (Topsy-Turvy mode), then either build a new
    model or load one from ``args.checkpoint``. The model is trained with
    Adam for ``args.num_epochs`` epochs and evaluated on the test pairs after
    each epoch. When ``args.save_prefix`` is set, the model is saved after
    every epoch and once more at the end.

    :param args: Parsed arguments as produced by the parser built in :func:`add_args`
    :param output: Writable text stream that receives the log
    :raises ValueError: if ``--allow_foldseek`` is set without ``--foldseek_fasta``
    """
    # Create data sets
    batch_size = args.batch_size
    use_cuda = (args.device > -1) and torch.cuda.is_available()
    train_fi = args.train
    test_fi = args.test
    no_augment = args.no_augment
    embedding_h5 = args.embedding

    ########## Foldseek code ##########
    allow_foldseek = args.allow_foldseek
    fold_fasta_file = args.foldseek_fasta
    # The Foldseek encoding is always fused after the projection layer here.
    add_first = False
    fold_record = {}
    if allow_foldseek:
        # Raise instead of assert: asserts vanish under ``python -O``.
        if fold_fasta_file is None:
            raise ValueError(
                "--foldseek_fasta is required when --allow_foldseek is set"
            )
        fold_record.update(parse_dict(fold_fasta_file))
    ###################################

    train_df = pd.read_csv(train_fi, sep="\t", header=None)
    train_df.columns = ["prot1", "prot2", "label"]

    if no_augment:
        train_p1 = train_df["prot1"]
        train_p2 = train_df["prot2"]
        train_y = torch.from_numpy(train_df["label"].values)
    else:
        # Augment by adding the mirrored pair (B, A) for every (A, B).
        train_p1 = pd.concat((train_df["prot1"], train_df["prot2"]), axis=0).reset_index(
            drop=True
        )
        train_p2 = pd.concat((train_df["prot2"], train_df["prot1"]), axis=0).reset_index(
            drop=True
        )
        train_y = torch.from_numpy(
            pd.concat((train_df["label"], train_df["label"])).values
        )

    train_dataset = PairedDataset(train_p1, train_p2, train_y)
    train_iterator = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=collate_paired_sequences,
        shuffle=True,
    )

    log(f"Loaded {len(train_p1)} training pairs", file=output)
    output.flush()

    test_df = pd.read_csv(test_fi, sep="\t", header=None)
    test_df.columns = ["prot1", "prot2", "label"]
    test_p1 = test_df["prot1"]
    test_p2 = test_df["prot2"]
    test_y = torch.from_numpy(test_df["label"].values)

    test_dataset = PairedDataset(test_p1, test_p2, test_y)
    test_iterator = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        collate_fn=collate_paired_sequences,
        shuffle=False,
    )

    log(f"Loaded {len(test_p1)} test pairs", file=output)
    log("Loading embeddings...", file=output)
    output.flush()

    all_proteins = set(train_p1).union(train_p2).union(test_p1).union(test_p2)

    embeddings = {}
    with h5py.File(embedding_h5, "r") as h5fi:
        for prot_name in tqdm(all_proteins):
            embeddings[prot_name] = torch.from_numpy(h5fi[prot_name][:, :])

    # Topsy-Turvy
    run_tt = args.run_tt
    glider_weight = args.glider_weight
    # The command line takes a fraction; it is scaled to a percentile here.
    glider_thresh = args.glider_thresh * 100

    if run_tt:
        log("Running D-SCRIPT Topsy-Turvy:", file=output)
        log(f"\tglider_weight: {glider_weight}", file=output)
        log(f"\tglider_thresh: {glider_thresh}th percentile", file=output)
        log("Computing GLIDER matrix...", file=output)
        output.flush()

        # Only the positive training pairs (label == 1) are passed on.
        glider_mat, glider_map = glide_compute_map(
            train_df[train_df.iloc[:, 2] == 1], thres_p=glider_thresh
        )
    else:
        glider_mat, glider_map = (None, None)

    if args.checkpoint is None:
        # Create embedding model
        input_dim = args.input_dim

        if allow_foldseek and add_first:
            input_dim += len(fold_vocab)

        projection_dim = args.projection_dim
        dropout_p = args.dropout_p
        embedding_model = FullyConnectedEmbed(
            input_dim, projection_dim, dropout=dropout_p
        )
        log("Initializing embedding model with:", file=output)
        log(f"\tprojection_dim: {projection_dim}", file=output)
        log(f"\tdropout_p: {dropout_p}", file=output)

        # Create contact model
        hidden_dim = args.hidden_dim
        kernel_width = args.kernel_width
        log("Initializing contact model with:", file=output)
        log(f"\thidden_dim: {hidden_dim}", file=output)
        log(f"\tkernel_width: {kernel_width}", file=output)

        proj_dim = projection_dim
        if allow_foldseek and not add_first:
            # Late fusion widens the contact module's input instead.
            proj_dim += len(fold_vocab)
        contact_model = ContactCNN(proj_dim, hidden_dim, kernel_width)

        # Create the full model
        do_w = not args.no_w
        do_pool = args.do_pool
        pool_width = args.pool_width
        do_sigmoid = not args.no_sigmoid
        log("Initializing interaction model with:", file=output)
        log(f"\tdo_poool: {do_pool}", file=output)
        log(f"\tpool_width: {pool_width}", file=output)
        log(f"\tdo_w: {do_w}", file=output)
        log(f"\tdo_sigmoid: {do_sigmoid}", file=output)
        model = ModelInteraction(
            embedding_model,
            contact_model,
            use_cuda,
            do_w=do_w,
            pool_size=pool_width,
            do_pool=do_pool,
            do_sigmoid=do_sigmoid,
        )

        log(model, file=output)

    else:
        log(
            f"Loading model from checkpoint {args.checkpoint}",
            file=output,
        )
        # SECURITY NOTE(review): ``torch.load`` unpickles the file, which can
        # execute arbitrary code. Only load checkpoints from trusted sources.
        model = torch.load(args.checkpoint)
        model.use_cuda = use_cuda

    if use_cuda:
        model.cuda()

    # Train the model
    lr = args.lr
    wd = args.weight_decay
    num_epochs = args.num_epochs
    inter_weight = args.interaction_weight
    cmap_weight = 1 - inter_weight
    # Zero-pad width for epoch numbers in checkpoint names. ``len(str(...))``
    # is also well-defined for ``num_epochs == 0``, unlike ``log10``.
    digits = len(str(num_epochs))
    save_prefix = args.save_prefix

    params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.Adam(params, lr=lr, weight_decay=wd)

    log(f'Using save prefix "{save_prefix}"', file=output)
    log(f"Training with Adam: lr={lr}, weight_decay={wd}", file=output)
    log(f"\tnum_epochs: {num_epochs}", file=output)
    log(f"\tbatch_size: {batch_size}", file=output)
    log(f"\tinteraction weight: {inter_weight}", file=output)
    log(f"\tcontact map weight: {cmap_weight}", file=output)
    output.flush()

    batch_report_fmt = "[{}/{}] training {:.1%}: Loss={:.6}, Accuracy={:.3%}, MSE={:.6}"
    epoch_report_fmt = "Finished Epoch {}/{}: Loss={:.6}, Accuracy={:.3%}, MSE={:.6}, Precision={:.6}, Recall={:.6}, F1={:.6}, AUPR={:.6}"

    # Use the true number of pairs: ``len(iterator) * batch_size`` over-counts
    # whenever the last batch is only partially filled.
    n_train = len(train_p1)
    n_test = len(test_p1)

    for epoch in range(num_epochs):
        model.train()

        n = 0
        loss_accum = 0
        acc_accum = 0
        mse_accum = 0

        # Train batches
        for z0, z1, y in train_iterator:
            loss, correct, mse, b = interaction_grad(
                model,
                z0,
                z1,
                y,
                embeddings,
                accuracy_weight=inter_weight,
                run_tt=run_tt,
                glider_weight=glider_weight,
                glider_map=glider_map,
                glider_mat=glider_mat,
                use_cuda=use_cuda,
                allow_foldseek=allow_foldseek,
                fold_record=fold_record,
                fold_vocab=fold_vocab,
                add_first=add_first,
            )
            # Keep a plain float for the running statistics so they do not
            # hold on to every batch's autograd graph.
            loss_value = loss.item()

            # Incremental (running) means over the samples seen this epoch.
            n += b
            loss_accum += b * (loss_value - loss_accum) / n
            acc_accum += (correct - b * acc_accum) / n
            mse_accum += b * (mse - mse_accum) / n

            # Report each time the sample count crosses a multiple of 100.
            report = (n - b) // 100 < n // 100

            optim.step()
            optim.zero_grad()
            model.clip()

            if report:
                tokens = [
                    epoch + 1,
                    num_epochs,
                    n / n_train,
                    loss_accum,
                    acc_accum,
                    mse_accum,
                ]
                log(batch_report_fmt.format(*tokens), file=output)
                output.flush()

        model.eval()

        with torch.no_grad():
            (
                inter_loss,
                inter_correct,
                inter_mse,
                inter_pr,
                inter_re,
                inter_f1,
                inter_aupr,
            ) = interaction_eval(
                model,
                test_iterator,
                embeddings,
                use_cuda,
                allow_foldseek,
                fold_record,
                fold_vocab,
                add_first,
            )
            tokens = [
                epoch + 1,
                num_epochs,
                inter_loss,
                inter_correct / n_test,
                inter_mse,
                inter_pr,
                inter_re,
                inter_f1,
                inter_aupr,
            ]
            log(epoch_report_fmt.format(*tokens), file=output)
            output.flush()

            # Save the model
            if save_prefix is not None:
                save_path = save_prefix + "_epoch" + str(epoch + 1).zfill(digits) + ".sav"
                log(f"Saving model to {save_path}", file=output)
                # Save from the CPU, then move back for the next epoch.
                model.cpu()
                torch.save(model, save_path)
                if use_cuda:
                    model.cuda()

        output.flush()

    if save_prefix is not None:
        save_path = save_prefix + "_final.sav"
        log(f"Saving final model to {save_path}", file=output)
        model.cpu()
        torch.save(model, save_path)
        if use_cuda:
            model.cuda()
def main(args):
    """
    Run training from arguments.

    Opens the log destination (``args.outfile``, or standard output when it
    is ``None``), reports the version, invocation and compute device, seeds
    NumPy and PyTorch when ``args.seed`` is given, and runs
    :func:`train_model`.

    A log file opened here is always closed, even if training raises.
    Standard output is only flushed, never closed, so the caller can keep
    using it afterwards.

    :meta private:
    """
    owns_output = args.outfile is not None
    output = open(args.outfile, "w") if owns_output else sys.stdout

    try:
        log(f"D-SCRIPT Version {__version__}", file=output, print_also=True)
        log(f"Called as: {' '.join(sys.argv)}", file=output, print_also=True)

        # Set the device
        device = args.device
        use_cuda = (device > -1) and torch.cuda.is_available()
        if use_cuda:
            torch.cuda.set_device(device)
            log(
                f"Using CUDA device {device} - {torch.cuda.get_device_name(device)}",
                file=output,
                print_also=True,
            )
        else:
            log("Using CPU", file=output, print_also=True)

        if args.seed is not None:
            np.random.seed(args.seed)
            torch.manual_seed(args.seed)

        train_model(args, output)
    finally:
        if owns_output:
            output.close()
        else:
            output.flush()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    add_args(parser)
    main(parser.parse_args())