-
-
Save hussius/04fb172154681b7174c7a923d6600784 to your computer and use it in GitHub Desktop.
Functions to decode COSSMO training examples stored in TFRecord format
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def read_single_cossmo_example(serialized_example, n_tissues=1, coord_sys='rna1'):
    """Decode a single serialized COSSMO ``SequenceExample``.

    Parameters
    ----------
    serialized_example : Scalar string tensor holding one serialized example,
        e.g. the value returned by ``tf.TFRecordReader.read``.
    n_tissues : Length of each per-site ``psi`` / ``psi_std`` vector.
    coord_sys : One of 'rna1' or 'dna0'. If 'dna0', an extra scalar string
        context field 'strand' must exist in the tfrecord and is extracted.

    Returns
    -------
    The result of ``tf.parse_single_sequence_example``: a pair
    ``(context, sequence)`` of dicts mapping feature names to tensors.
    Context features are per-example; sequence features have one entry per
    alternative splice site.

    Raises
    ------
    ValueError : If ``coord_sys`` is not 'rna1' or 'dna0'.
    """
    # Validate before touching TensorFlow so bad arguments fail fast, even
    # under `python -O` (which strips `assert` statements).
    if coord_sys not in ('dna0', 'rna1'):
        raise ValueError(
            "coord_sys must be one of 'rna1' or 'dna0', got %r" % (coord_sys,))

    # The module only binds `tf` inside its `__main__` guard, so import it here;
    # otherwise importing this module and calling the function raises NameError.
    import tensorflow as tf

    # Per-example (context) features.
    context_features = {
        'n_alt_ss': tf.FixedLenFeature([], tf.int64),
        'event_type': tf.FixedLenFeature([], tf.string),
        'const_seq': tf.FixedLenFeature([2], tf.string),
        'const_site_id': tf.FixedLenFeature([], tf.string),
        'const_site_position': tf.FixedLenFeature([], tf.int64),
    }
    if coord_sys == 'dna0':
        context_features['strand'] = tf.FixedLenFeature([], tf.string)

    # Per-alternative-site (sequence) features: one row per alternative site.
    sequence_features = {
        'alt_seq': tf.FixedLenSequenceFeature([2], tf.string),
        'psi': tf.FixedLenSequenceFeature([n_tissues], tf.float32),
        'psi_std': tf.FixedLenSequenceFeature([n_tissues], tf.float32),
        'alt_ss_position': tf.FixedLenSequenceFeature([], tf.int64),
        'alt_ss_type': tf.FixedLenSequenceFeature([], tf.string),
    }

    return tf.parse_single_sequence_example(
        serialized_example,
        context_features=context_features,
        sequence_features=sequence_features,
    )
def read_data_files(alt_ss_type, input_files, n_tissues=1,
                    num_epochs=None, shuffle=False, sort=True):
    """Read and decode a list of COSSMO tfrecord files.

    Builds a TF1 queue-based input pipeline (``string_input_producer`` +
    ``TFRecordReader``) that yields one decoded training example per
    ``sess.run``.

    Parameters
    ----------
    alt_ss_type : One of 'acceptor' or 'donor'. Selects how exonic and
        intronic flanks are concatenated into rna_seq, const_dna_seq and
        alt_dna_seq.
    input_files : A list of paths to the tfrecord files.
    n_tissues : Number of tissues, i.e. the length of each per-site psi /
        psi_std vector (always one in our dataset).
    num_epochs : Number of epochs for which to repeat, or None to repeat forever.
    shuffle : Whether to shuffle the order of the input files in each epoch.
        The flag is passed to ``tf.train.string_input_producer``; examples
        within a file are still read in file order.
    sort : If True, reorder the alternative splice sites (and all of their
        per-site fields) by increasing distance from the constitutive site.
        Should be true when training recurrent models or whenever the order
        of splice sites is important.

    Returns
    -------
    decoded_example : dict mapping the field names below to tensors.

    Raises
    ------
    ValueError : If ``alt_ss_type`` is not 'acceptor' or 'donor'.

    Fields
    ------
    tfrecord_key : (tf.Tensor, shape=(), dtype=string) A unique key for every
        training example.
    event_type : (tf.Tensor, shape=(), dtype=string) The event type
        ('acceptor' or 'donor').
    const_site_id : (tf.Tensor, shape=(), dtype=string) The constitutive site as
        'chromosome:strand:position'.
    const_site_position : (tf.Tensor, shape=(), dtype=int64) The position of the
        constitutive site. Positions are in "RNA1" format, i.e. forward strand
        positions are positive and one-based, and reverse strand positions are
        negative numbers.
    const_seq : (tf.Tensor, shape=(2,), dtype=string) The first element is the
        40nt exonic flank of the constitutive site and the second one is the
        40nt intronic flank. I.e. when event type is 'acceptor', the first
        element is 40nt upstream of the donor and the second element is 40nt
        downstream. When event type is 'donor', the first element is 40nt
        downstream of the constitutive acceptor and the second element is 40nt
        upstream of it. Sequence is according to the coding strand.
    const_dna_seq : (tf.Tensor, shape=(80,), dtype=uint8) The sequence from a
        symmetric window of 80nt around the constitutive splice site, encoded
        as the ASCII code of each nucleotide (uppercase only).
    n_alt_ss : (tf.Tensor, shape=(), dtype=int64) The number of alternative
        splice sites (K).
    alt_ss_position : (tf.Tensor, shape=(K,), dtype=int64) Position, in RNA1
        coordinates, of each of the K alternative splice sites.
    alt_ss_type : (tf.Tensor, shape=(K,), dtype=string) The type of each of the
        K alternative splice sites. Value is one of:
        - 'annotated': Splice site is from Gencode v19 annotations.
        - 'gtex': A de-novo splice site found in GTEx RNA-Seq data.
        - 'maxent': A "decoy" splice site that has MaxEntScore score >= 3.0,
          but without RNA-Seq evidence to be used as a splice site.
        - 'hard_negative': A random genomic location.
    alt_seq : (tf.Tensor, shape=(K, 2), dtype=string) One row per alternative
        splice site. The code below treats the first element of each row as
        the exonic 40nt flank and the second as the intronic 40nt flank,
        i.e. the same layout as const_seq.
        NOTE(review): confirm this element order against the tfrecord writer;
        rna_seq and alt_dna_seq are only correct if element 0 is exonic.
    alt_dna_seq : (tf.Tensor, shape=(K, 80), dtype=uint8) For each of the K
        alternative sites, the sequence from an 80nt window around it, encoded
        as the ASCII code of each nucleotide (uppercase only).
    rna_seq : (tf.Tensor, shape=(K, 80), dtype=uint8) For each of the K
        alternative sites, the "post-splicing" (mRNA) sequence, encoded as the
        ASCII code of each nucleotide (uppercase only). When event type is
        'acceptor' this is the exonic sequence of the constitutive site/donor
        followed by the exonic sequence of the alternative site/acceptor, or
        the reverse when the event type is 'donor'.
    psi : (tf.Tensor, shape=(K, n_tissues), dtype=float32) The PSI estimated by
        the positional bootstrap procedure for each alternative splice site.
    psi_std : (tf.Tensor, shape=(K, n_tissues), dtype=float32) The standard
        deviation of the PSI estimated by the positional bootstrap procedure
        for each alternative splice site.
    """
    # Validate before building any graph ops. Under `python -O` an `assert` is
    # stripped and a bad value would later surface as an unbound `rna_seq`.
    if alt_ss_type not in ('acceptor', 'donor'):
        raise ValueError(
            "alt_ss_type must be one of 'acceptor' or 'donor', got %r"
            % (alt_ss_type,))

    # The module only binds `tf` inside its `__main__` guard, so import it here.
    import tensorflow as tf

    with tf.name_scope('data_pipeline'):
        filename_queue = tf.train.string_input_producer(
            input_files, num_epochs=num_epochs, shuffle=shuffle)
        file_reader = tf.TFRecordReader()
        tf_record_key, serialized_example = file_reader.read(filename_queue)

        # Merge the (context, sequence) feature dicts into one flat dict.
        parsed = read_single_cossmo_example(serialized_example, n_tissues)
        decoded_example = dict(parsed[0])
        decoded_example.update(parsed[1])

        n_alt_ss = tf.cast(decoded_example['n_alt_ss'], tf.int32)

        if sort:
            # top_k over the *negated* distances yields indices ordered from
            # the nearest to the farthest alternative site.
            distance = tf.abs(decoded_example['alt_ss_position'] -
                              decoded_example['const_site_position'])
            order = tf.nn.top_k(-distance, k=n_alt_ss, sorted=True).indices
            # Apply the same permutation to every per-site field.
            for key in ('alt_seq', 'psi', 'psi_std',
                        'alt_ss_position', 'alt_ss_type'):
                decoded_example[key] = tf.gather(decoded_example[key], order)

        decoded_example['tfrecord_key'] = tf_record_key

        # Split the (exonic, intronic) string pairs into separate tensors.
        const_exonic_seq, const_intronic_seq = tf.split(
            axis=0, num_or_size_splits=2, value=decoded_example['const_seq'])
        alt_exonic_seq, alt_intronic_seq = tf.split(
            axis=1, num_or_size_splits=2, value=decoded_example['alt_seq'])
        alt_exonic_seq = tf.squeeze(alt_exonic_seq, [1])
        alt_intronic_seq = tf.squeeze(alt_intronic_seq, [1])

        # Turn each nucleotide string into a vector of uint8 ASCII codes.
        const_exonic_seq = tf.decode_raw(const_exonic_seq, tf.uint8)
        const_intronic_seq = tf.decode_raw(const_intronic_seq, tf.uint8)
        alt_exonic_seq = tf.decode_raw(alt_exonic_seq, tf.uint8)
        alt_intronic_seq = tf.decode_raw(alt_intronic_seq, tf.uint8)

        # Repeat the constitutive exonic flank once per alternative site so it
        # can be concatenated row-wise with the alternative flanks.
        const_exonic_seq_tiled = tf.tile(
            const_exonic_seq, tf.stack([n_alt_ss, 1]))

        if alt_ss_type == 'acceptor':
            rna_seq = tf.concat(
                axis=1, values=[const_exonic_seq_tiled, alt_exonic_seq])
            const_dna = tf.squeeze(
                tf.concat(axis=1, values=[const_exonic_seq, const_intronic_seq]),
                [0])
            alt_dna = tf.concat(axis=1, values=[alt_intronic_seq, alt_exonic_seq])
        else:  # 'donor' (validated above)
            rna_seq = tf.concat(
                axis=1, values=[alt_exonic_seq, const_exonic_seq_tiled])
            const_dna = tf.squeeze(
                tf.concat(axis=1, values=[const_intronic_seq, const_exonic_seq]),
                [0])
            alt_dna = tf.concat(axis=1, values=[alt_exonic_seq, alt_intronic_seq])

        decoded_example['rna_seq'] = rna_seq
        decoded_example['const_dna_seq'] = const_dna
        decoded_example['alt_dna_seq'] = alt_dna
        return decoded_example
if __name__ == '__main__':
    import os

    import tensorflow as tf

    # Get a sorted (deterministic) list of all input files.
    tfrecord_dir = 'local/path/to/tfrecords'
    files = [os.path.join(tfrecord_dir, f)
             for f in sorted(os.listdir(tfrecord_dir))
             if f.endswith('tfrecord')]

    # Build the graph that reads and decodes the tfrecords.
    decoded_examples_tensor = read_data_files('acceptor', files)

    sess = tf.Session()
    # Initialize local variables *before* starting the queue-runner threads.
    # Per the TF docs, string_input_producer keeps its epoch counter in a
    # local variable when num_epochs is set, and the threads read it at once.
    sess.run(tf.local_variables_initializer())
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
    try:
        # Training examples can now be read
        decoded_examples_values = sess.run(decoded_examples_tensor)
        # ...continue with batching etc
    finally:
        # Stop and join the reader threads, then release the session.
        coordinator.request_stop()
        coordinator.join(threads)
        sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.