"""
Train a new model.
"""
from __future__ import annotations
import argparse
import sys
from collections.abc import Callable
from typing import NamedTuple
import h5py
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import average_precision_score as average_precision
from torch.autograd import Variable
from tqdm import tqdm
from .. import __version__
from ..fasta import parse_dict
from ..foldseek import fold_vocab, get_foldseek_onehot
from ..glider import glide_compute_map, glider_score
from ..models.contact import ContactCNN
from ..models.embedding import FullyConnectedEmbed
from ..models.interaction import ModelInteraction
from ..utils import (
PairedDataset,
collate_paired_sequences,
log,
)
class TrainArguments(NamedTuple):
    """
    Typed view of the parsed ``train`` command-line arguments.

    Field names follow the ``dest`` names registered by :func:`add_args` and
    read by :func:`train_model`, with one exception: ``--dropout-p`` is stored
    by argparse as ``dropout_p`` while this tuple names the field ``dropout``;
    the read-only ``dropout_p`` property bridges the two. ``cmd`` and ``func``
    are not registered by :func:`add_args` — presumably they are filled in by
    the top-level CLI dispatcher (TODO confirm).
    """

    cmd: str
    device: int
    train: str
    test: str
    embedding: str
    no_augment: bool
    input_dim: int
    projection_dim: int
    dropout: float  # argparse dest is ``dropout_p``; see the property below
    hidden_dim: int
    kernel_width: int
    no_w: bool
    no_sigmoid: bool
    do_pool: bool
    pool_width: int
    num_epochs: int
    batch_size: int
    weight_decay: float
    lr: float
    interaction_weight: float
    run_tt: bool
    glider_weight: float
    glider_thresh: float
    outfile: str | None
    save_prefix: str | None
    checkpoint: str | None
    seed: int | None
    func: Callable[[TrainArguments], None]
    # Foldseek options read by train_model. They are appended with defaults so
    # existing 28-field positional/keyword construction keeps working.
    allow_foldseek: bool = False
    foldseek_fasta: str | None = None

    @property
    def dropout_p(self) -> float:
        """Alias of ``dropout`` under the name argparse and train_model use."""
        return self.dropout
def add_args(parser):
    """
    Create parser for command line utility.

    Registers the ``train`` options on ``parser`` in seven argument groups
    (data, projection module, contact module, interaction module, training,
    output/device, and Foldseek) and returns the same parser so the call can
    be chained.

    :param parser: Parser to extend
    :type parser: argparse.ArgumentParser
    :return: ``parser`` with the training arguments added
    :rtype: argparse.ArgumentParser

    :meta private:
    """
    data_grp = parser.add_argument_group("Data")
    proj_grp = parser.add_argument_group("Projection Module")
    contact_grp = parser.add_argument_group("Contact Module")
    inter_grp = parser.add_argument_group("Interaction Module")
    train_grp = parser.add_argument_group("Training")
    misc_grp = parser.add_argument_group("Output and Device")
    foldseek_grp = parser.add_argument_group("Foldseek related commands")

    # Data
    data_grp.add_argument("--train", required=True, help="list of training pairs")
    data_grp.add_argument(
        "--test", required=True, help="list of validation/testing pairs"
    )
    data_grp.add_argument(
        "--embedding",
        required=True,
        help="h5py path containing embedded sequences",
    )
    data_grp.add_argument(
        "--no-augment",
        action="store_true",
        help="data is automatically augmented by adding (B A) for all pairs (A B). Set this flag to not augment data",
    )

    # Embedding model
    proj_grp.add_argument(
        "--input-dim",
        type=int,
        default=6165,
        help="dimension of input language model embedding (per amino acid) (default: 6165)",
    )
    proj_grp.add_argument(
        "--projection-dim",
        type=int,
        default=100,
        help="dimension of embedding projection layer (default: 100)",
    )
    proj_grp.add_argument(
        "--dropout-p",
        type=float,
        default=0.5,
        help="parameter p for embedding dropout layer (default: 0.5)",
    )

    # Contact model
    contact_grp.add_argument(
        "--hidden-dim",
        type=int,
        default=50,
        help="number of hidden units for comparison layer in contact prediction (default: 50)",
    )
    contact_grp.add_argument(
        "--kernel-width",
        type=int,
        default=7,
        help="width of convolutional filter for contact prediction (default: 7)",
    )

    # Interaction Model
    inter_grp.add_argument(
        "--no-w",
        action="store_true",
        help="no use of weight matrix in interaction prediction model",
    )
    inter_grp.add_argument(
        "--no-sigmoid",
        action="store_true",
        help="no use of sigmoid activation at end of interaction model",
    )
    inter_grp.add_argument(
        "--do-pool",
        action="store_true",
        help="use max pool layer in interaction prediction model",
    )
    inter_grp.add_argument(
        "--pool-width",
        type=int,
        default=9,
        help="size of max-pool in interaction model (default: 9)",
    )

    # Training
    train_grp.add_argument(
        "--num-epochs",
        type=int,
        default=10,
        help="number of epochs (default: 10)",
    )
    train_grp.add_argument(
        "--batch-size",
        type=int,
        default=25,
        help="minibatch size (default: 25)",
    )
    train_grp.add_argument(
        "--weight-decay",
        type=float,
        # argparse does not apply ``type`` to non-string defaults, so use a
        # float here to keep ``args.weight_decay`` a float in every case.
        default=0.0,
        help="L2 regularization (default: 0)",
    )
    train_grp.add_argument(
        "--lr",
        type=float,
        default=0.001,
        help="learning rate (default: 0.001)",
    )
    train_grp.add_argument(
        "--lambda",
        dest="interaction_weight",
        type=float,
        default=0.35,
        help="weight on the interaction prediction (BCE) objective; the contact map magnitude objective is weighted by 1 - lambda (default: 0.35)",
    )

    # Topsy-Turvy
    train_grp.add_argument(
        "--topsy-turvy",
        dest="run_tt",
        action="store_true",
        help="run in Topsy-Turvy mode -- use top-down GLIDER scoring to guide training",
    )
    train_grp.add_argument(
        "--glider-weight",
        dest="glider_weight",
        type=float,
        default=0.2,
        help="weight on the GLIDER accuracy objective (default: 0.2)",
    )
    train_grp.add_argument(
        "--glider-thresh",
        dest="glider_thresh",
        type=float,
        default=0.925,
        help="threshold beyond which GLIDER scores treated as positive edges (0 < gt < 1) (default: 0.925)",
    )

    # Output
    misc_grp.add_argument("-o", "--outfile", help="output file path (default: stdout)")
    misc_grp.add_argument("--save-prefix", help="path prefix for saving models")
    misc_grp.add_argument(
        "-d",
        "--device",
        type=int,
        default=-1,
        help="CUDA device index to use; -1 uses the CPU (default: -1)",
    )
    misc_grp.add_argument("--checkpoint", help="checkpoint model to start training from")
    misc_grp.add_argument("--seed", help="Set random seed", type=int)

    # Foldseek
    foldseek_grp.add_argument(
        "--allow_foldseek",
        action="store_true",
        help="If set to true, adds the foldseek one-hot representation",
    )
    foldseek_grp.add_argument(
        "--foldseek_fasta",
        help="foldseek fasta file containing the foldseek representation (required with --allow_foldseek)",
    )
    return parser
def predict_cmap_interaction(
    model,
    n0,
    n1,
    tensors,
    use_cuda,
    ### Foldseek added here
    allow_foldseek=False,
    fold_record=None,
    fold_vocab=None,
    add_first=True,
    ###
):
    """
    Run the model pair by pair, collecting interaction predictions and the
    mean of each predicted contact map.

    Pair ``i`` is ``(n0[i], n1[i])``; pairs are fed to ``model.map_predict``
    one at a time (presumably because proteins differ in sequence length).

    :param model: Model to be trained
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names, indexed in step with ``n0``
    :type n1: list[str]
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :param allow_foldseek: Also build a Foldseek one-hot tensor for each protein
    :type allow_foldseek: bool
    :param fold_record: Protein name to Foldseek record; required when ``allow_foldseek``
    :param fold_vocab: Foldseek vocabulary; required when ``allow_foldseek``
    :param add_first: With ``allow_foldseek``, concatenate the one-hots onto the
        embeddings along dim 2 (``True``) or pass them to ``model.map_predict``
        as extra arguments (``False``)
    :type add_first: bool
    :return: ``(c_map_mag, p_hat)``: per-pair contact-map means and per-pair
        predictions, each stacked along dimension 0
    :rtype: (torch.Tensor, torch.Tensor)
    """
    probabilities = []
    contact_means = []
    separate_fold_inputs = allow_foldseek and not add_first
    for idx, name_a in enumerate(n0):
        name_b = n1[idx]
        emb_a = tensors[name_a]  # 1 x seqlen x dim
        emb_b = tensors[name_b]
        if use_cuda:
            emb_a, emb_b = emb_a.cuda(), emb_b.cuda()
        if allow_foldseek:
            assert fold_record is not None and fold_vocab is not None
            # seqlen x vocabsize one-hots, given a leading batch dimension.
            fold_a = get_foldseek_onehot(
                name_a, emb_a.shape[1], fold_record, fold_vocab
            ).unsqueeze(0)
            fold_b = get_foldseek_onehot(
                name_b, emb_b.shape[1], fold_record, fold_vocab
            ).unsqueeze(0)
            if use_cuda:
                fold_a, fold_b = fold_a.cuda(), fold_b.cuda()
            if add_first:
                emb_a = torch.concat([emb_a, fold_a], dim=2)
                emb_b = torch.concat([emb_b, fold_b], dim=2)
        if separate_fold_inputs:
            contact_map, prob = model.map_predict(emb_a, emb_b, True, fold_a, fold_b)
        else:
            contact_map, prob = model.map_predict(emb_a, emb_b)
        probabilities.append(prob)
        contact_means.append(torch.mean(contact_map))
    p_hat = torch.stack(probabilities, 0)
    c_map_mag = torch.stack(contact_means, 0)
    return c_map_mag, p_hat
def predict_interaction(
    model,
    n0,
    n1,
    tensors,
    use_cuda,
    ### Foldseek added here
    allow_foldseek=False,
    fold_record=None,
    fold_vocab=None,
    add_first=True,
    ###
):
    """
    Predict whether each pair ``(n0[i], n1[i])`` interacts.

    Delegates to :func:`predict_cmap_interaction` and keeps only the stacked
    predictions, dropping the contact-map means.

    :param model: Model to be trained
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names
    :type n1: list[str]
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :param allow_foldseek: Forwarded to :func:`predict_cmap_interaction`
    :param fold_record: Forwarded to :func:`predict_cmap_interaction`
    :param fold_vocab: Forwarded to :func:`predict_cmap_interaction`
    :param add_first: Forwarded to :func:`predict_cmap_interaction`
    :return: Per-pair predictions stacked along dimension 0
    :rtype: torch.Tensor
    """
    outputs = predict_cmap_interaction(
        model,
        n0,
        n1,
        tensors,
        use_cuda,
        allow_foldseek=allow_foldseek,
        fold_record=fold_record,
        fold_vocab=fold_vocab,
        add_first=add_first,
    )
    return outputs[1]
def interaction_grad(
    model,
    n0,
    n1,
    y,
    tensors,
    accuracy_weight=0.35,
    run_tt=False,
    glider_weight=0,
    glider_map=None,
    glider_mat=None,
    use_cuda=True,
    ### Foldseek added here
    allow_foldseek=False,
    fold_record=None,
    fold_vocab=None,
    add_first=True,
    ###
):
    """
    Compute gradient and backpropagate loss for a batch.

    :param model: Model to be trained
    :type model: dscript.models.interaction.ModelInteraction
    :param n0: First protein names
    :type n0: list[str]
    :param n1: Second protein names
    :type n1: list[str]
    :param y: Interaction labels
    :type y: torch.Tensor
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param accuracy_weight: Weight on the accuracy objective. Representation loss is :math:`1 - \\text{accuracy_weight}`.
    :type accuracy_weight: float
    :param run_tt: Use GLIDE top-down supervision
    :type run_tt: bool
    :param glider_weight: Weight on the GLIDE objective loss. Accuracy loss is :math:`(\\text{GLIDER_BCE}*\\text{glider_weight}) + (\\text{D-SCRIPT_BCE}*(1-\\text{glider_weight}))`.
    :type glider_weight: float
    :param glider_map: Map from protein identifier to index
    :type glider_map: dict[str, int]
    :param glider_mat: Matrix with pairwise GLIDE scores
    :type glider_mat: np.ndarray
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :param allow_foldseek: Forwarded to :func:`predict_cmap_interaction`
    :param fold_record: Forwarded to :func:`predict_cmap_interaction`
    :param fold_vocab: Forwarded to :func:`predict_cmap_interaction`
    :param add_first: Forwarded to :func:`predict_cmap_interaction`
    :return: (Loss, number correct, mean square error, batch size). ``backward()``
        has already been called on the loss; it is returned detached from the
        autograd graph so callers that accumulate it do not keep batch graphs alive.
    :rtype: (torch.Tensor, int, float, int)
    """
    c_map_mag, p_hat = predict_cmap_interaction(
        model,
        n0,
        n1,
        tensors,
        use_cuda,
        allow_foldseek,
        fold_record,
        fold_vocab,
        add_first,
    )
    if use_cuda:
        y = y.cuda()
    # Plain tensors are used directly; ``torch.autograd.Variable`` is a
    # deprecated no-op wrapper in current PyTorch.
    p_hat = p_hat.float()
    bce_loss = F.binary_cross_entropy(p_hat, y.float())

    if run_tt:
        # Top-down supervision: GLIDER score of each pair used as a soft target.
        g_score = torch.stack(
            [
                torch.tensor(
                    glider_score(prot_a, prot_b, glider_map, glider_mat),
                    dtype=torch.float64,
                )
                for prot_a, prot_b in zip(n0, n1)
            ],
            0,
        )
        if use_cuda:
            g_score = g_score.cuda()
        glider_loss = F.binary_cross_entropy(p_hat, g_score.float())
        accuracy_loss = (glider_weight * glider_loss) + ((1 - glider_weight) * bce_loss)
    else:
        accuracy_loss = bce_loss

    representation_loss = torch.mean(c_map_mag)
    loss = (accuracy_weight * accuracy_loss) + (
        (1 - accuracy_weight) * representation_loss
    )
    b = len(p_hat)

    # Backprop Loss
    loss.backward()

    # Batch statistics are computed on the CPU (``torch.ones(b)`` lives there).
    with torch.no_grad():
        p_hat = p_hat.detach().cpu()
        y = y.cpu().float()
        guess_cutoff = 0.5
        p_guess = (guess_cutoff * torch.ones(b) < p_hat).float()
        correct = torch.sum(p_guess == y).item()
        mse = torch.mean((y - p_hat) ** 2).item()

    return loss.detach(), correct, mse, b
def interaction_eval(
    model,
    test_iterator,
    tensors,
    use_cuda,
    ### Foldseek added here
    allow_foldseek=False,
    fold_record=None,
    fold_vocab=None,
    add_first=True,
    ###
):
    """
    Evaluate test data set performance.

    :param model: Model to be trained
    :type model: dscript.models.interaction.ModelInteraction
    :param test_iterator: Test data iterator
    :type test_iterator: torch.utils.data.DataLoader
    :param tensors: Dictionary of protein names to embeddings
    :type tensors: dict[str, torch.Tensor]
    :param use_cuda: Whether to use GPU
    :type use_cuda: bool
    :param allow_foldseek: Forwarded to :func:`predict_interaction`
    :param fold_record: Forwarded to :func:`predict_interaction`
    :param fold_vocab: Forwarded to :func:`predict_interaction`
    :param add_first: Forwarded to :func:`predict_interaction`
    :return: (Loss, number correct, mean square error, precision, recall, F1 Score, AUPR).
        Precision/recall/F1 are reported as 0.0 when their denominator is zero
        (e.g. a test set with no positive labels) instead of raising.
    :rtype: (float, int, float, float, float, float, float)
    """
    p_hat = []
    true_y = []
    for n0, n1, y in test_iterator:
        p_hat.append(
            predict_interaction(
                model,
                n0,
                n1,
                tensors,
                use_cuda,
                allow_foldseek,
                fold_record,
                fold_vocab,
                add_first,
            )
        )
        true_y.append(y)
    y = torch.cat(true_y, 0).cpu()
    # Predictions may be on the GPU; every metric below is computed on the CPU
    # (the cutoff comparison uses a CPU ``torch.ones``), so move them here once.
    p_hat = torch.cat(p_hat, 0).detach().cpu()

    loss = F.binary_cross_entropy(p_hat.float(), y.float()).item()
    b = len(y)

    with torch.no_grad():
        guess_cutoff = torch.Tensor([0.5]).float()
        p_hat = p_hat.float()
        y = y.float()
        p_guess = (guess_cutoff * torch.ones(b) < p_hat).float()
        correct = torch.sum(p_guess == y).item()
        mse = torch.mean((y - p_hat) ** 2).item()

        # NOTE(review): precision/recall use the raw probabilities ``p_hat``
        # ("soft" counts), not the thresholded ``p_guess``. Kept unchanged so
        # reported numbers stay comparable; confirm this is intended.
        tp = torch.sum(y * p_hat).item()
        predicted_pos = torch.sum(p_hat).item()
        actual_pos = torch.sum(y).item()
        pr = tp / predicted_pos if predicted_pos > 0 else 0.0
        re = tp / actual_pos if actual_pos > 0 else 0.0
        f1 = 2 * pr * re / (pr + re) if (pr + re) > 0 else 0.0

    aupr = average_precision(y.numpy(), p_hat.numpy())

    return loss, correct, mse, pr, re, f1, aupr
def _save_model(model, save_path, use_cuda):
    """
    Write ``model`` to ``save_path`` with ``torch.save``, temporarily on the CPU.

    The model is moved to the CPU for serialization (presumably so the file
    can be loaded without the original GPU) and back to the GPU afterwards
    when ``use_cuda`` is set.
    """
    model.cpu()
    torch.save(model, save_path)
    if use_cuda:
        model.cuda()


def train_model(args, output):
    """
    Train an interaction model as configured by the ``train`` command line.

    Reads the tab-separated training and test pair files (columns ``prot1``,
    ``prot2``, ``label``), augmenting training pairs with their reversed copies
    unless ``args.no_augment``; loads each protein's embedding from the HDF5
    file ``args.embedding``; optionally loads Foldseek sequences and computes
    the GLIDER matrix (Topsy-Turvy mode); builds a new ``ModelInteraction`` or
    loads ``args.checkpoint``; then runs ``args.num_epochs`` epochs of Adam,
    evaluating on the test pairs after every epoch and saving models when
    ``args.save_prefix`` is set.

    :param args: Parsed command-line arguments (see :func:`add_args`)
    :type args: argparse.Namespace
    :param output: Writable text stream used for progress logging
    :raises ValueError: If ``--allow_foldseek`` is given without ``--foldseek_fasta``
    """
    # Create data sets
    batch_size = args.batch_size
    use_cuda = (args.device > -1) and torch.cuda.is_available()
    train_fi = args.train
    test_fi = args.test
    no_augment = args.no_augment
    embedding_h5 = args.embedding

    # Foldseek. ``fold_vocab`` used below is the vocabulary imported at module
    # level from ``..foldseek``.
    allow_foldseek = args.allow_foldseek
    fold_fasta_file = args.foldseek_fasta
    # One-hots are not concatenated onto the input embedding; they are handed
    # to the model separately and widen the contact model input instead.
    add_first = False
    fold_record = {}
    if allow_foldseek:
        if fold_fasta_file is None:
            raise ValueError("--foldseek_fasta is required when --allow_foldseek is set")
        fold_record.update(parse_dict(fold_fasta_file).items())

    train_df = pd.read_csv(train_fi, sep="\t", header=None)
    train_df.columns = ["prot1", "prot2", "label"]

    if no_augment:
        train_p1 = train_df["prot1"]
        train_p2 = train_df["prot2"]
        train_y = torch.from_numpy(train_df["label"].values)
    else:
        # Add (B, A) for every (A, B) with the same label.
        train_p1 = pd.concat((train_df["prot1"], train_df["prot2"]), axis=0).reset_index(
            drop=True
        )
        train_p2 = pd.concat((train_df["prot2"], train_df["prot1"]), axis=0).reset_index(
            drop=True
        )
        train_y = torch.from_numpy(
            pd.concat((train_df["label"], train_df["label"])).values
        )

    train_dataset = PairedDataset(train_p1, train_p2, train_y)
    train_iterator = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=collate_paired_sequences,
        shuffle=True,
    )

    log(f"Loaded {len(train_p1)} training pairs", file=output)
    output.flush()

    test_df = pd.read_csv(test_fi, sep="\t", header=None)
    test_df.columns = ["prot1", "prot2", "label"]
    test_p1 = test_df["prot1"]
    test_p2 = test_df["prot2"]
    test_y = torch.from_numpy(test_df["label"].values)
    test_dataset = PairedDataset(test_p1, test_p2, test_y)
    test_iterator = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        collate_fn=collate_paired_sequences,
        shuffle=False,
    )

    log(f"Loaded {len(test_p1)} test pairs", file=output)
    log("Loading embeddings...", file=output)
    output.flush()

    all_proteins = set(train_p1).union(train_p2).union(test_p1).union(test_p2)

    embeddings = {}
    with h5py.File(embedding_h5, "r") as h5fi:
        for prot_name in tqdm(all_proteins):
            embeddings[prot_name] = torch.from_numpy(h5fi[prot_name][:, :])

    # Topsy-Turvy
    run_tt = args.run_tt
    glider_weight = args.glider_weight
    glider_thresh = args.glider_thresh * 100

    if run_tt:
        log("Running D-SCRIPT Topsy-Turvy:", file=output)
        log(f"\tglider_weight: {glider_weight}", file=output)
        log(f"\tglider_thresh: {glider_thresh}th percentile", file=output)
        log("Computing GLIDER matrix...", file=output)
        output.flush()

        # GLIDER is computed from the positive (label == 1) training pairs only.
        glider_mat, glider_map = glide_compute_map(
            train_df[train_df.iloc[:, 2] == 1], thres_p=glider_thresh
        )
    else:
        glider_mat, glider_map = (None, None)

    if args.checkpoint is None:
        # Create embedding model
        input_dim = args.input_dim
        if allow_foldseek and add_first:
            input_dim += len(fold_vocab)
        projection_dim = args.projection_dim
        dropout_p = args.dropout_p
        embedding_model = FullyConnectedEmbed(
            input_dim, projection_dim, dropout=dropout_p
        )
        log("Initializing embedding model with:", file=output)
        log(f"\tprojection_dim: {projection_dim}", file=output)
        log(f"\tdropout_p: {dropout_p}", file=output)

        # Create contact model
        hidden_dim = args.hidden_dim
        kernel_width = args.kernel_width
        log("Initializing contact model with:", file=output)
        log(f"\thidden_dim: {hidden_dim}", file=output)
        log(f"\tkernel_width: {kernel_width}", file=output)

        proj_dim = projection_dim
        if allow_foldseek and not add_first:
            # Foldseek one-hots are appended after projection.
            proj_dim += len(fold_vocab)
        contact_model = ContactCNN(proj_dim, hidden_dim, kernel_width)

        # Create the full model
        do_w = not args.no_w
        do_pool = args.do_pool
        pool_width = args.pool_width
        do_sigmoid = not args.no_sigmoid
        log("Initializing interaction model with:", file=output)
        log(f"\tdo_pool: {do_pool}", file=output)
        log(f"\tpool_width: {pool_width}", file=output)
        log(f"\tdo_w: {do_w}", file=output)
        log(f"\tdo_sigmoid: {do_sigmoid}", file=output)
        model = ModelInteraction(
            embedding_model,
            contact_model,
            use_cuda,
            do_w=do_w,
            pool_size=pool_width,
            do_pool=do_pool,
            do_sigmoid=do_sigmoid,
        )

        log(model, file=output)

    else:
        log(
            f"Loading model from checkpoint {args.checkpoint}",
            file=output,
        )
        # NOTE(review): torch.load unpickles the whole model object; only load
        # trusted checkpoints. Newer PyTorch releases default to
        # weights_only=True, which rejects full-model pickles — confirm the
        # installed version's behaviour.
        model = torch.load(args.checkpoint)
        model.use_cuda = use_cuda

    if use_cuda:
        model.cuda()

    # Train the model
    lr = args.lr
    wd = args.weight_decay
    num_epochs = args.num_epochs
    inter_weight = args.interaction_weight
    cmap_weight = 1 - inter_weight
    # Width used to zero-pad epoch numbers in checkpoint file names; unlike
    # log10 this does not fail for num_epochs == 0.
    digits = len(str(num_epochs))
    save_prefix = args.save_prefix

    params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.Adam(params, lr=lr, weight_decay=wd)

    log(f'Using save prefix "{save_prefix}"', file=output)
    log(f"Training with Adam: lr={lr}, weight_decay={wd}", file=output)
    log(f"\tnum_epochs: {num_epochs}", file=output)
    log(f"\tbatch_size: {batch_size}", file=output)
    log(f"\tinteraction weight: {inter_weight}", file=output)
    log(f"\tcontact map weight: {cmap_weight}", file=output)
    output.flush()

    batch_report_fmt = "[{}/{}] training {:.1%}: Loss={:.6}, Accuracy={:.3%}, MSE={:.6}"
    epoch_report_fmt = "Finished Epoch {}/{}: Loss={:.6}, Accuracy={:.3%}, MSE={:.6}, Precision={:.6}, Recall={:.6}, F1={:.6}, AUPR={:.6}"

    # Exact pair counts: ``len(iterator) * batch_size`` over-counts whenever
    # the last batch is partial, which skewed progress and test accuracy.
    n_train = len(train_dataset)
    n_test = len(test_dataset)

    for epoch in range(num_epochs):
        model.train()

        n = 0
        loss_accum = 0
        acc_accum = 0
        mse_accum = 0

        # Train batches
        for z0, z1, y in train_iterator:
            loss, correct, mse, b = interaction_grad(
                model,
                z0,
                z1,
                y,
                embeddings,
                accuracy_weight=inter_weight,
                run_tt=run_tt,
                glider_weight=glider_weight,
                glider_map=glider_map,
                glider_mat=glider_mat,
                use_cuda=use_cuda,
                allow_foldseek=allow_foldseek,
                fold_record=fold_record,
                fold_vocab=fold_vocab,
                add_first=add_first,
            )
            # Accumulate a plain float so running means never hold a reference
            # to a batch's autograd graph.
            loss = float(loss)

            # Running, sample-weighted means over the epoch so far.
            n += b
            delta = b * (loss - loss_accum)
            loss_accum += delta / n

            delta = correct - b * acc_accum
            acc_accum += delta / n

            delta = b * (mse - mse_accum)
            mse_accum += delta / n

            # Report each time the pair count crosses a multiple of 100.
            report = (n - b) // 100 < n // 100

            optim.step()
            optim.zero_grad()
            model.clip()

            if report:
                tokens = [
                    epoch + 1,
                    num_epochs,
                    n / n_train,
                    loss_accum,
                    acc_accum,
                    mse_accum,
                ]
                log(batch_report_fmt.format(*tokens), file=output)
                output.flush()

        model.eval()

        with torch.no_grad():
            (
                inter_loss,
                inter_correct,
                inter_mse,
                inter_pr,
                inter_re,
                inter_f1,
                inter_aupr,
            ) = interaction_eval(
                model,
                test_iterator,
                embeddings,
                use_cuda,
                allow_foldseek,
                fold_record,
                fold_vocab,
                add_first,
            )
            tokens = [
                epoch + 1,
                num_epochs,
                inter_loss,
                inter_correct / n_test,
                inter_mse,
                inter_pr,
                inter_re,
                inter_f1,
                inter_aupr,
            ]
            log(epoch_report_fmt.format(*tokens), file=output)
            output.flush()

            # Save the model
            if save_prefix is not None:
                save_path = save_prefix + "_epoch" + str(epoch + 1).zfill(digits) + ".sav"
                log(f"Saving model to {save_path}", file=output)
                _save_model(model, save_path, use_cuda)
                output.flush()

    if save_prefix is not None:
        save_path = save_prefix + "_final.sav"
        log(f"Saving final model to {save_path}", file=output)
        _save_model(model, save_path, use_cuda)
def main(args):
    """
    Run training from arguments.

    Logs the D-SCRIPT version and command line, selects the CUDA device or the
    CPU, seeds NumPy and PyTorch when ``--seed`` is given, and calls
    :func:`train_model`. Logs go to ``args.outfile`` when set, otherwise to
    standard output. A file opened here is closed even if training raises;
    standard output is left open.

    :meta private:
    """
    if args.outfile is None:
        output = sys.stdout
        owns_output = False
    else:
        output = open(args.outfile, "w")
        owns_output = True

    try:
        log(f"D-SCRIPT Version {__version__}", file=output, print_also=True)
        log(f"Called as: {' '.join(sys.argv)}", file=output, print_also=True)

        # Set the device
        device = args.device
        use_cuda = (device > -1) and torch.cuda.is_available()
        if use_cuda:
            torch.cuda.set_device(device)
            log(
                f"Using CUDA device {device} - {torch.cuda.get_device_name(device)}",
                file=output,
                print_also=True,
            )
        else:
            log("Using CPU", file=output, print_also=True)

        if args.seed is not None:
            np.random.seed(args.seed)
            torch.manual_seed(args.seed)

        train_model(args, output)
    finally:
        # Closing sys.stdout would break any later printing in this process,
        # so only close a stream this function opened.
        if owns_output:
            output.close()
if __name__ == "__main__":
    # Script entry point: the module docstring doubles as the CLI description;
    # add_args returns the same parser it extends, so parsing can be chained.
    cli_parser = argparse.ArgumentParser(description=__doc__)
    main(add_args(cli_parser).parse_args())