# Source code for dscript.pretrained

import os
import sys

import torch

from .models.contact import ContactCNN
from .models.embedding import FullyConnectedEmbed, SkipLSTM
from .models.interaction import ModelInteraction


def build_lm_1(state_dict_path):
    """
    Build the ``lm_v1`` language model and load its saved weights.

    :param state_dict_path: Path to the saved state dictionary to load
    :type state_dict_path: str
    :return: ``SkipLSTM`` with the saved weights loaded, in evaluation mode
    :rtype: SkipLSTM

    :meta private:
    """
    model = SkipLSTM(21, 100, 1024, 3)
    # Deserialize onto the CPU so that a checkpoint whose tensors were saved
    # from a CUDA device can still be read on a host without CUDA.
    # load_state_dict copies the values into the model's existing parameters,
    # so the device the model itself lives on is not changed by this.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


def build_human_1(state_dict_path):
    """
    Build the ``human_v1`` interaction model and load its saved weights.

    The model is a ``ModelInteraction`` wrapping a ``FullyConnectedEmbed``
    and a ``ContactCNN``.

    :param state_dict_path: Path to the saved state dictionary to load
    :type state_dict_path: str
    :return: ``ModelInteraction`` with the saved weights loaded, in evaluation mode
    :rtype: ModelInteraction

    :meta private:
    """
    embModel = FullyConnectedEmbed(6165, 100, 0.5)
    conModel = ContactCNN(100, 50, 7)
    # NOTE(review): use_cuda is hard-coded to True; confirm how
    # ModelInteraction behaves with this flag on hosts without CUDA.
    model = ModelInteraction(
        embModel,
        conModel,
        use_cuda=True,
        do_w=True,
        do_pool=True,
        do_sigmoid=True,
        pool_size=9,
    )
    # Deserialize onto the CPU so that a checkpoint whose tensors were saved
    # from a CUDA device can still be read on a host without CUDA.
    # load_state_dict copies the values into the model's existing parameters,
    # so the device the model itself lives on is not changed by this.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Registry of pre-trained model versions: maps each version name to the
# function that builds that model from a state-dictionary path.
VALID_MODELS = {
    "lm_v1": build_lm_1,
    "human_v1": build_human_1,
}


def get_state_dict(version="human_v1", verbose=True):
    """
    Download a pre-trained model if it does not already exist on the local device.

    The state dictionary is cached next to this module as
    ``dscript_<version>.pt``. The download is written to a temporary
    ``.part`` file and only renamed to the final name once it has completed,
    so an interrupted download is never mistaken for a cached model on a
    later call.

    If the download fails, an error message is printed and the process
    exits with status 1.

    :param version: Version of trained model to download [default: human_v1]
    :type version: str
    :param verbose: Print model download status on stdout [default: True]
    :type verbose: bool
    :return: Path to state dictionary for pre-trained language model
    :rtype: str
    """
    state_dict_basename = f"dscript_{version}.pt"
    state_dict_basedir = os.path.dirname(os.path.realpath(__file__))
    state_dict_fullname = f"{state_dict_basedir}/{state_dict_basename}"
    state_dict_url = (
        f"http://cb.csail.mit.edu/cb/dscript/data/models/{state_dict_basename}"
    )

    # Already cached: nothing to download.
    if os.path.exists(state_dict_fullname):
        return state_dict_fullname

    # Download target; kept separate from the final name so that a truncated
    # file never appears at the path checked by os.path.exists above.
    partial_fullname = f"{state_dict_fullname}.part"
    try:
        import shutil
        import urllib.request

        if verbose:
            print(f"Downloading model {version} from {state_dict_url}...")
        with urllib.request.urlopen(state_dict_url) as response, open(
            partial_fullname, "wb"
        ) as out_file:
            shutil.copyfileobj(response, out_file)
        # Publish the finished download under its final name in one step.
        os.replace(partial_fullname, state_dict_fullname)
    except Exception as e:
        print("Unable to download model - {}".format(e))
        sys.exit(1)
    finally:
        # Best-effort removal of a leftover partial download. After a
        # successful os.replace the file no longer exists, which is fine.
        try:
            os.remove(partial_fullname)
        except OSError:
            pass
    return state_dict_fullname


def get_pretrained(version="human_v1"):
    """
    Return a pre-trained model object for the requested version.

    The state dictionary is located through :func:`get_state_dict` and handed
    to the builder registered for ``version`` in ``VALID_MODELS``.

    Currently Available Models
    ==========================

    See the `documentation <https://d-script.readthedocs.io/en/main/data.html#trained-models>`_
    for most up-to-date list.

    - ``lm_v1`` - Language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
    - ``human_v1`` - Human trained model from D-SCRIPT manuscript.

    Default: ``human_v1``

    :param version: Version of pre-trained model to get
    :type version: str
    :raises ValueError: If ``version`` is not a key of ``VALID_MODELS``
    :return: Pre-trained model
    :rtype: dscript.models.*
    """
    builder = VALID_MODELS.get(version)
    # Guard clause: reject unknown versions before anything is fetched.
    if builder is None:
        raise ValueError("Model {} does not exist".format(version))
    return builder(get_state_dict(version))