Code for running In-Silico Mutagenesis (ISM)
""" | |
Author: Avanti Shrikumar | |
Here's a gist containing code to run in-silico mutagenesis (ISM) | |
on a model that takes one-hot encoded DNA sequence as the input. The | |
ISM score at a base is the prediction when that base is | |
present minus the average prediction across all 4 possible bases at | |
at that position. | |
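
In symbols, writing f for prediction_func and x_b for the sequence with base b
substituted at that position, the ISM score for base b is
f(x_b) - (1/4) * sum over b' in {A,C,G,T} of f(x_b').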
"prediction_func" needs to be a function that maps one-hot encoded sequence | |
to the output. Calling "make_ism_func" and supplying the appropriate | |
"prediction_func" would return a function that does ISM scoring | |
(that is, it would return a function that maps one-hot encoded | |
sequence to ISM scores). For example, you can supply keras_model.predict() | |
as a prediction_func. | |
If you don't need ISM scores for the entire sequence (e.g. you just care | |
about +/- 100bp from the center), use the "flank_around_middle_to_perturb" | |
argument - e.g. set it to 100 to get ISM scores only for the central 200bp. | |
Note: in some applications you may want ISM scores on all four possible | |
bases at each position, rather than just the ISM score for the base that | |
happens to be present in the original sequence. If that is what you are | |
after, simply change the line `return input_data_onehot*results_arr` | |
to be just `return results_arr`. | |
Some additional notes on prediction_func: | |
1. Many deep learning models expect the user to supply a *list* | |
of numpy arrays as input. Keras models are such that they can | |
handle both - i.e. if the user supplies a single numpy array, | |
under-the-hood keras will wrap that numpy array into a list of length 1. | |
In the code for make_ism_func, I am assuming that prediction_func accepts | |
(or can accept) a *list* of numpy arrays as the input; keras models satisfy | |
this requirement, but if your predictor is not like this, then you | |
can instead simply do: | |
`prediction_func = lambda x: original_prediction_func(x[0])` | |
to get around the fact that my code below is supplying a list of length 1 to | |
prediction_func. | |
2. The code assumes you do not have multi-modal inputs, though it wouldn't | |
be too hard to extend to the case of multi-modal inputs (you would need | |
to make sure the appropriate values for the additional modes were being | |
supplied alongside the perturbed sequences whenever prediction_func(...) | |
is called). | |
3. If you wanted to extend to one-hot encoded inputs with more than 4 characters, | |
I think you would just need to change the line `for base_idx in range(4):` | |
to have the appropriate number of characters. | |
""" | |

import numpy as np


#The list wrapper is a convenience function that makes the code work whether a
# user supplies just a single one-hot encoded numpy array or a
# list of length 1 containing the numpy array
def list_wrapper(func):
    def wrapped_func(input_data_list, **kwargs):
        #if a bare numpy array was supplied, wrap it in a list of length 1
        if (not isinstance(input_data_list, list)):
            input_data_list = [input_data_list]
        return func(input_data_list=input_data_list, **kwargs)
    return wrapped_func
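
# A minimal illustration (not from the original gist) of what the decorator
# does, using a toy function that just reports how many arrays it received:
#
#   @list_wrapper
#   def count_arrays(input_data_list):
#       return len(input_data_list)
#
#   count_arrays(np.zeros((10, 4)))    #bare array gets wrapped -> 1
#   count_arrays([np.zeros((10, 4))])  #list passes through unchanged -> 1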


def empty_ism_buffer(results_arr,
                     input_data_onehot,
                     perturbed_inputs_preds,
                     perturbed_inputs_info):
    #Write a batch of predictions into results_arr. Each entry of
    # perturbed_inputs_info is (example_idx, "original") for the unperturbed
    # sequence, or (example_idx, (pos_idx, base_idx)) for a perturbed one.
    for perturbed_input_pred,perturbed_input_info\
        in zip(perturbed_inputs_preds, perturbed_inputs_info):
        example_idx = perturbed_input_info[0]
        if (perturbed_input_info[1]=="original"):
            #record the original prediction at the bases present in the
            # original sequence
            results_arr[example_idx] +=\
                (perturbed_input_pred*input_data_onehot[example_idx])
        else:
            #record the prediction for the substituted base
            pos_idx,base_idx = perturbed_input_info[1]
            results_arr[example_idx,pos_idx,base_idx] = perturbed_input_pred


def make_ism_func(prediction_func,
                  flank_around_middle_to_perturb,
                  batch_size=200):
    @list_wrapper
    def ism_func(input_data_list, progress_update=10000, **kwargs):
        assert len(input_data_list)==1
        input_data_onehot = input_data_list[0]
        results_arr = np.zeros_like(input_data_onehot).astype("float64")
        perturbed_inputs_info = []
        perturbed_onehot_seqs = []
        num_done = 0
        for i,onehot_seq in enumerate(input_data_onehot):
            assert len(onehot_seq.shape)==2
            #the unperturbed sequence is scored too, to serve as the reference
            perturbed_onehot_seqs.append(onehot_seq)
            perturbed_inputs_info.append((i,"original"))
            #substitute every alternative base within the central window
            for pos in range(int(len(onehot_seq)/2)
                             -flank_around_middle_to_perturb,
                             int(len(onehot_seq)/2)
                             +flank_around_middle_to_perturb):
                for base_idx in range(4):
                    if onehot_seq[pos,base_idx]==0:
                        new_onehot = np.zeros_like(onehot_seq) + onehot_seq
                        new_onehot[pos,:] = 0
                        new_onehot[pos,base_idx] = 1
                        perturbed_onehot_seqs.append(new_onehot)
                        perturbed_inputs_info.append((i,(pos,base_idx)))
                        num_done += 1
                        if ((progress_update is not None)
                            and num_done%progress_update==0):
                            print("Done",num_done)
                        #flush the buffer once it reaches batch_size
                        if (len(perturbed_inputs_info)>=batch_size):
                            empty_ism_buffer(
                                results_arr=results_arr,
                                input_data_onehot=input_data_onehot,
                                #stack the buffered sequences into a single
                                # array before predicting
                                perturbed_inputs_preds=prediction_func(
                                    [np.array(perturbed_onehot_seqs)]),
                                perturbed_inputs_info=perturbed_inputs_info)
                            perturbed_inputs_info = []
                            perturbed_onehot_seqs = []
        #flush whatever is left in the buffer
        if (len(perturbed_inputs_info)>0):
            empty_ism_buffer(
                results_arr=results_arr,
                input_data_onehot=input_data_onehot,
                perturbed_inputs_preds=prediction_func(
                    [np.array(perturbed_onehot_seqs)]),
                perturbed_inputs_info=perturbed_inputs_info)
            perturbed_inputs_info = []
            perturbed_onehot_seqs = []
        #mean-normalize the ISM scores at each position (prediction for a
        # given base minus the average prediction across all 4 bases)
        results_arr = results_arr - np.mean(results_arr,axis=-1)[:,:,None]
        #mask to the base present in the original sequence
        return input_data_onehot*results_arr
    return ism_func
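

if __name__ == "__main__":
    #Quick self-contained demonstration (a sketch, not part of the original
    # gist): it uses a toy prediction_func in place of a real trained model,
    # so the resulting scores are meaningless, but the shapes and the calling
    # convention are the ones described in the docstring above. In a real
    # use-case you would supply e.g. keras_model.predict instead.
    np.random.seed(0)
    num_examples, seqlen = 3, 50
    #random one-hot sequences of shape (num_examples, seqlen, 4)
    base_indices = np.random.randint(0, 4, size=(num_examples, seqlen))
    onehot_seqs = np.eye(4)[base_indices]
    #toy predictor: a fixed random linear projection of the one-hot input,
    # summed to a single number per sequence
    projection = np.random.randn(seqlen, 4)
    toy_prediction_func = lambda x: np.sum(x[0]*projection[None,:,:],
                                           axis=(1,2))
    ism_func = make_ism_func(prediction_func=toy_prediction_func,
                             flank_around_middle_to_perturb=10,
                             batch_size=64)
    ism_scores = ism_func(onehot_seqs, progress_update=None)
    print(ism_scores.shape) #(3, 50, 4), same shape as onehot_seqs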