import os
import os.path
import sys
from functools import wraps
import torch
from .models.contact import ContactCNN
from .models.embedding import FullyConnectedEmbed, SkipLSTM
from .models.interaction import ModelInteraction
from .utils import log
def build_lm_1(state_dict_path):
    """
    Build the ``SkipLSTM`` language model and load its weights from disk.

    :meta private:

    :param state_dict_path: Path to the saved state dictionary
    :type state_dict_path: str
    :return: Model with the weights loaded, switched to evaluation mode
    """
    model = SkipLSTM(21, 100, 1024, 3)
    # Deserialize onto the CPU so that a checkpoint whose tensors were saved
    # from a CUDA device can still be read on a machine without a GPU. The
    # freshly built model lives on the CPU, and load_state_dict copies values
    # into its existing parameters, so the result is the same either way.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
def build_human_1(state_dict_path):
    """
    Build the human interaction model (embedding + contact modules) and load
    its weights from disk.

    :meta private:

    :param state_dict_path: Path to the saved state dictionary
    :type state_dict_path: str
    :return: Model with the weights loaded, switched to evaluation mode
    """
    embModel = FullyConnectedEmbed(6165, 100, 0.5)
    conModel = ContactCNN(100, 50, 7)
    # NOTE(review): use_cuda is hard-coded to True here; presumably callers
    # reset it to match the device they move the model to - confirm.
    model = ModelInteraction(
        embModel,
        conModel,
        use_cuda=True,
        do_w=True,
        do_pool=True,
        do_sigmoid=True,
        pool_size=9,
    )
    # Deserialize onto the CPU so that a checkpoint whose tensors were saved
    # from a CUDA device can still be read on a machine without a GPU;
    # load_state_dict copies the values into the model's own parameters.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
def build_human_tt3d(state_dict_path):
    """
    Build the TT3D human interaction model and load its weights from disk.

    Identical to the other human model builder in this file except that the
    contact module is constructed with a first argument of 121 instead of 100.

    :meta private:

    :param state_dict_path: Path to the saved state dictionary
    :type state_dict_path: str
    :return: Model with the weights loaded, switched to evaluation mode
    """
    embModel = FullyConnectedEmbed(6165, 100, 0.5)
    conModel = ContactCNN(121, 50, 7)
    # NOTE(review): use_cuda is hard-coded to True here; presumably callers
    # reset it to match the device they move the model to - confirm.
    model = ModelInteraction(
        embModel,
        conModel,
        use_cuda=True,
        do_w=True,
        do_pool=True,
        do_sigmoid=True,
        pool_size=9,
    )
    # Deserialize onto the CPU so that a checkpoint whose tensors were saved
    # from a CUDA device can still be read on a machine without a GPU;
    # load_state_dict copies the values into the model's own parameters.
    state_dict = torch.load(state_dict_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Registry of supported pre-trained model versions: maps each version string
# to the builder function that constructs the model from a state-dict path.
# Note that "human_v1" and "human_v2" share a builder (same architecture) and
# differ only in which weight file is loaded.
VALID_MODELS = {
    "human_v1": build_human_1,  # Original D-SCRIPT
    "human_v2": build_human_1,  # Topsy-Turvy
    "human_tt3d": build_human_tt3d,  # TT3D
    "lm_v1": build_lm_1,  # Bepler & Berger 2019
}
# File-name template for a model's weights; ``{version}`` is filled with a
# key of VALID_MODELS. Used for both the local file name and the remote URL.
STATE_DICT_BASENAME = "dscript_{version}.pt"
# Base URL under which the weight files are published.
# NOTE(review): this is plain HTTP and the downloaded file is later passed to
# torch.load (pickle-based) - consider HTTPS if the host supports it; confirm.
ROOT_URL = "http://cb.csail.mit.edu/cb/dscript/data/models"
def get_state_dict_path(version: str) -> str:
    """
    Return the local path where the state dictionary for ``version`` is stored.

    The file lives in the same directory as this module (symlinks resolved)
    and is named according to ``STATE_DICT_BASENAME``. The path is only
    computed; the file is not checked for existence.

    :param version: Version of the pre-trained model
    :type version: str
    :return: Full path to the state dictionary file
    :rtype: str
    """
    state_dict_basedir = os.path.dirname(os.path.realpath(__file__))
    return os.path.join(
        state_dict_basedir, STATE_DICT_BASENAME.format(version=version)
    )
def get_state_dict(version="human_v2", verbose=True):
    """
    Download a pre-trained model if it does not already exist on the local device.

    The file is streamed to a temporary sibling path and only moved to its
    final location once the transfer has completed, so an interrupted
    download never leaves a truncated file that a later call would mistake
    for a valid cached model. If the download fails for any reason, the error
    is logged and the process exits with status 1.

    :param version: Version of trained model to download [default: human_v2]
    :type version: str
    :param verbose: Log the model download status [default: True]
    :type verbose: bool
    :return: Path to state dictionary for the pre-trained model
    :rtype: str
    """
    state_dict_fullname = get_state_dict_path(version)
    if os.path.exists(state_dict_fullname):
        return state_dict_fullname

    import shutil
    import urllib.request
    from urllib.parse import urlparse

    state_dict_url = f"{ROOT_URL}/{STATE_DICT_BASENAME.format(version=version)}"
    # Per-process temporary name so that concurrent downloads of the same
    # model do not write into each other's partial file.
    partial_fullname = f"{state_dict_fullname}.{os.getpid()}.part"
    try:
        # Validate URL scheme for security
        parsed_url = urlparse(state_dict_url)
        if parsed_url.scheme not in ("http", "https"):
            raise ValueError(
                f"Invalid URL scheme '{parsed_url.scheme}'. Only http and https are allowed."
            )
        if verbose:
            log(f"Downloading model {version} from {state_dict_url}...")
        with urllib.request.urlopen(state_dict_url) as response:
            with open(partial_fullname, "wb") as out_file:
                shutil.copyfileobj(response, out_file)
        # Publish the file only after it has been written completely.
        os.replace(partial_fullname, state_dict_fullname)
    except Exception as e:
        log(f"Unable to download model - {e}")
        sys.exit(1)
    finally:
        # After a successful os.replace the partial file no longer exists;
        # on any failure (including interruption) drop whatever was written.
        try:
            os.remove(partial_fullname)
        except OSError:
            pass
    return state_dict_fullname
def retry(retry_count: int):
    """
    Decorator factory that retries a model-loading function after a truncated read.

    When the wrapped function raises a ``RuntimeError`` whose first argument
    is a string starting with ``"unexpected EOF"``, the locally stored
    state-dict file for the requested version is deleted and the function is
    called again, for at most ``retry_count`` calls in total. Any other
    ``RuntimeError`` is re-raised immediately with its original traceback.

    The model version is taken from the first positional argument, else from
    the ``version`` keyword argument, else from the wrapped function's first
    default value (``None`` if it has no defaults).

    :param retry_count: Maximum number of times to call the wrapped function
    :type retry_count: int
    :return: Decorator to apply to the loading function
    :raises Exception: If every attempt fails with an "unexpected EOF" error
        (also raised without any attempt when ``retry_count`` is not positive)
    """

    def decorate(func):
        @wraps(func)
        def retry_wrapper(*args, **kwargs):
            if args:
                version = args[0]
            elif "version" in kwargs:
                version = kwargs["version"]
            else:
                # __defaults__ is None for functions without default values.
                defaults = func.__defaults__
                version = defaults[0] if defaults else None

            last_error = None
            for attempt in range(1, retry_count + 1):
                try:
                    return func(*args, **kwargs)
                except RuntimeError as e:
                    # Guard against RuntimeError() with no args or with a
                    # non-string first argument.
                    message = e.args[0] if e.args else ""
                    is_truncated = isinstance(
                        message, str
                    ) and message.startswith("unexpected EOF")
                    if not is_truncated:
                        raise
                    last_error = e
                    log(
                        f"\033[93mLoading {version} from disk failed. Retrying download attempt: {attempt}\033[0m"
                    )
                    if version is not None:
                        # Remove the corrupt file so the next attempt starts
                        # from a clean state.
                        try:
                            os.remove(get_state_dict_path(version))
                        except FileNotFoundError:
                            pass
            raise Exception(f"Failed to download {version}") from last_error

        return retry_wrapper

    return decorate
@retry(3)
def get_pretrained(version="human_v2"):
    """
    Get pre-trained model object.

    See the `documentation <https://d-script.readthedocs.io/en/main/data.html#trained-models>`_ for most up-to-date list.

    - ``lm_v1`` - Language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
    - ``human_v1`` - Human trained model from D-SCRIPT manuscript.
    - ``human_v2`` - Human trained model from Topsy-Turvy manuscript.
    - ``human_tt3d`` - Human trained model with FoldSeek sequence inputs

    Default: ``human_v2``

    The state dictionary is fetched with ``get_state_dict`` if it is not yet
    present locally, then handed to the builder registered for ``version`` in
    ``VALID_MODELS``.

    :param version: Version of pre-trained model to get
    :type version: str
    :return: Pre-trained model
    :rtype: dscript.models.*
    :raises ValueError: If ``version`` is not a key of ``VALID_MODELS``
    """
    if version not in VALID_MODELS:
        raise ValueError(f"Model {version} does not exist")
    builder = VALID_MODELS[version]
    return builder(get_state_dict(version))