Created
April 18, 2019 13:55
-
-
Save raizam/eb0fceb69675b54aeafb5669db7aee0f to your computer and use it in GitHub Desktop.
Single file PatriciaTrie extracted from trienet (https://github.com/gmamaladze/trienet/tree/master/TrieNet).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// This code is distributed under MIT license. Copyright (c) 2013 George Mamaladze | |
// See license.txt or http://opensource.org/licenses/mit-license.php | |
using System; | |
using System.Collections; | |
using System.Collections.Generic; | |
using System.Diagnostics; | |
using System.Linq; | |
using System.Text; | |
namespace System.Collections.Specialized | |
{ | |
// This code is distributed under MIT license. Copyright (c) 2013 George Mamaladze | |
// See license.txt or http://opensource.org/licenses/mit-license.php | |
/// <summary>
/// Describes how two key fragments relate to each other once their
/// common head has been stripped off (see <c>ZipResult.MatchKind</c>).
/// </summary>
public enum MatchKind
{
    /// <summary>Neither fragment has a remainder.</summary>
    ExactMatch = 0,

    /// <summary>Only the first ("this") fragment has a remainder; the other one is fully consumed.</summary>
    Contains = 1,

    /// <summary>Only the other fragment has a remainder; the first ("this") one is fully consumed.</summary>
    IsContained = 2,

    /// <summary>Both fragments still have a remainder after the common head.</summary>
    Partial = 3
}
/// <summary>
/// Contract for a data structure that stores values of type
/// <typeparamref name="TValue"/> under <b>string</b> keys.
/// A single query may yield multiple values.
/// </summary>
/// <typeparam name="TValue">Type of the values associated with the keys.</typeparam>
public interface ITrie<TValue>
{
    /// <summary>
    /// Returns the values the implementation considers a match for <paramref name="query"/>.
    /// The matching rule (exact, prefix, infix, ...) is defined by the implementing class.
    /// </summary>
    /// <param name="query">Text to look up.</param>
    /// <returns>The matching values; which values match is decided by the implementation.</returns>
    IEnumerable<TValue> Retrieve(string query);

    /// <summary>
    /// Associates <paramref name="value"/> with <paramref name="key"/>.
    /// </summary>
    /// <param name="key">Key under which the value is stored.</param>
    /// <param name="value">Value to store.</param>
    void Add(string key, TValue value);
}
/// <summary>
/// Common skeleton for trie nodes: derived classes supply storage for values and
/// children, this class supplies insertion by character and recursive retrieval.
/// </summary>
/// <typeparam name="TValue">Type of the values stored in the node.</typeparam>
public abstract class TrieNodeBase<TValue>
{
    /// <summary>Number of query characters consumed when stepping into this node.</summary>
    protected abstract int KeyLength { get; }

    /// <summary>Values stored directly on this node.</summary>
    protected abstract IEnumerable<TValue> Values();

    /// <summary>Direct children of this node.</summary>
    protected abstract IEnumerable<TrieNodeBase<TValue>> Children();

    /// <summary>Stores <paramref name="value"/> on this node.</summary>
    protected abstract void AddValue(TValue value);

    /// <summary>Returns the child for <paramref name="key"/>, creating it when missing.</summary>
    protected abstract TrieNodeBase<TValue> GetOrCreateChild(char key);

    /// <summary>
    /// Returns the child matching <paramref name="query"/> at <paramref name="position"/>,
    /// or <c>null</c> when there is none.
    /// </summary>
    protected abstract TrieNodeBase<TValue> GetChildOrNull(string query, int position);

    /// <summary>
    /// Walks down one child per character of <paramref name="key"/>, starting at
    /// <paramref name="position"/>, and stores <paramref name="value"/> on the node
    /// reached when the key is exhausted.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <c>null</c>.</exception>
    public void Add(string key, int position, TValue value)
    {
        if (key == null) throw new ArgumentNullException("key");

        // Iterative form of the descent: one child lookup per remaining character.
        TrieNodeBase<TValue> current = this;
        int cursor = position;
        while (!EndOfString(cursor, key))
        {
            current = current.GetOrCreateChild(key[cursor]);
            cursor++;
        }

        current.AddValue(value);
    }

    /// <summary>
    /// Returns every value in this subtree when the query is exhausted at
    /// <paramref name="position"/>; otherwise continues the search in the matching child.
    /// </summary>
    protected virtual IEnumerable<TValue> Retrieve(string query, int position)
    {
        if (EndOfString(position, query))
        {
            return ValuesDeep();
        }

        return SearchDeep(query, position);
    }

    /// <summary>
    /// Continues retrieval in the child selected by <see cref="GetChildOrNull"/>,
    /// advancing the position by that child's <see cref="KeyLength"/>.
    /// Yields nothing when no child matches.
    /// </summary>
    protected virtual IEnumerable<TValue> SearchDeep(string query, int position)
    {
        TrieNodeBase<TValue> match = GetChildOrNull(query, position);
        if (match == null)
        {
            return Enumerable.Empty<TValue>();
        }

        return match.Retrieve(query, position + match.KeyLength);
    }

    /// <summary>This node followed by all of its descendants, parents before children.</summary>
    protected IEnumerable<TrieNodeBase<TValue>> Subtree()
    {
        IEnumerable<TrieNodeBase<TValue>> self = new[] { this };
        IEnumerable<TrieNodeBase<TValue>> descendants = Children().SelectMany(child => child.Subtree());
        return self.Concat(descendants);
    }

    // True once position has reached or passed the end of text.
    private static bool EndOfString(int position, string text) => text.Length <= position;

    // Values of this node and of every node below it.
    private IEnumerable<TValue> ValuesDeep() => Subtree().SelectMany(node => node.Values());
}
/// <summary>
/// Trie that indexes every suffix of each added key (down to a minimum suffix length)
/// in an inner <see cref="PatriciaTrie{TValue}"/>.
/// </summary>
/// <typeparam name="TValue">Type of the values associated with the keys.</typeparam>
public class PatriciaSuffixTrie<TValue> : ITrie<TValue>
{
    private readonly int m_MinQueryLength;
    private readonly PatriciaTrie<TValue> m_InnerTrie;

    /// <summary>Creates an empty suffix trie.</summary>
    /// <param name="minQueryLength">
    /// Length of the shortest suffix that gets indexed. Values below 1 are treated as 1,
    /// because the empty suffix cannot be stored in the inner trie.
    /// </param>
    public PatriciaSuffixTrie(int minQueryLength)
        : this(minQueryLength, new PatriciaTrie<TValue>())
    {
    }

    internal PatriciaSuffixTrie(int minQueryLength, PatriciaTrie<TValue> innerTrie)
    {
        if (innerTrie == null) throw new ArgumentNullException(nameof(innerTrie));

        m_MinQueryLength = minQueryLength;
        m_InnerTrie = innerTrie;
    }

    /// <summary>The minimum suffix length passed to the constructor.</summary>
    protected int MinQueryLength
    {
        get { return m_MinQueryLength; }
    }

    /// <summary>
    /// Looks <paramref name="query"/> up in the inner trie and removes duplicate values
    /// (the same value can be reachable through several indexed suffixes).
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is <c>null</c>.</exception>
    public IEnumerable<TValue> Retrieve(string query)
    {
        // Fail at the public boundary instead of with a NullReferenceException deep in the trie.
        if (query == null) throw new ArgumentNullException(nameof(query));

        return m_InnerTrie.Retrieve(query).Distinct();
    }

    /// <summary>
    /// Adds <paramref name="value"/> under every suffix of <paramref name="key"/> that is
    /// at least <see cref="MinQueryLength"/> characters long. A key shorter than that
    /// contributes no suffixes.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <c>null</c>.</exception>
    public void Add(string key, TValue value)
    {
        // Validate here: GetAllSuffixes is an iterator and would only fail lazily, with a NullReferenceException.
        if (key == null) throw new ArgumentNullException(nameof(key));

        foreach (StringPartition suffix in GetAllSuffixes(MinQueryLength, key))
        {
            m_InnerTrie.Add(suffix, value);
        }
    }

    // Yields the suffixes of word from the shortest allowed one up to the whole word.
    private static IEnumerable<StringPartition> GetAllSuffixes(int minSuffixLength, string word)
    {
        // Never emit the empty suffix (or start past the end of the word): the inner trie
        // indexes children by the first character of the partition, which does not exist then.
        int shortest = Math.Max(1, minSuffixLength);
        for (int start = word.Length - shortest; start >= 0; start--)
        {
            yield return new StringPartition(word, start);
        }
    }
}
/// <summary>
/// Root of a Patricia trie. The root node itself carries an empty key;
/// all keys are stored in the nodes below it.
/// </summary>
/// <typeparam name="TValue">Type of the values associated with the keys.</typeparam>
public class PatriciaTrie<TValue> :
    PatriciaTrieNode<TValue>,
    ITrie<TValue>
{
    /// <summary>Creates an empty trie.</summary>
    public PatriciaTrie()
        : base(
            new StringPartition(string.Empty),
            new Queue<TValue>(),
            new Dictionary<char, PatriciaTrieNode<TValue>>())
    {
    }

    /// <summary>
    /// Returns the values of all nodes reached by following <paramref name="query"/>
    /// from the root, including everything below the node where the query ends.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is <c>null</c>.</exception>
    public IEnumerable<TValue> Retrieve(string query)
    {
        // Fail at the public boundary instead of with a NullReferenceException inside the base class.
        if (query == null) throw new ArgumentNullException(nameof(query));

        return Retrieve(query, 0);
    }

    /// <summary>Stores <paramref name="value"/> under <paramref name="key"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <c>null</c>.</exception>
    public virtual void Add(string key, TValue value)
    {
        if (key == null) throw new ArgumentNullException(nameof(key));

        Add(new StringPartition(key), value);
    }

    /// <summary>
    /// Inserts <paramref name="keyRest"/> below the root. An empty key is stored on the
    /// root itself, whose own key is empty as well.
    /// </summary>
    internal override void Add(StringPartition keyRest, TValue value)
    {
        // An empty key has no first character to select a child with; indexing it would throw.
        if (keyRest.Length == 0)
        {
            AddValue(value);
            return;
        }

        GetOrCreateChild(keyRest, value);
    }
}
/// <summary>
/// Trie node labelled with a multi-character key fragment (<see cref="StringPartition"/>)
/// instead of a single character. Children are indexed by the first character of their fragment.
/// </summary>
/// <typeparam name="TValue">Type of the values stored in the node.</typeparam>
[DebuggerDisplay("'{m_Key}'")]
public class PatriciaTrieNode<TValue> : TrieNodeBase<TValue>
{
    private Dictionary<char, PatriciaTrieNode<TValue>> m_Children;
    private StringPartition m_Key;
    private Queue<TValue> m_Values;

    /// <summary>Creates a leaf node labelled <paramref name="key"/> holding a single value.</summary>
    protected PatriciaTrieNode(StringPartition key, TValue value)
        : this(key, new Queue<TValue>(new[] { value }), new Dictionary<char, PatriciaTrieNode<TValue>>())
    {
    }

    /// <summary>Creates a node that takes over the given value queue and child map.</summary>
    protected PatriciaTrieNode(StringPartition key, Queue<TValue> values,
        Dictionary<char, PatriciaTrieNode<TValue>> children)
    {
        m_Values = values;
        m_Key = key;
        m_Children = children;
    }

    /// <summary>Length of this node's key fragment.</summary>
    protected override int KeyLength
    {
        get { return m_Key.Length; }
    }

    /// <summary>Values stored directly on this node.</summary>
    protected override IEnumerable<TValue> Values()
    {
        return m_Values;
    }

    /// <summary>Direct children of this node.</summary>
    protected override IEnumerable<TrieNodeBase<TValue>> Children()
    {
        return m_Children.Values;
    }

    /// <summary>Appends <paramref name="value"/> to this node's values.</summary>
    protected override void AddValue(TValue value)
    {
        m_Values.Enqueue(value);
    }

    /// <summary>
    /// Inserts <paramref name="value"/> under <paramref name="keyRest"/>, which is compared
    /// against this node's key fragment to decide whether to store here, descend, or split.
    /// </summary>
    internal virtual void Add(StringPartition keyRest, TValue value)
    {
        ZipResult zipResult = m_Key.ZipWith(keyRest);

        switch (zipResult.MatchKind)
        {
            case MatchKind.ExactMatch:
                // Same fragment: the value belongs to this node.
                AddValue(value);
                break;

            case MatchKind.IsContained:
                // This fragment is a prefix of the key: continue below with what is left of the key.
                GetOrCreateChild(zipResult.OtherRest, value);
                break;

            case MatchKind.Contains:
                // The key ends inside this fragment: cut the fragment at that point.
                SplitOne(zipResult, value);
                break;

            case MatchKind.Partial:
                // Key and fragment diverge: fork into two children under the common head.
                SplitTwo(zipResult, value);
                break;
        }
    }

    // Shortens this node to the common head (which now holds the new value) and moves the
    // previous values and children into a single child labelled with the cut-off tail.
    private void SplitOne(ZipResult zipResult, TValue value)
    {
        var leftChild = new PatriciaTrieNode<TValue>(zipResult.ThisRest, m_Values, m_Children);

        m_Children = new Dictionary<char, PatriciaTrieNode<TValue>>();
        m_Values = new Queue<TValue>();
        AddValue(value);
        m_Key = zipResult.CommonHead;

        m_Children.Add(zipResult.ThisRest[0], leftChild);
    }

    // Shortens this node to the common head (now without values) and hangs two children
    // below it: one with the previous content, one with the new value.
    private void SplitTwo(ZipResult zipResult, TValue value)
    {
        var leftChild = new PatriciaTrieNode<TValue>(zipResult.ThisRest, m_Values, m_Children);
        var rightChild = new PatriciaTrieNode<TValue>(zipResult.OtherRest, value);

        m_Children = new Dictionary<char, PatriciaTrieNode<TValue>>();
        m_Values = new Queue<TValue>();
        m_Key = zipResult.CommonHead;

        char leftKey = zipResult.ThisRest[0];
        m_Children.Add(leftKey, leftChild);
        char rightKey = zipResult.OtherRest[0];
        m_Children.Add(rightKey, rightChild);
    }

    /// <summary>
    /// Adds <paramref name="value"/> to the child selected by the first character of
    /// <paramref name="key"/>, creating that child as a leaf when it does not exist yet.
    /// An empty <paramref name="key"/> stores the value on this node.
    /// </summary>
    protected void GetOrCreateChild(StringPartition key, TValue value)
    {
        // key[0] does not exist for an empty key: it would throw or read a character
        // outside the partition. Nothing is left to descend on, so the value lives here.
        if (key.Length == 0)
        {
            AddValue(value);
            return;
        }

        PatriciaTrieNode<TValue> child;
        if (!m_Children.TryGetValue(key[0], out child))
        {
            child = new PatriciaTrieNode<TValue>(key, value);
            m_Children.Add(key[0], child);
        }
        else
        {
            child.Add(key, value);
        }
    }

    /// <summary>Not supported by this node type; children are keyed by whole fragments.</summary>
    /// <exception cref="NotSupportedException">Always thrown.</exception>
    protected override TrieNodeBase<TValue> GetOrCreateChild(char key)
    {
        throw new NotSupportedException("Use alternative signature instead.");
    }

    /// <summary>
    /// Returns the child whose key fragment starts with the part of <paramref name="query"/>
    /// beginning at <paramref name="position"/> (cut to the fragment's length), or <c>null</c>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is <c>null</c>.</exception>
    protected override TrieNodeBase<TValue> GetChildOrNull(string query, int position)
    {
        if (query == null) throw new ArgumentNullException(nameof(query));

        PatriciaTrieNode<TValue> child;
        if (m_Children.TryGetValue(query[position], out child))
        {
            // The partition is at most as long as the child's fragment; a shorter query tail
            // still matches when the fragment begins with it.
            var queryPartition = new StringPartition(query, position, child.m_Key.Length);
            if (child.m_Key.StartsWith(queryPartition))
            {
                return child;
            }
        }

        return null;
    }

    /// <summary>
    /// Renders this subtree as text: the node's key followed, when it has children,
    /// by their renderings joined with " ; " inside square brackets.
    /// </summary>
    public string Traversal()
    {
        var result = new StringBuilder();
        result.Append(m_Key);

        string subtreeResult = string.Join(" ; ", m_Children.Values.Select(node => node.Traversal()));
        if (subtreeResult.Length != 0)
        {
            result.Append("[");
            result.Append(subtreeResult);
            result.Append("]");
        }

        return result.ToString();
    }

    /// <summary>Short description: key, number of own values, and the child index characters.</summary>
    public override string ToString()
    {
        return
            string.Format(
                "Key: {0}, Values: {1} Children:{2}, ",
                m_Key,
                Values().Count(),
                String.Join(";", m_Children.Keys));
    }
}
/// <summary>
/// Result of cutting a <see cref="StringPartition"/> in two: the part before the cut
/// (<see cref="Head"/>) and the part after it (<see cref="Rest"/>).
/// </summary>
public struct SplitResult : IEquatable<SplitResult>
{
    private readonly StringPartition m_Head;
    private readonly StringPartition m_Rest;

    /// <summary>Creates a result from its two parts.</summary>
    public SplitResult(StringPartition head, StringPartition rest)
    {
        m_Head = head;
        m_Rest = rest;
    }

    /// <summary>The part after the cut.</summary>
    public StringPartition Rest
    {
        get { return m_Rest; }
    }

    /// <summary>The part before the cut.</summary>
    public StringPartition Head
    {
        get { return m_Head; }
    }

    /// <summary>Two results are equal when both their heads and their rests are equal.</summary>
    public bool Equals(SplitResult other)
    {
        return m_Head == other.m_Head && m_Rest == other.m_Rest;
    }

    /// <inheritdoc />
    public override bool Equals(object obj)
    {
        // "is" is already false for null, so no separate null test is needed.
        return obj is SplitResult && Equals((SplitResult)obj);
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        unchecked
        {
            return (m_Head.GetHashCode() * 397) ^ m_Rest.GetHashCode();
        }
    }

    public static bool operator ==(SplitResult left, SplitResult right)
    {
        return left.Equals(right);
    }

    public static bool operator !=(SplitResult left, SplitResult right)
    {
        return !(left == right);
    }
}
/// <summary>
/// A window (start index + length) onto a string. Lets the trie refer to parts of a key
/// without allocating substrings.
/// </summary>
[DebuggerDisplay(
    "{m_Origin.Substring(0,m_StartIndex)} [ {m_Origin.Substring(m_StartIndex,m_PartitionLength)} ] {m_Origin.Substring(m_StartIndex + m_PartitionLength)}"
    )]
public struct StringPartition : IEnumerable<char>, IEquatable<StringPartition>
{
    private readonly string m_Origin;
    private readonly int m_PartitionLength;
    private readonly int m_StartIndex;

    /// <summary>Creates a partition covering the whole of <paramref name="origin"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="origin"/> is <c>null</c>.</exception>
    public StringPartition(string origin)
        : this(origin, 0, origin == null ? 0 : origin.Length)
    {
    }

    /// <summary>
    /// Creates a partition from <paramref name="startIndex"/> to the end of <paramref name="origin"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="origin"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="startIndex"/> is negative or greater than the length of <paramref name="origin"/>.
    /// </exception>
    public StringPartition(string origin, int startIndex)
        : this(origin, startIndex, origin == null ? 0 : origin.Length - startIndex)
    {
    }

    /// <summary>
    /// Creates a partition of <paramref name="origin"/> starting at <paramref name="startIndex"/>.
    /// <paramref name="partitionLength"/> is cut down to the characters actually available.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="origin"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="startIndex"/> is negative or greater than the length of
    /// <paramref name="origin"/>, or <paramref name="partitionLength"/> is negative.
    /// </exception>
    public StringPartition(string origin, int startIndex, int partitionLength)
    {
        if (origin == null) throw new ArgumentNullException(nameof(origin));
        if (startIndex < 0)
            throw new ArgumentOutOfRangeException(nameof(startIndex), "The value must be non negative.");
        // Without this check the available length below is negative and so is Length.
        if (startIndex > origin.Length)
            throw new ArgumentOutOfRangeException(nameof(startIndex),
                "The value must not exceed the length of the origin string.");
        if (partitionLength < 0)
            throw new ArgumentOutOfRangeException(nameof(partitionLength), "The value must be non negative.");

        // Deliberately not interned: equality below compares string contents, and interning
        // every origin (including every query string) would keep them alive for the process lifetime.
        m_Origin = origin;
        m_StartIndex = startIndex;
        int availableLength = origin.Length - startIndex;
        m_PartitionLength = Math.Min(partitionLength, availableLength);
    }

    /// <summary>
    /// Character at <paramref name="index"/>, counted from the start of the partition.
    /// The index is not checked against <see cref="Length"/>, only against the origin string.
    /// </summary>
    public char this[int index]
    {
        get { return m_Origin[m_StartIndex + index]; }
    }

    /// <summary>Number of characters in the partition.</summary>
    public int Length
    {
        get { return m_PartitionLength; }
    }

    #region IEnumerable<char> Members

    /// <summary>Enumerates the characters of the partition in order.</summary>
    public IEnumerator<char> GetEnumerator()
    {
        for (int i = 0; i < m_PartitionLength; i++)
        {
            yield return this[i];
        }
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    #endregion

    /// <summary>
    /// Two partitions are equal when they have equal origin strings, equal start indexes
    /// and equal lengths. Partitions with the same characters but a different origin or
    /// position are not equal.
    /// </summary>
    public bool Equals(StringPartition other)
    {
        return string.Equals(m_Origin, other.m_Origin) && m_PartitionLength == other.m_PartitionLength &&
               m_StartIndex == other.m_StartIndex;
    }

    /// <inheritdoc />
    public override bool Equals(object obj)
    {
        // "is" is already false for null, so no separate null test is needed.
        return obj is StringPartition && Equals((StringPartition)obj);
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        unchecked
        {
            int hashCode = (m_Origin != null ? m_Origin.GetHashCode() : 0);
            hashCode = (hashCode * 397) ^ m_PartitionLength;
            hashCode = (hashCode * 397) ^ m_StartIndex;
            return hashCode;
        }
    }

    /// <summary>
    /// True when the characters of this partition begin with the characters of <paramref name="other"/>.
    /// </summary>
    public bool StartsWith(StringPartition other)
    {
        if (Length < other.Length)
        {
            return false;
        }

        for (int i = 0; i < other.Length; i++)
        {
            if (this[i] != other[i])
            {
                return false;
            }
        }

        return true;
    }

    /// <summary>
    /// Cuts the partition after <paramref name="splitAt"/> characters into a head and a rest.
    /// </summary>
    public SplitResult Split(int splitAt)
    {
        var head = new StringPartition(m_Origin, m_StartIndex, splitAt);
        var rest = new StringPartition(m_Origin, m_StartIndex + splitAt, Length - splitAt);
        return new SplitResult(head, rest);
    }

    /// <summary>
    /// Finds the longest common head of this partition and <paramref name="other"/> and
    /// returns it together with what remains of each partition after it.
    /// </summary>
    public ZipResult ZipWith(StringPartition other)
    {
        // Advance while both partitions have a character at this position and they agree.
        int limit = Math.Min(Length, other.Length);
        int splitIndex = 0;
        while (splitIndex < limit && this[splitIndex] == other[splitIndex])
        {
            splitIndex++;
        }

        SplitResult thisSplitted = Split(splitIndex);
        SplitResult otherSplitted = other.Split(splitIndex);

        return new ZipResult(thisSplitted.Head, thisSplitted.Rest, otherSplitted.Rest);
    }

    /// <summary>Returns the characters of the partition as a new string.</summary>
    public override string ToString()
    {
        // A default-constructed partition has no origin; it represents the empty string.
        return m_Origin == null
            ? string.Empty
            : m_Origin.Substring(m_StartIndex, m_PartitionLength);
    }

    public static bool operator ==(StringPartition left, StringPartition right)
    {
        return left.Equals(right);
    }

    public static bool operator !=(StringPartition left, StringPartition right)
    {
        return !(left == right);
    }
}
/// <summary>
/// Result of aligning two <see cref="StringPartition"/>s from their first character:
/// the common head, plus what is left of each partition after it.
/// </summary>
[DebuggerDisplay("Head: '{CommonHead}', This: '{ThisRest}', Other: '{OtherRest}', Kind: {MatchKind}")]
public struct ZipResult : IEquatable<ZipResult>
{
    private readonly StringPartition m_CommonHead;
    private readonly StringPartition m_OtherRest;
    private readonly StringPartition m_ThisRest;

    /// <summary>Creates a result from the common head and the two remainders.</summary>
    public ZipResult(StringPartition commonHead, StringPartition thisRest, StringPartition otherRest)
    {
        m_CommonHead = commonHead;
        m_ThisRest = thisRest;
        m_OtherRest = otherRest;
    }

    /// <summary>
    /// Classifies the result by which remainders are empty:
    /// both empty is <c>ExactMatch</c>; only <see cref="ThisRest"/> empty is <c>IsContained</c>;
    /// only <see cref="OtherRest"/> empty is <c>Contains</c>; neither empty is <c>Partial</c>.
    /// </summary>
    public MatchKind MatchKind
    {
        get
        {
            bool thisConsumed = m_ThisRest.Length == 0;
            bool otherConsumed = m_OtherRest.Length == 0;

            if (thisConsumed)
            {
                return otherConsumed ? MatchKind.ExactMatch : MatchKind.IsContained;
            }

            return otherConsumed ? MatchKind.Contains : MatchKind.Partial;
        }
    }

    /// <summary>What is left of the other partition after the common head.</summary>
    public StringPartition OtherRest
    {
        get { return m_OtherRest; }
    }

    /// <summary>What is left of the first ("this") partition after the common head.</summary>
    public StringPartition ThisRest
    {
        get { return m_ThisRest; }
    }

    /// <summary>The leading part shared by both partitions.</summary>
    public StringPartition CommonHead
    {
        get { return m_CommonHead; }
    }

    /// <summary>Two results are equal when all three of their parts are equal.</summary>
    public bool Equals(ZipResult other)
    {
        return
            m_CommonHead == other.m_CommonHead &&
            m_OtherRest == other.m_OtherRest &&
            m_ThisRest == other.m_ThisRest;
    }

    /// <inheritdoc />
    public override bool Equals(object obj)
    {
        // "is" is already false for null, so no separate null test is needed.
        return obj is ZipResult && Equals((ZipResult)obj);
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        unchecked
        {
            int hashCode = m_CommonHead.GetHashCode();
            hashCode = (hashCode * 397) ^ m_OtherRest.GetHashCode();
            hashCode = (hashCode * 397) ^ m_ThisRest.GetHashCode();
            return hashCode;
        }
    }

    public static bool operator ==(ZipResult left, ZipResult right)
    {
        return left.Equals(right);
    }

    public static bool operator !=(ZipResult left, ZipResult right)
    {
        return !(left == right);
    }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment