Last active
March 1, 2016 21:31
-
-
Save lorenzhs/0ad22028cf6b68a91f3f to your computer and use it in GitHub Desktop.
Unsafe boost::mpi things
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
// Copyright (C) 2015 Lorenz Hübschle-Schneider <[email protected]>
// MIT License
#include <algorithm>
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <utility>
#include <vector>

#include <boost/mpi.hpp>
#include <boost/serialization/string.hpp>
#include <boost/serialization/utility.hpp>
#include <boost/serialization/vector.hpp>

namespace mpi = boost::mpi;
// Trait deciding whether values of `Elem` may be shipped over MPI as raw
// bytes. By default this is exactly std::is_trivial<Elem>.
template <typename Elem>
struct is_trivial_enough : std::is_trivial<Elem> {};

// std::pair is not itself a trivial type, but we still want to treat it like
// one: a pair is trivial enough [TM] precisely when both of its member types
// are trivial enough [TM]. The check recurses, so nested pairs work as well.
template <typename First, typename Second>
struct is_trivial_enough<std::pair<First, Second>>
    : std::integral_constant<bool, is_trivial_enough<First>::value
                                       && is_trivial_enough<Second>::value> {};
// Sort-of unsafe MPI operations on "trivial enough" data (see
// is_trivial_enough): such elements are transmitted as raw arrays of
// `transmit_type` instead of being serialized. This is "unsafe" mostly because
// std::pair isn't technically a trivial type, but we want to treat it like one.
//
// Template parameters:
//   T             element type
//   trivial       selects the raw-bytes path (default: is_trivial_enough<T>).
//                 When false, T must be serializable with Boost.Serialization.
//                 The choice is made at compile time, so a trivial T does NOT
//                 need a serialize() function.
//   transmit_type MPI datatype used for the raw-bytes path; sizeof(T) must be
//                 a multiple of sizeof(transmit_type).
//
// MPI counts are `int`, so every message (in transmit_type words or archive
// bytes) and every gathered total must stay below INT_MAX.
template <typename T, bool trivial = is_trivial_enough<T>::value, typename transmit_type = std::uint64_t>
struct unsafe_mpi {
    static_assert(!trivial || sizeof(T) % sizeof(transmit_type) == 0,
                  "Invalid transmit_type for element type (sizeof(T) is not a multiple of sizeof(transmit_type))");

private:
    // Tag type used to pick the raw-bytes or the serialization implementation
    // at compile time; only the selected overload gets instantiated.
    typedef std::integral_constant<bool, trivial> dispatch_tag;

public:
    // Broadcast `data` from PE `root` to all PEs of `comm`. On non-root PEs,
    // `data` is replaced by the root's contents. No-op if comm has < 2 PEs.
    static void broadcast(const mpi::communicator &comm, std::vector<T> &data, int root) {
        if (comm.size() < 2) return;
        broadcast_impl(comm, data, root, dispatch_tag());
    }

    // Gather `in` from every PE into `out` on every PE, concatenated in rank
    // order. `out` is overwritten. Trivial (enough) datatypes are transmitted
    // directly via MPI_Allgatherv; all others are serialized with
    // Boost.Serialization.
    static void allgatherv(const mpi::communicator &comm, const std::vector<T> &in, std::vector<T> &out) {
        allgatherv_impl(comm, in, out, dispatch_tag());
    }

    // Serializing allgatherv: works for any serializable T.
    // On success, `out` holds the concatenation of all PEs' `in` in rank
    // order. If MPI_Allgatherv reports an error, a message is printed to
    // std::cerr and `out` is left untouched.
    static void allgatherv_serialize(const mpi::communicator &comm, const std::vector<T> &in, std::vector<T> &out) {
        // Step 1: serialize input data
        mpi::packed_oarchive oa(comm);
        oa << in;
        // Step 2: exchange element counts and archive sizes (archives' .size()
        // is measured in bytes). MPI counts are int, hence the casts.
        const int in_size = static_cast<int>(in.size());
        const int transmit_size = static_cast<int>(oa.size());
        std::vector<int> in_sizes(comm.size()), transmit_sizes(comm.size());
        mpi::all_gather(comm, in_size, in_sizes.data());
        mpi::all_gather(comm, transmit_size, transmit_sizes.data());
        // Step 3: byte displacements are the exclusive prefix sum of the
        // archive sizes, starting at 0; the last entry is the total size.
        std::vector<int> displacements(comm.size() + 1, 0);
        std::partial_sum(transmit_sizes.begin(), transmit_sizes.end(), displacements.begin() + 1);
        // Step 4: gather all archives into one buffer. The buffer is owned by
        // a vector, so it is released on every exit path.
        std::vector<char> recvbuf(displacements.back());
        // MPI_Allgatherv only reads the send buffer; the cast accommodates
        // MPI versions whose signature takes a non-const pointer.
        const int status = MPI_Allgatherv(const_cast<void*>(oa.address()), transmit_size, MPI_PACKED,
                                          recvbuf.data(), transmit_sizes.data(), displacements.data(),
                                          MPI_PACKED, comm);
        if (status != MPI_SUCCESS) {
            std::cerr << "PE " << comm.rank() << ": MPI_Allgatherv returned "
                      << status << ", errno " << errno << std::endl;
            return;
        }
        // Step 5: deserialize the archives one by one in rank order, replacing
        // the previous contents of `out` (matching allgatherv_unsafe).
        out.clear();
        out.reserve(std::accumulate(in_sizes.begin(), in_sizes.end(), std::size_t(0)));
        // Reused across iterations; reserving the largest count avoids
        // reallocations while deserializing.
        std::vector<T> temp;
        temp.reserve(*std::max_element(in_sizes.begin(), in_sizes.end()));
        for (int i = 0; i < comm.size(); ++i) {
            mpi::packed_iarchive archive(comm);
            archive.resize(transmit_sizes[i]);
            std::memcpy(archive.address(), recvbuf.data() + displacements[i], transmit_sizes[i]);
            temp.clear();
            archive >> temp;
            out.insert(out.end(), std::make_move_iterator(temp.begin()), std::make_move_iterator(temp.end()));
        }
    }

    // Raw-bytes allgatherv: reinterprets the elements as transmit_type words
    // and calls MPI_Allgatherv directly. Only meaningful for trivial (enough)
    // T. `out` is resized to the gathered total and overwritten. If
    // MPI_Allgatherv reports an error, a message is printed to std::cerr and
    // the contents of `out` are unspecified.
    static void allgatherv_unsafe(const mpi::communicator &comm, const std::vector<T> &in, std::vector<T> &out) {
        // Guarantees factor >= 1 below (no division by zero).
        static_assert(sizeof(T) % sizeof(transmit_type) == 0,
                      "allgatherv_unsafe requires sizeof(T) to be a multiple of sizeof(transmit_type)");
        // Number of transmit_type words per element.
        const int factor = static_cast<int>(sizeof(T) / sizeof(transmit_type));
        // Step 1: exchange sizes, measured in transmit_type words. MPI counts
        // are int, hence the cast.
        const int in_size = static_cast<int>(in.size()) * factor;
        std::vector<int> sizes(comm.size());
        mpi::all_gather(comm, in_size, sizes.data());
        // Step 2: displacements (where each PE's data lands in `out`) are the
        // exclusive prefix sum of the sizes; the last entry is the total.
        std::vector<int> displacements(comm.size() + 1, 0);
        std::partial_sum(sizes.begin(), sizes.end(), displacements.begin() + 1);
        // Convert the total back from words to elements.
        out.resize(displacements.back() / factor);
        // Step 3: MPI_Allgatherv
        const transmit_type *sendptr = reinterpret_cast<const transmit_type*>(in.data());
        transmit_type *recvptr = reinterpret_cast<transmit_type*>(out.data());
        const MPI_Datatype datatype = mpi::get_mpi_datatype<transmit_type>();
        const int status = MPI_Allgatherv(sendptr, in_size, datatype, recvptr,
                                          sizes.data(), displacements.data(),
                                          datatype, comm);
        if (status != MPI_SUCCESS) {
            std::cerr << "PE " << comm.rank() << ": MPI_Allgatherv returned "
                      << status << ", errno " << errno << std::endl;
        }
    }

    // Send `size` elements of type `T` starting at `data` to `dest` via `comm`
    // with `tag`. The element count is sent first (as std::size_t), then the
    // elements, either as `transmit_type` words (trivial path) or serialized.
    static void send(const mpi::communicator &comm, int dest, int tag, const T *data, const std::size_t size) {
        comm.send(dest, tag, size);
        send_elements(comm, dest, tag, data, size, dispatch_tag());
    }

    // Convenience wrapper for vectors.
    static void send(const mpi::communicator &comm, int dest, int tag, const std::vector<T> &data) {
        send(comm, dest, tag, data.data(), data.size());
    }

    // Receive a message produced by send() from `src` with `tag`; `data` is
    // resized to the received element count and overwritten.
    static void recv(const mpi::communicator &comm, int src, int tag, std::vector<T> &data) {
        // Must have the same type as send()'s `size` parameter.
        std::size_t size = 0;
        comm.recv(src, tag, size);
        data.resize(size);
        recv_elements(comm, src, tag, data.data(), size, dispatch_tag());
    }

private:
    // Number of transmit_type words occupied by `n` elements of T, as an MPI
    // (int) count.
    static int transmit_count(std::size_t n) {
        return static_cast<int>(n * sizeof(T) / sizeof(transmit_type));
    }

    // Trivial path: broadcast the element count so receivers can allocate
    // space, then the elements as transmit_type words.
    static void broadcast_impl(const mpi::communicator &comm, std::vector<T> &data, int root, std::true_type) {
        int size = static_cast<int>(data.size());
        mpi::broadcast<int>(comm, size, root);
        data.resize(size);
        mpi::broadcast(comm, reinterpret_cast<transmit_type*>(data.data()), transmit_count(data.size()), root);
    }

    // Serialization path.
    static void broadcast_impl(const mpi::communicator &comm, std::vector<T> &data, int root, std::false_type) {
        if (mpi::is_mpi_datatype<T>()) {
            // Boost.MPI can transmit MPI datatypes directly. With the default
            // `trivial` argument this is probably never reached, since native
            // MPI datatypes should be trivial enough.
            mpi::broadcast(comm, data, root);
            return;
        }
        // Boost.MPI doesn't use MPI_Bcast for types it doesn't know, so we
        // broadcast the packed archive ourselves: first its size in bytes,
        // then its contents.
        if (comm.rank() == root) {
            mpi::packed_oarchive oa(comm);
            oa << data;
            std::size_t archive_size = oa.size();
            mpi::broadcast<std::size_t>(comm, archive_size, root);
            // The root only reads this buffer; the cast accommodates MPI_Bcast's
            // non-const buffer parameter.
            MPI_Bcast(const_cast<void*>(oa.address()), static_cast<int>(archive_size), MPI_PACKED, root, comm);
        } else {
            std::size_t archive_size = 0;
            mpi::broadcast<std::size_t>(comm, archive_size, root);
            mpi::packed_iarchive ia(comm);
            ia.resize(archive_size);
            MPI_Bcast(ia.address(), static_cast<int>(archive_size), MPI_PACKED, root, comm);
            ia >> data;
        }
    }

    static void allgatherv_impl(const mpi::communicator &comm, const std::vector<T> &in, std::vector<T> &out, std::true_type) {
        allgatherv_unsafe(comm, in, out);
    }

    static void allgatherv_impl(const mpi::communicator &comm, const std::vector<T> &in, std::vector<T> &out, std::false_type) {
        allgatherv_serialize(comm, in, out);
    }

    // Trivial path: send the elements as transmit_type words.
    static void send_elements(const mpi::communicator &comm, int dest, int tag, const T *data, std::size_t size, std::true_type) {
        comm.send(dest, tag, reinterpret_cast<const transmit_type*>(data), transmit_count(size));
    }

    // Serialization path: let Boost.MPI serialize the element array.
    static void send_elements(const mpi::communicator &comm, int dest, int tag, const T *data, std::size_t size, std::false_type) {
        comm.send(dest, tag, data, static_cast<int>(size));
    }

    static void recv_elements(const mpi::communicator &comm, int src, int tag, T *data, std::size_t size, std::true_type) {
        comm.recv(src, tag, reinterpret_cast<transmit_type*>(data), transmit_count(size));
    }

    static void recv_elements(const mpi::communicator &comm, int src, int tag, T *data, std::size_t size, std::false_type) {
        comm.recv(src, tag, data, static_cast<int>(size));
    }
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment