Skip to content

Instantly share code, notes, and snippets.

@pgoodman
Last active August 3, 2020 01:59
Show Gist options
  • Save pgoodman/4b74b95732f18f31b4cd17efda040c4d to your computer and use it in GitHub Desktop.
Save pgoodman/4b74b95732f18f31b4cd17efda040c4d to your computer and use it in GitHub Desktop.
// Copyright 2020 Peter Goodman, all rights reserved.
#include <cassert>
#include <cstdint>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <vector>
// One fixed-width field of an address: `kNumBits_` bits starting at bit
// position `kShift_` within an integer of type `T_`.
template<typename T_, unsigned kNumBits_, unsigned kShift_>
struct AddressBits {
  using IntegralType = T_;

  static constexpr IntegralType kShift = kShift_;
  static constexpr IntegralType kNumBits = kNumBits_;

  // Mask for selecting only the bits of an address that belongs to this
  // component. Built by shifting a full-width all-ones value right (leaving
  // `kNumBits` ones) and then left into position; this formulation remains
  // well-defined even when the field spans the entire integer.
  static constexpr IntegralType kMask =
      ((~IntegralType(0)) >> (sizeof(IntegralType) * 8 - kNumBits)) << kShift;

  // Number of distinct values this component can take.
  static constexpr IntegralType kNumEntries = (kMask >> kShift) + 1;
  static_assert(kNumEntries == (1ull << kNumBits));

  // Default validity check: every bit pattern is acceptable. Derived
  // components may shadow this with a stricter check.
  static bool IsValid(IntegralType) noexcept {
    return true;
  }

  // Pull this component's bits out of `addr`, shifted down to bit zero.
  static IntegralType Extract(IntegralType addr) noexcept {
    return (addr & kMask) >> kShift;
  }
};
// Intermediate step of an address walk: we are at component `AddrCompType`,
// our parent tree node is `ParentType`, and the remaining (lower-order)
// components `AddressCompTypes...` still need to be processed.
template<typename Tag, typename ParentType, typename AddrCompType,
         typename ... AddressCompTypes>
struct AddressWalker {
  static inline uint8_t *Walk(ParentType &parent,
                              typename AddrCompType::IntegralType bits,
                              Tag tag) noexcept {
    const auto level_index = AddrCompType::Extract(bits);
    const auto child = parent.Enter(level_index, AddrCompType::kNumEntries, tag);
    if (!child) {
      return nullptr;
    }

    // Recurse into the child node with the rest of the component list.
    using ChildType = std::remove_reference_t<decltype(*child)>;
    return AddressWalker<Tag, ChildType, AddressCompTypes...>::Walk(
        *child, bits, tag);
  }
};
// Base case of an address walk: `AddrCompType` is the final (lowest-order)
// component, where actual mapped bytes should live, so ask the parent
// directly for a pointer to the byte.
template<typename Tag, typename ParentType, typename AddrCompType>
struct AddressWalker<Tag, ParentType, AddrCompType> {
  static inline uint8_t *Walk(ParentType &parent,
                              typename AddrCompType::IntegralType bits,
                              Tag tag) noexcept {
    return parent.PointerTo(AddrCompType::Extract(bits),
                            AddrCompType::kNumEntries, tag);
  }
};
// An address, parameterized by its components. The first component represents the highest
// order bits of the integral representation of the address. The last component represents
// the lowest order bits.
template<typename FirstAddrPart, typename ... AddrParts>
struct Address {
  // Total width is the sum of all component widths (fold over `+`).
  static constexpr unsigned kNumBits = (FirstAddrPart::kNumBits + ... + AddrParts::kNumBits);
  static constexpr unsigned kNumBytes = kNumBits / 8;

  // The components must cover a whole number of bytes.
  static_assert((kNumBytes * 8) == kNumBits);

  // The integral representation comes from the first (highest-order) component.
  using IntegralType = typename FirstAddrPart::IntegralType;
  static_assert(sizeof(IntegralType) == kNumBytes);

  // Get the last type in the parameter pack, which should represent the page level type,
  // and figure out our page granularity. The comma-operator fold evaluates to
  // its final operand, so `decltype` yields the last `AddrParts` type.
  using PageType = decltype((AddrParts(), ...));
  static constexpr IntegralType kPageMask = PageType::kMask;
  static constexpr IntegralType kPageSize = kPageMask + 1;

  // Default-construct as the zero address.
  inline Address(void) noexcept
      : bits(0) {
  }

  // Wrap a raw integral address.
  inline Address(IntegralType bits_) noexcept
      : bits(bits_) {
  }

  // Build from a host pointer by converting through `uintptr_t`.
  template<typename T>
  inline Address(T *ptr)
      : bits(static_cast<IntegralType>(reinterpret_cast<uintptr_t>(ptr))) {}

  // An address is valid iff every component deems its bits valid.
  bool IsValid(void) const noexcept {
    return (FirstAddrPart::IsValid(bits) && ... && AddrParts::IsValid(bits));
  }

  // Descend `root` level-by-level, one level per address component, with
  // `tag` selecting the behavior at each level.
  template<typename ParentType, typename Tag>
  uint8_t *Walk(ParentType &root, Tag tag) {
    return AddressWalker<Tag, ParentType, FirstAddrPart, AddrParts...>::Walk(
        root, bits, tag);
  }

  IntegralType bits;
};
// 64-bit addresses on x86 must be canonical. That means bit 47 must match bits [48:63].
struct ReservedBits64 : public AddressBits<uint64_t, 16, 48> {
  // Check for a canonical address: bits [48:63] must all equal bit 47, i.e.
  // the high 16 bits are 0x0000 when bit 47 is clear and 0xFFFF when it is
  // set.
  static bool IsValid(uint64_t addr) noexcept {
    const auto high_16bits = static_cast<uint16_t>(addr >> 48);
    const auto bit_47 = static_cast<uint16_t>((addr >> 47) & 1u);

    // BUGFIX: the sum must be truncated back to 16 bits so that
    // 0xFFFF + 1 wraps to zero for upper-half canonical addresses. Without
    // the cast, integer promotion computes the sum as an `int` (0x10000),
    // making every kernel-half canonical address look invalid.
    return static_cast<uint16_t>(high_16bits + bit_47) == 0;
  }
};
// Specialize the address walker to ignore the high 16 bits.
// `ReservedBits64` carries no index into the tree (it is validity metadata
// only), so this specialization simply drops it from the component list and
// continues with the remaining components via the inherited `Walk`.
template<typename Tag, typename ParentType, typename ... AddressCompTypes>
struct AddressWalker<Tag, ParentType, ReservedBits64, AddressCompTypes...> :
public AddressWalker<Tag, ParentType, AddressCompTypes...> {
};
// NOTE(pag): Using a struct here instead of `typedef` or `using` means we'll see slightly
// nicer types show up in any compiler errors.
struct PageOffset64 : public AddressBits<uint64_t, 12, 0> {};   // Bits [0:11]: byte within a 4 KiB page.
struct PTIndex64 : public AddressBits<uint64_t, 9, 12> {};      // Bits [12:20]: page-table index.
struct PDIndex64 : public AddressBits<uint64_t, 9, 21> {};      // Bits [21:29]: page-directory index.
struct PDPIndex64 : public AddressBits<uint64_t, 9, 30> {};     // Bits [30:38]: page-directory-pointer index.
struct PML4Index64 : public AddressBits<uint64_t, 9, 39> {};    // Bits [39:47]: PML4 index.
// A full x86-64 virtual address: 16 reserved/canonical bits plus the
// four-level paging indices and the in-page offset.
using VA64 = Address<ReservedBits64, PML4Index64, PDPIndex64, PDIndex64, PTIndex64, PageOffset64>;
// Tag type: resolve a pointer to a mapped byte, regardless of permissions.
struct ResolveTag {};
// Tag type: resolve only if the page is marked readable.
struct CheckReadTag {};
// Tag type: resolve only if the page is marked readable and writable.
struct CheckReadWriteTag {};
// Page permissions; also doubles as a walk "tag" that creates a page or
// updates its permissions (see `MMU::MapRange`).
struct Permissions {
bool can_read { false };
bool can_write { false };
bool can_execute { false };
};
// Sparse, page-table-shaped memory map whose layout is dictated by the
// address type `VA`. Interior tree levels live in `ptes`; the leaf level
// refers into `bytes`, which owns page data and permissions.
template<typename VA>
class Memory {
 private:
  // The bytes backing the actual pages, along with their permissions.
  std::vector<std::pair<Permissions, std::vector<uint8_t>>> bytes;

  // Opaque, 1-based offsets back into itself or 1-based indices into `bytes`.
  // A zero value represents an entry with no children.
  //
  // TODO(pag): 32-bit values should be enough for anyone :-P
  std::vector<uint32_t> ptes;

 public:
  Memory(void)
      : ptes() {
    ptes.reserve(1024 * 1024);
    ptes.push_back(0);  // Base case for the visitor's constructor.
  }

  using IntegralType = typename VA::IntegralType;

  // Walk `addr` down through the levels of the tree. `tag` selects the
  // behavior: resolve a byte pointer, check permissions, or (with a
  // `Permissions` tag) create/update a page.
  template<typename Tag>
  uint8_t *Walk(IntegralType addr, Tag tag) {
    Visitor vis(*this);
    return VA(addr).Walk(vis, tag);
  }

  // Descends through the levels of an address.
  struct Visitor {
    explicit Visitor(Memory<VA> &memory_)
        : memory(memory_),
          ptr_to_index(&(memory.ptes[0])) {}

    Memory<VA> &memory;

    // Points into `memory.ptes`.
    uint32_t *ptr_to_index;

    // Enter into an address with the `Permissions` tag, which we use to either
    // change or initialize a page and its permissions. Missing interior
    // entries are created on demand by `EnterSlow`.
    inline Visitor *Enter(IntegralType offset, IntegralType size,
                          Permissions new_perms) {
      if (auto pte_offset = *ptr_to_index; pte_offset) {
        // `pte_offset` is 1-based and relative to `ptr_to_index`;
        // `pte_offset - 1` is the first slot of the child table.
        ptr_to_index = &(ptr_to_index[pte_offset - 1 + offset]);
        return this;
      } else {
        return EnterSlow(offset, size, new_perms);
      }
    }

    // Slow path to the above function, where an entry in our `ptes` vector
    // doesn't exist: append a zero-filled child table of `size` entries to
    // `ptes` and link it into the current slot.
    __attribute__((noinline))
    Visitor *EnterSlow(IntegralType offset, IntegralType size,
                       Permissions new_perms) {
      const auto index_of_pte_offset = static_cast<uint32_t>(ptr_to_index
          - &(memory.ptes[0]));
      const auto curr_size = static_cast<uint32_t>(memory.ptes.size());
      const auto new_size = curr_size + size;
      memory.ptes.resize(new_size);

      // Resizing `memory.ptes` might have invalidated `ptr_to_index`, so
      // we'll re-compute it.
      ptr_to_index = &(memory.ptes[index_of_pte_offset]);
      const auto pte_offset = (curr_size - index_of_pte_offset) + 1;
      *ptr_to_index = pte_offset;
      ptr_to_index = &(ptr_to_index[pte_offset - 1 + offset]);
      return this;
    }

    // Generic traversal method for all other tags (e.g. permission checking).
    // Unlike the `Permissions` overload, a missing entry aborts the walk.
    template<typename Tag>
    inline Visitor *Enter(IntegralType offset, IntegralType size, Tag) {
      if (auto pte_offset = *ptr_to_index; pte_offset) {
        ptr_to_index = &(ptr_to_index[pte_offset - 1 + offset]);
        return this;
      } else {
        return nullptr;
      }
    }

    // We've descended all the way down to the page level, and all we care
    // about is resolving a pointer to the byte if it's mapped.
    inline uint8_t *PointerTo(IntegralType offset, IntegralType size,
                              ResolveTag) {
      if (auto pte_offset = *ptr_to_index; pte_offset) {
        // BUGFIX: `pte_offset` is a 1-based index into `memory.bytes` (see
        // `PointerToSlow`), so subtract one here, matching the other
        // overloads. Indexing with the raw value read the wrong page (or one
        // past the end of `bytes`).
        return &(memory.bytes[pte_offset - 1].second[offset]);
      } else {
        return nullptr;
      }
    }

    // We've descended all the way down to the page level, and all we care
    // about is resolving a pointer to the byte if it's mapped and it is
    // marked as readable.
    inline uint8_t *PointerTo(IntegralType offset, IntegralType size,
                              CheckReadTag) {
      if (auto pte_offset = *ptr_to_index; pte_offset) {
        if (auto &page = memory.bytes[pte_offset - 1]; page.first.can_read) {
          return &(page.second[offset]);
        }
      }
      return nullptr;
    }

    // We've descended all the way down to the page level, and all we care
    // about is resolving a pointer to the byte if it's mapped and it is
    // marked as readable and writable.
    inline uint8_t *PointerTo(IntegralType offset, IntegralType size,
                              CheckReadWriteTag) {
      if (auto pte_offset = *ptr_to_index; pte_offset) {
        if (auto &page = memory.bytes[pte_offset - 1]; page.first.can_read
            && page.first.can_write) {
          return &(page.second[offset]);
        }
      }
      return nullptr;
    }

    // We've descended all the way down to the page level, and we want to
    // change the permissions of this page, or create and initialize the
    // permissions of this page if it doesn't exist. Always returns `nullptr`;
    // mapping does not hand out byte pointers.
    inline uint8_t *PointerTo(IntegralType offset, IntegralType size,
                              Permissions new_perms) {
      if (auto pte_offset = *ptr_to_index; pte_offset) {
        auto &page = memory.bytes[pte_offset - 1];
        page.first = new_perms;
        return nullptr;
      } else {
        return PointerToSlow(offset, size, new_perms);
      }
    }

    // Slow path of the above function; the page doesn't exist, so add in the
    // backing bytes and record its 1-based index in the current PTE slot.
    __attribute__((noinline))
    uint8_t *PointerToSlow(IntegralType offset, IntegralType size,
                           Permissions new_perms) {
      std::vector<uint8_t> data(VA::kPageSize);
      memory.bytes.emplace_back(new_perms, std::move(data));
      *ptr_to_index = static_cast<uint32_t>(memory.bytes.size());
      return nullptr;
    }
  };

  friend struct Visitor;
};
// Thin facade over `Memory<VA>`: pointer resolution with optional permission
// checks, plus page-range mapping.
template<typename VA>
class MMU {
 public:
  using IntegralType = typename VA::IntegralType;

  // Get a raw pointer to the data if it's mapped.
  uint8_t *PointerTo(VA addr) {
    return memory.Walk(addr.bits, ResolveTag());
  }

  // Get a raw pointer to the data if it's mapped and if it is readable.
  const uint8_t *ReadPointerTo(VA addr) {
    return memory.Walk(addr.bits, CheckReadTag());
  }

  // Get a raw pointer to the data if it's mapped and if it is readable and
  // writable.
  uint8_t *WritePointerTo(VA addr) {
    return memory.Walk(addr.bits, CheckReadWriteTag());
  }

  // Map (or update the permissions of) every page touched by the byte range
  // `[begin, begin + size)`.
  void MapRange(VA begin, IntegralType size, Permissions perms) {
    // Align the starting address down to its page boundary.
    const auto addr = begin.bits & ~VA::kPageMask;

    // BUGFIX: include the sub-page offset of `begin` before rounding the
    // length up to whole pages. Rounding `size` alone under-mapped when an
    // unaligned range spilled into a final partial page. Aligned callers see
    // identical behavior.
    const auto max_offset =
        ((begin.bits & VA::kPageMask) + size + VA::kPageMask) & ~VA::kPageMask;
    for (IntegralType offset = 0; offset < max_offset;
         offset += VA::kPageSize) {
      memory.Walk(addr + offset, perms);
    }
  }

 private:
  Memory<VA> memory;
};
// Exercise the MMU: map a few pages with varying permissions, then verify
// that write/read probes resolve (or fail to resolve) as expected.
// BUGFIX: the log messages for the second page printed "0x3465a6980001"
// (one digit too many) while the probed address is 0x3465a698001; the
// messages now match the probed address.
int main(void) {
  int x;
  VA64 addr_of_x(&x);
  static_assert(sizeof(VA64) == sizeof(VA64::IntegralType));
  MMU<VA64> mmu;
  std::cout << "Num bits in address: " << VA64::kNumBits << std::endl
            << "Page size: " << VA64::kPageSize << std::endl << "Is valid: "
            << addr_of_x.IsValid() << std::endl;

  // Map four unrelated pages with distinct permission sets.
  mmu.MapRange(0x1909eff29000, 0x1000, { true, true, true });
  mmu.MapRange(0x3465a698000, 0x1000, { true, false, true });
  mmu.MapRange(0x1c2aac84e000, 0x1000, { true, false, true });
  mmu.MapRange(0x2229e3956000, 0x1000, { true, true, false });

  // This page was mapped read/write, so the write probe must succeed.
  if (mmu.WritePointerTo(0x1909eff29001)) {
    std::cout << "Address 0x1909eff29001 is writable" << std::endl;
  } else {
    assert(false);
  }

  // This page was mapped read-only: the write probe must fail. Remap it
  // read/write, write through it, drop write permission again, and confirm
  // the written byte is still readable.
  if (mmu.WritePointerTo(0x3465a698001)) {
    std::cout << "Address 0x3465a698001 is writable" << std::endl;
  } else {
    mmu.MapRange(0x3465a698000, 0x1000, { true, true, true });
    if (auto write_ptr = mmu.WritePointerTo(0x3465a698001); write_ptr) {
      std::cout << "Address 0x3465a698001 is NOW writable" << std::endl;
      *write_ptr = 1;
      mmu.MapRange(0x3465a698000, 0x1000, { true, false, false });
      if (mmu.WritePointerTo(0x3465a698001)) {
        assert(false);
      } else {
        std::cout << "Address 0x3465a698001 is no longer writable"
                  << std::endl;
        if (auto read_ptr = mmu.ReadPointerTo(0x3465a698001); read_ptr) {
          std::cout << "Value read from 0x3465a698001 is "
                    << unsigned(*read_ptr) << std::endl;
        } else {
          assert(false);
        }
      }
    } else {
      assert(false);
    }
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment