Last active
March 22, 2021 07:00
-
-
Save kwilczynski/8cb77f389a9643a2ab8f93f1737b8361 to your computer and use it in GitHub Desktop.
Generic Registry type using concurrent sharded map in Go
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package registry | |
import ( | |
"fmt" | |
"hash/fnv" | |
"sync" | |
) | |
type Registrar interface { | |
Add(string, interface{}) interface{} | |
Contains(string) bool | |
Get(string) interface{} | |
Remove(string) interface{} | |
Names() []string | |
Values() []interface{} | |
Each(EachCallback) | |
Count() int | |
Empty() bool | |
Copy() interface{} | |
Load(map[string]interface{}) | |
Save() map[string]interface{} | |
Reset() | |
String() string | |
} | |
// EachCallback is the function type accepted by Each. It receives the name
// and the value of one entry; returning false stops the iteration.
type EachCallback func(name string, value interface{}) bool

// RegistryShards is the number of shards a Registry is split into.
const RegistryShards = 32
type Registry []*registry | |
// registry is a single shard: a map guarded by the embedded read/write
// mutex, whose Lock/RLock methods are promoted onto the shard itself.
type registry struct {
	sync.RWMutex

	// items holds the entries that belong to this shard.
	items map[string]interface{}
}
// tuple is one name/value pair, as sent over the channel built by collect.
type tuple struct {
	Key   string      // name of the entry
	Value interface{} // value stored under Key
}
func New() *Registry { | |
r := make(Registry, RegistryShards) | |
for i := 0; i < RegistryShards; i++ { | |
r[i] = ®istry{ | |
items: make(map[string]interface{}), | |
} | |
} | |
return &r | |
} | |
func (r *Registry) Add(name string, value interface{}) interface{} { | |
s := r.shard(name) | |
s.Lock() | |
s.items[name] = value | |
s.Unlock() | |
return value | |
} | |
func (r *Registry) Contains(name string) bool { | |
s := r.shard(name) | |
s.RLock() | |
_, found := s.items[name] | |
s.RUnlock() | |
return found | |
} | |
func (r *Registry) Get(name string) interface{} { | |
s := r.shard(name) | |
s.RLock() | |
value := s.items[name] | |
s.RUnlock() | |
return value | |
} | |
func (r *Registry) Remove(name string) interface{} { | |
s := r.shard(name) | |
s.Lock() | |
value := s.items[name] | |
delete(s.items, name) | |
s.Unlock() | |
return value | |
} | |
func (r *Registry) Names() []string { | |
c := r.collect() | |
names := make([]string, cap(c)) | |
i := 0 | |
for t := range c { | |
names[i] = t.Key | |
i++ | |
} | |
return names | |
} | |
func (r *Registry) Values() []interface{} { | |
c := r.collect() | |
values := make([]interface{}, cap(c)) | |
i := 0 | |
for t := range c { | |
values[i] = t.Value | |
i++ | |
} | |
return values | |
} | |
func (r *Registry) Each(function EachCallback) { | |
for t := range r.collect() { | |
if !function(t.Key, t.Value) { | |
break | |
} | |
} | |
} | |
func (r *Registry) Count() int { | |
c := 0 | |
for i := 0; i < RegistryShards; i++ { | |
s := (*r)[i] | |
s.RLock() | |
c += len(s.items) | |
s.RUnlock() | |
} | |
return c | |
} | |
func (r *Registry) Empty() bool { | |
return r.Count() == 0 | |
} | |
func (r *Registry) Copy() interface{} { | |
copy := make(Registry, RegistryShards) | |
for i := 0; i < RegistryShards; i++ { | |
s := (*r)[i] | |
s.RLock() | |
copy[i] = ®istry{ | |
items: make(map[string]interface{}, len(s.items)), | |
} | |
s.RUnlock() | |
} | |
for t := range r.collect() { | |
copy.Add(t.Key, t.Value) | |
} | |
return © | |
} | |
func (r *Registry) Load(data map[string]interface{}) { | |
for k, v := range data { | |
s := r.shard(k) | |
s.Lock() | |
s.items[k] = v | |
s.Unlock() | |
} | |
} | |
func (r *Registry) Save() map[string]interface{} { | |
data := make(map[string]interface{}, r.Count()) | |
for t := range r.collect() { | |
data[t.Key] = t.Value | |
} | |
return data | |
} | |
func (r *Registry) Reset() { | |
for i := 0; i < RegistryShards; i++ { | |
s := (*r)[i] | |
s.Lock() | |
s.items = make(map[string]interface{}) | |
s.Unlock() | |
} | |
} | |
func (r *Registry) String() string { | |
return fmt.Sprintf("Registry(%d): %v", r.Count(), r.Names()) | |
} | |
func (r *Registry) hash(name string) uint32 { | |
h := fnv.New32() | |
if _, err := h.Write([]byte(name)); err != nil { | |
panic("unreachable") | |
} | |
return h.Sum32() | |
} | |
func (r *Registry) shard(name string) *registry { | |
return (*r)[r.hash(name)&uint32(RegistryShards-1)] | |
} | |
func (r *Registry) collect() <-chan tuple { | |
count := r.Count() | |
c := make(chan tuple, count) | |
go func() { | |
wg := sync.WaitGroup{} | |
wg.Add(RegistryShards) | |
for _, shard := range *r { | |
go func(s *registry) { | |
s.RLock() | |
for k, v := range s.items { | |
c <- tuple{k, v} | |
} | |
s.RUnlock() | |
wg.Done() | |
}(shard) | |
} | |
wg.Wait() | |
close(c) | |
}() | |
return c | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package registry | |
import ( | |
"fmt" | |
"reflect" | |
"sort" | |
"testing" | |
) | |
func TestRegistry_New(t *testing.T) { | |
var b bool | |
r := New() | |
if r == nil { | |
t.Fatalf("expecetd type %T, got %T", &Registry{}, r) | |
} | |
func(v interface{}) { | |
if _, ok := v.(*Registry); !ok { | |
t.Fatalf("expected type %T, got %T", &Registry{}, v) | |
} | |
}(r) | |
b = r.Empty() | |
if !b { | |
t.Errorf("expected empty %v, got %v", true, b) | |
} | |
} | |
func TestRegistry_Add(t *testing.T) { | |
var i int | |
var s string | |
var v interface{} | |
r := New() | |
r.Add("test1", "") | |
r.Add("test2", "") | |
i = r.Count() | |
if i != 2 { | |
t.Errorf("expecetd count %d, got %d", 2, i) | |
} | |
v = r.Add("test2", "test456") | |
if _, ok := v.(string); !ok { | |
t.Errorf("expected type %T, got %T", "", v) | |
} | |
s = v.(string) | |
if s != "test456" { | |
t.Errorf("expected string %q, got %q", "test456", s) | |
} | |
i = r.Count() | |
if i != 2 { | |
t.Errorf("expecetd count %d, got %d", 2, i) | |
} | |
s = r.Add("test3", "test123").(string) | |
if s != "test123" { | |
t.Errorf("expected string %q, got %q", "test123", s) | |
} | |
i = r.Count() | |
if i != 3 { | |
t.Errorf("expecetd count %d, got %d", 3, i) | |
} | |
} | |
func TestRegistry_Contains(t *testing.T) { | |
var b bool | |
r := New() | |
r.Add("test1", "") | |
r.Add("test3", "") | |
b = r.Contains("test1") | |
if !b { | |
t.Errorf("expected contains %v, got %v", true, b) | |
} | |
b = r.Contains("test2") | |
if b { | |
t.Errorf("expected contains %v, got %v", false, b) | |
} | |
} | |
func TestRegistry_Get(t *testing.T) { | |
var i int | |
var s string | |
var v interface{} | |
type type1 struct { | |
value int | |
} | |
type type2 struct { | |
value string | |
} | |
t1 := type1{ | |
value: 123, | |
} | |
t2 := type2{ | |
value: "test123", | |
} | |
r := New() | |
r.Add("test1", t1) | |
r.Add("test2", t2) | |
v = r.Get("test1") | |
if _, ok := v.(type1); !ok { | |
t.Errorf("expected type %T, got %T", type1{}, v) | |
} | |
if !reflect.DeepEqual(v.(type1), t1) { | |
t.Errorf("expected object %v, got %v", t1, v) | |
} | |
v = r.Get("test2") | |
if _, ok := v.(type2); !ok { | |
t.Errorf("expected type %T, got %T", type2{}, v) | |
} | |
if !reflect.DeepEqual(v.(type2), t2) { | |
t.Errorf("expected object %v, got %v", t2, v) | |
} | |
s = v.(type2).value | |
if s != "test123" { | |
t.Errorf("expecetd string %q, got %q", "test123", s) | |
} | |
r.Add("test2", t1) | |
v = r.Get("test2") | |
if _, ok := v.(type2); ok { | |
t.Errorf("expected type %T, got %T", type1{}, v) | |
} | |
i = v.(type1).value | |
if i != 123 { | |
t.Errorf("expecetd integer %d, got %d", 123, i) | |
} | |
} | |
func TestRegistry_Remove(t *testing.T) { | |
var i int | |
var b bool | |
var s string | |
var v interface{} | |
r := New() | |
r.Add("test1", "test123") | |
r.Add("test2", "") | |
v = r.Remove("test1") | |
if _, ok := v.(string); !ok { | |
t.Errorf("") | |
} | |
s = v.(string) | |
if s != "test123" { | |
t.Errorf("expecetd string %q, got %q", "test123", s) | |
} | |
v = r.Remove("test3") | |
if v != nil { | |
t.Errorf("expecetd object %v, got %v", nil, v) | |
} | |
b = r.Empty() | |
if b { | |
t.Errorf("expected empty %v, got %v", false, b) | |
} | |
i = r.Count() | |
if i != 1 { | |
t.Errorf("expecetd count %d, got %d", 1, i) | |
} | |
b = r.Contains("test1") | |
if b { | |
t.Errorf("expected contains %v, got %v", false, b) | |
} | |
} | |
func TestRegistry_Names(t *testing.T) { | |
var ss1, ss2 []string | |
r := New() | |
r.Add("test1", "") | |
r.Add("test3", "") | |
r.Add("test5", "") | |
ss1 = r.Names() | |
sort.Strings(ss1) | |
ss2 = []string{"test1", "test3", "test5"} | |
if !reflect.DeepEqual(ss1, ss2) { | |
t.Errorf("expected object %v, got %v", ss2, ss1) | |
} | |
} | |
func TestRegistry_Values(t *testing.T) { | |
var si1, si2 []interface{} | |
type type1 struct { | |
value int | |
} | |
type type2 struct { | |
value string | |
} | |
t1 := type1{ | |
value: 123, | |
} | |
t2 := type2{ | |
value: "test123", | |
} | |
r := New() | |
r.Add("test1", t1) | |
r.Add("test2", t2) | |
// Order of elements in a map is random, | |
// therefore we normalize to ensure that | |
// the integer value is always first. | |
si1 = r.Values() | |
if _, ok := si1[0].(type2); ok { | |
si1 = append([]interface{}{si1[1]}, si1...) | |
si1 = si1[:len(si1)-1] | |
} | |
si2 = []interface{}{t1, t2} | |
if !reflect.DeepEqual(si1, si2) { | |
t.Errorf("expected object %v, got %v", si2, si1) | |
} | |
} | |
func TestRegistry_Each(t *testing.T) { | |
var i int | |
var ss1, ss2 []string | |
r := New() | |
r.Add("test1", "") | |
r.Add("test2", "") | |
r.Each(func(k string, v interface{}) bool { | |
ss1 = append(ss1, k) | |
return true | |
}) | |
sort.Strings(ss1) | |
ss2 = []string{"test1", "test2"} | |
if !reflect.DeepEqual(ss1, ss2) { | |
t.Errorf("expecetd object %v, got %v", ss2, ss1) | |
} | |
i = 0 | |
r.Each(func(k string, v interface{}) bool { | |
i++ | |
return false | |
}) | |
if i != 1 { | |
t.Errorf("expecetd integer %d, got %d", 1, i) | |
} | |
} | |
func TestRegistry_Count(t *testing.T) { | |
var i int | |
r := New() | |
i = r.Count() | |
if i != 0 { | |
t.Errorf("expecetd count %d, got %d", 0, i) | |
} | |
r.Add("test1", "") | |
r.Add("test2", "") | |
i = r.Count() | |
if i != 2 { | |
t.Errorf("expecetd count %d, got %d", 2, i) | |
} | |
} | |
func TestRegistry_Empty(t *testing.T) { | |
var b bool | |
r := New() | |
b = r.Empty() | |
if !b { | |
t.Errorf("expecetd empty %v, got %v", true, b) | |
} | |
r.Add("test1", "") | |
b = r.Empty() | |
if b { | |
t.Errorf("expecetd empty %v, got %v", false, b) | |
} | |
} | |
func TestRegistry_Copy(t *testing.T) { | |
var b bool | |
var r1, r2 *Registry | |
var v interface{} | |
type type1 struct { | |
value int | |
} | |
type type2 struct { | |
value string | |
} | |
t1 := type1{ | |
value: 123, | |
} | |
t2 := type2{ | |
value: "test123", | |
} | |
r1 = New() | |
r1.Add("test1", t1) | |
r1.Add("test2", t2) | |
v = r1.Copy() | |
if _, ok := v.(*Registry); !ok { | |
t.Fatalf("expected type %T, got %T", &Registry{}, v) | |
} | |
r2 = v.(*Registry) | |
b = reflect.DeepEqual(r1, r2) | |
if !b { | |
t.Errorf("expecetd object %v, got %v", true, b) | |
} | |
r1.Reset() | |
b = reflect.DeepEqual(r1, r2) | |
if b { | |
t.Errorf("expecetd object %v, got %v", false, b) | |
} | |
} | |
func TestRegistry_Load(t *testing.T) { | |
var i int | |
var v interface{} | |
type type1 struct { | |
value int | |
} | |
type type2 struct { | |
value string | |
} | |
t1 := type1{ | |
value: 123, | |
} | |
t2 := type2{ | |
value: "test123", | |
} | |
m := map[string]interface{}{ | |
"test1": t1, | |
"test2": t2, | |
} | |
r := New() | |
r.Load(m) | |
i = r.Count() | |
if i != 2 { | |
t.Errorf("expecetd count %d, got %d", 2, i) | |
} | |
v = r.Get("test1") | |
if _, ok := v.(type1); !ok { | |
t.Errorf("expected type %T, got %T", type1{}, v) | |
} | |
if !reflect.DeepEqual(v.(type1), t1) { | |
t.Errorf("expected object %v, got %v", t1, v) | |
} | |
} | |
func TestRegistry_Save(t *testing.T) { | |
var m1, m2 map[string]interface{} | |
type type1 struct { | |
value int | |
} | |
type type2 struct { | |
value string | |
} | |
t1 := type1{ | |
value: 123, | |
} | |
t2 := type2{ | |
value: "test123", | |
} | |
m1 = map[string]interface{}{ | |
"test1": t1, | |
"test2": t2, | |
} | |
r := New() | |
r.Add("test1", t1) | |
r.Add("test2", t2) | |
m2 = r.Save() | |
if !reflect.DeepEqual(m1, m2) { | |
t.Errorf("expecetd object %v, got %v", m1, m2) | |
} | |
} | |
func TestRegistry_Reset(t *testing.T) { | |
var b bool | |
r := New() | |
r.Add("test1", "") | |
r.Reset() | |
b = r.Empty() | |
if !b { | |
t.Errorf("expecetd empty %v, got %v", true, b) | |
} | |
} | |
func TestRegistry_String(t *testing.T) { | |
var s1, s2 string | |
r := New() | |
r.Add("test1", "") | |
s1 = r.String() | |
s2 = "Registry(1): [test1]" | |
if s1 != s2 { | |
t.Errorf("expecetd string %q, got %q", s2, s1) | |
} | |
} | |
func TestRegistrySerial(t *testing.T) { | |
const iterations = 1000 | |
var i int | |
var a [iterations]int | |
r := New() | |
for j := 0; j < iterations; j++ { | |
r.Add(fmt.Sprintf("test%d", j), j) | |
a[j] = r.Get(fmt.Sprintf("test%d", j)).(int) | |
} | |
sort.Ints(a[0:iterations]) | |
i = r.Count() | |
if i != iterations { | |
t.Errorf("expecetd count %d, got %d", i, iterations) | |
} | |
for i = 0; i < iterations; i++ { | |
if i != a[i] { | |
t.Errorf("expecetd integer %d, got %v", i, nil) | |
} | |
} | |
} | |
func TestRegistryConcurrent(t *testing.T) { | |
const iterations = 1000 | |
var i int | |
var a [iterations]int | |
c := make(chan int) | |
r := New() | |
go func() { | |
for j := 0; j < iterations/2; j++ { | |
r.Add(fmt.Sprintf("test%d", j), j) | |
v := r.Get(fmt.Sprintf("test%d", j)) | |
c <- v.(int) | |
} | |
}() | |
go func() { | |
for j := iterations / 2; j < iterations; j++ { | |
r.Add(fmt.Sprintf("test%d", j), j) | |
v := r.Get(fmt.Sprintf("test%d", j)) | |
c <- v.(int) | |
} | |
}() | |
i = 0 | |
for v := range c { | |
a[i] = v | |
i++ | |
if i == iterations { | |
break | |
} | |
} | |
sort.Ints(a[0:iterations]) | |
i = r.Count() | |
if i != iterations { | |
t.Errorf("expecetd count %d, got %d", i, iterations) | |
} | |
for i := 0; i < iterations; i++ { | |
if i != a[i] { | |
t.Errorf("expecetd integer %d, got %v", i, nil) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment