Created August 4, 2018 02:33
Save Zhentar/eac2d9078860c29c58575e04fbe1deca to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using System; | |
using System.Collections; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Runtime.CompilerServices; | |
/// <summary>
/// Base class for open-addressing hash tables that keep all entries in one flat array and hand out
/// <c>ref</c>s to the stored values so callers can read and update them in place.
/// </summary>
/// <remarks>
/// <para>
/// A slot counts as empty when its value equals <c>default(TValue)</c>; there is no separate "occupied"
/// flag. Writing <c>default(TValue)</c> through a reference returned by <c>FindEntry</c> therefore makes
/// that key vanish from lookups and enumeration, while <see cref="EntryCount"/> keeps counting it until
/// the next resize recounts the table.
/// </para>
/// <para>
/// A key is always stored within <c>log2(bucket count)</c> slots of its home bucket. Lookups scan that
/// whole window (rather than stopping at the first empty slot) because a slot can become empty in the
/// middle of a run without the entries after it being moved back. An insert that finds its window full
/// grows the table.
/// </para>
/// <para>No synchronization is performed.</para>
/// </remarks>
public abstract class BaseDictionary<TKey, TValue>
{
    /// <summary>Size used by the parameterless constructor (same default as the derived dictionaries).</summary>
    private const int DefaultInitialSize = 64;

    /// <summary>Largest supported bucket-count exponent; keeps bucket indexes and the array length within int.</summary>
    private const int MaxSizePowTwo = 30;

    /// <summary>
    /// Scratch slot whose reference is returned by failed, non-inserting lookups so callers read
    /// <c>default(TValue)</c> rather than another key's value. Reset before every use.
    /// </summary>
    private static TValue s_missingValue;

    /// <summary>Creates a table sized for <see cref="DefaultInitialSize"/> entries.</summary>
    protected BaseDictionary() : this(DefaultInitialSize) { }

    /// <summary>
    /// Creates a table whose bucket count is the smallest power of two (at least 2) that is not less
    /// than <paramref name="initialSize"/>.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="initialSize"/> is negative.</exception>
    protected BaseDictionary(int initialSize)
    {
        if (initialSize < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(initialSize));
        }

        // Integer ceil(log2): no reliance on floating-point Math.Log, which also yields NaN / -Infinity for 0.
        var power = 1;
        while (power < MaxSizePowTwo && (1 << power) < initialSize)
        {
            power++;
        }
        Initialize(power);
    }

    /// <summary>Hashes keys into buckets and compares them. Used as a generic argument (THandler).</summary>
    protected interface IKeyHandler<T>
    {
        uint MapKeyToBucket(T key);
        bool KeysEqual(T lhs, T rhs);
    }

    /// <summary>Key handler backed by <see cref="EqualityComparer{T}.Default"/>.</summary>
    protected struct EquatableKeyHandler : IKeyHandler<TKey>
    {
        public uint MapKeyToBucket(TKey key) => (uint)EqualityComparer<TKey>.Default.GetHashCode(key);
        public bool KeysEqual(TKey lhs, TKey rhs) => EqualityComparer<TKey>.Default.Equals(lhs, rhs);
    }

    private struct Entry
    {
        public TKey Key;
        public TValue Value;
    }

    private int _sizePowTwo;      // log2 of the bucket count
    private uint _maxDisplacement; // probe window length; equals _sizePowTwo
    private uint _mask;           // bucket count - 1
    private Entry[] _entries;

    /// <summary>Number of keys inserted and not removed (see the type remarks for the default-value caveat).</summary>
    protected uint EntryCount { get; private set; }

    /// <summary>
    /// Replaces the table with an empty one of <c>2^powerOf2Size</c> buckets and resets <see cref="EntryCount"/>.
    /// </summary>
    /// <param name="powerOf2Size">Base-2 log of the bucket count; also the maximum probe distance.</param>
    /// <exception cref="ArgumentOutOfRangeException">The exponent is negative or greater than 30.</exception>
    protected void Initialize(int powerOf2Size)
    {
        if (powerOf2Size < 0 || powerOf2Size > MaxSizePowTwo)
        {
            throw new ArgumentOutOfRangeException(nameof(powerOf2Size));
        }

        // The tail after the 2^p buckets lets a probe window that starts at the last bucket run without
        // wrapping around or leaving the array. Allocate first so a failed allocation changes nothing.
        _entries = new Entry[(1 << powerOf2Size) + powerOf2Size + 1];
        _sizePowTwo = powerOf2Size;
        _maxDisplacement = (uint)powerOf2Size;
        _mask = (1u << powerOf2Size) - 1;
        EntryCount = 0;
    }

    /// <summary>Drops every entry, keeping the current table size.</summary>
    protected void RemoveAll() => Initialize(_sizePowTwo);

    /// <summary>Removes <paramref name="key"/> if present.</summary>
    /// <returns>True if the key was found and removed.</returns>
    protected bool RemoveEntry<THandler>(TKey key, THandler keyHandler) where THandler : IKeyHandler<TKey>
    {
        var slot = FindSlot(key, keyHandler, out _);
        if (slot < 0)
        {
            return false;
        }

        // Clear the key as well as the value so the table no longer references the removed key.
        _entries[slot] = default;
        EntryCount--;
        return true;
    }

    /// <summary>Looks up <paramref name="key"/> and returns a reference to its value slot.</summary>
    /// <param name="key">Key to look up.</param>
    /// <param name="insertIfNotFound">
    /// When true and the key is absent, the key is stored (growing the table if its window is full) and a
    /// reference to its still-default value is returned.
    /// </param>
    /// <param name="found">True if the key was already present.</param>
    /// <param name="keyHandler">Hashing/equality strategy for keys.</param>
    /// <returns>
    /// The key's value slot; for an absent key with <paramref name="insertIfNotFound"/> false, a scratch slot
    /// holding <c>default(TValue)</c>. The reference points into the current array, so it must not be used
    /// after any operation that may resize the table.
    /// </returns>
    protected ref TValue FindEntry<THandler>(TKey key, bool insertIfNotFound, out bool found, THandler keyHandler) where THandler : IKeyHandler<TKey>
    {
        while (true)
        {
            var slot = FindSlot(key, keyHandler, out var firstEmptySlot);
            if (slot >= 0)
            {
                found = true;
                return ref _entries[slot].Value;
            }

            found = false;
            if (!insertIfNotFound)
            {
                s_missingValue = default;
                return ref s_missingValue;
            }

            if (firstEmptySlot >= 0)
            {
                ref var entry = ref _entries[firstEmptySlot];
                entry.Key = key;
                EntryCount++;
                return ref entry.Value;
            }

            // Every slot in the key's window is taken: grow, then probe the new table.
            Resize(keyHandler);
        }
    }

    /// <summary>
    /// Scans the key's whole probe window.
    /// </summary>
    /// <returns>The slot holding <paramref name="key"/>, or -1; <paramref name="firstEmptySlot"/> is the first empty slot seen, or -1.</returns>
    private int FindSlot<THandler>(TKey key, THandler keyHandler, out int firstEmptySlot) where THandler : IKeyHandler<TKey>
    {
        firstEmptySlot = -1;
        var home = (int)(keyHandler.MapKeyToBucket(key) & _mask);
        var window = (int)_maxDisplacement;

        // home + window never exceeds the array thanks to the tail allocated by Initialize,
        // so the unchecked walk below stays in bounds.
        ref var start = ref _entries[home];
        for (var offset = 0; offset < window; offset++)
        {
            ref var candidate = ref Unsafe.Add(ref start, offset);
            if (IsEmpty(candidate.Value))
            {
                if (firstEmptySlot < 0)
                {
                    firstEmptySlot = home + offset;
                }
            }
            else if (keyHandler.KeysEqual(candidate.Key, key))
            {
                return home + offset;
            }
        }
        return -1;
    }

    /// <summary>A slot is empty when its value is <c>default(TValue)</c>.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static bool IsEmpty(TValue value) => EqualityComparer<TValue>.Default.Equals(value, default);

    /// <summary>
    /// Doubles the table (repeatedly, if some key still cannot be placed within its window) and rehashes
    /// the live entries. On failure the original table is restored.
    /// </summary>
    /// <exception cref="InvalidOperationException">The table cannot grow past 2^30 buckets.</exception>
    private void Resize<THandler>(THandler keyHandler) where THandler : IKeyHandler<TKey>
    {
        var oldEntries = _entries;
        var oldPower = _sizePowTwo;
        var oldCount = EntryCount;
        var grown = false;
        try
        {
            for (var power = oldPower + 1; power <= MaxSizePowTwo && !grown; power++)
            {
                Initialize(power);
                grown = TryReinsertAll(oldEntries, keyHandler);
            }
        }
        finally
        {
            if (!grown)
            {
                // Put the original table back so a failed grow (or a throwing key handler) loses nothing.
                _entries = oldEntries;
                _sizePowTwo = oldPower;
                _maxDisplacement = (uint)oldPower;
                _mask = (1u << oldPower) - 1;
                EntryCount = oldCount;
            }
        }

        if (!grown)
        {
            throw new InvalidOperationException("Too many keys share a probe window for the table to grow any further.");
        }
    }

    /// <summary>
    /// Copies every non-empty entry of <paramref name="source"/> into the current (fresh) table, each within
    /// its probe window, and recounts <see cref="EntryCount"/>.
    /// </summary>
    /// <returns>False if some entry's window is already full.</returns>
    private bool TryReinsertAll<THandler>(Entry[] source, THandler keyHandler) where THandler : IKeyHandler<TKey>
    {
        var entries = _entries;
        var mask = _mask;
        var window = (int)_maxDisplacement;
        uint count = 0;
        foreach (var entry in source)
        {
            if (IsEmpty(entry.Value))
            {
                continue; // empty or removed slot: nothing to carry over
            }

            var slot = (int)(keyHandler.MapKeyToBucket(entry.Key) & mask);
            var end = slot + window;
            while (slot < end && !IsEmpty(entries[slot].Value))
            {
                slot++;
            }
            if (slot == end)
            {
                return false;
            }

            entries[slot] = entry;
            count++;
        }

        EntryCount = count;
        return true;
    }

    /// <summary>
    /// Yields every slot whose value is not <c>default(TValue)</c>, in table order. No modification check
    /// is made: each <see cref="MoveNext"/> reads the parent's current array. <see cref="Reset"/> is not supported.
    /// </summary>
    public struct Enumerator : IEnumerator<KeyValuePair<TKey, TValue>>
    {
        private readonly BaseDictionary<TKey, TValue> _parent;
        private uint _index;

        internal Enumerator(BaseDictionary<TKey, TValue> parent)
        {
            Current = default;
            _index = 0;
            _parent = parent;
        }

        public KeyValuePair<TKey, TValue> Current { get; private set; }

        object IEnumerator.Current => Current;

        public bool MoveNext()
        {
            Entry[] entries = _parent._entries;
            for (uint index = _index; index < entries.Length; index++)
            {
                if (!IsEmpty(entries[index].Value))
                {
                    Current = new KeyValuePair<TKey, TValue>(entries[index].Key, entries[index].Value);
                    _index = index + 1;
                    return true;
                }
            }

            _index = uint.MaxValue;
            Current = new KeyValuePair<TKey, TValue>(default, default);
            return false;
        }

        public void Reset()
        {
            throw new NotSupportedException();
        }

        public void Dispose() { }
    }
}
/// <summary>
/// Caches the result of transforming source values, keyed by a key derived from each source value.
/// </summary>
/// <remarks>
/// A transform result equal to <c>default(TValue)</c> is not stored (the base table treats such a value as an
/// empty slot), so that transform is run again on the next request for the same key.
/// </remarks>
public class TransformationCache<TKey, TValue> : BaseDictionary<TKey, TValue> where TKey : IEquatable<TKey>
{
    /// <summary>Creates a cache sized for 64 entries.</summary>
    /// <remarks>
    /// Kept as a true parameterless constructor (not an optional parameter) so <c>new()</c> constraints and
    /// reflection-based construction keep working.
    /// </remarks>
    public TransformationCache() : this(64) { }

    /// <summary>Creates a cache sized for <paramref name="initialSize"/> entries.</summary>
    public TransformationCache(int initialSize) : base(initialSize) { }

    /// <summary>Derives the cache key from a source value and computes the cached value.</summary>
    public interface ITransformHandler<in TSource>
    {
        TKey KeyForValue(TSource value);
        TValue Transform(TSource source);
    }

    /// <summary>
    /// Returns the cached value for <paramref name="sourceValue"/>'s key, running
    /// <c>transformer.Transform</c> and caching its result when the key is not cached yet.
    /// </summary>
    public TValue GetOrAdd<TSource, TTransformHandler>(TSource sourceValue, TTransformHandler transformer) where TTransformHandler : ITransformHandler<TSource>
    {
        var key = transformer.KeyForValue(sourceValue);
        var keyHandler = default(EquatableKeyHandler);

        var cached = FindEntry(key, false, out var found, keyHandler);
        if (found)
        {
            return cached;
        }

        var value = transformer.Transform(sourceValue);

        // Insert only after Transform returns: a ref taken earlier would point into a discarded array if the
        // transform re-entered this cache and resized it, and a throwing transform would leave a phantom entry.
        if (!EqualityComparer<TValue>.Default.Equals(value, default))
        {
            FindEntry(key, true, out _, keyHandler) = value;
        }
        return value;
    }
}
/// <summary>
/// Map from keys to integer counters built on <see cref="BaseDictionary{TKey, TValue}"/>.
/// </summary>
/// <remarks>
/// A count of 0 is the base table's "empty" value, so enumeration skips keys whose count is 0.
/// Reading a key through the indexer never inserts it; incrementing, decrementing or assigning does.
/// </remarks>
public class CountingDictionary<TKey> : BaseDictionary<TKey, int>, IEnumerable<KeyValuePair<TKey, int>> where TKey : IEquatable<TKey>
{
    /// <summary>Creates a dictionary sized for <paramref name="initialSize"/> keys.</summary>
    public CountingDictionary(int initialSize = 64) : base(initialSize) { }

    /// <summary>Allocation-free struct enumerator over the non-zero counts.</summary>
    public Enumerator GetEnumerator() => new Enumerator(this);

    IEnumerator<KeyValuePair<TKey, int>> IEnumerable<KeyValuePair<TKey, int>>.GetEnumerator() => GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    /// <summary>Adds one to the key's count (inserting the key if needed) and returns the new count.</summary>
    public int Increment(TKey key) => Adjust(key, 1);

    /// <summary>Subtracts one from the key's count (inserting the key if needed) and returns the new count.</summary>
    public int Decrement(TKey key) => Adjust(key, -1);

    /// <summary>Gets the key's count without inserting it, or sets it (inserting the key if needed).</summary>
    public int this[TKey key]
    {
        get
        {
            var handler = default(EquatableKeyHandler);
            return FindEntry(key, false, out _, handler);
        }
        set => FindEntry(key, true, out _, default(EquatableKeyHandler)) = value;
    }

    /// <summary>Adds <paramref name="delta"/> to the key's count in place and returns the result.</summary>
    private int Adjust(TKey key, int delta)
    {
        ref var counter = ref FindEntry(key, true, out _, default(EquatableKeyHandler));
        counter += delta;
        return counter;
    }
}
/// <summary>
/// <see cref="IDictionary{TKey, TValue}"/> implementation on top of <see cref="BaseDictionary{TKey, TValue}"/>.
/// </summary>
/// <remarks>
/// The base table treats a <c>default(TValue)</c> value as an empty slot, so storing <c>default(TValue)</c> for
/// a key hides that key from lookups and enumeration, while <see cref="Count"/> may still include it.
/// Unlike <see cref="Dictionary{TKey, TValue}"/>, the indexer getter returns <c>default(TValue)</c> for a
/// missing key instead of throwing.
/// </remarks>
public class BoringButFastDictionary<TKey, TValue> : BaseDictionary<TKey, TValue>, IDictionary<TKey, TValue> where TKey : IEquatable<TKey>
{
    /// <summary>Creates a dictionary sized for <paramref name="initialSize"/> entries.</summary>
    public BoringButFastDictionary(int initialSize = 64) : base(initialSize) { }

    /// <summary>
    /// Stores <c>updateValueFactory(key, current)</c> if the key exists, otherwise <c>addValueFactory(key)</c>,
    /// and returns the stored value.
    /// </summary>
    public TValue AddOrUpdate(TKey key, Func<TKey, TValue> addValueFactory, Func<TKey, TValue, TValue> updateValueFactory)
    {
        var keyHandler = default(EquatableKeyHandler);
        var existing = FindEntry(key, false, out var found, keyHandler);
        var newValue = found ? updateValueFactory(key, existing) : addValueFactory(key);

        // Store through a fresh reference obtained after the factory returns: the factory may have resized
        // this dictionary, and a throwing factory then leaves no half-inserted key behind.
        FindEntry(key, true, out _, keyHandler) = newValue;
        return newValue;
    }

    /// <summary>Returns the key's value, first storing <c>valueFactory(key)</c> if the key is absent.</summary>
    public TValue GetOrAdd(TKey key, Func<TKey, TValue> valueFactory)
    {
        var keyHandler = default(EquatableKeyHandler);
        var existing = FindEntry(key, false, out var found, keyHandler);
        if (found)
        {
            return existing;
        }

        var value = valueFactory(key);
        // Insert only after the factory ran (see AddOrUpdate for why the ref is not held across the call).
        FindEntry(key, true, out _, keyHandler) = value;
        return value;
    }

    IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator() => GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    /// <summary>Allocation-free struct enumerator over the stored (non-default) entries.</summary>
    public Enumerator GetEnumerator() => new Enumerator(this);

    /// <summary>Adds the pair; throws <see cref="ArgumentException"/> if the key already exists, like <see cref="Add(TKey, TValue)"/>.</summary>
    public void Add(KeyValuePair<TKey, TValue> item) => Add(item.Key, item.Value);

    public void Clear() => RemoveAll();

    /// <summary>True if the key is present and its value equals <paramref name="item"/>'s value.</summary>
    public bool Contains(KeyValuePair<TKey, TValue> item)
    {
        var value = FindEntry(item.Key, false, out var found, default(EquatableKeyHandler));
        return found && EqualityComparer<TValue>.Default.Equals(item.Value, value);
    }

    /// <summary>Copies every stored entry into <paramref name="array"/> starting at <paramref name="arrayIndex"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="array"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="arrayIndex"/> is outside the array.</exception>
    /// <exception cref="ArgumentException">The entries do not fit; nothing is copied in that case.</exception>
    public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex)
    {
        if (array == null)
        {
            throw new ArgumentNullException(nameof(array));
        }
        if (arrayIndex < 0 || arrayIndex > array.Length)
        {
            throw new ArgumentOutOfRangeException(nameof(arrayIndex));
        }

        // Count what enumeration will actually yield; Count can include keys whose value was reset to default.
        var required = 0;
        foreach (var unused in this)
        {
            required++;
        }
        if (array.Length - arrayIndex < required)
        {
            throw new ArgumentException("The destination array is too small to hold the dictionary's entries.", nameof(array));
        }

        foreach (var kvp in this)
        {
            array[arrayIndex] = kvp;
            arrayIndex++;
        }
    }

    /// <summary>Removes the key only if it is mapped to <paramref name="item"/>'s value.</summary>
    public bool Remove(KeyValuePair<TKey, TValue> item)
    {
        if (Contains(item))
        {
            RemoveEntry(item.Key, default(EquatableKeyHandler));
            return true;
        }
        return false;
    }

    public int Count => (int)EntryCount;

    public bool IsReadOnly => false;

    public bool ContainsKey(TKey key)
    {
        FindEntry(key, false, out var found, default(EquatableKeyHandler));
        return found;
    }

    /// <summary>Adds a new key; throws <see cref="ArgumentException"/> if the key already exists.</summary>
    public void Add(TKey key, TValue value)
    {
        ref TValue valueRef = ref FindEntry(key, true, out var found, default(EquatableKeyHandler));
        if (found)
        {
            throw new ArgumentException("Duplicate key");
        }
        valueRef = value;
    }

    public bool Remove(TKey key) => RemoveEntry(key, default(EquatableKeyHandler));

    /// <summary>Gets the key's value; <paramref name="value"/> is <c>default(TValue)</c> when the key is absent.</summary>
    public bool TryGetValue(TKey key, out TValue value)
    {
        var stored = FindEntry(key, false, out var found, default(EquatableKeyHandler));
        value = found ? stored : default;
        return found;
    }

    /// <summary>Gets the key's value (<c>default(TValue)</c> if absent) or sets it, inserting the key if needed.</summary>
    public TValue this[TKey key]
    {
        get
        {
            var stored = FindEntry(key, false, out var found, default(EquatableKeyHandler));
            return found ? stored : default;
        }
        set
        {
            ref TValue entry = ref FindEntry(key, true, out _, default(EquatableKeyHandler));
            entry = value;
        }
    }

    /// <summary>Snapshot of the stored keys.</summary>
    public ICollection<TKey> Keys => this.Select(kvp => kvp.Key).ToArray();

    /// <summary>Snapshot of the stored values.</summary>
    public ICollection<TValue> Values => this.Select(kvp => kvp.Value).ToArray();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.