-
-
Save r-wheeler/ad2d34b8b07e41f87476 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import luigi | |
import luigi.scheduler | |
import luigi.worker | |
import logging as log | |
import socket | |
from datetime import datetime as dt | |
from ConfigParser import ConfigParser | |
from components import AssessSVMRegression | |
from components import CreateProteinList | |
from components import CreateReport | |
from components import CreateSparseTestDataset | |
from components import CreateSparseTrainDataset | |
from components import CreateUniqueSignaturesCopy | |
from components import CrossValidate | |
from components import ExistingSmiles | |
from components import ExtractDataFromChembl | |
from components import ExtractDataFromChemblSQLite | |
from components import FilterMinCompounds | |
from components import GenerateSignaturesFilterSubstances | |
from components import SampleTrainAndTest | |
from components import SplitDataPerProtein | |
from components import TrainSVMModel | |
from components import WorkflowUtils | |
# Module-level configuration: parsed once at import time from the workflow's
# INI file (path is relative to the current working directory).
config = ConfigParser()
config.read('mm_workflow_config.ini')
###########################################
# W O R K F L O W   D E F I N I T I O N   #
###########################################
class MMExistingSmiles(ExistingSmiles.ExistingSmiles):
    """Workflow-level ExistingSmiles task.

    Adds nothing beyond the ``dataset_name`` parameter; all behaviour is
    inherited from ``ExistingSmiles.ExistingSmiles``. It is the first task in
    the dependency chain defined in this file and declares no ``requires``.
    """

    # Name of the dataset this task instance refers to.
    dataset_name = luigi.Parameter()
class MMGenerateSignaturesFilterSubstances(GenerateSignaturesFilterSubstances.GenerateSignaturesFilterSubstances):
    """Workflow-level GenerateSignaturesFilterSubstances task.

    Wires the inherited component to its upstream task, ``MMExistingSmiles``,
    for the same dataset.
    """

    # Name of the dataset; forwarded unchanged to the upstream task.
    dataset_name = luigi.Parameter()

    def requires(self):
        """Return the MMExistingSmiles task for this dataset."""
        upstream = MMExistingSmiles(dataset_name=self.dataset_name)
        return upstream
class MMCreateUniqueSignaturesCopy(CreateUniqueSignaturesCopy.CreateUniqueSignaturesCopy):
    """Workflow-level CreateUniqueSignaturesCopy task.

    Wires the inherited component to its upstream task,
    ``MMGenerateSignaturesFilterSubstances``, for the same dataset.
    """

    # Name of the dataset; forwarded unchanged to the upstream task.
    dataset_name = luigi.Parameter()

    def requires(self):
        """Return the MMGenerateSignaturesFilterSubstances task for this dataset."""
        upstream = MMGenerateSignaturesFilterSubstances(
            dataset_name=self.dataset_name)
        return upstream
class MMSampleTrainAndTest(SampleTrainAndTest.SampleTrainAndTest):
    """Workflow-level SampleTrainAndTest task.

    Wires the inherited component to its upstream task,
    ``MMCreateUniqueSignaturesCopy``, forwarding the dataset name and the
    run identifier.
    """

    # Name of the dataset; forwarded unchanged to the upstream task.
    dataset_name = luigi.Parameter()
    # Identifier of this run/replicate; forwarded unchanged to the upstream task.
    unique_run_id = luigi.Parameter()

    def requires(self):
        """Return the MMCreateUniqueSignaturesCopy task for this dataset and run."""
        # NOTE(review): MMCreateUniqueSignaturesCopy declares only
        # dataset_name in this file, so unique_run_id must be declared by its
        # base class for this call to be accepted -- confirm.
        upstream_params = dict(
            dataset_name=self.dataset_name,
            unique_run_id=self.unique_run_id)
        return MMCreateUniqueSignaturesCopy(**upstream_params)
class MMCreateSparseTrainDataset(CreateSparseTrainDataset.CreateSparseTrainDataset):
    """Workflow-level CreateSparseTrainDataset task.

    Wires the inherited component to its upstream task,
    ``MMSampleTrainAndTest``, forwarding all five workflow parameters.
    """

    # Workflow parameters; each is forwarded unchanged to the upstream task.
    dataset_name = luigi.Parameter()
    unique_run_id = luigi.Parameter()
    test_size = luigi.Parameter()
    training_size = luigi.Parameter()
    sampling_method = luigi.Parameter()

    def requires(self):
        """Return the MMSampleTrainAndTest task configured like this task."""
        # NOTE(review): MMSampleTrainAndTest declares only dataset_name and
        # unique_run_id in this file; test_size, training_size and
        # sampling_method are presumably declared by its base class -- confirm.
        upstream_params = dict(
            dataset_name=self.dataset_name,
            unique_run_id=self.unique_run_id,
            test_size=self.test_size,
            training_size=self.training_size,
            sampling_method=self.sampling_method)
        return MMSampleTrainAndTest(**upstream_params)
class MMCreateSparseTestDataset(CreateSparseTestDataset.CreateSparseTestDataset):
    """Workflow-level CreateSparseTestDataset task.

    Depends on two upstream tasks, both configured with this task's own
    parameters: the train/test sampling task and the sparse training dataset
    task.
    """

    # Workflow parameters; each is forwarded unchanged to both upstream tasks.
    dataset_name = luigi.Parameter()
    unique_run_id = luigi.Parameter()
    test_size = luigi.Parameter()
    training_size = luigi.Parameter()
    sampling_method = luigi.Parameter()

    def requires(self):
        """Return the upstream tasks keyed by role.

        Keys are ``"sampletraintest"`` (MMSampleTrainAndTest) and
        ``"sparsetrain"`` (MMCreateSparseTrainDataset).
        """
        # Both upstream tasks take exactly the same parameter set.
        shared_params = dict(
            dataset_name=self.dataset_name,
            unique_run_id=self.unique_run_id,
            test_size=self.test_size,
            training_size=self.training_size,
            sampling_method=self.sampling_method)
        return {
            "sampletraintest": MMSampleTrainAndTest(**shared_params),
            "sparsetrain": MMCreateSparseTrainDataset(**shared_params),
        }
class MMTrainSVMModel(TrainSVMModel.TrainSVMModel):
    """Workflow-level TrainSVMModel task.

    Wires the inherited component to its upstream task,
    ``MMCreateSparseTrainDataset``, forwarding all five workflow parameters.
    """

    # Workflow parameters; each is forwarded unchanged to the upstream task.
    dataset_name = luigi.Parameter()
    unique_run_id = luigi.Parameter()
    test_size = luigi.Parameter()
    training_size = luigi.Parameter()
    sampling_method = luigi.Parameter()

    def requires(self):
        """Return the MMCreateSparseTrainDataset task configured like this task."""
        upstream_params = dict(
            dataset_name=self.dataset_name,
            unique_run_id=self.unique_run_id,
            test_size=self.test_size,
            training_size=self.training_size,
            sampling_method=self.sampling_method)
        return MMCreateSparseTrainDataset(**upstream_params)
class MMAssessSVMRegression(AssessSVMRegression.AssessSVMRegression):
    """Workflow-level AssessSVMRegression task.

    Depends on two upstream tasks, both configured with this task's own
    parameters: the SVM training task and the sparse test dataset task.
    """

    # Workflow parameters; each is forwarded unchanged to both upstream tasks.
    dataset_name = luigi.Parameter()
    unique_run_id = luigi.Parameter()
    test_size = luigi.Parameter()
    training_size = luigi.Parameter()
    sampling_method = luigi.Parameter()

    def requires(self):
        """Return the upstream tasks keyed by role.

        Keys are ``"svm"`` (MMTrainSVMModel) and ``"testcsr"``
        (MMCreateSparseTestDataset).
        """
        # Both upstream tasks take exactly the same parameter set.
        shared_params = dict(
            dataset_name=self.dataset_name,
            unique_run_id=self.unique_run_id,
            test_size=self.test_size,
            training_size=self.training_size,
            sampling_method=self.sampling_method)
        return {
            "svm": MMTrainSVMModel(**shared_params),
            "testcsr": MMCreateSparseTestDataset(**shared_params),
        }
class MMCreateReport(CreateReport.CreateReport):
    """Workflow-level CreateReport task.

    Wires the inherited component to its upstream task,
    ``MMAssessSVMRegression``, forwarding all five workflow parameters.
    """

    # Workflow parameters; each is forwarded unchanged to the upstream task.
    dataset_name = luigi.Parameter()
    unique_run_id = luigi.Parameter()
    test_size = luigi.Parameter()
    training_size = luigi.Parameter()
    sampling_method = luigi.Parameter()

    def requires(self):
        """Return the MMAssessSVMRegression task configured like this task."""
        upstream_params = dict(
            dataset_name=self.dataset_name,
            unique_run_id=self.unique_run_id,
            test_size=self.test_size,
            training_size=self.training_size,
            sampling_method=self.sampling_method)
        return MMAssessSVMRegression(**upstream_params)
class MMRunAll(luigi.Task):
    """Top-level task that requires one report per parameter combination.

    The combinations are the cross product of six dataset names, seven
    training-set sizes and six replicate ids (252 MMCreateReport tasks in
    total). Test size and sampling method are fixed.
    """

    def requires(self):
        """Yield an MMCreateReport task for every dataset/size/replicate combination."""
        dataset_names = ["psa", "alogp", "acd_logd", "acd_logp",
                         "acd_most_apka", "acd_most_bpka"]
        training_sizes = [100, 800, 5000, 20000, 40000, 80000, 160000]
        replicate_ids = ["r1", "r2", "r3", "s1", "s2", "s3"]

        for dataset_name in dataset_names:
            for size in training_sizes:
                for run_id in replicate_ids:
                    # Sizes are passed as strings, matching the plain
                    # luigi.Parameter declarations on MMCreateReport.
                    yield MMCreateReport(
                        dataset_name=dataset_name,
                        unique_run_id=run_id,
                        test_size="10000",
                        training_size=str(size),
                        sampling_method="random")
# Script entry point: start luigi's command-line runner with MMRunAll as the
# main task, so running this file schedules the whole workflow defined above.
if __name__ == "__main__":
    luigi.run(main_task_cls=MMRunAll)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment