@AvantiShri
Last active August 1, 2020 21:44
Code for running In-Silico Mutagenesis (ISM)
"""
Author: Avanti Shrikumar
Here's a gist containing code to run in-silico mutagenesis (ISM)
on a model that takes one-hot encoded DNA sequence as the input. The
ISM score at a base is the prediction when that base is
present minus the average prediction across all 4 possible bases at
at that position.
"prediction_func" needs to be a function that maps one-hot encoded sequence
to the output. Calling "make_ism_func" and supplying the appropriate
"prediction_func" would return a function that does ISM scoring
(that is, it would return a function that maps one-hot encoded
sequence to ISM scores). For example, you can supply keras_model.predict()
as a prediction_func.
If you don't need ISM scores for the entire sequence (e.g. you just care
about +/- 100bp from the center), use the "flank_around_middle_to_perturb"
argument - e.g. set it to 100 to get ISM scores only for the central 200bp.
Note: in some applications you may want ISM scores on all four possible
bases at each position, rather than just the ISM score for the base that
happens to be present in the original sequence. If that is what you are
after, simply change the line `return input_data_onehot*results_arr`
to be just `return results_arr`.
Some additional notes on prediction_func:
1. Many deep learning models expect the user to supply a *list*
of numpy arrays as input. Keras models are such that they can
handle both - i.e. if the user supplies a single numpy array,
under-the-hood keras will wrap that numpy array into a list of length 1.
In the code for make_ism_func, I am assuming that prediction_func accepts
(or can accept) a *list* of numpy arrays as the input; keras models satisfy
this requirement, but if your predictor is not like this, then you
can instead simply do:
`prediction_func = lambda x: original_prediction_func(x[0])`
to get around the fact that my code below is supplying a list of length 1 to
prediction_func.
2. The code assumes you do not have multi-modal inputs, though it wouldn't
be too hard to extend to the case of multi-modal inputs (you would need
to make sure the appropriate values for the additional modes were being
supplied alongside the perturbed sequences whenever prediction_func(...)
is called).
3. If you wanted to extend to one-hot encoded inputs with more than 4 characters,
I think you would just need to change the line `for base_idx in range(4):`
to have the appropriate number of characters.
"""
import numpy as np

#The list wrapper is a convenience decorator that makes the code work whether
# a user supplies just a single one-hot encoded numpy array or a
# list of length 1 containing the numpy array
def list_wrapper(func):
    def wrapped_func(input_data_list, **kwargs):
        if not isinstance(input_data_list, list):
            input_data_list = [input_data_list]
        return func(input_data_list=input_data_list, **kwargs)
    return wrapped_func
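A quick sanity check of the wrapper's behavior with a toy function (`total_sum` is a hypothetical example, and the decorator is reproduced inside the snippet so it runs on its own):

```python
import numpy as np

# Same convenience decorator as above, repeated so this snippet is standalone.
def list_wrapper(func):
    def wrapped_func(input_data_list, **kwargs):
        if not isinstance(input_data_list, list):
            input_data_list = [input_data_list]
        return func(input_data_list=input_data_list, **kwargs)
    return wrapped_func

@list_wrapper
def total_sum(input_data_list):
    # always receives a list, whether or not the caller passed one
    return float(np.sum(input_data_list[0]))

arr = np.ones((2, 3))
print(total_sum([arr]), total_sum(arr))  # both print 6.0
```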
def empty_ism_buffer(results_arr,
                     input_data_onehot,
                     perturbed_inputs_preds,
                     perturbed_inputs_info):
    #Write a batch of predictions into results_arr at the positions
    # recorded in perturbed_inputs_info
    for perturbed_input_pred, perturbed_input_info in zip(
            perturbed_inputs_preds, perturbed_inputs_info):
        example_idx = perturbed_input_info[0]
        if (perturbed_input_info[1] == "original"):
            results_arr[example_idx] += (
                perturbed_input_pred*input_data_onehot[example_idx])
        else:
            pos_idx, base_idx = perturbed_input_info[1]
            results_arr[example_idx, pos_idx, base_idx] = perturbed_input_pred
def make_ism_func(prediction_func,
                  flank_around_middle_to_perturb,
                  batch_size=200):
    @list_wrapper
    def ism_func(input_data_list, progress_update=10000, **kwargs):
        assert len(input_data_list) == 1
        input_data_onehot = input_data_list[0]
        results_arr = np.zeros_like(input_data_onehot).astype("float64")
        perturbed_inputs_info = []
        perturbed_onehot_seqs = []
        num_done = 0
        for i, onehot_seq in enumerate(input_data_onehot):
            assert len(onehot_seq.shape) == 2
            perturbed_onehot_seqs.append(onehot_seq)
            perturbed_inputs_info.append((i, "original"))
            midpoint = int(len(onehot_seq)/2)
            for pos in range(midpoint - flank_around_middle_to_perturb,
                             midpoint + flank_around_middle_to_perturb):
                for base_idx in range(4):
                    if onehot_seq[pos, base_idx] == 0:
                        new_onehot = onehot_seq.copy()
                        new_onehot[pos, :] = 0
                        new_onehot[pos, base_idx] = 1
                        perturbed_onehot_seqs.append(new_onehot)
                        perturbed_inputs_info.append((i, (pos, base_idx)))
                        num_done += 1
                        if ((progress_update is not None)
                                and num_done % progress_update == 0):
                            print("Done", num_done)
                        if (len(perturbed_inputs_info) >= batch_size):
                            empty_ism_buffer(
                                results_arr=results_arr,
                                input_data_onehot=input_data_onehot,
                                perturbed_inputs_preds=prediction_func(
                                    [np.array(perturbed_onehot_seqs)]),
                                perturbed_inputs_info=perturbed_inputs_info)
                            perturbed_inputs_info = []
                            perturbed_onehot_seqs = []
        if (len(perturbed_inputs_info) > 0):
            empty_ism_buffer(
                results_arr=results_arr,
                input_data_onehot=input_data_onehot,
                perturbed_inputs_preds=prediction_func(
                    [np.array(perturbed_onehot_seqs)]),
                perturbed_inputs_info=perturbed_inputs_info)
        #mean-normalize the ISM scores at each position
        results_arr = results_arr - np.mean(results_arr, axis=-1)[:, :, None]
        #keep only the score of the base present in the original sequence
        return input_data_onehot*results_arr
    return ism_func
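As a sanity check of the ISM definition itself, here is a minimal brute-force computation on a toy linear model, independent of the batched code above (`toy_predict` and `brute_force_ism` are hypothetical names for this sketch). For a linear model, the mean-normalized ISM score at the observed base has a closed form: the weight at that base minus the mean weight across the 4 bases at that position.

```python
import numpy as np

# Hypothetical linear model: score = sum of per-(position, base) weights
# over the bases present in the one-hot sequence.
rng = np.random.RandomState(0)
seq_len, n_bases = 8, 4
weights = rng.randn(seq_len, n_bases)

def toy_predict(onehot_batch):
    return np.einsum('npb,pb->n', np.array(onehot_batch), weights)

def brute_force_ism(onehot_seq):
    # Predict with every possible base substituted at every position
    scores = np.zeros_like(onehot_seq, dtype="float64")
    for pos in range(seq_len):
        for base in range(n_bases):
            mutated = onehot_seq.copy()
            mutated[pos, :] = 0
            mutated[pos, base] = 1
            scores[pos, base] = toy_predict(mutated[None])[0]
    # Mean-normalize across the 4 bases at each position,
    # then keep only the score of the observed base
    scores -= scores.mean(axis=-1, keepdims=True)
    return onehot_seq * scores

onehot = np.eye(n_bases)[rng.randint(n_bases, size=seq_len)]
ism = brute_force_ism(onehot)
print(ism.shape)  # (8, 4)
```

For this linear model the result matches `onehot * (weights - weights.mean(axis=-1, keepdims=True))`, which is a useful check when validating an ISM implementation.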