# train for wavenet binary target
#' @title Trains a (mostly) LSTM model on genomic data. Designed for developing genome-based language models (GenomeNet)
#'
#' @description
#' Depth and number of neurons per layer of the network can be specified. The first layer can be a Convolutional Neural Network (CNN) that is designed to capture codons.
#' If a path to a folder where FASTA files are located is provided, batches will be generated by an external generator, which
#' is recommended for big training sets. Alternatively, a dataset can be supplied that holds the preprocessed batches (generated by \code{preprocessSemiRedundant()})
#' and keeps them in RAM. Also supports training on instances with multiple GPUs and scales linearly with the number of GPUs present.
#' @param train_type Either "lm" for language model, "label_header" or "label_folder". A language model is trained to predict the next character in a sequence.
#' label_header/label_folder models are trained to predict a corresponding class, given a sequence as input. If "label_header", the class will be read from fasta headers.
#' If "label_folder", the class will be read from the folder, i.e. all fasta files in one folder must belong to the same class.
#' @param model A keras model.
#' @param built_model Call to a function that creates a model. \code{create_model_function} can be "create_model_lstm_cnn", "create_model_lstm_cnn_target_middle" or "create_model_wavenet".
#' Arguments of the corresponding function can be specified in \code{function_args}; if no argument is given, default values will be used.
#' Example: \code{built_model = list(create_model_function = "create_model_lstm_cnn", function_args = list(maxlen = 50, layer.size = 32, layers.lstm = 1))}
#' @param model_path Path to a pretrained model.
#' @param path Path to folder where individual or multiple FASTA files are located for training. If \code{train_type} is \code{label_folder}, this should be a vector
#' containing a path for each class.
#' @param path.val Path to folder where individual or multiple FASTA files are located for validation. If \code{train_type} is \code{label_folder}, this should be a vector
#' containing a path for each class.
#' @param dataset Dataframe holding training samples in RAM instead of using a generator.
#' @param checkpoint_path Path to checkpoints folder.
#' @param validation.split Defines the fraction of the batches that will be used for validation (compared to the size of the training data).
#' @param run.name Name of the run (without file ending). The name will be used to identify output from callbacks.
#' @param batch.size Number of samples that are used for one network update.
#' @param epochs Number of iterations.
#' @param max.queue.size Maximum queue size for fit_generator().
#' @param lr.plateau.factor Factor by which the learning rate is decreased when a plateau is reached.
#' @param patience Number of epochs to wait for a decrease in loss before reducing the learning rate.
#' @param cooldown Number of epochs without changing the learning rate.
#' @param steps.per.epoch Number of batches to finish one epoch.
#' @param step Frequency of sampling steps.
#' @param randomFiles TRUE/FALSE, whether to shuffle files beforehand or go through them sequentially.
#' @param vocabulary Vector of allowed characters; characters outside the vocabulary get encoded as a zero vector.
#' @param initial_epoch Epoch at which to start training; set to 0 if no \code{model_path} argument is given. Note that the network
#' will run for (\code{epochs} - \code{initial_epoch}) rounds and not \code{epochs} rounds.
#' @param tensorboard.log Path to tensorboard log directory.
#' @param save_best_only Only save the model if it improved on the best val_loss score.
#' @param compile Whether to compile the model after loading.
#' @param solver Optimization method; options are "adam", "adagrad", "rmsprop" or "sgd". Only used when a pretrained model is given (\code{model_path} is not NULL) and compile is FALSE.
#' Otherwise the solver is determined when the model is created.
#' @param learning.rate Learning rate for the optimizer. Only used when a pretrained model is given (\code{model_path} is not NULL) and compile is FALSE.
#' Otherwise the learning rate is determined when the model is created.
#' @param seed Sets the seed for the set.seed function, for reproducible results when using \code{randomFiles} or \code{shuffleFastaEntries}.
#' @param shuffleFastaEntries Logical, shuffle entries in the file.
#' @param output List of optional outputs; no output is produced if \code{none} is TRUE.
#' @param tb_images Boolean, whether to show plots in tensorboard. Note that this doubles the time needed for the validation step.
#' @param format File format, "fasta" or "fastq".
#' @param fileLog Write the names of used files to a csv file if a path is specified.
#' @param labelVocabulary Character vector of possible targets. Targets outside \code{labelVocabulary} will be discarded.
#' @param numberOfFiles Use only the specified number of files; ignored if greater than the number of files in corpus.dir.
#' @param reverseComplements Logical; half of each batch contains the sequences and the other half their reverse complements. The reverse complement
#' is given by the reversed order of the sequence with A/T and C/G switched. The \code{batch.size} argument has to be even, otherwise 1 will be added
#' to \code{batch.size}.
#' @param wavenet_format Boolean. If TRUE, the target is a sequence equal to the input shifted by one position to the right (the last target position is not in the input).
#' If the sequence is ACGT and maxlen = 3, the first input corresponds to ACG and the target to CGT.
#' @param target_middle Boolean, whether the target is in the middle of the sequence.
#' @param reset_states Boolean, whether to reset the hidden states of the RNN layer at every new input file.
#' @param ambiguous_nuc How to handle nucleotides outside the vocabulary, either "zero", "discard", "empirical" or "equal". If "zero", the input gets encoded as a zero vector;
#' if "equal", every entry of the input vector is 1/length(vocabulary). If "discard", samples containing nucleotides outside the vocabulary get discarded.
#' If "empirical", the nucleotide distribution of the current file is used.
#' @param percentage_per_file Numerical value between 0 and 1. Proportion of possible samples to take from one file. Takes samples from a random subsequence.
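#' @examples
#' \dontrun{
#' # A minimal usage sketch; paths and most settings below are hypothetical
#' # placeholders, not values prescribed by this function.
#' # Train a language model on FASTA files:
#' history <- trainNetwork(
#'   train_type = "lm",
#'   built_model = list(create_model_function = "create_model_lstm_cnn",
#'                      function_args = list(maxlen = 50, layer.size = 32, layers.lstm = 1)),
#'   path = "train/fasta",
#'   path.val = "validation/fasta",
#'   checkpoint_path = "checkpoints",
#'   tensorboard.log = "tb_logs",
#'   run.name = "lm_run_1",
#'   steps.per.epoch = 100,
#'   epochs = 5)
#'
#' # Train a classifier where each folder corresponds to one class:
#' history <- trainNetwork(
#'   train_type = "label_folder",
#'   built_model = list(create_model_function = "create_model_lstm_cnn"),
#'   path = c("train/bacteria", "train/virus"),
#'   path.val = c("validation/bacteria", "validation/virus"),
#'   labelVocabulary = c("bacteria", "virus"),
#'   checkpoint_path = "checkpoints",
#'   tensorboard.log = "tb_logs",
#'   run.name = "label_run_1")
#' }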
#' @export
trainNetwork <- function(train_type = "lm",
                         model_path = NULL,
                         built_model = list(create_model_function = NULL, function_args = list()),
                         model = NULL,
                         path = NULL,
                         path.val = NULL,
                         dataset = NULL,
                         checkpoint_path,
                         validation.split = 0.2,
                         run.name = "run",
                         batch.size = 64,
                         epochs = 10,
                         max.queue.size = 100,
                         lr.plateau.factor = 0.9,
                         patience = 20,
                         cooldown = 1,
                         steps.per.epoch = 1000,
                         step = 1,
                         randomFiles = FALSE,
                         initial_epoch = 0,
                         vocabulary = c("a", "c", "g", "t"),
                         tensorboard.log,
                         save_best_only = TRUE,
                         compile = TRUE,
                         learning.rate = NULL,
                         solver = NULL,
                         seed = c(1234, 4321),
                         shuffleFastaEntries = FALSE,
                         output = list(none = FALSE,
                                       checkpoints = TRUE,
                                       tensorboard = TRUE,
                                       log = FALSE,
                                       serialize_model = FALSE,
                                       full_model = FALSE),
                         tb_images = FALSE,
                         format = "fasta",
                         fileLog = NULL,
                         labelVocabulary = NULL,
                         numberOfFiles = NULL,
                         reverseComplements = FALSE,
                         wavenet_format = FALSE,
                         target_middle = FALSE,
                         reset_states = FALSE,
                         ambiguous_nuc = "zero",
                         percentage_per_file = NULL) {
  stopifnot(train_type %in% c("lm", "label_header", "label_folder"))
  stopifnot(ambiguous_nuc %in% c("zero", "equal", "discard", "empirical"))
  if (!is.null(built_model$create_model_function) && !is.null(model)) {
    stop("Two models were specified. Set either model or built_model$create_model_function argument to NULL.")
  }
  if (train_type == "lm") {
    labelGen <- FALSE
    labelByFolder <- FALSE
  }
  if (train_type == "label_header") {
    labelGen <- TRUE
    labelByFolder <- FALSE
    stopifnot(!is.null(labelVocabulary))
  }
  if (train_type == "label_folder") {
    labelGen <- TRUE
    labelByFolder <- TRUE
    stopifnot(!is.null(labelVocabulary))
    stopifnot(length(path) == length(labelVocabulary))
  }
  if (output$none) {
    output$checkpoints <- FALSE
    output$tensorboard <- FALSE
    output$log <- FALSE
    output$serialize_model <- FALSE
    output$full_model <- FALSE
  }
  # set model arguments
  default_arguments <- NULL # stays NULL if a ready-made model is passed instead of built_model
  if (!is.null(built_model[[1]])) {
    if (built_model[[1]] == "create_model_lstm_cnn_target_middle") {
      target_middle <- TRUE
      wavenet_format <- FALSE
    }
    if (built_model[[1]] == "create_model_lstm_cnn") {
      target_middle <- FALSE
      wavenet_format <- FALSE
    }
    if (built_model[[1]] == "create_model_wavenet") {
      target_middle <- TRUE
      wavenet_format <- TRUE
    }
    new_arguments <- names(built_model[[2]])
    default_arguments <- formals(built_model[[1]])
    # overwrite default arguments
    for (arg in new_arguments) {
      default_arguments[arg] <- built_model[[2]][arg]
    }
    # create model
    if (built_model[[1]] == "create_model_lstm_cnn") {
      formals(create_model_lstm_cnn) <- default_arguments
      model <- create_model_lstm_cnn()
    }
    if (built_model[[1]] == "create_model_lstm_cnn_target_middle") {
      formals(create_model_lstm_cnn_target_middle) <- default_arguments
      model <- create_model_lstm_cnn_target_middle()
    }
    if (built_model[[1]] == "create_model_wavenet") {
      if (!wavenet_format) {
        warning("Argument wavenet_format should be TRUE when using wavenet architecture.")
      }
      formals(create_model_wavenet) <- default_arguments
      model <- create_model_wavenet()
    }
  }
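  # Note: the formals() assignment above overwrites a function's default
  # arguments before calling it. A minimal illustration of the mechanism
  # (f is a hypothetical function, not part of this package):
  #   f <- function(x = 1, y = 2) x + y
  #   args <- formals(f)
  #   args[["x"]] <- 10
  #   formals(f) <- args
  #   f() # now returns 12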
  # function arguments
  argumentList <- as.list(match.call(expand.dots = FALSE))
  label.vocabulary.size <- length(labelVocabulary)
  vocabulary.size <- length(vocabulary)
  # extract maxlen from model
  if (!target_middle) {
    maxlen <- model$input$shape[[2]]
  } else {
    maxlen <- model$input[[1]]$shape[[2]] + model$input[[2]]$shape[[2]]
  }
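  # The keras input shape is assumed to be (batch, maxlen, vocabulary size),
  # so the second dimension is the sequence length; e.g. an input created with
  # layer_input(shape = c(50, 4)) gives model$input$shape[[2]] == 50.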
  if (labelByFolder) {
    if (length(path) == 1) warning("Training with just one label")
  }
  if (output$checkpoints) {
    ## create folder for checkpoints using run.name
    ## filenames contain epoch, validation loss and validation accuracy
    checkpoint_dir <- paste0(checkpoint_path, "/", run.name, "_checkpoints")
    dir.create(checkpoint_dir, showWarnings = FALSE)
    filepath_checkpoints <- file.path(checkpoint_dir, "Ep.{epoch:03d}-val_loss{val_loss:.2f}-val_acc{val_acc:.3f}.hdf5")
  }
  # check that fileLog does not exist yet
  if (!is.null(fileLog) && file.exists(fileLog)) {
    stop("fileLog entry is already present. Please give this file a unique name.")
  }
  # check that run.name is unique
  if (output$tensorboard && dir.exists(file.path(tensorboard.log, run.name))) {
    stop(paste0("Tensorboard entry '", run.name, "' is already present. Please give your run a unique name."))
  }
  # load pretrained model
  if (!is.null(model_path)) {
    # the epochs argument can be misleading
    if (!missing(initial_epoch) && !is.null(initial_epoch)) {
      if (initial_epoch >= epochs) {
        stop("Network trains for (epochs - initial_epoch) rounds overall, NOT epochs rounds. Increase epochs or decrease initial_epoch.")
      }
    }
    # extract initial_epoch from the checkpoint filename if NULL was given
    if (is.null(initial_epoch)) {
      epochFromFilename <- stringr::str_extract(model_path, "Ep.\\d+")
      initial_epoch <- as.integer(substring(epochFromFilename, 4, nchar(epochFromFilename)))
      if (initial_epoch >= epochs) {
        stop("Network trains for (epochs - initial_epoch) rounds overall, NOT epochs rounds. Increase epochs or decrease initial_epoch.")
      }
    }
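    # e.g. a model_path ending in "Ep.007-val_loss0.53-val_acc0.812.hdf5"
    # (the pattern written by the checkpoint callback; values here are made up
    # for illustration) yields initial_epoch = 7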
    # load model
    model <- keras::load_model_hdf5(model_path, compile = compile)
    model$hparam <- reticulate::dict()
    summary(model)
    # extract maxlen
    if (!target_middle) {
      maxlen <- model$input$shape[[2]]
    } else {
      maxlen <- model$input[[1]]$shape[[2]] + model$input[[2]]$shape[[2]]
    }
    if (compile && (!is.null(learning.rate) || !is.null(solver))) {
      message("Arguments for solver and learning rate will be ignored. Set compile to FALSE to use a custom solver and learning rate.")
    }
    if (!compile) {
      # choose optimization method
      optimizer <- switch(solver,
                          adam = keras::optimizer_adam(lr = learning.rate),
                          adagrad = keras::optimizer_adagrad(lr = learning.rate),
                          rmsprop = keras::optimizer_rmsprop(lr = learning.rate),
                          sgd = keras::optimizer_sgd(lr = learning.rate),
                          stop("solver must be one of 'adam', 'adagrad', 'rmsprop', 'sgd'"))
      model %>% keras::compile(loss = "categorical_crossentropy",
                               optimizer = optimizer, metrics = c("acc", percentage_training_files_cb))
    }
  }
  # if no dataset is supplied, an external fasta generator will generate batches
  removeLog <- FALSE # initialize here so the cleanup after training also works when a dataset is supplied
  gen_cb <- NULL # only set when tb_images is TRUE; passed to the tensorboard callback below
  if (is.null(dataset)) {
    message("Starting fasta generator...")
    # temporary file to log training data
    if (is.null(fileLog)) {
      removeLog <- TRUE
      fileLog <- tempfile(pattern = "", fileext = ".csv")
    }
    if (reset_states) {
      fileLogVal <- tempfile(pattern = "", fileext = ".csv")
    } else {
      fileLogVal <- NULL
    }
    if (!labelGen) {
      # generator for training
      gen <- fastaFileGenerator(corpus.dir = path, batch.size = batch.size,
                                maxlen = maxlen, step = step, randomFiles = randomFiles,
                                vocabulary = vocabulary, seed = seed[1],
                                shuffleFastaEntries = shuffleFastaEntries, format = format,
                                fileLog = fileLog, reverseComplements = reverseComplements,
                                wavenet_format = wavenet_format, target_middle = target_middle,
                                ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
      # generator for validation
      gen.val <- fastaFileGenerator(corpus.dir = path.val, batch.size = batch.size,
                                    maxlen = maxlen, step = step, randomFiles = randomFiles,
                                    vocabulary = vocabulary, seed = seed[2],
                                    shuffleFastaEntries = shuffleFastaEntries, format = format,
                                    fileLog = fileLogVal, reverseComplements = FALSE,
                                    wavenet_format = wavenet_format, target_middle = target_middle,
                                    ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
      if (tb_images) {
        # TODO: check if gen_cb uses same data if max_samples_per_file != NULL
        gen_cb <- fastaFileGenerator(corpus.dir = path.val, batch.size = batch.size,
                                     maxlen = maxlen, step = step, randomFiles = randomFiles,
                                     vocabulary = vocabulary, seed = seed[2],
                                     shuffleFastaEntries = shuffleFastaEntries, format = format,
                                     fileLog = NULL, reverseComplements = FALSE,
                                     wavenet_format = wavenet_format, target_middle = target_middle,
                                     ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
      }
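      # each of these generators is assumed to yield list(X, Y) batches, with X
      # of shape (batch.size, maxlen, length(vocabulary)) one-hot encoded as
      # described in the parameter docs (assumption, not verified here)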
      # label generator
    } else {
      # label by folder
      if (labelByFolder) {
        # initialize training generators
        initializeGenerators(directories = path, format = format, batch.size = batch.size, maxlen = maxlen, vocabulary = vocabulary,
                             verbose = FALSE, randomFiles = randomFiles, step = step, showWarnings = FALSE, seed = seed[1],
                             shuffleFastaEntries = shuffleFastaEntries, numberOfFiles = numberOfFiles,
                             fileLog = fileLog, reverseComplements = reverseComplements, val = FALSE,
                             ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
        # initialize validation generators
        initializeGenerators(directories = path.val, format = format, batch.size = batch.size, maxlen = maxlen,
                             vocabulary = vocabulary, verbose = FALSE, randomFiles = randomFiles, step = step,
                             showWarnings = FALSE, seed = seed[2], shuffleFastaEntries = shuffleFastaEntries,
                             numberOfFiles = NULL, fileLog = fileLogVal, reverseComplements = FALSE, val = TRUE,
                             ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
        gen <- labelByFolderGeneratorWrapper(val = FALSE, path = path)
        gen.val <- labelByFolderGeneratorWrapper(val = TRUE, path = path.val)
        if (tb_images) {
          gen_cb <- labelByFolderGeneratorWrapper(val = TRUE, path = path.val)
        }
      } else {
        # generator for training
        gen <- fastaLabelGenerator(corpus.dir = path, format = format, batch.size = batch.size, maxlen = maxlen,
                                   vocabulary = vocabulary, verbose = FALSE, randomFiles = randomFiles, step = step,
                                   showWarnings = FALSE, seed = seed[1], shuffleFastaEntries = shuffleFastaEntries,
                                   fileLog = fileLog, labelVocabulary = labelVocabulary, reverseComplements = reverseComplements,
                                   ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
        # generator for validation
        gen.val <- fastaLabelGenerator(corpus.dir = path.val, format = format, batch.size = batch.size, maxlen = maxlen,
                                       vocabulary = vocabulary, verbose = FALSE, randomFiles = randomFiles, step = step,
                                       showWarnings = FALSE, seed = seed[2], shuffleFastaEntries = shuffleFastaEntries,
                                       fileLog = fileLogVal, labelVocabulary = labelVocabulary, reverseComplements = FALSE,
                                       ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
        if (tb_images) {
          gen_cb <- fastaLabelGenerator(corpus.dir = path.val, format = format, batch.size = batch.size, maxlen = maxlen,
                                        vocabulary = vocabulary, verbose = FALSE, randomFiles = randomFiles, step = step,
                                        showWarnings = FALSE, seed = seed[2], shuffleFastaEntries = shuffleFastaEntries,
                                        fileLog = fileLogVal, labelVocabulary = labelVocabulary, reverseComplements = FALSE,
                                        ambiguous_nuc = ambiguous_nuc, percentage_per_file = percentage_per_file)
        }
      }
    }
    # callbacks
    callbacks <- vector("list")
    callbacks[[1]] <- reduce_lr_cb(patience = patience, cooldown = cooldown, lr.plateau.factor = lr.plateau.factor)
    if (output$log) {
      callbacks <- c(callbacks, log_cb(run.name))
    }
    if (output$tensorboard) {
      # count files in path
      num_train_files <- rep(0, length(path))
      if (train_type != "label_folder" && endsWith(path, paste0(".", format))) {
        num_train_files <- 1
      } else {
        for (k in seq_along(path)) {
          if (endsWith(path[k], paste0(".", format))) {
            num_train_files[k] <- 1
          } else {
            num_train_files[k] <- length(list.files(path[k], pattern = paste0(".", format)))
          }
        }
      }
      complete_tb <- tensorboard_complete_cb(default_arguments = default_arguments, model = model, tensorboard.log = tensorboard.log, run.name = run.name, train_type = train_type,
                                             model_path = model_path, path = path, validation.split = validation.split, batch.size = batch.size, epochs = epochs,
                                             max.queue.size = max.queue.size, lr.plateau.factor = lr.plateau.factor, patience = patience, cooldown = cooldown,
                                             steps.per.epoch = steps.per.epoch, step = step, randomFiles = randomFiles, initial_epoch = initial_epoch, vocabulary = vocabulary,
                                             learning.rate = learning.rate, shuffleFastaEntries = shuffleFastaEntries, labelVocabulary = labelVocabulary, solver = solver,
                                             numberOfFiles = numberOfFiles, reverseComplements = reverseComplements, wavenet_format = wavenet_format,
                                             create_model_function = built_model$create_model_function, vocabulary.size = vocabulary.size, gen_cb = gen_cb, argumentList = argumentList,
                                             maxlen = maxlen, labelGen = labelGen, labelByFolder = labelByFolder, label.vocabulary.size = label.vocabulary.size, tb_images = tb_images,
                                             target_middle = target_middle, num_train_files = num_train_files, fileLog = fileLog, percentage_per_file = percentage_per_file)
      callbacks <- c(callbacks, complete_tb)
    }
    if (output$checkpoints) {
      callbacks <- c(callbacks, checkpoint_cb(filepath = filepath_checkpoints, save_weights_only = TRUE,
                                              save_best_only = save_best_only))
    }
    if (reset_states) {
      callbacks <- c(callbacks, reset_states_cb(fileLog = fileLog, fileLogVal = fileLogVal))
    }
    # training
    message("Start training ...")
    history <-
      model %>% keras::fit_generator(
        generator = gen,
        validation_data = gen.val,
        validation_steps = ceiling(steps.per.epoch * validation.split),
        steps_per_epoch = steps.per.epoch,
        max_queue_size = max.queue.size,
        epochs = epochs,
        initial_epoch = initial_epoch,
        callbacks = callbacks,
        verbose = 1
      )
  } else {
    message("Start training ...")
    history <- model %>% keras::fit(
      dataset$X,
      dataset$Y,
      batch_size = batch.size,
      validation_split = validation.split,
      epochs = epochs)
  }
  if (removeLog) {
    file.remove(fileLog)
  }
  # save final model
  message("Training done.")
  if (output$serialize_model) {
    Rmodel <- keras::serialize_model(model, include_optimizer = TRUE)
    save(Rmodel, file = paste0(run.name, "_full_model.Rdata"))
  }
  if (output$full_model) {
    keras::save_model_hdf5(
      model,
      paste0(run.name, "_full_model.hdf5"),
      overwrite = TRUE,
      include_optimizer = TRUE
    )
  }
  return(history)
}
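
# A hypothetical sketch of resuming training from a saved checkpoint; the file
# name follows the "Ep.{epoch}" pattern written by the checkpoint callback
# above, and all paths and settings are placeholders. Passing
# initial_epoch = NULL makes trainNetwork read the epoch from the file name.
if (FALSE) {
  history <- trainNetwork(
    train_type = "lm",
    model_path = "checkpoints/lm_run_1_checkpoints/Ep.007-val_loss0.53-val_acc0.812.hdf5",
    initial_epoch = NULL, # extracted from the file name, here 7
    epochs = 20,
    path = "train/fasta",
    path.val = "validation/fasta",
    checkpoint_path = "checkpoints",
    tensorboard.log = "tb_logs",
    run.name = "lm_run_2")
}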