Last active
April 7, 2021 02:12
-
-
Save funny-falcon/084f7ebc159401dfd56bf33241a2ebd6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# An insertion-ordered hash table built on open addressing.
#
# Storage is split into three parallel arrays:
#
# * `@entries` - the key/value pairs, appended in insertion order;
# * `@hashes`  - the hashsum of the entry at the same index, where `0` marks a
#   deleted (or never used) slot. `hash_key` always sets the top bit, so a
#   live entry never has a zero hashsum;
# * `@bins`    - the open-addressing index: each bin holds `entry index + 1`,
#   and `0` means the bin is empty.
#
# The capacity of all three arrays is selected by `@sz`, an index into `SIZES`.
class HashChTbl(K, V)
  include Enumerable({K, V})
  include Iterable({K, V})

  # Number of live entries.
  getter size : Int32 = 0

  # Index into the SIZES array.
  @sz : UInt8 = 0_u8

  # Bumped whenever the storage is rebuilt (`rehash`, `clear`); running
  # iterations compare against it to detect modification.
  # NOTE(review): `+= 1_u16` wraps or raises on overflow depending on the
  # compiler version - confirm which is wanted.
  @rebuild_num : UInt16 = 0_u16

  # Index of the first live entry.
  @first : UInt32 = 0_u32

  # One past the index of the last used entry slot.
  @last : UInt32 = 0_u32

  # Bins contain `entry index + 1` (0 = empty bin); hashsums live in @hashes.
  @bins : Pointer(UInt32) = Pointer(UInt32).new(0)
  @hashes : Pointer(UInt64) = Pointer(UInt64).new(0)
  @entries : Pointer(Entry(K, V)) = Pointer(Entry(K, V)).new(0)

  # Optional default-value block, called by `fetch(key)` for missing keys.
  @block : (self, K -> V)?

  # Creates an empty table. When *initial_capacity* is given, storage for at
  # least that many entries is allocated up front.
  def initialize(block : (self, K -> V)? = nil, initial_capacity : Int? = nil)
    resize_data(new_sz(initial_capacity)) unless initial_capacity.nil?
    @block = block
  end

  # Creates an empty table whose missing keys are resolved by *block*.
  def self.new(initial_capacity = nil, &block : (self, K -> V))
    new(block, initial_capacity: initial_capacity)
  end

  # Creates an empty table whose missing keys resolve to *default_value*.
  def self.new(default_value : V, initial_capacity = nil)
    new(initial_capacity: initial_capacity) { default_value }
  end

  # Stores *val* under *key* (replacing any existing value) and returns *val*.
  def []=(key : K, val : V)
    hash = hash_key(key)
    ent, _ = find_entry(hash, key)
    if ent.null?
      idx = push_entry(hash, key, val)
      insert_entry_reuse(idx)
    else
      # Store a whole entry instead of mutating a struct field through
      # `Pointer#value`.
      ent.value = Entry(K, V).new(ent.value.key, val)
    end
    val
  end

  # Returns the value for *key*; see `fetch(key)` for the missing-key behavior.
  def [](key)
    fetch(key)
  end

  # Returns the value for *key*, or `nil` when the key is missing.
  def []?(key)
    fetch(key, nil)
  end

  # Returns `true` when *key* is present.
  def has_key?(key)
    hash = hash_key(key)
    entry, _ = find_entry(hash, key)
    !entry.null?
  end

  # Returns `true` when any stored value equals *val*.
  def has_value?(val)
    each_value do |value|
      return true if value == val
    end
    false
  end

  # Returns the value for *key*. For a missing key the default block is called
  # (if one was given and *key* is a `K`); otherwise `KeyError` is raised.
  def fetch(key)
    fetch(key) do
      if (block = @block) && key.is_a?(K)
        block.call(self, key.as(K))
      else
        raise KeyError.new "Missing hash key: #{key.inspect}"
      end
    end
  end

  # Returns the value for *key*, or *default* when the key is missing.
  def fetch(key, default)
    fetch(key) { default }
  end

  # Returns the value for *key*, or the result of the block (which receives
  # *key*) when the key is missing.
  def fetch(key)
    hash = hash_key(key)
    entry, _ = find_entry(hash, key)
    entry.null? ? yield(key) : entry.value.value
  end

  # Returns the values for the given keys, using `#[]` for each.
  def values_at(*indexes : K)
    indexes.map { |index| self[index] }
  end

  # Returns the first key whose value equals *value*; raises `KeyError` when
  # there is none.
  def key(value)
    key(value) { raise KeyError.new "Missing hash key for value: #{value}" }
  end

  # Returns the first key whose value equals *value*, or `nil`.
  def key?(value)
    key(value) { nil }
  end

  # Returns the first key whose value equals *value*, or the result of the
  # block (which receives *value*) when there is none.
  def key(value)
    each do |k, v|
      return k if v == value
    end
    yield value
  end

  # Removes *key* and returns its value, or `nil` when the key is missing.
  def delete(key)
    delete(key) { nil }
  end

  # Removes *key* and returns its value, or the result of the block (which
  # receives *key*) when the key is missing.
  def delete(key)
    hash = hash_key(key)
    entry, idx = find_entry(hash, key)
    if entry.null?
      yield key
    else
      value = entry.value.value
      delete_idx(idx)
      value
    end
  end

  # Removes every pair for which the block is truthy. Returns `self`.
  def delete_if
    iter_entries do |entry, idx|
      delete_idx(idx) if yield(entry.value.key, entry.value.value)
    end
    self
  end

  # Returns `true` when there are no live entries.
  def empty?
    @size == 0
  end

  # Yields every `{key, value}` pair in insertion order.
  def each : Nil
    iter_entries do |entry, _|
      yield({entry.value.key, entry.value.value})
    end
  end

  # Returns an iterator over `{key, value}` pairs.
  def each
    EntryIterator(K, V).new(self, @first, @rebuild_num)
  end

  # Yields every key in insertion order.
  def each_key
    iter_entries do |entry, _|
      yield entry.value.key
    end
  end

  # Returns an iterator over the keys.
  def each_key
    KeyIterator(K, V).new(self, @first, @rebuild_num)
  end

  # Yields every value in insertion order.
  def each_value
    iter_entries do |entry, _|
      yield entry.value.value
    end
  end

  # Returns an iterator over the values.
  def each_value
    ValueIterator(K, V).new(self, @first, @rebuild_num)
  end

  # Returns a new array with all keys.
  def keys
    keys = Array(K).new(@size)
    each_key { |key| keys << key }
    keys
  end

  # Returns a new array with all values.
  def values
    values = Array(V).new(@size)
    each_value { |value| values << value }
    values
  end

  # Returns the zero-based position of *key* among the live entries (in
  # insertion order), or `nil` when the key is missing.
  def key_index(key)
    hash = hash_key(key)
    entry, idx = find_entry(hash, key)
    return nil if entry.null?

    if @last - @first == @size
      # No holes between @first and @last: the position is a plain offset.
      (idx - @first).to_i32
    else
      # Count the live entries that precede idx. The range is exclusive so the
      # result is zero-based, like the branch above.
      index = 0
      (@first...idx).each do |i|
        index += 1 unless @hashes[i] == 0_u64
      end
      index
    end
  end

  # Returns a new `Hash` holding the pairs of `self` overridden by *other*.
  def merge(other : Hash(L, W)) forall L, W
    hash = Hash(K | L, V | W).new
    # `Hash#merge!` only accepts a `Hash`, so copy our own pairs explicitly.
    each { |k, v| hash[k] = v }
    hash.merge! other
    hash
  end

  # Like `merge(other)`, but for keys present on both sides the block decides
  # the value; it receives the key, our value and *other*'s value.
  def merge(other : Hash(L, W), &block : K, V, W -> V | W) forall L, W
    hash = Hash(K | L, V | W).new
    each { |k, v| hash[k] = v }
    hash.merge!(other) { |k, v1, v2| yield k, v1, v2 }
    hash
  end

  # Copies every pair of *other* into `self`. Returns `self`.
  def merge!(other : Hash | HashChTbl)
    other.each do |k, v|
      self[k] = v
    end
    self
  end

  # Like `merge!(other)`, but for keys already present the block decides the
  # value; it receives the key, the current value and *other*'s value.
  def merge!(other : Hash | HashChTbl, &block)
    other.each do |k, v|
      if has_key?(k)
        self[k] = yield k, self[k], v
      else
        self[k] = v
      end
    end
    self
  end

  # Returns a new `Hash` with the pairs for which the block is truthy.
  def select(&block : K, V -> _)
    reject { |k, v| !yield(k, v) }
  end

  # Keeps only the pairs for which the block is truthy. Returns `nil` when
  # nothing was removed, `self` otherwise.
  def select!(&block : K, V -> _)
    reject! { |k, v| !yield(k, v) }
  end

  # Returns a new `Hash` with the pairs for which the block is falsey.
  def reject(&block : K, V -> _)
    each_with_object({} of K => V) do |(k, v), memo|
      memo[k] = v unless yield k, v
    end
  end

  # Removes the pairs for which the block is truthy. Returns `nil` when
  # nothing was removed, `self` otherwise.
  def reject!(&block : K, V -> _)
    num_entries = @size
    delete_if &block
    num_entries == @size ? nil : self
  end

  # Returns a copy of `self` without the given keys.
  def reject(*keys)
    hash = self.dup
    hash.reject!(*keys)
  end

  # Removes the given keys. Returns `self`.
  def reject!(keys : Array | Tuple)
    keys.each { |k| delete(k) }
    self
  end

  # :ditto:
  def reject!(*keys)
    reject!(keys)
  end

  # Returns a new `Hash` with only the given keys (those that are present).
  def select(keys : Array | Tuple)
    hash = {} of K => V
    keys.each { |k| hash[k] = self[k] if has_key?(k) }
    hash
  end

  # :ditto:
  def select(*keys)
    self.select(keys)
  end

  # Removes every key that is not in *keys*. Returns `self`.
  def select!(keys : Array | Tuple)
    each { |k, _| delete(k) unless keys.includes?(k) }
    self
  end

  # :ditto:
  def select!(*keys)
    select!(keys)
  end

  # Returns a new `Hash` without the pairs whose value is `nil`.
  def compact
    each_with_object({} of K => typeof(self.first_value.not_nil!)) do |(key, value), memo|
      memo[key] = value unless value.nil?
    end
  end

  # Removes the pairs whose value is `nil`. Returns `nil` when nothing was
  # removed, `self` otherwise.
  def compact!
    reject! { |_, value| value.nil? }
  end

  # Builds a `Hash` pairing `ary1[i]` (key) with `ary2[i]` (value).
  def self.zip(ary1 : Array(K), ary2 : Array(V))
    hash = {} of K => V
    ary1.each_with_index do |key, i|
      hash[key] = ary2[i]
    end
    hash
  end

  # Returns the first key; raises when the table is empty.
  def first_key
    nil.not_nil! if empty?
    @entries[@first].key
  end

  # Returns the first key, or `nil` when the table is empty.
  def first_key?
    empty? ? nil : @entries[@first].key
  end

  # Returns the first value; raises when the table is empty.
  def first_value
    nil.not_nil! if empty?
    @entries[@first].value
  end

  # Returns the first value, or `nil` when the table is empty.
  def first_value?
    empty? ? nil : @entries[@first].value
  end

  # Removes and returns the first `{key, value}` pair; raises `IndexError`
  # when the table is empty.
  def shift
    shift { raise IndexError.new }
  end

  # Removes and returns the first `{key, value}` pair, or `nil` when empty.
  def shift?
    shift { nil }
  end

  # Removes and returns the first `{key, value}` pair, or the result of the
  # block when the table is empty.
  def shift
    if empty?
      yield
    else
      idx = @first
      res = {@entries[idx].key, @entries[idx].value}
      delete_idx(idx)
      res
    end
  end

  # Removes every entry and releases the storage. Returns `self`.
  def clear
    # The arrays are about to be released, so make running iterations notice
    # the change instead of walking freed memory.
    @rebuild_num += 1_u16
    resize_data(0_u8)
    @size = 0
    @first = 0_u32
    @last = 0_u32
    self
  end

  # Two tables are equal when they hold the same keys with equal values,
  # regardless of order.
  def ==(other : Hash | HashChTbl)
    return false unless size == other.size
    each do |key, value|
      other_value = other.fetch(key) { return false }
      return false unless other_value == value
    end
    true
  end

  def hash(hasher)
    # The hash value must be the same regardless of the
    # order of the keys.
    # NOTE(review): `+=` on the UInt64 result wraps or raises on overflow
    # depending on the compiler version - confirm which is wanted.
    result = hasher.result
    each do |key, value|
      copy = hasher
      copy = key.hash(copy)
      copy = value.hash(copy)
      result += copy.result
    end
    result.hash(hasher)
  end

  # Returns a shallow copy that owns its own storage arrays.
  def dup
    copy = super
    copy.init_dup(self)
  end

  # Replaces the storage pointers (still shared with *original* after the
  # shallow copy) with private copies.
  protected def init_dup(original)
    entries = nentries(@sz)
    @hashes = Pointer(UInt64).malloc(entries)
    @entries = Pointer(Entry(K, V)).malloc(entries)
    @hashes.copy_from(original.@hashes, entries)
    @entries.copy_from(original.@entries, entries)

    # @bins starts as the null pointer and `resize_data` only reallocates it
    # when the bin count changes, which it does not between SIZES[0] and
    # SIZES[1]. Copying from a null source would dereference it, so such a
    # table keeps its (null) bins.
    unless original.@bins.null?
      bins = nbins(@sz)
      @bins = Pointer(UInt32).malloc(bins)
      @bins.copy_from(original.@bins, bins)
    end
    self
  end

  # Returns a copy in which every key and value has been cloned as well.
  def clone
    copy = dup
    copy.init_clone
  end

  # Replaces every live entry with a clone of its key and value.
  protected def init_clone
    iter_entries do |entry, _|
      current = entry.value
      entry.value = Entry(K, V).new(current.key.clone, current.value.clone)
    end
    self
  end

  def inspect(io : IO)
    to_s(io)
  end

  # Writes `{k => v, ...}` to *io*; a recursive reference prints as `{...}`.
  def to_s(io : IO)
    executed = exec_recursive(:to_s) do
      io << "{"
      found_one = false
      each do |key, value|
        io << ", " if found_one
        key.inspect(io)
        io << " => "
        value.inspect(io)
        found_one = true
      end
      io << "}"
    end
    io << "{...}" unless executed
  end

  def pretty_print(pp) : Nil
    executed = exec_recursive(:pretty_print) do
      pp.list("{", self, "}") do |key, value|
        pp.group do
          key.pretty_print(pp)
          pp.text " =>"
          pp.nest do
            pp.breakable
            value.pretty_print(pp)
          end
        end
      end
    end
    pp.text "{...}" unless executed
  end

  # Returns `self`.
  def to_h
    self
  end

  # Returns a new `Hash` mapping every value to its key.
  def invert
    hash = Hash(V, K).new(initial_capacity: @size)
    each do |k, v|
      hash[v] = k
    end
    hash
  end

  # Yields a pointer to every live entry together with its index. Raises if
  # the storage was rebuilt by the block.
  @[AlwaysInline]
  private def iter_entries
    return if empty?
    rnum = @rebuild_num
    @first.upto(@last - 1) do |idx|
      if @hashes[idx] != 0_u64
        yield @entries + idx, idx
        raise "Hash modified during iteration" unless rnum == @rebuild_num
      end
    end
  end

  # Yields the endless probe sequence of bin positions for *hash*; the block
  # must `break` to stop it.
  @[AlwaysInline]
  protected def iter_bins(hash : UInt64)
    mask = binmask(@sz)
    pos = hash & mask
    mix = hash
    d = 1_u64
    while true
      yield pos.to_u32
      pos = (pos + d) & mask
      mix >>= 8
      d += (1_u64 + mix) & mask
    end
  end

  # Yields the index of every live entry that may match *hash*: all live
  # entries when the current size has no bins, otherwise the entries reached
  # by the probe sequence up to the first empty bin.
  @[AlwaysInline]
  private def iter_search(hash)
    return if empty?
    if binmask(@sz) == 0
      @first.upto(@last - 1) do |idx|
        yield idx unless @hashes[idx] == 0
      end
    else
      iter_bins(hash) do |pos|
        idx = @bins[pos]
        break if idx == 0
        idx -= 1 # bins store index + 1
        yield idx unless @hashes[idx] == 0
      end
    end
  end

  # Returns a pointer to the entry for *key* and its index, or a null pointer
  # (and index 0) when the key is missing. *key* is deliberately unrestricted
  # so lookups with a foreign type simply miss.
  private def find_entry(hash, key) : {Pointer(Entry(K, V)), UInt32}
    iter_search(hash) do |idx|
      if @hashes[idx] == hash && @entries[idx].key == key
        return @entries + idx, idx
      end
    end
    {Pointer(Entry(K, V)).new(0), 0_u32}
  end

  # Points the first empty bin of the probe sequence at entry *idx*.
  protected def insert_entry_simple(idx : UInt32)
    iter_bins(@hashes[idx]) do |pos|
      if @bins[pos] == 0
        @bins[pos] = idx + 1
        break
      end
    end
  end

  # Like `insert_entry_simple`, but may also take over a bin that points at a
  # deleted entry. Does nothing when the current size has no bins.
  protected def insert_entry_reuse(idx : UInt32)
    return if binmask(@sz) == 0
    reuse = @size != @last
    iter_bins(@hashes[idx]) do |pos|
      oidx = @bins[pos]
      if oidx == 0 || (reuse && @hashes[oidx - 1] == 0_u64)
        @bins[pos] = idx + 1
        break
      end
    end
  end

  # Returns the hashsum stored for *key*. The top bit is always set so that a
  # stored hashsum is never 0 (0 marks a deleted slot).
  def hash_key(key)
    h = key.hash.to_u64
    h | (1_u64 << 63)
  end

  # A stored key/value pair.
  struct Entry(K, V)
    property key : K
    property value : V

    def initialize(@key : K, @value : V)
    end
  end

  # Reallocates the storage arrays for the size index *newsz* and makes it the
  # current size.
  private def resize_data(newsz : UInt8)
    oldsz = @sz
    old_bins = nbins(oldsz)
    new_bins = nbins(newsz)
    old_ents = nentries(oldsz)
    new_ents = nentries(newsz)
    @hashes = @hashes.realloc(new_ents)
    @entries = @entries.realloc(new_ents)
    if new_bins != old_bins
      @bins = @bins.realloc(new_bins)
      # Do not rely on the allocator returning zeroed memory. Every caller
      # either has an empty table or rebuilds the bins right afterwards.
      @bins.clear(new_bins)
    end
    if new_ents > old_ents
      (@entries + old_ents).clear(new_ents - old_ents)
    end
    @sz = newsz
  end

  # Appends a new entry (rehashing first when the entry array is full) and
  # returns its index. The caller is responsible for updating the bins.
  private def push_entry(hash : UInt64, key : K, val : V) : UInt32
    rehash if @last == nentries(@sz)
    idx = @last
    @hashes[idx] = hash
    @entries[idx] = Entry(K, V).new(key, val)
    @last += 1
    @size += 1
    idx
  end

  # Makes room for more entries: compacts (and possibly shrinks) when the
  # table is sparse, otherwise grows to the next size.
  private def rehash
    @rebuild_num += 1_u16
    if need_shrink(@size, @sz)
      reclaim_without_bins
      if need_shrink(@size, @sz - 1)
        resize_data(@sz - 1)
      end
      if binmask(@sz) != 0
        fix_bins
      end
    elsif nentries(@sz + 1) == 0
      # The last SIZES item is a sentinel with no entries.
      raise "Hash table too big"
    else
      resize_data(@sz + 1)
      if binmask(@sz) != binmask(@sz - 1)
        reclaim_without_bins
        fix_bins
      end
    end
  end

  # Returns `true` when *size* entries fill less than half of size index *sz*.
  private def need_shrink(size : Int32, sz : UInt8) : Bool
    # `>> 1` halves the capacity without depending on the semantics of `/`.
    sz > 1 && size < (nentries(sz) >> 1)
  end

  # Compacts the live entries to the front of the arrays, leaving
  # `@first == 0` and `@last == size`. The bins are NOT updated.
  private def reclaim_without_bins
    pos = 0_u32
    unless empty?
      idx = @first
      if @first == 0_u32
        if @last == @size
          # Already dense: nothing to move.
          pos = idx = @last
        else
          # Skip the leading run of live entries, which is already in place.
          while @hashes[idx] != 0_u64
            idx += 1
          end
          pos = idx
        end
      end
      idx.upto(@last - 1) do |src|
        unless @hashes[src] == 0_u64
          @entries[pos] = @entries[src]
          @hashes[pos] = @hashes[src]
          pos += 1
        end
      end
    end
    (@entries + pos).clear(@last - pos)
    @first = 0_u32
    @last = pos
  end

  # Rebuilds the bins from scratch for the entries in `@first...@last`.
  private def fix_bins
    @bins.clear(nbins(@sz))
    return if empty?
    @first.upto(@last - 1) do |idx|
      insert_entry_simple(idx)
    end
  end

  # Marks the entry at *idx* as deleted and advances `@first` past it when it
  # was the first live entry.
  private def delete_idx(idx)
    @hashes[idx] = 0_u64
    (@entries + idx).clear
    @size -= 1
    if @first == idx
      if @size == 0
        @first = @last
      else
        idx += 1
        while @hashes[idx] == 0_u64
          idx += 1
        end
        @first = idx
      end
    end
  end

  # Returns the smallest non-zero size index able to hold *size* entries.
  private def new_sz(size) : UInt8
    (1...SIZES.size).each do |i|
      return i.to_u8 if SIZES[i].nentries >= size
    end
    raise "Hash table too big"
  end

  # Number of bins at size index *sz*.
  private def nbins(sz)
    SIZES[sz].binmask + 1
  end

  # Bin position mask at size index *sz* (0 = this size uses no bins).
  private def binmask(sz)
    SIZES[sz].binmask
  end

  # Entry capacity at size index *sz*.
  private def nentries(sz)
    SIZES[sz].nentries
  end

  alias PRIMITIVES = Int::Primitive | Float::Primitive | Symbol
  alias INT_SMALL = Int8 | Int16 | Int32 | UInt8 | UInt16 | UInt32

  # Compile-time check: is the key type covered by PRIMITIVES, or an Enum?
  def primkey?
    {{ K <= PRIMITIVES || K <= Enum }}
  end

  # Shared cursor logic of the entry, key and value iterators.
  private module BaseIterator
    def initialize(@hash, @current, @rebuild_num)
    end

    # Advances to the next live entry and returns the block's result for it,
    # or `stop` at the end. Raises if the table was rebuilt since the iterator
    # was created.
    def base_next
      if @hash.@rebuild_num != @rebuild_num
        raise "Hash modified during iteration"
      end
      while @current < @hash.@last
        if @hash.@hashes[@current] != 0_u64
          value = yield @hash.@entries + @current
          @current += 1_u32
          return value
        end
        @current += 1_u32
      end
      stop
    end

    # Moves the cursor back to the first entry. Returns `self`.
    def rewind
      @current = @hash.@first
      self
    end
  end

  private class EntryIterator(K, V)
    include BaseIterator
    include Iterator({K, V})

    @hash : HashChTbl(K, V)
    @current : UInt32
    @rebuild_num : UInt16

    def next
      base_next { |entry| {entry.value.key, entry.value.value} }
    end
  end

  private class KeyIterator(K, V)
    include BaseIterator
    include Iterator(K)

    @hash : HashChTbl(K, V)
    @current : UInt32
    @rebuild_num : UInt16

    def next
      base_next &.value.key
    end
  end

  private class ValueIterator(K, V)
    include BaseIterator
    include Iterator(V)

    @hash : HashChTbl(K, V)
    @current : UInt32
    @rebuild_num : UInt16

    def next
      base_next &.value.value
    end
  end

  # One row of the size table: entry capacity and bin position mask.
  record SizeItem, nentries : UInt32, binmask : UInt32

  # Size table indexed by @sz. The first item is the unallocated table, the
  # second has entries but no bins, and the final item is a sentinel that
  # `rehash` uses to detect that the table cannot grow any further.
  {% begin %}
    SIZES = StaticArray[
      SizeItem.new(0_u32, 0_u32),
      SizeItem.new(8_u32, 0_u32),
      SizeItem.new(16_u32, 31_u32),
      {% for i in 5..30 %}
        {% p = 1 << i %}
        SizeItem.new({{p - (p >> 2)}}_u32, {{p - 1}}_u32),
        SizeItem.new({{p}}_u32, {{p * 2 - 1}}_u32),
      {% end %}
      SizeItem.new(0_u32, 0_u32),
    ]
  {% end %}
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment