Skip to content

Instantly share code, notes, and snippets.

@adamkewley
Created February 25, 2020 09:45
Show Gist options
  • Save adamkewley/54a7cd422ad564f925489beeae608e41 to your computer and use it in GitHub Desktop.
Save adamkewley/54a7cd422ad564f925489beeae608e41 to your computer and use it in GitHub Desktop.
#include <array>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <mutex>
#include <new>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include <immintrin.h>

#define UNLIKELY(x) __builtin_expect(!!(x), 0)

using std::istream;
using std::ostream;
using std::runtime_error;
using std::string;
using std::bad_alloc;
using std::array;
using std::thread;
using std::mutex;
using std::atomic;
using std::condition_variable;
using std::unique_lock;
using std::vector;
// Layout of a full FASTA sequence line: `basepairs_in_line` basepair
// characters followed by a single '\n' byte, `line_len` bytes in total.
constexpr size_t basepairs_in_line = 60;
constexpr size_t line_len = basepairs_in_line + 1;  // + 1 byte for the '\n'
// custom vector impl. that has *similar* methods to a
// `vector<char>`. The reason this is necessary is because the stdlib
// `vector<char>` implementation requires that `.resize` initializes
// the newly-allocated content, and that `realloc` cannot be
// used. Valgrind reports that that is ~10-20 % of application
// perf. for large inputs.
class unsafe_vector {
public:
unsafe_vector() {
_buf = (char*)malloc(_capacity);
if (_buf == nullptr) {
throw bad_alloc{};
}
}
unsafe_vector(const unsafe_vector& other) = delete;
unsafe_vector(unsafe_vector&& other) = delete;
unsafe_vector& operator=(unsafe_vector& other) = delete;
unsafe_vector& operator=(unsafe_vector&& other) = delete;
~unsafe_vector() noexcept {
if (_buf != nullptr) {
free(_buf);
}
}
char* data() {
return _buf;
}
void resize(size_t count) {
size_t rem = _capacity - _size;
if (count > rem) {
grow(count);
}
_size = count;
}
size_t size() const {
return _size;
}
private:
void grow(size_t min_cap) {
size_t new_cap = _capacity;
while (new_cap < min_cap) {
new_cap *= 2;
}
char* new_buf = (char*)realloc(_buf, new_cap);
if (new_buf != nullptr) {
_capacity = new_cap;
_buf = new_buf;
} else {
// no need to reset _buf to prevent a double-free by the
// dtor. The POSIX definition of `realloc` states that a
// failed reallocation leaves the supplied pointer
// untouched, so throw here and let the class's destructor
// free the memory.
throw bad_alloc{};
}
}
char* _buf = nullptr;
size_t _size = 0;
size_t _capacity = 1024;
};
// Returns the complement of a single nucleotide character, e.g.
// 'A' <-> 'T', 'C' <-> 'G', and the IUPAC ambiguity codes
// ('R' <-> 'Y', 'K' <-> 'M', 'B' <-> 'V', 'D' <-> 'H', 'N' -> 'N', ...).
//
// Only the low 5 bits of `character` are used, so the lookup is
// case-insensitive ('A' == 0x41 and 'a' == 0x61 both hit index 1) and
// '\n' (index 10) maps to itself. Indices with no defined complement
// yield '\0'. Any other character that shares low 5 bits with a table
// entry gets that entry's value.
char complement(char character) {
    // Indexed by (character & 0x1f). The table has 32 entries so every
    // possible 5-bit index is in bounds. Entries 26..31 are '\0', the
    // same values `reverse_complement_simd` produces for those indices.
    static constexpr char complement_lut[32] = {
        '\0', 'T', 'V', 'G', 'H', '\0', '\0', 'C',
        'D', '\0', '\n', 'M', '\0', 'K', 'N', '\0',
        '\0', '\0', 'Y', 'S', 'A', 'A', 'B', 'W',
        '\0', 'R', '\0', '\0', '\0', '\0', '\0', '\0',
    };
    return complement_lut[static_cast<unsigned char>(character) & 0x1f];
}
// Broadcasts `c` into all 16 byte lanes of an SSE register.
__m128i packed(char c) {
    return _mm_set1_epi8(c);
}
// Reverse-complements 16 bytes at once: lane i of the result is the
// complement of lane (15 - i) of `v`.
//
// Uses the same low-5-bit lookup as `complement`. The 0..31 index range
// exceeds the 16 entries a single `_mm_shuffle_epi8` (SSSE3 pshufb)
// lookup can address, so it does two masked lookups and ORs them:
// - one table for indices 0..15,
// - one for 16..31 (indices 26..31 yield '\0').
__m128i reverse_complement_simd(__m128i v) {
    // Lane i takes lane 15 - i (byte reversal), then keep the low 5 bits.
    const __m128i reverse_lanes = _mm_setr_epi8(15, 14, 13, 12, 11, 10, 9, 8,
                                                7, 6, 5, 4, 3, 2, 1, 0);
    const __m128i idx = _mm_and_si128(_mm_shuffle_epi8(v, reverse_lanes), packed(0x1f));

    // Complement table entries 0..15 and 16..31, listed in lane order.
    const __m128i low_table = _mm_setr_epi8('\0', 'T', 'V', 'G', 'H', '\0', '\0', 'C',
                                            'D', '\0', '\n', 'M', '\0', 'K', 'N', '\0');
    const __m128i high_table = _mm_setr_epi8('\0', '\0', 'Y', 'S', 'A', 'A', 'B', 'W',
                                             '\0', 'R', '\0', '\0', '\0', '\0', '\0', '\0');

    // Indices >= 16 are zeroed before the low lookup, so they hit
    // low_table[0] == '\0'.
    const __m128i is_low = _mm_cmplt_epi8(idx, packed(16));
    const __m128i from_low = _mm_shuffle_epi8(low_table, _mm_and_si128(idx, is_low));

    // Indices < 16 become negative (high bit set) after subtracting 16;
    // pshufb writes 0 for those lanes.
    const __m128i from_high = _mm_shuffle_epi8(high_table, _mm_sub_epi8(idx, packed(16)));

    return _mm_or_si128(from_low, from_high);
}
// Exchanges the bytes at `a` and `b`, complementing each as it moves.
// When `a == b` the single byte ends up complemented exactly once.
void complement_swap(char* a, char* b) {
    const char new_a = complement(*b);
    const char new_b = complement(*a);
    *b = new_b;
    *a = new_a;
}
// Reverse-complements `num_bps` mirrored byte pairs of [start, end).
// For each i in [0, num_bps), start[i] and end[-1 - i] are swapped and
// both complemented.
//
// Preconditions:
// - The bytes can be swapped blindly. Any newlines in the touched
//   region must sit at mirrored offsets; the caller handles this
//   externally.
// - num_bps is at most ceil((end - start) / 2). If it equals that
//   value for an odd-length range, the middle byte is paired with
//   itself and complemented once.
void reverse_complement_bps(char* start, char* end, size_t num_bps) {
#ifdef SIMD
    // 16 pairs per iteration: load a 16-byte window from each end, then
    // store each window's reverse complement into the opposite end.
    while (num_bps >= 16) {
        end -= 16;
        const __m128i front = _mm_lddqu_si128(reinterpret_cast<const __m128i*>(start));
        const __m128i back = _mm_lddqu_si128(reinterpret_cast<const __m128i*>(end));
        _mm_storeu_si128(reinterpret_cast<__m128i*>(start), reverse_complement_simd(back));
        _mm_storeu_si128(reinterpret_cast<__m128i*>(end), reverse_complement_simd(front));
        num_bps -= 16;
        start += 16;
    }
#else
    // Even without platform-specific SIMD it pays to unroll this loop by
    // hand. Most compilers won't, because they can't know the inputs are
    // usually larger than 16. The author measured a ~10 % speedup.
    while (num_bps >= 16) {
        for (size_t i = 0; i < 16; ++i) {
            complement_swap(start++, --end);
        }
        num_bps -= 16;
    }
#endif
    // Remaining (< 16) pairs, one at a time.
    for (size_t i = 0; i < num_bps; ++i) {
        complement_swap(start++, --end);
    }
}
// Reverse-complements a FASTA sequence in place, keeping its line layout.
//
// Input: unformatted basepair data (no header) in [begin, end).
// - Every line except the last holds *exactly* `basepairs_in_line`
//   basepairs followed by '\n'.
// - The last line may hold fewer basepairs and must also end in '\n'.
// An empty range is a no-op.
//
// Why the algorithm is more involved than a plain reverse:
//
// - Stripping newlines while reading would make this step mostly
//   branchless, but writing the output would then need the newlines
//   re-inserted into an intermediate buffer (very bad).
//
// - Keeping newlines means this step must handle them. The simplest way
//   is an `if (next_char == '\n') skip;` check at both ends. That adds
//   two compare-and-branch operations per swap on top of the loop
//   condition, and it rules out multi-basepair swaps (SIMD, unrolling).
//   The author measured a 20-35 % overall slowdown even without
//   vectorization.
//
// - So this function swaps blocks of basepairs without per-byte checks.
//   Newlines can only be swapped blindly when they sit at symmetric
//   offsets from the two ends of the data. That holds when every line
//   is full. A shorter last line breaks the symmetry, and the code then
//   steps over newlines explicitly, block by block.
void reverse_complement_basepairs(char* begin, char* end) {
    if (UNLIKELY(begin == end)) {
        return;
    }
    const size_t len = end - begin;
    const size_t trailer_len = len % line_len;

    // Step back so that `end` points at the final newline ("just past
    // the last basepair"). [begin, end) now holds len - 1 bytes.
    end--;
    const size_t span = len - 1;

    // Optimal case: every line is exactly `line_len` bytes. The interior
    // newlines sit at mirrored offsets within [begin, end), so the whole
    // range can be reverse-complemented blindly ('\n' complements to
    // itself).
    if (trailer_len == 0) {
        // Pair counts come from the range actually being reversed (span),
        // not from `len`. Using `len` would complement the middle byte a
        // second time whenever the line count is odd.
        const size_t num_pairs = span / 2;
        reverse_complement_bps(begin, end, num_pairs);
        const bool has_middle_bp = (span % 2) != 0;
        if (has_middle_bp) {
            begin[num_pairs] = complement(begin[num_pairs]);
        }
        return;
    }

    // Suboptimal case: the last ("trailing") line is shorter than
    // `line_len`, so newlines sit at non-mirrored offsets. The loop below
    // steps over them so they are not moved to the wrong place.
    const size_t trailer_bps = trailer_len - 1;  // trailer minus its '\n'
    const size_t rem_bps = basepairs_in_line - trailer_bps;
    const size_t rem_bytes = rem_bps + 1;
    const size_t num_whole_lines = len / line_len;
    const size_t num_steps = num_whole_lines / 2;

    // Each iteration consumes one line's worth of bytes from each end:
    // 1. revcomp `trailer_bps` pairs, then skip the trailer and its
    //    newline on the back side;
    // 2. revcomp the remaining `rem_bps` pairs, then skip them and the
    //    newline on the front side.
    // Afterwards both ends are back in the same shape as at loop entry.
    for (size_t i = 0; i < num_steps; ++i) {
        reverse_complement_bps(begin, end, trailer_bps);
        begin += trailer_bps;
        end -= trailer_len;
        reverse_complement_bps(begin, end, rem_bps);
        begin += rem_bytes;
        end -= rem_bps;
    }

    // One whole line (plus the trailer's worth) may remain. Do only the
    // first half of a loop step: the second half would overlap from
    // both directions.
    const bool has_unpaired_line = (num_whole_lines % 2) != 0;
    if (has_unpaired_line) {
        reverse_complement_bps(begin, end, trailer_bps);
        begin += trailer_bps;
        end -= trailer_len;
    }

    // No whole lines remain. What is left is a newline-free run of
    // basepairs in the middle; revcomp it in place.
    const size_t bps_in_last_line = end - begin;
    const size_t swaps_in_last_line = bps_in_last_line / 2;
    reverse_complement_bps(begin, end, swaps_in_last_line);

    // Edge case: an odd-length run leaves one middle byte that is
    // complemented but not swapped.
    const bool has_unpaired_byte = (bps_in_last_line % 2) != 0;
    if (has_unpaired_byte) {
        begin[swaps_in_last_line] = complement(begin[swaps_in_last_line]);
    }
}
// Reads from `in` up to the next `delim` or EOF. The bytes read replace
// the contents of `out`. A found delimiter is consumed from the stream
// but not stored.
//
// Reads in `read_size` chunks via `istream::getline`, growing `out`
// whenever a chunk fills the space provided. Returns without looping if
// the stream is already at EOF or fails for any reason other than a
// full chunk.
void read_up_to(istream& in, unsafe_vector& out, char delim) {
    constexpr size_t read_size = 1 << 16;
    size_t bytes_read = 0;
    out.resize(read_size);
    while (in) {
        in.getline(out.data() + bytes_read, read_size, delim);
        const size_t extracted = static_cast<size_t>(in.gcount());
        bytes_read += extracted;
        if (in.eof()) {
            // Hit EOF rather than a delimiter, so there is nothing to
            // strip. This must be checked before fail(): when getline
            // extracts nothing at EOF it sets failbit *and* eofbit.
            // Treating that as "buffer full" and clearing failbit would
            // loop forever.
            break;
        }
        if (in.fail()) {
            if (extracted + 1 != read_size) {
                // Failed for a reason other than running out of space;
                // stop instead of retrying forever.
                break;
            }
            // Stored read_size - 1 chars without finding the delimiter:
            // expand the buffer and keep reading.
            out.resize(bytes_read + read_size);
            in.clear(in.rdstate() & ~std::ios::failbit);
            continue;
        }
        // Read up to *and including* the delimiter. gcount() counts the
        // delimiter, but getline does not store it.
        --bytes_read;
        break;
    }
    out.resize(bytes_read);
}
// One FASTA record as read from the input.
struct Sequence {
    // Header line text: no leading '>' and no trailing newline.
    std::string header;
    // Raw basepair bytes as read, newlines included. Each line is
    // expected to end in '\n'.
    unsafe_vector seq;
};
// Reads one FASTA record into `out`. The record's leading '>' must
// already have been consumed.
// - The rest of the current line becomes `out.header`.
// - Everything up to the next '>' or EOF becomes `out.seq`. That '>'
//   is consumed but not stored.
void read_sequence(istream& in, Sequence& out) {
    // Cleared explicitly: std::getline leaves the string untouched if
    // the stream is not good.
    out.header.clear();
    std::getline(in, out.header);
    read_up_to(in, out.seq, '>');
}
// Reverse-complements the basepair data of `s` in place. The header is
// left untouched.
void reverse_complement(Sequence& s) {
    char* const first = s.seq.data();
    char* const last = first + s.seq.size();
    reverse_complement_basepairs(first, last);
}
// Writes `s` to `out` in FASTA form: '>' + header + '\n', then the
// basepair bytes exactly as stored.
void write_sequence(ostream& out, Sequence& s) {
    out << '>' << s.header << '\n';
    out.write(s.seq.data(), s.seq.size());
}
int main() {
// required for *large* (e.g. 1 GiB) inputs
std::cin.sync_with_stdio(false);
std::cout.sync_with_stdio(false);
// the read function assumes that '>' has already been read
// (because istream::getline will read it per loop iteration:
// prevents needing to 'peek' a bunch).
if (std::cin.get() != '>') {
throw runtime_error{"unexpected input: next char should be the start of a seqence header"};
}
constexpr size_t num_workers = 12;
std::vector<thread> workers;
mutex read_mutex;
atomic<size_t> read_idx{0};
mutex write_mutex;
atomic<size_t> write_idx{0};
condition_variable write_condvar;
for (size_t i = 0; i < num_workers; ++i) {
auto f = [&]() {
Sequence s;
size_t idx = -1;
for (;;) {
{
unique_lock<mutex> l{read_mutex};
if (std::cin.eof()) {
return;
}
read_sequence(std::cin, s);
idx = read_idx++;
}
reverse_complement(s);
{
unique_lock<mutex> l{write_mutex};
write_condvar.wait(l, [&]() { return write_idx == idx; });
write_sequence(std::cout, s);
std::cout.flush();
++write_idx;
write_condvar.notify_one();
}
}
};
workers.emplace_back(thread{f});
}
for (auto& worker : workers) {
worker.join();
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment