Skip to content

Instantly share code, notes, and snippets.

@mrdomino
Last active March 13, 2024 23:37
Show Gist options
  • Save mrdomino/2aba53d3849f5fa050993265d9734ae8 to your computer and use it in GitHub Desktop.
Save mrdomino/2aba53d3849f5fa050993265d9734ae8 to your computer and use it in GitHub Desktop.
#pragma once
#include <cassert>
#include <compare>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <new>
#include <type_traits>
#include <utility>
namespace detail {
// Reference-count bookkeeping shared by every Rc/Weak that refers to the same
// managed object.
//
// Both counters are stored biased by -1: a freshly constructed block
// (shared == 0, weak == 0) already represents ONE strong use, and the strong
// uses collectively hold ONE phantom weak use. Consequently:
//   * use_count() reports `shared + 1`;
//   * releasing the last strong use (shared drops to -1) runs
//     on_zero_shared() and then releases the phantom weak use;
//   * releasing the last weak use (weak drops to -1) runs on_zero_weak().
// weak_count() reports only explicit weak uses while strong uses remain, and
// -1 once everything has been released.
//
// The counters are plain int32_t, so they are not safe to mutate from several
// threads without external synchronization.
//
// A control block is an identity object: a copy would duplicate the counts
// and let a subclass dispose of the same managed object twice, so copying is
// disabled.
class RcControl {
public:
  constexpr RcControl() noexcept: shared(0), weak(0) {}
  RcControl(RcControl const&) = delete;
  RcControl& operator=(RcControl const&) = delete;
  virtual ~RcControl() noexcept {}

  // Records one more strong use.
  void add_shared() noexcept {
    ++shared;
  }
  // Drops one strong use. On the last one, disposes of the managed object via
  // on_zero_shared() and then gives up the phantom weak use.
  void release_shared() noexcept {
    if (--shared == -1) {
      on_zero_shared();
      release_weak();
    }
  }
  // Records one more explicit weak use.
  void add_weak() noexcept {
    ++weak;
  }
  // Drops one weak use (explicit or phantom); on the last one, calls
  // on_zero_weak().
  void release_weak() noexcept {
    if (--weak == -1) {
      on_zero_weak();
    }
  }
  // Number of strong uses (0 once the managed object has been disposed of).
  int32_t use_count() const noexcept {
    return shared + 1;
  }
  // Raw explicit-weak counter (see the class comment for its meaning).
  int32_t weak_count() const noexcept {
    return weak;
  }

private:
  virtual void on_zero_shared() noexcept = 0;
  virtual void on_zero_weak() noexcept = 0;
  int32_t shared;
  int32_t weak;
};
// Control block for an object allocated separately with `new`. The block
// remembers the pointer, deletes the object once the last strong use is
// released, and deletes itself (it is created with `new` too) once the last
// weak use is released.
template <typename T>
class RcPointer: public RcControl {
public:
  explicit constexpr RcPointer(T* const owned) noexcept: ptr(owned) {}

private:
  // Last strong use gone: destroy and free the managed object.
  void on_zero_shared() noexcept override { delete ptr; }

  // Last weak use gone: free this control block.
  void on_zero_weak() noexcept override { delete this; }

  T* ptr;  // owning; released in on_zero_shared()
};
// Control block that stores the managed T inline, so the object and its
// reference counts live in a single allocation. The T is constructed by
// construct(). It is destroyed, but not deallocated, when the last strong use
// is released. The storage is freed together with the block (`delete this`)
// when the last weak use is released.
template <typename T>
class RcEmplace: public RcControl {
public:
  constexpr RcEmplace() {}

  // Address of the inline T.
  // NOTE(review): std::launder formally requires a live T at this address;
  // makeRc also calls t() before construct() to learn the future object's
  // address, which is pedantically outside launder's contract.
  T* t() noexcept {
    return std::launder(reinterpret_cast<T*>(blob));
  }

  // Constructs the T in the inline storage from args (makeRc calls this
  // exactly once). Exceptions from T's constructor propagate, and in that
  // case no T exists.
  template <typename... Args>
  void construct(Args&&... args) {
    // Placement-new straight into the raw storage. No T lives there yet, so
    // going through t() (and hence std::launder) would break launder's
    // precondition.
    ::new (static_cast<void*>(blob)) T(std::forward<Args>(args)...);
  }

private:
  void on_zero_shared() noexcept override {
    t()->~T();
  }
  void on_zero_weak() noexcept override {
    delete this;
  }

  alignas(T) char blob[sizeof(T)];
};
// Scope guard that deletes `ptr` when destroyed, unless ownership was taken
// first by setting `ptr` to nullptr. Used so that a throwing step (allocating
// a control block, constructing T) does not leak what was already allocated.
//
// Not copyable, because a copy would delete the same pointer twice.
// `auto h = Hold { p };` still compiles because C++17 guarantees copy elision
// for prvalue initialization.
template <typename T>
struct Hold {
  T* ptr;
  explicit constexpr Hold(T* const ptr) noexcept: ptr(ptr) {}
  Hold(Hold const&) = delete;
  Hold& operator=(Hold const&) = delete;
  ~Hold() noexcept {
    delete ptr;
  }
};
} // namespace detail
// Implements a (non-atomic) reference-counted smart pointer. The main
// difference from std::shared_ptr (aside from not being atomic and not
// supporting custom deleters / allocators) is that instead of
// std::enable_shared_from_this, makeRc passes a weak `this` reference to T's
// constructor.
//
template <typename T>
class Rc;
// Implements weak references to Rc<T> to allow breaking refcount cycles.
//
template <typename T>
class Weak;
// Thrown when trying to construct an Rc from an expired Weak.
//
// what() returns a fixed, descriptive message so the error is identifiable
// even when it is caught generically as std::exception.
class BadWeak: public std::exception {
public:
  const char* what() const noexcept override {
    return "BadWeak: Rc constructed from an expired Weak";
  }
};
// Non-atomic reference-counted owning pointer. Copies share ownership of one
// object through a detail::RcControl: the last strong reference disposes of
// the object, and Weak references keep only the control block alive.
//
// An Rc may be empty (no control block). Through the aliasing constructors it
// may also hold a control block while pointing at nullptr or at some other
// object. operator bool, get() and the comparisons look only at the stored
// pointer.
template <typename T>
class Rc {
public:
  // Constructs an empty Rc: no object, no control block.
  constexpr Rc(std::nullptr_t = nullptr) noexcept: ptr(nullptr), rc(nullptr) {}

  // Takes ownership of the `new`-allocated object nPtr (asserted non-null)
  // and allocates a separate control block for it. If that allocation
  // throws, nPtr is deleted before the exception propagates.
  //
  // The constraint checks pointer convertibility (U* -> T*), so U may be T or
  // any class derived from T. A value-convertibility check (U -> T) would
  // wrongly reject abstract or non-copyable T. The object is later deleted
  // through a U*.
  template <typename U>
    requires std::is_convertible_v<U*, T*>
  explicit Rc(U* const nPtr) {
    assert(nPtr);
    auto hold = detail::Hold { nPtr };
    rc = new detail::RcPointer { hold.ptr };
    ptr = std::exchange(hold.ptr, nullptr);
  }

  // Acquires a strong reference from a Weak. Throws BadWeak if the Weak has
  // expired. An empty (default-constructed) Weak yields an empty Rc.
  //
  // NOTE(review): Weak<U> befriends only Rc<U>, so this constructor compiles
  // only for U == T even though the constraint admits derived types.
  template <typename U>
    requires std::is_convertible_v<U*, T*>
  explicit Rc(Weak<U> const& r): ptr(r.ptr), rc(r.rc) {
    if (r.expired()) {
      throw BadWeak{};
    }
    if (rc) {
      rc->add_shared();
    }
  }

  // Aliasing constructor: shares r's control block, keeping r's object alive
  // for as long as this Rc lives, but points at `ptr`. `ptr` is not checked
  // against r's object and may be nullptr.
  template <typename U>
  Rc(Rc<U> const& r, T* const ptr) noexcept: ptr(ptr), rc(r.rc) {
    if (rc) {
      rc->add_shared();
    }
  }

  // Aliasing move constructor: like the above, but takes over r's strong
  // reference and leaves r empty.
  template <typename U>
  Rc(Rc<U>&& r, T* const ptr) noexcept: ptr(ptr), rc(r.rc) {
    r.ptr = nullptr;
    r.rc = nullptr;
  }

  Rc(Rc const& r) noexcept: ptr(r.ptr), rc(r.rc) {
    if (rc) {
      rc->add_shared();
    }
  }
  Rc(Rc&& r) noexcept: ptr(r.ptr), rc(r.rc) {
    r.ptr = nullptr;
    r.rc = nullptr;
  }
  ~Rc() noexcept {
    if (rc) {
      rc->release_shared();
    }
  }

  void swap(Rc& r) noexcept {
    using std::swap;
    swap(ptr, r.ptr);
    swap(rc, r.rc);
  }
  // Copy-and-swap: one by-value overload serves as both copy and move
  // assignment, and is safe under self-assignment.
  Rc& operator=(Rc r) noexcept {
    r.swap(*this);
    return *this;
  }

  // Releases this reference, leaving the Rc empty.
  void reset(std::nullptr_t = nullptr) noexcept {
    Rc().swap(*this);
  }
  // Releases this reference and takes ownership of the `new`-allocated ptr
  // (same rules as the raw-pointer constructor).
  template <typename U>
    requires std::is_convertible_v<U*, T*>
  void reset(U* const ptr) {
    Rc(ptr).swap(*this);
  }

  T& operator*() {
    return *ptr;
  }
  T const& operator*() const {
    return *ptr;
  }
  T* operator->() noexcept {
    return ptr;
  }
  T const* operator->() const noexcept {
    return ptr;
  }
  // True when the stored pointer is non-null (an aliasing Rc that points at
  // nullptr is false even though it shares ownership).
  explicit operator bool() const noexcept {
    return ptr;
  }
  T* get() noexcept {
    return ptr;
  }
  T const* get() const noexcept {
    return ptr;
  }

  // Comparisons use the stored pointer only, never the control block.
  template <typename U>
  bool operator==(Rc<U> const& r) const noexcept {
    return ptr == r.ptr;
  }
  bool operator==(std::nullptr_t) const noexcept {
    return ptr == nullptr;
  }
  template <typename U>
  std::strong_ordering operator<=>(Rc<U> const& r) const noexcept {
    return ptr <=> r.ptr;
  }
  std::strong_ordering operator<=>(std::nullptr_t) const noexcept {
    return ptr <=> static_cast<T*>(nullptr);
  }

  // Number of strong references sharing the control block (0 if empty).
  int64_t use_count() const noexcept {
    return rc ? rc->use_count() : 0;
  }

  // Adopts `rc` WITHOUT incrementing it. The resulting Rc takes over the
  // control block's implicit initial strong use. The caller must ensure that
  // use is unused, i.e. no other Rc has already adopted this block.
  static Rc _unsafe_create_with_control(
      T* const ptr, detail::RcControl* const rc) noexcept {
    Rc r;
    r.ptr = ptr;
    r.rc = rc;
    return r;
  }

private:
  friend class Weak<T>;
  template <typename U>
  friend class Rc;
  T* ptr;
  detail::RcControl* rc;
};
// Non-owning reference to an object managed by Rc<T>. A Weak keeps the
// control block alive, but not the object, so it can report whether the
// object is gone and hand out a new Rc while it is not.
template <typename T>
class Weak {
public:
  // Constructs an empty Weak that refers to nothing.
  constexpr Weak(std::nullptr_t = nullptr) noexcept: ptr(nullptr), rc(nullptr) {}
  // Observes r's object (if any) without extending its lifetime.
  Weak(Rc<T> const& r) noexcept: ptr(r.ptr), rc(r.rc) {
    if (rc) {
      rc->add_weak();
    }
  }
  Weak(Weak const& r) noexcept: ptr(r.ptr), rc(r.rc) {
    if (rc) {
      rc->add_weak();
    }
  }
  Weak(Weak&& r) noexcept: ptr(r.ptr), rc(r.rc) {
    r.ptr = nullptr;
    r.rc = nullptr;
  }
  ~Weak() noexcept {
    if (rc) {
      rc->release_weak();
    }
  }

  void swap(Weak& r) noexcept {
    using std::swap;
    swap(ptr, r.ptr);
    swap(rc, r.rc);
  }
  // Releases this weak reference, leaving the Weak empty.
  void reset() noexcept {
    Weak().swap(*this);
  }
  // Copy-and-swap assignment (serves both copy and move).
  Weak& operator=(Weak r) noexcept {
    r.swap(*this);
    return *this;
  }

  // True once every strong reference sharing this control block has been
  // released. Unlike std::weak_ptr, an empty Weak is NOT reported as expired.
  bool expired() const noexcept {
    return rc && rc->use_count() == 0;
  }
  // Returns a new strong reference, or an empty Rc if the object is gone or
  // this Weak is empty. Declared const because it does not modify the Weak,
  // so it is callable through a `Weak const&` (e.g. from const members).
  Rc<T> lock() const {
    return expired() ? Rc<T>() : Rc<T>(*this);
  }

private:
  // Adopts `rc` WITHOUT incrementing its weak count. The caller must ensure
  // an unused weak count (makeRc uses the phantom weak use of a fresh block).
  static Weak _unsafe_create_with_control(
      T* const ptr, detail::RcControl* const rc) noexcept {
    Weak r;
    r.ptr = ptr;
    r.rc = rc;
    return r;
  }
  // Detaches without touching the counts. The caller must dispose of both
  // ptr and rc.
  void _unsafe_release() && noexcept {
    assert(ptr && rc->use_count() == 1 && rc->weak_count() == 0);
    ptr = nullptr;
    rc = nullptr;
  }
  // Converts this Weak into an Rc without touching the counts. The caller
  // must ensure an unused shared count and that it is safe to leak a weak.
  Rc<T> _unsafe_promote() && noexcept {
    using std::swap;
    assert(!expired());
    Rc<T> r;
    swap(ptr, r.ptr);
    swap(rc, r.rc);
    return r;
  }

  friend class Rc<T>;
  // Befriends every specialization of makeRc. The parameter list mirrors
  // makeRc's <T, Args...>; only the name of the first parameter differs.
  template <typename U, typename... Args>
  friend Rc<U> makeRc(Args&&... args);
  T* ptr;
  detail::RcControl* rc;
};
// Constructs an instance of T in-place alongside the reference-counting
// metadata, using a single allocation (a detail::RcEmplace<T>) for both.
//
// If T has a constructor that takes a Weak<T> const& followed by Args..., then
// that constructor is called with a weak reference to the T being constructed.
// A suggested usage is to store a copy of the reference as e.g. weak_this_ and
// expose it via a sharedThis() (and/or weakThis()) method. This works like
// std::enable_shared_from_this, except that the STL's shared_from_this() cannot
// be accessed during construction.
//
// All copies of the Weak parameter *must* be destroyed if the constructor
// throws. (This happens automatically if the parameter is only assigned to Weak
// fields of transitive members.)
//
// makeRc throws on allocation failure or exception in T's constructor. If T's
// constructor throws, the control block is freed (via `hold`) before the
// exception propagates.
//
template <typename T, typename... Args>
Rc<T> makeRc(Args&&... args) {
  // `hold` deletes the control block if anything below throws. It is
  // disarmed (hold.ptr = nullptr) only after T has been constructed.
  auto control = new detail::RcEmplace<T> {};
  auto hold = detail::Hold { control };
  if constexpr (std::is_constructible_v<T, Weak<T> const&, Args...>) {
    // We avoid some unnecessary refcount mutations by directly creating a Weak,
    // passing it by const&, and later promoting it to an Rc. A freshly created
    // RcControl has a phantom weak use owned by the shared uses; our Weak
    // essentially inhabits that use. This is safe as long as we dispose of it
    // via either _unsafe_release or _unsafe_promote with invariants met.
    //
    auto w = Weak<T>::_unsafe_create_with_control(control->t(), control);
    try {
      // The const_cast only ADDS const, so T's constructor receives exactly a
      // Weak<T> const& (the overload detected above).
      control->construct(
        const_cast<Weak<T> const&>(w), std::forward<Args>(args)...);
    }
    catch (...) {
      std::move(w)._unsafe_release(); // hold will dispose
      throw;
    }
    hold.ptr = nullptr;
    return std::move(w)._unsafe_promote(); // takes the first use
  }
  else {
    // No self-reference wanted: construct T, then hand the block's initial
    // strong use straight to the returned Rc.
    control->construct(std::forward<Args>(args)...);
    hold.ptr = nullptr;
    return Rc<T>::_unsafe_create_with_control(control->t(), control);
  }
}
#include "rc.h"
#include <cstdio>
#include <functional>
#include <memory>
#include <string>
#include <sstream>
#include <vector>
// Running tallies maintained by the EXPECT_* macros.
int g_total{0};    // expectations evaluated so far
int g_failing{0};  // expectations that failed
// EXPECT_THAT(P): counts one expectation. If P is false, prints the stringized
// expression plus file, line and enclosing function to stderr and counts a
// failure. Execution continues either way (nothing aborts).
#define EXPECT_THAT(P) do { \
  ++g_total; \
  if (!(P)) { \
    fprintf(stderr, "Expectation failed: %s\n" \
      " %s:%d (%s)\n", #P, __FILE__, __LINE__, \
      __PRETTY_FUNCTION__); \
    ++g_failing; \
  } \
} while (0)
// EXPECT_EQUAL(E, A): counts one expectation that fails when (E) != (A). On
// failure it prints the expected *expression* text and the actual *value*,
// rendered via lexical_cast, so A's type must be streamable.
// NOTE(review): on failure A is evaluated a second time for printing, so
// avoid side effects in A.
#define EXPECT_EQUAL(E, A) do { \
  ++g_total; \
  if ((E) != (A)) { \
    fprintf(stderr, "Expected %s, got %s (%s)\n" \
      " %s:%d (%s)\n", #E, \
      lexical_cast(A).c_str(), #A, __FILE__, \
      __LINE__, __PRETTY_FUNCTION__); \
    ++g_failing; \
  } \
} while (0)
// EXPECT_THROWS(EXC, EXPR): runs EXPR inside a lambda and counts a failure
// unless it throws something catchable as EXC&. If a different exception
// escaped instead, that is reported as well.
#define EXPECT_THROWS(EXC, EXPR) do { \
  ++g_total; \
  bool threw = false; \
  bool threwOther = false; \
  try { \
    [&]{ EXPR; }(); \
  } catch (EXC& _exc) { \
    threw = true; \
  } catch (...) { \
    threwOther = true; \
  } \
  if (!threw) { \
    fprintf(stderr, "Expected `%s` to throw a(n) " \
      "`%s`, but it didn't.\n", #EXPR, #EXC); \
    if (threwOther) { \
      fprintf(stderr, "An unexpected exception was " \
        "thrown instead.\n"); \
    } \
    fprintf(stderr, " %s:%d (%s)\n", __FILE__, \
      __LINE__, __PRETTY_FUNCTION__); \
    ++g_failing; \
  } \
} while (0)
// Renders any streamable value as a std::string, printing bools as
// "true"/"false" instead of 1/0.
template <typename T>
std::string lexical_cast(T&& t) {
  std::ostringstream out;
  out << std::boolalpha;
  out << std::forward<T>(t);
  return out.str();
}
// Streams an Rc as "Rc(<address>)" for test diagnostics.
//
// The pointer is cast to void const* so that every T prints as an address.
// Without the cast, Rc<char> would select the C-string overload of operator<<
// and read the pointee as a NUL-terminated string.
template <typename T>
std::ostream& operator<<(std::ostream& os, Rc<T> const& r) {
  os << "Rc(" << static_cast<void const*>(r.get()) << ")";
  return os;
}
// Test double for detail::RcControl. Instead of freeing anything, it records
// that each zero-count hook fired and flags (through EXPECT_THAT) a hook that
// fires a second time.
struct StubRcControl: detail::RcControl {
  bool on_zero_shared_called{false};
  bool on_zero_weak_called{false};

  void on_zero_shared() noexcept override {
    EXPECT_THAT(!on_zero_shared_called);
    on_zero_shared_called = true;
  }

  void on_zero_weak() noexcept override {
    EXPECT_THAT(!on_zero_weak_called);
    on_zero_weak_called = true;
  }
};
// Test fixture for Rc/Weak. Each test runs on a fresh RcTests (see testAll)
// owning two StubRcControl blocks, `rc` and `other`, that record which
// zero-count hooks fired instead of freeing memory.
struct RcTests {
  RcTests():
    rc(std::make_unique<StubRcControl>()),
    other(std::make_unique<StubRcControl>()) {}

  // Adopt a stub control block without incrementing it. Each stub must be
  // adopted at most once (Rc::_unsafe_create_with_control's precondition).
  template <typename T>
  Rc<T> makeRcWithStub(T* const ptr) {
    return Rc<T>::_unsafe_create_with_control(ptr, rc.get());
  }
  template <typename T>
  Rc<T> makeRcWithStub(T* const ptr, detail::RcControl* const rc) {
    return Rc<T>::_unsafe_create_with_control(ptr, rc);
  }
  Rc<int> makeRcWithIntStub() {
    return Rc<int>::_unsafe_create_with_control(&x, rc.get());
  }

  // Dropping the only Rc fires both hooks.
  void testNonNull() {
    {
      auto r = makeRcWithIntStub();
      EXPECT_EQUAL(false, rc->on_zero_shared_called);
      EXPECT_EQUAL(false, rc->on_zero_weak_called);
    }
    EXPECT_EQUAL(true, rc->on_zero_shared_called);
    EXPECT_EQUAL(true, rc->on_zero_weak_called);
  }

  // An outstanding Weak delays on_zero_weak until the Weak is gone.
  void testWithWeak() {
    {
      Weak<int> w;
      {
        auto r = makeRcWithIntStub();
        w = r;
      }
      EXPECT_EQUAL(true, rc->on_zero_shared_called);
      EXPECT_EQUAL(false, rc->on_zero_weak_called);
    }
    EXPECT_EQUAL(true, rc->on_zero_weak_called);
  }

  // use_count/weak_count track copies and Weaks through nested scopes.
  void testUseCount() {
    {
      auto r = makeRcWithIntStub();
      EXPECT_EQUAL(1, rc->use_count());
      EXPECT_EQUAL(0, rc->weak_count());
      {
        auto r1 = r;
        EXPECT_EQUAL(2, rc->use_count());
        EXPECT_EQUAL(0, rc->weak_count());
        {
          auto r2 = r;
          auto w = Weak(r2);
          EXPECT_EQUAL(3, rc->use_count());
          EXPECT_EQUAL(1, rc->weak_count());
        }
        EXPECT_EQUAL(2, rc->use_count());
        EXPECT_EQUAL(0, rc->weak_count());
      }
      EXPECT_EQUAL(1, rc->use_count());
      EXPECT_EQUAL(0, rc->weak_count());
      EXPECT_EQUAL(false, rc->on_zero_shared_called);
      EXPECT_EQUAL(false, rc->on_zero_weak_called);
    }
    EXPECT_EQUAL(0, rc->use_count());
    EXPECT_EQUAL(-1, rc->weak_count());
    EXPECT_EQUAL(true, rc->on_zero_shared_called);
    EXPECT_EQUAL(true, rc->on_zero_weak_called);
  }

  struct Bad {
    Rc<Bad> other;
  };
  // Two objects owning each other through Rc keep each other alive.
  void testRefcountCycle() {
    Bad b1, b2;
    {
      auto r1 = makeRcWithStub(&b1, rc.get());
      auto r2 = makeRcWithStub(&b2, other.get());
      r1->other = r2;
      r2->other = r1;
    }
    EXPECT_EQUAL(false, rc->on_zero_shared_called);
    EXPECT_EQUAL(false, rc->on_zero_weak_called);
    EXPECT_EQUAL(false, other->on_zero_shared_called);
    EXPECT_EQUAL(false, other->on_zero_weak_called);
  }

  struct Delegate;
  struct Owner {
    Rc<Delegate> other;
  };
  struct Delegate {
    Weak<Owner> other;
  };
  // A Weak back-reference breaks the cycle; teardown is simulated by hand.
  void testRefcountBreakCycle() {
    Owner owner;
    Delegate delegate;
    {
      auto r1 = makeRcWithStub(&owner, rc.get());
      auto r2 = makeRcWithStub(&delegate, other.get());
      r1->other = r2;
      r2->other = r1;
      r2.reset();
      EXPECT_EQUAL(false, rc->on_zero_shared_called);
      EXPECT_EQUAL(false, rc->on_zero_weak_called);
      EXPECT_EQUAL(false, other->on_zero_shared_called);
      EXPECT_EQUAL(false, other->on_zero_weak_called);
    }
    EXPECT_EQUAL(true, rc->on_zero_shared_called);
    EXPECT_EQUAL(false, rc->on_zero_weak_called);
    EXPECT_EQUAL(false, other->on_zero_shared_called);
    EXPECT_EQUAL(false, other->on_zero_weak_called);
    EXPECT_EQUAL(0, rc->weak_count());
    EXPECT_EQUAL(0, other->weak_count());
    // Simulate owner going out of scope due to rc->on_zero_shared
    owner.other.reset();
    EXPECT_EQUAL(false, rc->on_zero_weak_called);
    EXPECT_EQUAL(true, other->on_zero_shared_called);
    EXPECT_EQUAL(true, other->on_zero_weak_called);
    // Simulate delegate going out of scope due to other->on_zero_shared
    delegate.other.reset();
    EXPECT_EQUAL(true, rc->on_zero_weak_called);
  }

  struct HasSharedThis {
    Weak<HasSharedThis> weak_this;
    explicit HasSharedThis(Weak<HasSharedThis> const& self):
      weak_this(self) {}
    Rc<HasSharedThis> sharedThis() {
      return weak_this.lock();
    }
    Weak<HasSharedThis> weakThis() {
      return weak_this;
    }
  };
  // makeRc hands the constructor a Weak to the object being built.
  void testMakeRcWeakThis() {
    auto r = makeRc<HasSharedThis>();
    EXPECT_THAT(r);
    EXPECT_EQUAL(r, r->sharedThis());
    auto w = r->weakThis();
    EXPECT_THAT(!w.expired());
    EXPECT_EQUAL(r, r->weakThis().lock());
    EXPECT_EQUAL(r, w.lock());
    auto r2 = r->sharedThis();
    r.reset();
    EXPECT_THAT(!w.expired());
    r2.reset();
    EXPECT_THAT(w.expired());
  }

  // Types without a Weak-taking constructor are built from Args alone.
  void testMakeRcNoWeakThis() {
    auto x = makeRc<int>(5);
    EXPECT_EQUAL(5, *x);
    auto w = Weak(x);
    EXPECT_THAT(!w.expired());
    x.reset();
    EXPECT_THAT(w.expired());
  }

  struct Error {};
  struct CtorThrows {
    CtorThrows() {
      throw Error{};
    }
  };
  void testMakeRcThrowingCtor() {
    EXPECT_THROWS(Error, makeRc<CtorThrows>());
  }

  struct WeakSelfThrows {
    WeakSelfThrows(Weak<WeakSelfThrows> const&) {
      throw Error{};
    }
  };
  void testMakeRcWeakSelfThrows() {
    EXPECT_THROWS(Error, makeRc<WeakSelfThrows>());
  }

  bool didDestroy = false;
  struct CheckedDestroy {
    CheckedDestroy(RcTests* t): t(t) {}
    RcTests* t;
    ~CheckedDestroy() {
      EXPECT_THAT(!t->didDestroy);
      t->didDestroy = true;
    }
  };
  // Rc(new T) deletes the object exactly once when the last Rc goes away.
  void testRawPointer() {
    auto r = Rc<CheckedDestroy>(new CheckedDestroy(this));
    EXPECT_THAT(!didDestroy);
    auto w = Weak(r);
    EXPECT_THAT(!w.expired());
    r.reset();
    EXPECT_THAT(didDestroy);
    EXPECT_THAT(w.expired());
  }

  // An aliasing Rc pointing at nullptr still keeps the object alive.
  void testNull() {
    Rc<nullptr_t> hold;
    Weak<int> w;
    {
      auto r = makeRcWithIntStub();
      hold = Rc<nullptr_t>(r, nullptr);
      w = r;
    }
    EXPECT_THAT(!w.expired());
    x = 5;
    EXPECT_EQUAL(5, *w.lock());
    hold.reset();
    EXPECT_THAT(w.expired());
    w.reset();
    EXPECT_THAT(rc->on_zero_weak_called);
  }

  // A Weak made from an aliasing null Rc still pins the control block.
  void testHoldNullWeak() {
    auto r = makeRcWithIntStub();
    Rc<nullptr_t> r2(r, nullptr);
    auto w = Weak(r2);
    r2.reset();
    r.reset();
    EXPECT_EQUAL(false, rc->on_zero_weak_called);
    w.reset();
    EXPECT_EQUAL(true, rc->on_zero_weak_called);
  }

  // Exercises ==, != and <=> between Rcs and against nullptr.
  void testCompare() {
    int ints[2];
    // Each stub is adopted at most once: r1 takes `rc` and r2 takes `other`.
    // r3 shares r1's block through the aliasing constructor, which adds a
    // strong use instead of re-adopting the block.
    auto r1 = makeRcWithStub(&ints[0]);
    auto r2 = makeRcWithStub(&ints[1], other.get());
    EXPECT_THAT(!(r1 == r2));
    EXPECT_THAT(r1 != r2);
    EXPECT_THAT(r1 < r2);
    EXPECT_THAT(!(r1 > r2));
    EXPECT_THAT(r1 <= r2);
    EXPECT_THAT(!(r1 >= r2));
    auto r3 = Rc<int>(r1, &ints[0]);
    EXPECT_THAT(r1 == r3);
    EXPECT_THAT(!(r1 != r3));
    EXPECT_THAT(!(r1 < r3));
    EXPECT_THAT(!(r1 > r3));
    EXPECT_THAT(r1 <= r3);
    EXPECT_THAT(r1 >= r3);
    EXPECT_THAT(r1 != nullptr);
    EXPECT_THAT(!(r1 == nullptr));
    EXPECT_THAT(!(r1 < nullptr));
    EXPECT_THAT(!(r1 <= nullptr));
    EXPECT_THAT(r1 > nullptr);
    EXPECT_THAT(r1 >= nullptr);
  }

  // Runs every test on its own freshly constructed fixture.
  static void testAll() {
    auto ts = std::vector<std::function<void(RcTests&&)>> {
      &RcTests::testNonNull,
      &RcTests::testWithWeak,
      &RcTests::testUseCount,
      &RcTests::testRefcountCycle,
      &RcTests::testRefcountBreakCycle,
      &RcTests::testMakeRcWeakThis,
      &RcTests::testMakeRcNoWeakThis,
      &RcTests::testMakeRcThrowingCtor,
      &RcTests::testMakeRcWeakSelfThrows,
      &RcTests::testRawPointer,
      &RcTests::testNull,
      &RcTests::testHoldNullWeak,
      &RcTests::testCompare,
    };
    for (auto const& t: ts) {
      t(RcTests());
    }
  }

  int x;
  std::unique_ptr<StubRcControl> rc;
  std::unique_ptr<StubRcControl> other;
};
// Runs the suite. On success prints a check mark with the pass count to
// stdout and exits 0; on failure prints a summary to stderr and exits 1.
int main() {
  RcTests::testAll();
  if (!g_failing) {
    setlocale(LC_CTYPE, "");
    wchar_t const check = 0x2705;
    wprintf(L"%lc %d/%d\n", check, g_total, g_total);
    return 0;
  }
  const char* plural[] = {"", "s"};
  fprintf(stderr, "%d failure%s of %d\n",
    g_failing, plural[g_failing > 1], g_total);
  return 1;
}
#include "rc.h"
#include <cstdio>
#include <iterator>
#include <ranges>
#define FWD(x) std::forward<decltype(x)>(x)
// Binary tree node. Children are owned through Rc and the parent is referred
// to through Weak, so parent/child links form no refcount cycle. makeRc
// supplies `self` when building a node (see tree()), and the constructor
// points each non-null child's `parent` back at it.
template <typename T>
struct Tree {
  T value;
  Rc<Tree> left, right;
  Weak<Tree> parent;

  Tree(Weak<Tree> const& self, auto&& nValue, auto&& nLeft, auto&& nRight):
    value(FWD(nValue)), left(FWD(nLeft)), right(FWD(nRight))
  {
    auto adopt = [&self](Rc<Tree>& child) {
      if (child) {
        child->parent = self;
      }
    };
    adopt(left);
    adopt(right);
  }
};
// Builds a balanced binary tree from [begin, end). The middle element becomes
// the root and each half is built recursively. An empty range yields an empty
// Rc.
//
// Positions are computed with std::next/std::distance rather than `+`, so
// any forward iterator works, not just random-access ones. Random-access
// iterators remain O(1) per step.
template <typename Iter>
Rc<Tree<std::iter_value_t<Iter>>> tree(Iter begin, Iter end) {
  using Tree = Tree<std::iter_value_t<Iter>>;
  using std::distance;
  if (begin == end) {
    return nullptr;
  }
  auto mid = std::next(begin, distance(begin, end) / 2);
  return makeRc<Tree>(*mid, tree(begin, mid), tree(std::next(mid), end));
}
// Constructs a balanced binary tree from the input range. If the range is
// sorted, then the tree is a binary search tree.
//
// begin/end are found through `using std::begin/std::end` plus ADL, and both
// must yield the same iterator type.
template <typename R>
Rc<Tree<std::ranges::range_value_t<R>>> tree(R const& r) {
  using std::begin;
  using std::end;
  auto first = begin(r);
  auto last = end(r);
  return tree(first, last);
}
// Inorder traversal of the subtree rooted at `tree`. Calls f(value, depth) for
// every node, where depth is the node's distance below `tree`.
//
// Uses O(1) extra space: instead of a stack, it walks back up each node's Weak
// `parent` link. The climb never goes above `tree` itself, so traversing a
// subtree visits only that subtree and not its ancestors or their other
// descendants.
template <typename T, typename F>
void traverse(Rc<Tree<T>> const& tree, F f) {
  if (!tree) {
    return;
  }
  auto curr = tree;
  auto depth = 0;
  // Start at the leftmost (first inorder) node.
  while (curr->left) {
    curr = curr->left; ++depth;
  }
  while (curr) {
    f(curr->value, depth);
    if (curr->right) {
      // Successor: leftmost node of the right subtree.
      curr = curr->right; ++depth;
      while (curr->left) {
        curr = curr->left; ++depth;
      }
    }
    else {
      // Successor: the first ancestor reached by climbing up from a left
      // child. Finished right subtrees are climbed out of, but the climb
      // stops at `tree`.
      auto ancestor = Rc<Tree<T>>();
      while (curr != tree && (ancestor = curr->parent.lock())) {
        --depth;
        if (ancestor->right != curr) {
          break;  // came up from a left child: `ancestor` is next
        }
        curr = std::move(ancestor);
      }
      curr = std::move(ancestor);  // empty once the traversal is complete
    }
  }
}
// Demo: builds a balanced BST from 1..15, prints it inorder (each value
// right-aligned in a field 4*depth wide), then shows a Weak to the root
// expiring once the last Rc is dropped.
int main() {
  const auto xs = std::views::iota(1, 16);
  auto t = tree(xs);
  traverse(t, [](int value, int level) {
    printf("%*d\n", level * 4, value);
  });
  auto watcher = Weak(t);
  printf("expired: %d\n", watcher.expired());
  t.reset();
  printf("expired: %d\n", watcher.expired());
}
# Flags for both programs. The sources need C++20 or later (concepts, <=>,
# std::ranges); c++2b selects the C++23 draft.
PROJECT_FLAGS=-std=c++2b -Wall -pedantic

all: test-rc tree-shared

# Runs both binaries under `leaks --atExit` (presumably the macOS leaks(1)
# tool, which checks for leaked memory when the program exits).
leak-check: all
	leaks --atExit -- ./test-rc && \
	leaks --atExit -- ./tree-shared

# Each program is a single translation unit that includes rc.h. Rebuild when
# the source, the header, or this makefile (named z.mk here) changes.
test-rc: test-rc.c++ rc.h z.mk
	$(CXX) $(PROJECT_FLAGS) $(CXXFLAGS) $< -o $@

tree-shared: tree-shared.c++ rc.h z.mk
	$(CXX) $(PROJECT_FLAGS) $(CXXFLAGS) $< -o $@

clean:
	rm -f test-rc tree-shared

.PHONY: all clean leak-check
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment