Skip to content

Instantly share code, notes, and snippets.

@warrenseine
Created November 1, 2012 16:37
Show Gist options
  • Save warrenseine/3994894 to your computer and use it in GitHub Desktop.
Save warrenseine/3994894 to your computer and use it in GitHub Desktop.
C++11 Observer Pattern (removable slots, simple syntax, event-based)
#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <vector>
/// Root of the message hierarchy.
///
/// Messages travel as shared pointers (`Message::Ptr`). The virtual
/// destructor makes the type polymorphic, so `typeid(*message)` yields the
/// dynamic type, and a derived message owned through a `Message::Ptr` built
/// from a raw `Message*` is destroyed correctly.
struct Message
{
    virtual ~Message() = default;

    using Ptr = std::shared_ptr<Message>;
};
/// Abstract, type-erased message handler.
///
/// Concrete observers wrap a free function or a member function. They are
/// stored as `Observer::Ptr` and compared with `equals()` when a handler is
/// unregistered.
class Observer
{
public:
    using Ptr          = std::shared_ptr<Observer>;
    using ArgumentType = Message::Ptr;
    using ReturnType   = void;

    virtual ~Observer() = default;

    /// Delivers `argument` to the wrapped handler.
    /// @return false if the handler could not be invoked; MessageDispatcher
    ///         unregisters an observer that returns false.
    virtual bool call(ArgumentType argument) = 0;

    /// @return true if `other` wraps the same handler as this observer.
    virtual bool equals(Observer::Ptr other) const = 0;
};
/// Observer that forwards messages to a free (non-member) function.
///
/// @tparam M  the handler's parameter type: a std::shared_ptr to Message or a
///            Message subclass. Incoming messages are converted to it with an
///            unchecked static_pointer_cast. This relies on the dispatcher only
///            routing messages whose dynamic type matches M's element type
///            (or any message when the element type is Message itself).
template <typename M>
class FreeObserver : public Observer
{
public:
    using FreeFunctionType = ReturnType (*)(M);

    /// `explicit` so that a bare function pointer never silently converts
    /// into an observer.
    explicit FreeObserver(FreeFunctionType f) :
        _f(f)
    {
    }

    /// Calls the wrapped function with the downcast message.
    /// @return always true; a free function has no target that can expire.
    bool
    call(ArgumentType argument) override
    {
        _f(std::static_pointer_cast<typename M::element_type>(argument));
        return true;
    }

    /// @return true if `p` is a FreeObserver of the same message type that
    ///         wraps the same function pointer.
    bool
    equals(Observer::Ptr p) const override
    {
        const FreeObserver<M>* other = dynamic_cast<const FreeObserver<M>*>(p.get());
        return other != nullptr && _f == other->_f;
    }

private:
    FreeFunctionType _f;
};
/// Observer that forwards messages to a member function of an object it
/// observes through a std::weak_ptr (the observer does not keep it alive).
///
/// @tparam M  the handler's parameter type (a std::shared_ptr to a Message
///            type); incoming messages are static_pointer_cast to it.
/// @tparam T  the class that owns the member function.
template <typename M, typename T>
class WeakMemberObserver : public Observer
{
public:
    using MemberFunctionType = ReturnType (T::*)(M);

    WeakMemberObserver(MemberFunctionType f, std::weak_ptr<T> p) :
        _p(p),
        _f(f)
    {
    }

    /// Invokes the member function on the target if it is still alive.
    /// @return false when the target has expired, so the dispatcher can
    ///         unregister this observer.
    bool
    call(ArgumentType argument) override
    {
        if (std::shared_ptr<T> target = _p.lock())
        {
            const M typed = std::static_pointer_cast<typename M::element_type>(argument);
            ((*target).*_f)(typed);
            return true;
        }
        return false;
    }

    /// Two weak observers are equal when they wrap the same member function
    /// and their targets lock to the same pointer. Two expired targets both
    /// lock to null, so they compare equal.
    bool
    equals(Observer::Ptr p) const override
    {
        auto other = dynamic_cast<const WeakMemberObserver<M, T>*>(p.get());
        if (!other)
            return false;
        return _f == other->_f && _p.lock() == other->_p.lock();
    }

private:
    std::weak_ptr<T> _p;
    MemberFunctionType _f;
};
/// Observer that forwards messages to a member function of an object held
/// by raw, non-owning pointer. The caller must keep the object alive for as
/// long as the observer is registered; nothing here checks that.
///
/// @tparam M  the handler's parameter type (a std::shared_ptr to a Message
///            type); incoming messages are static_pointer_cast to it.
/// @tparam T  the class that owns the member function.
template <typename M, typename T>
class RawMemberObserver : public Observer
{
public:
    using MemberFunctionType = ReturnType (T::*)(M);

    RawMemberObserver(MemberFunctionType f, T* p) :
        _p(p),
        _f(f)
    {
    }

    /// Invokes the member function on the stored object.
    /// @return always true.
    bool
    call(ArgumentType argument) override
    {
        const M typed = std::static_pointer_cast<typename M::element_type>(argument);
        (_p->*_f)(typed);
        return true;
    }

    /// @return true if `p` is a RawMemberObserver of the same types that wraps
    ///         the same member function on the same object address.
    bool
    equals(Observer::Ptr p) const override
    {
        auto other = dynamic_cast<const RawMemberObserver<M, T>*>(p.get());
        return other != nullptr && _f == other->_f && _p == other->_p;
    }

private:
    T* _p;
    MemberFunctionType _f;
};
class MessageDispatcher
{
public:
using ObserverList = std::vector<Observer::Ptr>;
using ObserverMap = std::map<std::type_index, ObserverList>;
int
dispatchMessage(Message* message)
{
return dispatchMessage(Message::Ptr(message));
}
int
dispatchMessage(Message::Ptr message)
{
int dispatched = 0;
setUp(message);
ObserverList collected;
collectObservers(message, collected);
for (auto observer : collected)
{
if (observer->call(message))
dispatched++;
else
doRemoveMessageHandler(*observer, typeid(*message));
}
tearDown(message);
return dispatched;
}
bool
setUp(Message::Ptr message) const
{
return true;
}
bool
tearDown(Message::Ptr message) const
{
return true;
}
void
collectObservers(Message::Ptr message, ObserverList& collected) const
{
std::type_index ti = typeid(*message);
for (auto i : _globalObservers)
collected.push_back(i);
auto found = _localObservers.find(ti);
if (found != _localObservers.end())
{
const ObserverList& localObservers = found->second;
for (auto i : localObservers)
collected.push_back(i);
}
}
template <typename T, typename M>
void
addMessageHandler(void (T::*func)(std::shared_ptr<M>), T* that)
{
Observer::Ptr observer(new RawMemberObserver<std::shared_ptr<M>, T>(func, that));
doAddMessageHandler(observer, typeid(M));
}
template <typename T, typename M>
void
addMessageHandler(void (T::*func)(std::shared_ptr<M>), std::weak_ptr<T> that)
{
Observer::Ptr observer(new WeakMemberObserver<std::shared_ptr<M>, T>(func, that));
doAddMessageHandler(observer, typeid(M));
}
template <typename M>
void
addMessageHandler(void (*func)(std::shared_ptr<M>))
{
Observer::Ptr observer(new FreeObserver<std::shared_ptr<M>>(func));
doAddMessageHandler(observer, typeid(M));
}
void
doAddMessageHandler(Observer::Ptr observer, const std::type_info& info)
{
std::type_index ti(info);
if (info == typeid(Message))
_globalObservers.push_back(observer);
else
_localObservers[ti].push_back(observer);
}
template <typename T, typename M>
void
removeMessageHandler(void (T::*func)(std::shared_ptr<M>), T* that)
{
RawMemberObserver<std::shared_ptr<M>, T> observer(func, that);
doRemoveMessageHandler(observer, typeid(M));
}
template <typename T, typename M>
void
removeMessageHandler(void (T::*func)(std::shared_ptr<M>), std::weak_ptr<T> that)
{
WeakMemberObserver<std::shared_ptr<M>, T> observer(func, that);
doRemoveMessageHandler(observer, typeid(M));
}
template <typename M>
void
removeMessageHandler(void (*func)(std::shared_ptr<M>))
{
FreeObserver<std::shared_ptr<M>> observer(func);
doRemoveMessageHandler(observer, typeid(M));
}
template <typename T, typename M>
void
removeMessageHandler(std::function<void (std::shared_ptr<M>)> func)
{
}
void
doRemoveMessageHandler(Observer& observer, const std::type_info& info)
{
std::type_index ti(info);
ObserverMap::iterator found = _localObservers.find(ti);
if (found != _localObservers.end())
{
ObserverList& localObservers = found->second;
auto end = std::remove_if(localObservers.begin(), localObservers.end(), [&](Observer::Ptr current) -> bool
{
return observer.equals(current);
});
localObservers.erase(end, localObservers.end());
}
auto end = std::remove_if(_globalObservers.begin(), _globalObservers.end(), [&](Observer::Ptr current) -> bool
{
bool b = observer.equals(current);
return observer.equals(current);
});
_globalObservers.erase(end, _globalObservers.end());
}
private:
ObserverMap _localObservers;
ObserverList _globalObservers;
};
/// Message subtype handled by A::e and j in the demo below.
struct FooMessage : Message
{
    using Ptr = std::shared_ptr<FooMessage>;
};
/// Message subtype with no dedicated handler in the demo; it is used to check
/// that non-matching handlers are skipped.
struct BarMessage : Message
{
    using Ptr = std::shared_ptr<BarMessage>;
};
/// Demo dispatcher whose member functions also serve as handlers in main().
struct A : MessageDispatcher
{
    /// Takes FooMessage::Ptr, so it is registered as a FooMessage-only handler.
    void e(FooMessage::Ptr)
    {
        std::cout << "A::e()" << std::endl;
    }

    /// Takes Message::Ptr, so it is registered as a global handler.
    void f(Message::Ptr)
    {
        std::cout << "A::f()" << std::endl;
    }

    /// Second global handler, distinct from f for removal tests.
    void g(Message::Ptr)
    {
        std::cout << "A::g()" << std::endl;
    }
};
/// Free handler for any message (global when registered).
void h(Message::Ptr)
{
    std::cout << "h()" << std::endl;
}
/// Second free handler for any message, distinct from h.
void i(Message::Ptr)
{
    std::cout << "i()" << std::endl;
}
/// Free handler that only receives FooMessage.
void j(FooMessage::Ptr)
{
    std::cout << "j()" << std::endl;
}
int main()
{
std::cout << "# Dispatch on no handler" << std::endl;
{
A a;
int r = a.dispatchMessage(new Message());
assert(r == 0);
}
std::cout << "# Dispatch on one free function" << std::endl;
{
A a;
a.addMessageHandler(&h);
int r = a.dispatchMessage(new Message());
assert(r == 1);
}
std::cout << "# Dispatch on a non-matching free function" << std::endl;
{
A a;
a.addMessageHandler(&j);
int r = a.dispatchMessage(new BarMessage());
assert(r == 0);
}
std::cout << "# Dispatch on multiple free functions" << std::endl;
{
A a;
a.addMessageHandler(&h);
a.addMessageHandler(&i);
int r = a.dispatchMessage(new Message());
assert(r == 2);
}
std::cout << "# Multiple remove of free functions at once" << std::endl;
{
A a;
a.addMessageHandler(&h);
a.addMessageHandler(&h);
a.removeMessageHandler(&h);
int r = a.dispatchMessage(new Message());
assert(r == 0);
}
std::cout << "# Remove of free function in global handlers" << std::endl;
{
A a;
a.addMessageHandler(&h);
a.addMessageHandler(&i);
a.addMessageHandler(&j);
a.removeMessageHandler(&j);
int r = a.dispatchMessage(new Message());
assert(r == 2);
}
std::cout << "# Remove of free function in local handlers" << std::endl;
{
A a;
a.addMessageHandler(&h);
a.addMessageHandler(&i);
a.addMessageHandler(&j);
a.removeMessageHandler(&j);
int r = a.dispatchMessage(new FooMessage());
assert(r == 2);
}
std::cout << "# Dispatch of specific message with a free function" << std::endl;
{
A a;
a.addMessageHandler(&j);
int r1 = a.dispatchMessage(new FooMessage());
int r2 = a.dispatchMessage(new BarMessage());
assert(r1 == 1 && r2 == 0);
}
std::cout << "# Dispatch on one member function" << std::endl;
{
A a;
a.addMessageHandler(&A::f, &a);
int r = a.dispatchMessage(new Message());
assert(r == 1);
}
std::cout << "# Dispatch on the same member function twice" << std::endl;
{
A a;
a.addMessageHandler(&A::f, &a);
a.addMessageHandler(&A::f, &a);
int r = a.dispatchMessage(new Message());
assert(r == 2);
}
std::cout << "# Dispatch on multiple member functions" << std::endl;
{
A a;
a.addMessageHandler(&A::f, &a);
a.addMessageHandler(&A::g, &a);
int r = a.dispatchMessage(new Message());
assert(r == 2);
}
std::cout << "# Remove of member function in global handlers" << std::endl;
{
A a;
a.addMessageHandler(&A::f, &a);
a.addMessageHandler(&A::g, &a);
a.removeMessageHandler(&A::f, &a);
int r = a.dispatchMessage(new Message());
assert(r == 1);
}
std::cout << "# Remove of member function in local handlers" << std::endl;
{
A a;
a.addMessageHandler(&A::g, &a);
a.addMessageHandler(&A::e, &a);
a.removeMessageHandler(&A::e, &a);
int r = a.dispatchMessage(new FooMessage());
assert(r == 1);
}
std::cout << "# Useless remove" << std::endl;
{
A a;
a.addMessageHandler(&A::f, &a);
a.addMessageHandler(&A::g, &a);
a.removeMessageHandler(&A::f, &a);
a.removeMessageHandler(&A::f, &a);
int r = a.dispatchMessage(new Message());
assert(r == 1);
}
std::cout << "# Dispatch on different MessageDispatchers" << std::endl;
{
A a1;
A a2;
a1.addMessageHandler(&A::f, &a1);
a1.addMessageHandler(&A::f, &a2);
int r1 = a1.dispatchMessage(new Message());
int r2 = a2.dispatchMessage(new Message());
assert(r1 == 2 && r2 == 0);
}
std::cout << "# Dispatch on both member and free functions" << std::endl;
{
A a;
a.addMessageHandler(&A::f, &a);
a.addMessageHandler(&h);
int r = a.dispatchMessage(new BarMessage());
assert(r == 2);
}
std::cout << "# Remove of a member function while keeping a free function" << std::endl;
{
A a;
a.addMessageHandler(&A::f, &a);
a.addMessageHandler(&h);
a.removeMessageHandler(&A::f, &a);
int r = a.dispatchMessage(new BarMessage());
assert(r == 1);
}
std::cout << "# Dispatch on lambda function" << std::endl;
{
A a;
auto f = [](std::shared_ptr<Message>)
{
std::cout << "[](Message)" << std::endl;
};
a.addMessageHandler<Message>(f);
int r = a.dispatchMessage(new BarMessage());
assert(r == 1);
}
std::cout << "# Remove a lambda function" << std::endl;
{
A a;
auto f = [](std::shared_ptr<Message>)
{
std::cout << "[](Message)" << std::endl;
};
a.addMessageHandler<Message>(f);
a.removeMessageHandler<Message>(f);
int r = a.dispatchMessage(new BarMessage());
assert(r == 0);
}
std::cout << "# Try to remove an inline lambda function" << std::endl;
{
A a;
a.addMessageHandler<Message>([](std::shared_ptr<Message>)
{
std::cout << "[](Message)" << std::endl;
});
a.removeMessageHandler<Message>([](std::shared_ptr<Message>)
{
std::cout << "[](Message)" << std::endl;
});
int r = a.dispatchMessage(new BarMessage());
assert(r == 1);
}
std::cout << "# Dispatch on a non-matching lambda function" << std::endl;
{
A a;
a.addMessageHandler<FooMessage>([](std::shared_ptr<FooMessage>)
{
std::cout << "[](FooMessage)" << std::endl;
});
int r = a.dispatchMessage(new BarMessage());
assert(r == 0);
}
}
// Compile and test with: clang++ -std=c++11 -stdlib=libc++ Observer.cpp && ./a.out
@stephenamills
Copy link

This is pretty cool.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment