Skip to content

Instantly share code, notes, and snippets.

@kwilczynski
Last active March 22, 2021 07:00
Show Gist options
  • Save kwilczynski/8cb77f389a9643a2ab8f93f1737b8361 to your computer and use it in GitHub Desktop.
Save kwilczynski/8cb77f389a9643a2ab8f93f1737b8361 to your computer and use it in GitHub Desktop.
Generic Registry type using concurrent sharded map in Go
package registry
import (
"fmt"
"hash/fnv"
"sync"
)
// Registrar is the set of operations Registry provides: a name-keyed store
// of arbitrary values.
type Registrar interface {
	// Add stores value under name, replacing any existing entry, and
	// returns value.
	Add(name string, value interface{}) interface{}
	// Contains reports whether an entry named name exists.
	Contains(name string) bool
	// Get returns the value stored under name, or nil if there is none.
	Get(name string) interface{}
	// Remove deletes the entry named name and returns its former value
	// (nil if there was none).
	Remove(name string) interface{}
	// Names returns the names of all entries, in no particular order.
	Names() []string
	// Values returns the values of all entries, in no particular order.
	Values() []interface{}
	// Each calls fn for every entry until fn returns false.
	Each(fn EachCallback)
	// Count returns the number of entries.
	Count() int
	// Empty reports whether there are no entries.
	Empty() bool
	// Copy returns an independent copy of the registry.
	Copy() interface{}
	// Load adds every entry of data, overwriting entries with the same name.
	Load(data map[string]interface{})
	// Save returns all entries as a newly allocated map.
	Save() map[string]interface{}
	// Reset removes all entries.
	Reset()
	// String returns a human-readable summary.
	String() string
}
// EachCallback is invoked by Each once per entry with the entry's name and
// value. Returning false stops the iteration early.
type EachCallback func(name string, value interface{}) bool
// RegistryShards is the number of independently locked shards a Registry is
// split into; names are spread across shards by hash so that operations on
// different shards do not contend for the same lock. Keeping it a power of
// two lets the hash-to-shard reduction compile down to a cheap bit mask.
const RegistryShards = 1 << 5
type (
	// Registry maps names to arbitrary values, partitioned into
	// RegistryShards independently locked shards. Build one with New: every
	// element must be a non-nil shard with an allocated map.
	Registry []*registry

	// registry is a single shard: a plain map guarded by the embedded
	// RWMutex (read paths take RLock, write paths take Lock).
	registry struct {
		sync.RWMutex
		items map[string]interface{}
	}
)
// tuple is one name/value pair produced by collect. Field order matters:
// collect builds tuples with positional composite literals.
type tuple struct {
	Key   string      // entry name
	Value interface{} // value stored under Key
}
// New returns an empty Registry with all RegistryShards shards allocated,
// each holding its own empty map.
func New() *Registry {
	shards := make(Registry, RegistryShards)
	for i := range shards {
		shards[i] = &registry{items: make(map[string]interface{})}
	}
	return &shards
}
// Add stores value under name, replacing any existing entry, and returns
// value. Only the shard that owns name is write-locked.
func (r *Registry) Add(name string, value interface{}) interface{} {
	shard := r.shard(name)
	shard.Lock()
	defer shard.Unlock()
	shard.items[name] = value
	return value
}
// Contains reports whether an entry named name exists, including entries
// whose stored value is nil.
func (r *Registry) Contains(name string) bool {
	shard := r.shard(name)
	shard.RLock()
	defer shard.RUnlock()
	_, ok := shard.items[name]
	return ok
}
// Get returns the value stored under name, or nil when there is no such
// entry. A stored nil is indistinguishable from a missing entry here; use
// Contains to tell them apart.
func (r *Registry) Get(name string) interface{} {
	shard := r.shard(name)
	shard.RLock()
	defer shard.RUnlock()
	return shard.items[name]
}
// Remove deletes the entry named name and returns the value it held, or nil
// if there was no such entry.
func (r *Registry) Remove(name string) interface{} {
	shard := r.shard(name)
	shard.Lock()
	defer shard.Unlock()
	old := shard.items[name]
	delete(shard.items, name)
	return old
}
// Names returns the names of all entries, in no particular order.
//
// The result is grown with append and cap(c) is used only as a size hint.
// The channel's capacity need not equal the number of tuples it delivers
// when entries are added or removed concurrently. Indexing into a slice of
// length cap(c) would then panic (more tuples) or leave empty trailing
// names (fewer tuples).
func (r *Registry) Names() []string {
	c := r.collect()
	names := make([]string, 0, cap(c))
	for t := range c {
		names = append(names, t.Key)
	}
	return names
}
// Values returns the values of all entries, in no particular order.
//
// As in Names, cap(c) is only a size hint and the slice is grown with
// append. Indexing into a fixed-length slice could panic when entries are
// added concurrently, or return trailing nils when they are removed.
func (r *Registry) Values() []interface{} {
	c := r.collect()
	values := make([]interface{}, 0, cap(c))
	for t := range c {
		values = append(values, t.Value)
	}
	return values
}
// Each calls function for every entry, in no particular order, until
// function returns false.
func (r *Registry) Each(function EachCallback) {
	c := r.collect()
	for t := range c {
		if !function(t.Key, t.Value) {
			break
		}
	}
	// Drain whatever is left. This guarantees that every producer feeding c
	// can finish sending and release its shard read lock even when iteration
	// stopped early; an abandoned, blocked producer would keep its shard
	// read-locked and stall all writers on that shard forever.
	for range c {
	}
}
// Count returns the total number of entries. Shards are read-locked one at a
// time, so under concurrent writes the total is not an atomic snapshot of
// the whole registry.
func (r *Registry) Count() (total int) {
	shards := *r
	for i := 0; i < RegistryShards; i++ {
		shards[i].RLock()
		total += len(shards[i].items)
		shards[i].RUnlock()
	}
	return total
}
// Empty reports whether the registry holds no entries.
//
// It returns as soon as it finds a non-empty shard instead of locking and
// counting every shard as Count does.
func (r *Registry) Empty() bool {
	for i := 0; i < RegistryShards; i++ {
		s := (*r)[i]
		s.RLock()
		n := len(s.items)
		s.RUnlock()
		if n > 0 {
			return false
		}
	}
	return true
}
// Copy returns a new, independent *Registry (typed as interface{} to match
// Registrar) holding the same entries. Values are copied by assignment, so
// pointer or reference values are shared with the original registry.
//
// Each shard is copied directly while its read lock is held. The copy has
// the same shard count and shard selection depends only on the name, so
// every entry belongs at the same shard index. No re-hashing or channel
// round-trip is needed, and each copied shard is internally consistent.
func (r *Registry) Copy() interface{} {
	dup := make(Registry, RegistryShards)
	for i := 0; i < RegistryShards; i++ {
		src := (*r)[i]
		src.RLock()
		items := make(map[string]interface{}, len(src.items))
		for k, v := range src.items {
			items[k] = v
		}
		src.RUnlock()
		dup[i] = &registry{items: items}
	}
	return &dup
}
// Load merges data into the registry, overwriting entries whose names
// already exist. Each entry is stored under its own shard's write lock; the
// load as a whole is not atomic.
func (r *Registry) Load(data map[string]interface{}) {
	for name, value := range data {
		r.Add(name, value)
	}
}
// Save returns all entries as a newly allocated map that the caller owns;
// modifying it does not affect the registry.
func (r *Registry) Save() map[string]interface{} {
	snapshot := make(map[string]interface{}, r.Count())
	for entry := range r.collect() {
		snapshot[entry.Key] = entry.Value
	}
	return snapshot
}
// Reset removes every entry by giving each shard a fresh, empty map. Shards
// are cleared one at a time under their own write lock.
func (r *Registry) Reset() {
	shards := *r
	for i := 0; i < RegistryShards; i++ {
		shard := shards[i]
		shard.Lock()
		shard.items = map[string]interface{}{}
		shard.Unlock()
	}
}
// String renders the registry as "Registry(<count>): [<names>]". Names are
// listed in no particular order. The count and the names are gathered by
// separate calls.
func (r *Registry) String() string {
	count, names := r.Count(), r.Names()
	return fmt.Sprintf("Registry(%d): %v", count, names)
}
// hash returns the 32-bit FNV-1 hash of name. hash.Hash documents that Write
// never returns an error, so the panic guards an invariant that cannot fail.
func (r *Registry) hash(name string) uint32 {
	hasher := fnv.New32()
	_, err := hasher.Write([]byte(name))
	if err != nil {
		panic("unreachable")
	}
	return hasher.Sum32()
}
// shard returns the shard responsible for name.
//
// The hash is reduced with a modulo rather than a bit mask. "h & (n-1)"
// spreads keys over all n shards only when n is a power of two; "h % n" is
// correct for any RegistryShards. For a power-of-two constant such as 32,
// the compiler emits the same mask.
func (r *Registry) shard(name string) *registry {
	return (*r)[r.hash(name)%uint32(RegistryShards)]
}
// collect snapshots every entry and returns the tuples on an already-closed
// channel. The channel's buffer holds exactly that many tuples, so cap(c)
// equals the number of tuples delivered. Callers may stop reading at any
// time without leaking anything.
//
// Each shard is copied while its read lock is held, and the lock is
// released before anything is sent. The buffer used to be sized from a
// Count() taken before the shards were read, with sends from per-shard
// goroutines holding the shard's read lock. Entries added in between could
// overflow the buffer, leaving a goroutine blocked forever with the lock
// held whenever the reader stopped early, which deadlocked writers on that
// shard.
//
// The snapshot is consistent within each shard, not across shards.
func (r *Registry) collect() <-chan tuple {
	items := make([]tuple, 0, r.Count()) // size hint only
	for _, s := range *r {
		s.RLock()
		for k, v := range s.items {
			items = append(items, tuple{k, v})
		}
		s.RUnlock()
	}
	c := make(chan tuple, len(items))
	for _, t := range items {
		c <- t // never blocks: the buffer fits every tuple
	}
	close(c)
	return c
}
package registry
import (
"fmt"
"reflect"
"sort"
"testing"
)
// TestRegistry_New checks that New returns a non-nil, empty *Registry with
// every shard allocated.
func TestRegistry_New(t *testing.T) {
	r := New()
	if r == nil {
		t.Fatalf("expected type %T, got %T", &Registry{}, r)
	}
	func(v interface{}) {
		if _, ok := v.(*Registry); !ok {
			t.Fatalf("expected type %T, got %T", &Registry{}, v)
		}
	}(r)
	if n := len(*r); n != RegistryShards {
		t.Fatalf("expected %d shards, got %d", RegistryShards, n)
	}
	if b := r.Empty(); !b {
		t.Errorf("expected empty %v, got %v", true, b)
	}
}
// TestRegistry_Add checks that Add stores entries, returns the stored value,
// and overwrites existing names without changing the count.
func TestRegistry_Add(t *testing.T) {
	r := New()
	r.Add("test1", "")
	r.Add("test2", "")
	if i := r.Count(); i != 2 {
		t.Errorf("expected count %d, got %d", 2, i)
	}

	// Overwriting an existing name returns and stores the new value.
	v := r.Add("test2", "test456")
	s, ok := v.(string)
	if !ok {
		t.Fatalf("expected type %T, got %T", "", v)
	}
	if s != "test456" {
		t.Errorf("expected string %q, got %q", "test456", s)
	}
	if got := r.Get("test2"); got != "test456" {
		t.Errorf("expected stored %q, got %v", "test456", got)
	}
	if i := r.Count(); i != 2 {
		t.Errorf("expected count %d, got %d", 2, i)
	}

	s, ok = r.Add("test3", "test123").(string)
	if !ok || s != "test123" {
		t.Errorf("expected string %q, got %q", "test123", s)
	}
	if i := r.Count(); i != 3 {
		t.Errorf("expected count %d, got %d", 3, i)
	}
}
// TestRegistry_Contains checks Contains for a present and an absent name.
func TestRegistry_Contains(t *testing.T) {
	r := New()
	r.Add("test1", "")
	r.Add("test3", "")
	cases := []struct {
		name string
		want bool
	}{
		{"test1", true},
		{"test2", false},
	}
	for _, tc := range cases {
		if got := r.Contains(tc.name); got != tc.want {
			t.Errorf("expected contains %v, got %v", tc.want, got)
		}
	}
}
// TestRegistry_Get checks that Get returns stored values with their dynamic
// types intact, reflects overwrites, and yields nil for missing names.
func TestRegistry_Get(t *testing.T) {
	type type1 struct{ value int }
	type type2 struct{ value string }
	t1 := type1{value: 123}
	t2 := type2{value: "test123"}

	r := New()
	r.Add("test1", t1)
	r.Add("test2", t2)

	v := r.Get("test1")
	got1, ok := v.(type1)
	if !ok {
		t.Fatalf("expected type %T, got %T", type1{}, v)
	}
	if !reflect.DeepEqual(got1, t1) {
		t.Errorf("expected object %v, got %v", t1, v)
	}

	v = r.Get("test2")
	got2, ok := v.(type2)
	if !ok {
		t.Fatalf("expected type %T, got %T", type2{}, v)
	}
	if !reflect.DeepEqual(got2, t2) {
		t.Errorf("expected object %v, got %v", t2, v)
	}
	if s := got2.value; s != "test123" {
		t.Errorf("expected string %q, got %q", "test123", s)
	}

	// Overwriting with a value of a different type replaces it entirely.
	r.Add("test2", t1)
	v = r.Get("test2")
	got1, ok = v.(type1)
	if !ok {
		t.Fatalf("expected type %T, got %T", type1{}, v)
	}
	if i := got1.value; i != 123 {
		t.Errorf("expected integer %d, got %d", 123, i)
	}

	if v := r.Get("missing"); v != nil {
		t.Errorf("expected object %v, got %v", nil, v)
	}
}
// TestRegistry_Remove checks that Remove returns the removed value, returns
// nil for missing names, and leaves other entries in place.
func TestRegistry_Remove(t *testing.T) {
	r := New()
	r.Add("test1", "test123")
	r.Add("test2", "")

	v := r.Remove("test1")
	s, ok := v.(string)
	if !ok {
		t.Fatalf("expected type %T, got %T", "", v)
	}
	if s != "test123" {
		t.Errorf("expected string %q, got %q", "test123", s)
	}

	if v := r.Remove("test3"); v != nil {
		t.Errorf("expected object %v, got %v", nil, v)
	}
	if b := r.Empty(); b {
		t.Errorf("expected empty %v, got %v", false, b)
	}
	if i := r.Count(); i != 1 {
		t.Errorf("expected count %d, got %d", 1, i)
	}
	if b := r.Contains("test1"); b {
		t.Errorf("expected contains %v, got %v", false, b)
	}
}
// TestRegistry_Names checks that Names returns every stored name; order is
// normalized by sorting because map iteration order is random.
func TestRegistry_Names(t *testing.T) {
	want := []string{"test1", "test3", "test5"}
	r := New()
	for _, name := range want {
		r.Add(name, "")
	}
	got := r.Names()
	sort.Strings(got)
	if !reflect.DeepEqual(got, want) {
		t.Errorf("expected object %v, got %v", want, got)
	}
}
// TestRegistry_Values checks that Values returns every stored value.
func TestRegistry_Values(t *testing.T) {
	type type1 struct{ value int }
	type type2 struct{ value string }
	t1 := type1{value: 123}
	t2 := type2{value: "test123"}

	r := New()
	r.Add("test1", t1)
	r.Add("test2", t2)

	si1 := r.Values()
	si2 := []interface{}{t1, t2}
	// Check the length first so a short result fails the test instead of
	// panicking on the indexing below.
	if len(si1) != len(si2) {
		t.Fatalf("expected %d values, got %d: %v", len(si2), len(si1), si1)
	}
	// Values comes back in no particular order; put the type1 value first.
	if _, ok := si1[0].(type2); ok {
		si1[0], si1[1] = si1[1], si1[0]
	}
	if !reflect.DeepEqual(si1, si2) {
		t.Errorf("expected object %v, got %v", si2, si1)
	}
}
// TestRegistry_Each checks that Each visits every entry, that returning
// false stops after the first entry, and that the registry stays usable
// after an early stop.
func TestRegistry_Each(t *testing.T) {
	r := New()
	r.Add("test1", "")
	r.Add("test2", "")

	var got []string
	r.Each(func(k string, v interface{}) bool {
		got = append(got, k)
		return true
	})
	sort.Strings(got)
	want := []string{"test1", "test2"}
	if !reflect.DeepEqual(got, want) {
		t.Errorf("expected object %v, got %v", want, got)
	}

	calls := 0
	r.Each(func(k string, v interface{}) bool {
		calls++
		return false
	})
	if calls != 1 {
		t.Errorf("expected integer %d, got %d", 1, calls)
	}

	// Writes after an early stop must still go through.
	r.Add("test3", "")
	if i := r.Count(); i != 3 {
		t.Errorf("expected count %d, got %d", 3, i)
	}
}
// TestRegistry_Count checks Count on an empty and a populated registry, and
// that re-adding an existing name does not increase it.
func TestRegistry_Count(t *testing.T) {
	r := New()
	if i := r.Count(); i != 0 {
		t.Errorf("expected count %d, got %d", 0, i)
	}
	r.Add("test1", "")
	r.Add("test2", "")
	if i := r.Count(); i != 2 {
		t.Errorf("expected count %d, got %d", 2, i)
	}
	r.Add("test2", "again")
	if i := r.Count(); i != 2 {
		t.Errorf("expected count %d, got %d", 2, i)
	}
}
// TestRegistry_Empty checks Empty before and after adding an entry.
func TestRegistry_Empty(t *testing.T) {
	r := New()
	if b := r.Empty(); !b {
		t.Errorf("expected empty %v, got %v", true, b)
	}
	r.Add("test1", "")
	if b := r.Empty(); b {
		t.Errorf("expected empty %v, got %v", false, b)
	}
}
func TestRegistry_Copy(t *testing.T) {
var b bool
var r1, r2 *Registry
var v interface{}
type type1 struct {
value int
}
type type2 struct {
value string
}
t1 := type1{
value: 123,
}
t2 := type2{
value: "test123",
}
r1 = New()
r1.Add("test1", t1)
r1.Add("test2", t2)
v = r1.Copy()
if _, ok := v.(*Registry); !ok {
t.Fatalf("expected type %T, got %T", &Registry{}, v)
}
r2 = v.(*Registry)
b = reflect.DeepEqual(r1, r2)
if !b {
t.Errorf("expecetd object %v, got %v", true, b)
}
r1.Reset()
b = reflect.DeepEqual(r1, r2)
if b {
t.Errorf("expecetd object %v, got %v", false, b)
}
}
// TestRegistry_Load checks that Load copies every map entry into the
// registry with its dynamic type intact.
func TestRegistry_Load(t *testing.T) {
	type type1 struct{ value int }
	type type2 struct{ value string }
	t1 := type1{value: 123}
	t2 := type2{value: "test123"}
	m := map[string]interface{}{
		"test1": t1,
		"test2": t2,
	}

	r := New()
	r.Load(m)
	if i := r.Count(); i != 2 {
		t.Errorf("expected count %d, got %d", 2, i)
	}

	v := r.Get("test1")
	got, ok := v.(type1)
	if !ok {
		t.Fatalf("expected type %T, got %T", type1{}, v)
	}
	if !reflect.DeepEqual(got, t1) {
		t.Errorf("expected object %v, got %v", t1, v)
	}
	if v := r.Get("test2"); !reflect.DeepEqual(v, t2) {
		t.Errorf("expected object %v, got %v", t2, v)
	}
}
// TestRegistry_Save checks that Save returns a map equal to what was added.
func TestRegistry_Save(t *testing.T) {
	type type1 struct{ value int }
	type type2 struct{ value string }
	t1 := type1{value: 123}
	t2 := type2{value: "test123"}
	want := map[string]interface{}{
		"test1": t1,
		"test2": t2,
	}

	r := New()
	r.Add("test1", t1)
	r.Add("test2", t2)
	if got := r.Save(); !reflect.DeepEqual(want, got) {
		t.Errorf("expected object %v, got %v", want, got)
	}
}
// TestRegistry_Reset checks that Reset leaves the registry empty.
func TestRegistry_Reset(t *testing.T) {
	r := New()
	r.Add("test1", "")
	r.Reset()
	if b := r.Empty(); !b {
		t.Errorf("expected empty %v, got %v", true, b)
	}
}
// TestRegistry_String checks the String rendering for a single entry.
func TestRegistry_String(t *testing.T) {
	r := New()
	r.Add("test1", "")
	want := "Registry(1): [test1]"
	if got := r.String(); got != want {
		t.Errorf("expected string %q, got %q", want, got)
	}
}
// TestRegistrySerial adds and reads back many entries from one goroutine and
// checks every value round-trips.
func TestRegistrySerial(t *testing.T) {
	const iterations = 1000
	var a [iterations]int
	r := New()
	for j := 0; j < iterations; j++ {
		name := fmt.Sprintf("test%d", j)
		r.Add(name, j)
		v := r.Get(name)
		n, ok := v.(int)
		if !ok {
			t.Fatalf("expected type %T, got %T", 0, v)
		}
		a[j] = n
	}
	sort.Ints(a[:])
	if i := r.Count(); i != iterations {
		t.Errorf("expected count %d, got %d", iterations, i)
	}
	for i := range a {
		if a[i] != i {
			t.Errorf("expected integer %d, got %d", i, a[i])
		}
	}
}
// TestRegistryConcurrent adds and reads back entries from two goroutines at
// once and checks every value round-trips. Run with -race to also catch
// unsynchronized access.
func TestRegistryConcurrent(t *testing.T) {
	const iterations = 1000
	var a [iterations]int
	c := make(chan int)
	r := New()

	// worker adds and reads back names in [from, to), reporting each value.
	worker := func(from, to int) {
		for j := from; j < to; j++ {
			name := fmt.Sprintf("test%d", j)
			r.Add(name, j)
			c <- r.Get(name).(int)
		}
	}
	go worker(0, iterations/2)
	go worker(iterations/2, iterations)

	// Exactly iterations values are sent in total, so both workers finish.
	for i := range a {
		a[i] = <-c
	}
	sort.Ints(a[:])
	if i := r.Count(); i != iterations {
		t.Errorf("expected count %d, got %d", iterations, i)
	}
	for i := range a {
		if a[i] != i {
			t.Errorf("expected integer %d, got %d", i, a[i])
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment