69 changes: 69 additions & 0 deletions run/constants.py
@@ -0,0 +1,69 @@
import torch
from torch import nn
from explainn.utils.tools import pearson_loss

CRITERIONS = {
"bcewithlogits": nn.BCEWithLogitsLoss(),
"crossentropy": nn.CrossEntropyLoss(),
"mse": nn.MSELoss(),
"pearson": pearson_loss,
"poissonnll": nn.PoissonNLLLoss(),
}

OPTIMIZERS = {
"adam": torch.optim.Adam,
"sgd": torch.optim.SGD
}
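A quick usage sketch (hypothetical, not part of this file): entries in CRITERIONS are already-instantiated loss objects, while entries in OPTIMIZERS are classes that must still be called with model parameters and a learning rate.

# Hypothetical usage of the two lookup tables above;
# `model` is assumed to be any torch.nn.Module.
criterion = CRITERIONS["mse"]  # an nn.MSELoss() instance, ready to call
optimizer = OPTIMIZERS["adam"](model.parameters(), lr=0.0005)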

CONFIG_REQUIRED_FIELDS = {
"data": {
"input_files": list,
"output_dir": str,
"prefix": str,
"rev_complement": bool,
"input_length": int,
"intermediates": {
"training_file": str,
"validation_file": str,
"test_file": str,
},
},
"cnn": {
"filter_size": int,
"num_fc": int,
"num_units": int,
"pool_size": int,
"pool_stride": int,
},
"training": {
"cpu_threads": int,
"batch_size": int,
"num_epochs": int,
"checkpoint": int,
"patience": int,
"trim_weights": bool,
},
"optimizer": {"criterion": str, "lr": float, "optimizer": str},
"interpretation": {
"model_file": str,
"cpu_threads": int,
"batch_size": int,
"num_well_pred_seqs": int,
"correlation": int,
"exact_match": bool,
"percentile_bottom": int,
"percentile_top": int,
},
"options": {"debugging": bool, "use_time": bool, "store_intermediates": bool},
"postprocess": {
"cpu_threads": int,
"target_file": str,
"tomtom": {
"dist": str,
"evalue": bool,
"min_overlap": int,
"motif_pseudo": float,
"threshold": float,
},
    },
    "preprocessing": {  # preprocessing flags consumed in run.py
        "match_seqs_by_gc": bool,
        "subsample_seqs_by_gc": bool,
        "resize": bool,
    },
}
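validate_config is imported from utils in the scripts below, but its body is not shown in this diff. A minimal sketch of how it might consume CONFIG_REQUIRED_FIELDS, assuming a recursive key-presence and type check (this helper is an assumption, not the actual utils implementation):

def validate_config(config, required=None, path=""):
    # Recursively verify that every required field exists and has the
    # expected type; nested dicts in the schema denote config sections.
    if required is None:
        required = CONFIG_REQUIRED_FIELDS
    for key, expected in required.items():
        full_key = f"{path}.{key}" if path else key
        if key not in config:
            raise KeyError(f"Missing required config field: '{full_key}'")
        if isinstance(expected, dict):
            validate_config(config[key], expected, full_key)
        elif not isinstance(config[key], expected):
            raise TypeError(
                f"Config field '{full_key}' should be {expected.__name__}, "
                f"got {type(config[key]).__name__}"
            )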
72 changes: 52 additions & 20 deletions run/run.py
@@ -2,6 +2,7 @@
import os
import click
import json
import logging

from explainn.train.train import train_explainn
from explainn.utils.tools import pearson_loss
@@ -12,34 +13,67 @@
from train import run_train
from test import test_model
from interpret import interpret_results
from utils import save_data_splits, validate_config


# Setup logging
logging.basicConfig(
format="{asctime} - {name} - {levelname} - {message}",
style="{",
datefmt="%Y-%m-%d %H:%M",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
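With style="{", the format string uses str.format-style placeholders rather than %-style; a call like the one below would print a line of roughly this shape (timestamp illustrative):

logger.info("Config file validated.")
# -> 2024-01-01 12:00 - __main__ - INFO - Config file validated.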


CONTEXT_SETTINGS = {
"help_option_names": ["-h", "--help"],
}


@click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS)
@click.argument(
"config_file",
type=click.Path(exists=True, resolve_path=True),
)
def main(**args):
# Read config file
with open(args["config_file"]) as f:
config = json.load(f)

    # Validate the fields of the config file
    try:
        validate_config(config)
        logger.info("Config file validated.")
    except Exception as e:
        logger.error(str(e))
        raise

# Check that output dir exists
output_dir = config["data"]["output_dir"]
if not os.path.isdir(output_dir):
raise OSError(
f"The output directory: {output_dir} does not exist.\n"
f"Check the path relative to the current working directory: {os.getcwd()}"
)
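An alternative (an option, not what this code does) would be to create the missing directory instead of raising:

    # Alternative: create the output directory on demand
    os.makedirs(output_dir, exist_ok=True)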

    # TODO: Add preprocessing steps as arguments/config, e.g. match-seqs-by-gc,
    # subsample-seqs-by-gc, resize, etc.
if config["preprocessing"]["match_seqs_by_gc"]:
# TODO: perform match seqs by gc
pass
if config["preprocessing"]["subsample_seqs_by_gc"]:
# TODO: perform subsample_seqs_by_gc
pass
if config["preprocessing"]["resize"]:
        # TODO: perform resize
pass
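The three branches above are stubs. As one concrete possibility for the first TODO, GC matching could greedily pair each sequence with the unused background sequence of closest GC fraction; the helpers below are a hypothetical sketch, not part of this codebase:

def gc_fraction(seq):
    # Fraction of G/C bases in a DNA sequence
    seq = seq.upper()
    return (seq.count("G") + seq.count("C")) / max(len(seq), 1)

def match_seqs_by_gc(pos_seqs, bg_seqs):
    # Greedy nearest-GC matching; assumes len(bg_seqs) >= len(pos_seqs)
    bg = list(bg_seqs)
    matched = []
    for seq in pos_seqs:
        target = gc_fraction(seq)
        best = min(range(len(bg)), key=lambda i: abs(gc_fraction(bg[i]) - target))
        matched.append(bg.pop(best))
    return matched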

# Preprocess the data
# TODO: Add this as an argument/in config
classes = combine_seq_files(config["data"]["input_files"])
splits = json2explainn(classes)
    save_data_splits(
        config["data"]["output_dir"],
splits[0],
splits[1],
splits[2],
@@ -48,7 +82,9 @@ def main(**args):
# TODO: Update config file with output location? Where to store path to intermediates

if config["options"]["store_intermediates"]:
        with open(
            os.path.join(config["data"]["output_dir"], "combined_data.json"), "wt"
        ) as handle:
            json.dump(classes, handle, indent=4, sort_keys=True)

@@ -64,30 +100,26 @@ def main(**args):
# Finetune the model
    # TODO: Specify this with config/arguments

# Further interpretation
# TODO: Specify these with config/arguments
# MEME to logos
meme2logo(config)

# MEME to scores
    # meme2scores(config)

    # MEME to clusters
    # meme2clusters(config)

    # Tomtom
    # tomtom(config)

    # JASPAR to logos
    # jaspar2logo(config)

    # PWM to scores
    # pwm2scores(config)


if __name__ == "__main__":
    main()
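For reference, a skeleton of the sections this script touches, shown as a Python dict (the actual file is JSON with the same structure; every value here is an illustrative placeholder, and the interpretation/postprocess sections are omitted for brevity):

config = {
    "data": {
        "input_files": ["classA.json", "classB.json"],
        "output_dir": "results",
        "prefix": "run1",
        "rev_complement": True,
        "input_length": 200,
        "intermediates": {
            "training_file": "train.h5",
            "validation_file": "validation.h5",
            "test_file": "test.h5",
        },
    },
    "preprocessing": {
        "match_seqs_by_gc": False,
        "subsample_seqs_by_gc": False,
        "resize": False,
    },
    "cnn": {"filter_size": 19, "num_fc": 2, "num_units": 100,
            "pool_size": 7, "pool_stride": 7},
    "training": {"cpu_threads": 4, "batch_size": 100, "num_epochs": 100,
                 "checkpoint": 0, "patience": 10, "trim_weights": False},
    "optimizer": {"criterion": "bcewithlogits", "lr": 0.0005, "optimizer": "adam"},
    "options": {"debugging": False, "use_time": False, "store_intermediates": True},
}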
18 changes: 15 additions & 3 deletions run/test.py
@@ -26,7 +26,7 @@
from explainn.models.networks import ExplaiNN
from explainn.interpretation.interpretation import get_explainn_predictions
from run.utils import (get_file_handle, get_seqs_labels_ids, get_data_loader,
                       get_device, data_split_names, get_criterion, validate_config)

CONTEXT_SETTINGS = {
"help_option_names": ["-h", "--help"],
@@ -40,11 +40,23 @@ def main(**args):
"""
"""
# Read config file
with open(args["config_file"]) as f:
config = json.load(f)

    # Validate the fields of the config file
    try:
        validate_config(config)
        logging.info("Config file validated.")
    except Exception as e:
        logging.error(str(e))
        raise

# Check that output dir exists
output_dir = config["data"]["output_dir"]
if not os.path.isdir(output_dir):
raise OSError(
f"The output directory: {output_dir} does not exist.\n"
f"Check the path relative to the current working directory: {os.getcwd()}"
)

test_model(config)

43 changes: 24 additions & 19 deletions run/train.py
@@ -1,11 +1,13 @@
#!/usr/bin/env python

import logging
import os
import sys
import time
import torch
import click
import json
import constants

import pandas as pd

@@ -15,7 +17,7 @@
from explainn.train.train import train_explainn
from explainn.models.networks import ExplaiNN
from utils import (get_file_handle, get_seqs_labels_ids, get_data_loader,
                   get_device, data_split_names, get_criterion, validate_config)

CONTEXT_SETTINGS = {
"help_option_names": ["-h", "--help"],
@@ -29,11 +31,23 @@ def main(**args):
"""
"""
# Read config file
with open(args["config_file"]) as f:
config = json.load(f)

    # Validate the fields of the config file
    try:
        validate_config(config)
        logging.info("Config file validated.")
    except Exception as e:
        logging.error(str(e))
        raise

# Check that output dir exists
output_dir = config["data"]["output_dir"]
if not os.path.isdir(output_dir):
raise OSError(
f"The output directory: {output_dir} does not exist.\n"
f"Check the path relative to the current working directory: {os.getcwd()}"
)

run_train(config)

@@ -76,15 +90,11 @@ def run_train(config):
try:
criterion = get_criterion()[config["optimizer"]["criterion"].lower()]
except KeyError:
raise KeyError(
f"Invalid criterion '{config['optimizer']['criterion']}'. "
f"Please choose one of: {', '.join(get_criterion().keys())}"
        )

# Get model
m = ExplaiNN(config["cnn"]["num_units"], config["data"]["input_length"],
@@ -116,13 +126,8 @@ def _get_optimizer(optimizer, parameters, lr=0.0005):
def _get_optimizer(optimizer, parameters, lr=0.0005):
"""
"""
    return constants.OPTIMIZERS[optimizer.lower()](parameters, lr=lr)
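A possible hardening, mirroring the criterion error handling in run_train above (a suggestion, not part of this diff):

    # Suggested variant: surface unknown optimizer names with the valid choices
    try:
        return constants.OPTIMIZERS[optimizer.lower()](parameters, lr=lr)
    except KeyError:
        raise KeyError(
            f"Invalid optimizer '{optimizer}'. "
            f"Please choose one of: {', '.join(constants.OPTIMIZERS)}"
        )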

def _train(train_loader, test_loader, model, device, criterion, optimizer,
num_epochs=100, output_dir="./", name_ind=None, verbose=False,
trim_weights=False, checkpoint=0, patience=0):