Skip to content

Instantly share code, notes, and snippets.

@hannes-brt
Last active May 7, 2023 11:36
Show Gist options
  • Save hannes-brt/54ca5d4094b3d96237fa2e820c0945dd to your computer and use it in GitHub Desktop.
Save hannes-brt/54ca5d4094b3d96237fa2e820c0945dd to your computer and use it in GitHub Desktop.
One-hot encoding DNA with TensorFlow
# Copyright 2019 Hannes Bretschneider
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation
# files (the "Software"), to deal in the Software without
# restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following
# conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
# OTHER DEALINGS IN THE SOFTWARE.
import numpy as np
import tensorflow as tf
import twobitreader
import timeit
def tf_dna_encode_lookup_table(seq, name="dna_encode"):
    """Convert a DNA string tensor into integer ids via a vocabulary table.

    The vocabulary is ``A, C, G, T`` in that order, so those characters map
    to ids 0..3; any other character is looked up as ``-1``.

    Args:
        seq: String tensor holding the DNA sequence.
        name: Name scope under which the ops are created.

    Returns:
        The tensor produced by the table lookup, one id per character.
    """
    with tf.name_scope(name):
        # Vocabulary table: no out-of-vocabulary buckets, -1 for unknowns.
        vocabulary = tf.constant(["A", "C", "G", "T"])
        base_to_id = tf.contrib.lookup.index_table_from_tensor(
            mapping=vocabulary, num_oov_buckets=0, default_value=-1)

        # Split the string into single characters (empty delimiter).
        sparse_chars = tf.string_split([seq], delimiter="")
        dense_chars = tf.sparse.to_dense(sparse_chars, default_value="")
        # Drop the leading axis that came from wrapping `seq` in a list.
        chars = tf.squeeze(dense_chars, 0)

        return base_to_id.lookup(chars)
def tf_dna_encode_bit_manipulation(seq, name='dna_encode'):
    """Map a DNA string tensor to integer ids using only bitwise ops.

    The ASCII codes of the upper-case bases are reduced to
    ``A -> 0, C -> 1, G -> 2, T -> 3``:

    ======  =========  ==============  ======  =====
    base    ASCII      bits 6,4 clear  ``>>1``  final
    ======  =========  ==============  ======  =====
    A       0100 0001  0000 0001       0       0
    C       0100 0011  0000 0011       1       1
    G       0100 0111  0000 0111       3       2
    T       0101 0100  0000 0100       2       3
    ======  =========  ==============  ======  =====

    The last step XORs each value with its own bit 1 (shifted down to bit 0),
    which swaps 2 and 3 so that G and T end up in alphabetical order.

    Input is not validated: bytes other than upper-case ``A``, ``C``, ``G``
    and ``T`` (including lower-case bases) yield values with no meaning.

    Args:
        seq: String tensor holding the DNA sequence.
        name: Name scope under which the ops are created.

    Returns:
        A ``tf.uint8`` tensor with one id per input byte.
    """
    with tf.name_scope(name):
        codes = tf.decode_raw(seq, tf.uint8)

        # Clear bits 6 and 4 in a single AND. The mask is built as an explicit
        # uint8 constant (0xAF) rather than from negative Python ints such as
        # ~(1 << 6), so nothing depends on implicit wraparound to uint8.
        keep_mask = tf.constant(0xFF & ~((1 << 6) | (1 << 4)), tf.uint8)
        codes = tf.bitwise.bitwise_and(codes, keep_mask)

        codes = tf.bitwise.right_shift(codes, 1)

        # Bit 1 is set only for G (3) and T (2); XOR-ing it into bit 0 swaps
        # those two values and leaves A (0) and C (1) untouched.
        swap_bit = tf.bitwise.right_shift(tf.bitwise.bitwise_and(codes, 2), 1)
        return tf.bitwise.bitwise_xor(codes, swap_bit)
#%%
def tf_dna_encode_embedding_table(dna_input, name="dna_encode"):
    """Map DNA sequence to one-hot encoding using an embedding table.

    Each input byte is used as a row index into a fixed, non-trainable table
    whose four columns are the weights for ``(A, C, G, T)``. The unambiguous
    bases get a one-hot row; IUPAC ambiguity codes get a uniform distribution
    over the bases they stand for (e.g. ``N`` -> 0.25 each). Bytes that are
    not listed below (including lower-case letters) map to an all-zero row.

    Args:
        dna_input: String tensor holding the DNA sequence.
        name: Name scope under which the decoding/lookup ops are created.

    Returns:
        A ``tf.float32`` tensor with a trailing dimension of 4, one row per
        input byte.

    Note:
        The table is created with ``tf.get_variable`` under the fixed name
        ``'dna_lookup_table'``, so it must be initialized before use, and
        calling this function more than once in the same variable scope
        requires variable reuse to be enabled.
    """
    third = 1. / 3
    # IUPAC nucleotide code -> weights over (A, C, G, T).
    iupac_weights = {
        'A': [1, 0, 0, 0],
        'C': [0, 1, 0, 0],
        'G': [0, 0, 1, 0],
        'T': [0, 0, 0, 1],
        'W': [.5, 0, 0, .5],
        'S': [0, .5, .5, 0],
        'M': [.5, .5, 0, 0],
        'K': [0, 0, .5, .5],
        'R': [.5, 0, .5, 0],
        'Y': [0, .5, 0, .5],
        'B': [0, third, third, third],
        'D': [third, 0, third, third],
        'H': [third, third, 0, third],
        'V': [third, third, third, 0],
        'N': [.25, .25, .25, .25],
    }

    # One row per possible byte value. The previous size of 89 rows only
    # covered indices 0..88, so the assignment for 'Y' (ord 89) raised
    # IndexError; 256 rows also keep every decoded uint8 a valid index.
    _embedding_values = np.zeros([256, 4], np.float32)
    for code, weights in iupac_weights.items():
        _embedding_values[ord(code)] = weights

    embedding_table = tf.get_variable(
        'dna_lookup_table', _embedding_values.shape,
        initializer=tf.constant_initializer(_embedding_values),
        trainable=False)  # Ensure that embedding table is not trained

    with tf.name_scope(name):
        dna_input = tf.decode_raw(dna_input, tf.uint8)  # Interpret string as bytes
        # embedding_lookup needs int32/int64 ids, not uint8.
        dna_32 = tf.cast(dna_input, tf.int32)
        encoded_dna = tf.nn.embedding_lookup(embedding_table, dna_32)
    return encoded_dna
#%%
if __name__ == "__main__":
    # Benchmark driver: encodes the reverse complement of the DMD gene region
    # with each of the three encoders and reports timeit results.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "genome_file", help="Location to genome 2bit file (hg38)")
    # Defaults are required: without them an omitted flag is None, which
    # timeit.repeat() rejects with a TypeError.
    parser.add_argument(
        "-N", type=int, default=10,
        help="Number of iterations for each method (default: %(default)s)")
    parser.add_argument(
        "-r", type=int, default=3,
        help="Number of repeats (default: %(default)s)")
    args = parser.parse_args()

    # Extract DMD sequence and compute reverse complement
    genome = twobitreader.TwoBitFile(args.genome_file)
    dmd_sequence = genome['chrX'][31097676:33339441].upper()

    # Complement of every IUPAC nucleotide code (S, W and N are their own
    # complement), matching the codes the embedding-table encoder accepts.
    complement = dict(zip("ACGTRYKMBVDHSWN", "TGCAYRMKVBHDSWN"))

    def reverse_complement(seq):
        """Return the reverse complement of an upper-case DNA string.

        Raises:
            ValueError: if `seq` contains a character that is not an
                IUPAC nucleotide code.
        """
        try:
            return "".join(complement[base] for base in reversed(seq))
        except KeyError as err:
            raise ValueError(
                "Unsupported nucleotide code: {!r}".format(err.args[0]))

    dmd_sequence_r = reverse_complement(dmd_sequence)

    # Set up TensorFlow graph (same construction order as before, so the
    # auto-generated name scopes are unchanged).
    seq_t = tf.constant(dmd_sequence_r, tf.string)
    seq_encoded_bit_manip_t = tf.one_hot(
        tf_dna_encode_bit_manipulation(seq_t), 4)
    seq_encoded_lookup_t = tf.one_hot(tf_dna_encode_lookup_table(seq_t), 4)
    seq_encoded_embedding_table_t = tf_dna_encode_embedding_table(seq_t)

    # (label, tensor) pairs in the order they are benchmarked.
    benchmarks = [
        ("bit manipulation", seq_encoded_bit_manip_t),
        ("embedding table", seq_encoded_embedding_table_t),
        ("lookup table", seq_encoded_lookup_t),
    ]
    report = ("{} method ({} iterations, {} repeats):\n"
              "Total time: {}\n"
              "Best time: {}\n")

    # The context manager closes the session (and frees its resources) even
    # if a benchmark raises; it also installs it as the default session, which
    # the initializers' .run() calls rely on.
    with tf.Session() as session:
        tf.tables_initializer().run()
        tf.global_variables_initializer().run()

        # Now benchmark each method
        for label, tensor in benchmarks:
            print("### Benchmarking {} method ###".format(label))
            # Bind the tensor as a default argument so the callable does not
            # depend on the loop variable after this iteration.
            results = timeit.repeat(
                lambda t=tensor: session.run(t),
                number=args.N, repeat=args.r)
            print(report.format(
                label.capitalize(), args.N, args.r,
                sum(results), min(results)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment