Skip to content

Instantly share code, notes, and snippets.

@raizam
Created April 18, 2019 13:55
Show Gist options
  • Save raizam/eb0fceb69675b54aeafb5669db7aee0f to your computer and use it in GitHub Desktop.
Save raizam/eb0fceb69675b54aeafb5669db7aee0f to your computer and use it in GitHub Desktop.
Single file PatriciaTrie extracted from trienet (https://github.com/gmamaladze/trienet/tree/master/TrieNet).
// This code is distributed under MIT license. Copyright (c) 2013 George Mamaladze
// See license.txt or http://opensource.org/licenses/mit-license.php
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
namespace System.Collections.Specialized
{
// This code is distributed under MIT license. Copyright (c) 2013 George Mamaladze
// See license.txt or http://opensource.org/licenses/mit-license.php
/// <summary>
/// Classifies how two keys relate once their common leading characters have been removed.
/// The value is derived by <see cref="ZipResult.MatchKind"/> from the lengths of the two remainders.
/// </summary>
public enum MatchKind
{
/// <summary>Neither key has characters left after the common head.</summary>
ExactMatch,
/// <summary>Only this key has characters left after the common head; the other key is used up.</summary>
Contains,
/// <summary>Only the other key has characters left after the common head; this key is used up.</summary>
IsContained,
/// <summary>Both keys still have characters left after the common head.</summary>
Partial,
}
/// <summary>
/// Interface to be implemented by a data structure which allows adding values of type
/// <typeparamref name="TValue"/> associated with <b>string</b> keys.
/// The interface allows retrieval of multiple values for a single query.
/// </summary>
/// <typeparam name="TValue">Type of the values stored under the string keys.</typeparam>
public interface ITrie<TValue>
{
/// <summary>
/// Returns the values that match <paramref name="query"/>.
/// How a query is matched against the stored keys is defined by the implementation.
/// </summary>
/// <param name="query">The text to look up.</param>
/// <returns>A sequence of matching values.</returns>
IEnumerable<TValue> Retrieve(string query);
/// <summary>
/// Associates <paramref name="value"/> with <paramref name="key"/>.
/// </summary>
/// <param name="key">The string key.</param>
/// <param name="value">The value to store under the key.</param>
void Add(string key, TValue value);
}
/// <summary>
/// Base class for trie nodes. It implements the generic insert and lookup walk and leaves
/// the storage of values and children to derived classes.
/// </summary>
/// <typeparam name="TValue">Type of the values stored in the trie.</typeparam>
public abstract class TrieNodeBase<TValue>
{
    /// <summary>Number of query characters consumed when a lookup steps into this node.</summary>
    protected abstract int KeyLength { get; }

    /// <summary>Values stored directly on this node.</summary>
    protected abstract IEnumerable<TValue> Values();

    /// <summary>Direct children of this node.</summary>
    protected abstract IEnumerable<TrieNodeBase<TValue>> Children();

    /// <summary>
    /// Stores <paramref name="value"/> under <paramref name="key"/>, descending one character
    /// per level starting at <paramref name="position"/>.
    /// </summary>
    /// <param name="key">The full key.</param>
    /// <param name="position">Index of the next key character to consume.</param>
    /// <param name="value">The value to store.</param>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <c>null</c>.</exception>
    public void Add(string key, int position, TValue value)
    {
        if (key == null) throw new ArgumentNullException("key");

        if (!EndOfString(position, key))
        {
            // More characters to consume: hand the rest of the key to the matching child.
            GetOrCreateChild(key[position]).Add(key, position + 1, value);
            return;
        }

        // The whole key has been consumed, so the value belongs to this node.
        AddValue(value);
    }

    /// <summary>Stores <paramref name="value"/> directly on this node.</summary>
    protected abstract void AddValue(TValue value);

    /// <summary>Returns the child reached via <paramref name="key"/>, creating it when absent.</summary>
    protected abstract TrieNodeBase<TValue> GetOrCreateChild(char key);

    /// <summary>
    /// Looks up <paramref name="query"/> starting at <paramref name="position"/>. Once the query
    /// is used up, every value in this node's subtree is returned; otherwise the search continues
    /// in the matching child.
    /// </summary>
    protected virtual IEnumerable<TValue> Retrieve(string query, int position)
    {
        if (EndOfString(position, query))
        {
            return ValuesDeep();
        }

        return SearchDeep(query, position);
    }

    /// <summary>
    /// Continues the lookup in the child matching the query at <paramref name="position"/>,
    /// or yields nothing when there is no such child.
    /// </summary>
    protected virtual IEnumerable<TValue> SearchDeep(string query, int position)
    {
        TrieNodeBase<TValue> matchingChild = GetChildOrNull(query, position);
        if (matchingChild == null)
        {
            return Enumerable.Empty<TValue>();
        }

        // Skip as many query characters as the child's key covers.
        return matchingChild.Retrieve(query, position + matchingChild.KeyLength);
    }

    /// <summary>
    /// Returns the child matching the query at <paramref name="position"/>, or <c>null</c>.
    /// </summary>
    protected abstract TrieNodeBase<TValue> GetChildOrNull(string query, int position);

    /// <summary>True when <paramref name="position"/> lies at or past the end of <paramref name="text"/>.</summary>
    private static bool EndOfString(int position, string text) => text.Length <= position;

    /// <summary>Values of this node followed by the values of all of its descendants.</summary>
    private IEnumerable<TValue> ValuesDeep() => Subtree().SelectMany(node => node.Values());

    /// <summary>This node followed by all of its descendants, parents before their children.</summary>
    protected IEnumerable<TrieNodeBase<TValue>> Subtree()
    {
        IEnumerable<TrieNodeBase<TValue>> self = Enumerable.Repeat(this, 1);
        IEnumerable<TrieNodeBase<TValue>> descendants = Children().SelectMany(child => child.Subtree());
        return self.Concat(descendants);
    }
}
/// <summary>
/// Trie that indexes every suffix of each added key (down to a minimum suffix length) in an
/// inner <see cref="PatriciaTrie{TValue}"/>. Because the inner trie matches queries as prefixes
/// of stored keys, this finds values whose key contains the query anywhere, provided the suffix
/// starting at that occurrence was long enough to be indexed.
/// </summary>
/// <typeparam name="TValue">Type of the values stored in the trie.</typeparam>
public class PatriciaSuffixTrie<TValue> : ITrie<TValue>
{
    private readonly int m_MinQueryLength;
    private readonly PatriciaTrie<TValue> m_InnerTrie;

    /// <summary>Creates an empty suffix trie.</summary>
    /// <param name="minQueryLength">
    /// Length of the shortest suffix that gets indexed. Values below 1 are treated as 1.
    /// </param>
    public PatriciaSuffixTrie(int minQueryLength)
        : this(minQueryLength, new PatriciaTrie<TValue>())
    {
    }

    internal PatriciaSuffixTrie(int minQueryLength, PatriciaTrie<TValue> innerTrie)
    {
        m_MinQueryLength = minQueryLength;
        m_InnerTrie = innerTrie;
    }

    /// <summary>The minimum suffix length as passed to the constructor.</summary>
    protected int MinQueryLength
    {
        get { return m_MinQueryLength; }
    }

    /// <summary>
    /// Returns the distinct values found by the inner trie for <paramref name="query"/>.
    /// </summary>
    /// <param name="query">The text to search for.</param>
    public IEnumerable<TValue> Retrieve(string query)
    {
        // The same value is reachable through several suffixes of one key, hence Distinct().
        return m_InnerTrie.Retrieve(query).Distinct();
    }

    /// <summary>
    /// Indexes <paramref name="value"/> under every suffix of <paramref name="key"/> that is at
    /// least <see cref="MinQueryLength"/> characters long. Keys shorter than that are not indexed.
    /// </summary>
    /// <param name="key">The key whose suffixes are indexed.</param>
    /// <param name="value">The value to store.</param>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <c>null</c>.</exception>
    public void Add(string key, TValue value)
    {
        // Fail eagerly with the same exception type PatriciaTrie.Add uses for a null key.
        if (key == null) throw new ArgumentNullException(nameof(key));

        foreach (StringPartition suffix in GetAllSuffixes(MinQueryLength, key))
        {
            m_InnerTrie.Add(suffix, value);
        }
    }

    /// <summary>
    /// Enumerates the suffixes of <paramref name="word"/> from the shortest allowed one up to
    /// the whole word.
    /// </summary>
    private static IEnumerable<StringPartition> GetAllSuffixes(int minSuffixLength, string word)
    {
        // An empty suffix cannot be stored because children are keyed by the first character,
        // and a negative minimum would produce an invalid partition, so never go below 1.
        int shortest = Math.Max(minSuffixLength, 1);
        for (int start = word.Length - shortest; start >= 0; start--)
        {
            yield return new StringPartition(word, start);
        }
    }
}
/// <summary>
/// Root of a Patricia (compressed) trie. The root has an empty key; queries are matched as
/// prefixes of the stored keys.
/// </summary>
/// <typeparam name="TValue">Type of the values stored in the trie.</typeparam>
public class PatriciaTrie<TValue> :
    PatriciaTrieNode<TValue>,
    ITrie<TValue>
{
    /// <summary>Creates an empty trie.</summary>
    public PatriciaTrie()
        : base(
            new StringPartition(string.Empty),
            new Queue<TValue>(),
            new Dictionary<char, PatriciaTrieNode<TValue>>())
    {
    }

    /// <summary>
    /// Returns the values of all keys that start with <paramref name="query"/>.
    /// An empty query returns every value in the trie.
    /// </summary>
    /// <param name="query">The prefix to look up.</param>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is <c>null</c>.</exception>
    public IEnumerable<TValue> Retrieve(string query)
    {
        if (query == null) throw new ArgumentNullException(nameof(query));
        return Retrieve(query, 0);
    }

    /// <summary>Stores <paramref name="value"/> under <paramref name="key"/>.</summary>
    /// <param name="key">The key; an empty key stores the value on the root node.</param>
    /// <param name="value">The value to store.</param>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <c>null</c>.</exception>
    public virtual void Add(string key, TValue value)
    {
        if (key == null) throw new ArgumentNullException(nameof(key));
        Add(new StringPartition(key), value);
    }

    /// <summary>
    /// Root-level insert. The root key is empty, so a non-empty key always belongs to a child
    /// and the key comparison done by ordinary nodes can be skipped.
    /// </summary>
    internal override void Add(StringPartition keyRest, TValue value)
    {
        if (keyRest.Length == 0)
        {
            // An empty key matches the root's empty key exactly. Without this branch
            // GetOrCreateChild would index the first character of an empty partition.
            AddValue(value);
            return;
        }

        GetOrCreateChild(keyRest, value);
    }
}
/// <summary>
/// Node of a Patricia trie. Each node carries a multi-character key fragment, a queue of
/// values, and its children indexed by the first character of their key fragment.
/// </summary>
/// <typeparam name="TValue">Type of the values stored in the trie.</typeparam>
[DebuggerDisplay("'{m_Key}'")]
public class PatriciaTrieNode<TValue> : TrieNodeBase<TValue>
{
    private Dictionary<char, PatriciaTrieNode<TValue>> m_Children;
    private StringPartition m_Key;
    private Queue<TValue> m_Values;

    /// <summary>Creates a leaf node holding a single value.</summary>
    protected PatriciaTrieNode(StringPartition key, TValue value)
        : this(key, new Queue<TValue>(new[] { value }), new Dictionary<char, PatriciaTrieNode<TValue>>())
    {
    }

    /// <summary>Creates a node that takes over the given value queue and child map.</summary>
    protected PatriciaTrieNode(StringPartition key, Queue<TValue> values,
        Dictionary<char, PatriciaTrieNode<TValue>> children)
    {
        m_Key = key;
        m_Children = children;
        m_Values = values;
    }

    /// <summary>Length of this node's key fragment.</summary>
    protected override int KeyLength => m_Key.Length;

    /// <summary>Values stored directly on this node.</summary>
    protected override IEnumerable<TValue> Values() => m_Values;

    /// <summary>Direct children of this node.</summary>
    protected override IEnumerable<TrieNodeBase<TValue>> Children() => m_Children.Values;

    /// <summary>Appends <paramref name="value"/> to this node's values.</summary>
    protected override void AddValue(TValue value) => m_Values.Enqueue(value);

    /// <summary>
    /// Inserts <paramref name="value"/> under <paramref name="keyRest"/>, which is compared
    /// against this node's key fragment to decide whether to store, descend or split.
    /// </summary>
    internal virtual void Add(StringPartition keyRest, TValue value)
    {
        ZipResult comparison = m_Key.ZipWith(keyRest);
        MatchKind kind = comparison.MatchKind;

        if (kind == MatchKind.ExactMatch)
        {
            // Same key: the value lives on this node.
            AddValue(value);
        }
        else if (kind == MatchKind.IsContained)
        {
            // The new key extends this node's key: continue below with what is left of it.
            GetOrCreateChild(comparison.OtherRest, value);
        }
        else if (kind == MatchKind.Contains)
        {
            // The new key is a proper prefix of this node's key.
            SplitOne(comparison, value);
        }
        else if (kind == MatchKind.Partial)
        {
            // The keys diverge after a common head.
            SplitTwo(comparison, value);
        }
    }

    /// <summary>
    /// Shortens this node to the common head and pushes its former tail, values and children
    /// down into a single new child. The new value stays on this node.
    /// </summary>
    private void SplitOne(ZipResult zipResult, TValue value)
    {
        StringPartition formerTail = zipResult.ThisRest;
        var pushedDown = new PatriciaTrieNode<TValue>(formerTail, m_Values, m_Children);

        m_Values = new Queue<TValue>();
        m_Children = new Dictionary<char, PatriciaTrieNode<TValue>>();
        AddValue(value);
        m_Key = zipResult.CommonHead;
        m_Children.Add(formerTail[0], pushedDown);
    }

    /// <summary>
    /// Shortens this node to the common head and gives it two children: one with this node's
    /// former tail, values and children, and one leaf for the new key's tail and value.
    /// </summary>
    private void SplitTwo(ZipResult zipResult, TValue value)
    {
        StringPartition formerTail = zipResult.ThisRest;
        StringPartition addedTail = zipResult.OtherRest;

        var existingBranch = new PatriciaTrieNode<TValue>(formerTail, m_Values, m_Children);
        var addedBranch = new PatriciaTrieNode<TValue>(addedTail, value);

        m_Key = zipResult.CommonHead;
        m_Values = new Queue<TValue>();
        m_Children = new Dictionary<char, PatriciaTrieNode<TValue>>();
        m_Children.Add(formerTail[0], existingBranch);
        m_Children.Add(addedTail[0], addedBranch);
    }

    /// <summary>
    /// Adds <paramref name="value"/> to the child selected by the first character of
    /// <paramref name="key"/>, creating a leaf for the whole key when no such child exists.
    /// </summary>
    protected void GetOrCreateChild(StringPartition key, TValue value)
    {
        char firstChar = key[0];
        PatriciaTrieNode<TValue> existing;
        if (m_Children.TryGetValue(firstChar, out existing))
        {
            existing.Add(key, value);
            return;
        }

        m_Children.Add(firstChar, new PatriciaTrieNode<TValue>(key, value));
    }

    /// <summary>Not supported; this node type creates children from key partitions instead.</summary>
    /// <exception cref="NotSupportedException">Always thrown.</exception>
    protected override TrieNodeBase<TValue> GetOrCreateChild(char key)
    {
        throw new NotSupportedException("Use alternative signature instead.");
    }

    /// <summary>
    /// Returns the child whose key fragment starts with the query text found at
    /// <paramref name="position"/> (cut to the child's key length), or <c>null</c>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> is <c>null</c>.</exception>
    protected override TrieNodeBase<TValue> GetChildOrNull(string query, int position)
    {
        if (query == null) throw new ArgumentNullException("query");

        PatriciaTrieNode<TValue> candidate;
        if (!m_Children.TryGetValue(query[position], out candidate))
        {
            return null;
        }

        // The partition is clipped to the end of the query, so a query that ends inside the
        // child's key still matches when it agrees with the key up to that point.
        var queryWindow = new StringPartition(query, position, candidate.m_Key.Length);
        return candidate.m_Key.StartsWith(queryWindow) ? candidate : null;
    }

    /// <summary>
    /// Renders this subtree as text: the node's key followed, when it has children, by their
    /// renderings in square brackets separated by " ; ".
    /// </summary>
    public string Traversal()
    {
        string ownKey = m_Key.ToString();
        string[] childTexts = m_Children.Values.Select(child => child.Traversal()).ToArray();
        string nested = string.Join(" ; ", childTexts);

        if (nested.Length == 0)
        {
            return ownKey;
        }

        return ownKey + "[" + nested + "]";
    }

    /// <summary>Returns the key, the number of values and the child characters of this node.</summary>
    public override string ToString()
    {
        int valueCount = Values().Count();
        string childKeys = String.Join(";", m_Children.Keys);
        return string.Format("Key: {0}, Values: {1} Children:{2}, ", m_Key, valueCount, childKeys);
    }
}
/// <summary>
/// The two pieces produced by splitting a <see cref="StringPartition"/>: the part before the
/// split position (<see cref="Head"/>) and the part from it onwards (<see cref="Rest"/>).
/// </summary>
public struct SplitResult : IEquatable<SplitResult>
{
    private readonly StringPartition m_Head;
    private readonly StringPartition m_Rest;

    /// <summary>Creates a result from its two pieces.</summary>
    public SplitResult(StringPartition head, StringPartition rest)
    {
        m_Head = head;
        m_Rest = rest;
    }

    /// <summary>The piece from the split position onwards.</summary>
    public StringPartition Rest
    {
        get { return m_Rest; }
    }

    /// <summary>The piece before the split position.</summary>
    public StringPartition Head
    {
        get { return m_Head; }
    }

    /// <summary>
    /// Two results are equal when both their heads and their rests are equal.
    /// Implements <see cref="IEquatable{T}"/> so generic comparisons do not box the struct.
    /// </summary>
    public bool Equals(SplitResult other)
    {
        return m_Head == other.m_Head && m_Rest == other.m_Rest;
    }

    /// <inheritdoc />
    public override bool Equals(object obj)
    {
        // "is" is false for null, so no separate null check is needed.
        return obj is SplitResult && Equals((SplitResult)obj);
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        unchecked
        {
            return (m_Head.GetHashCode() * 397) ^ m_Rest.GetHashCode();
        }
    }

    /// <summary>Equality operator; same semantics as <see cref="Equals(SplitResult)"/>.</summary>
    public static bool operator ==(SplitResult left, SplitResult right)
    {
        return left.Equals(right);
    }

    /// <summary>Inequality operator; negation of <c>==</c>.</summary>
    public static bool operator !=(SplitResult left, SplitResult right)
    {
        return !left.Equals(right);
    }
}
/// <summary>
/// A window (start index plus length) into an origin string that can be enumerated, indexed,
/// split and compared without copying characters. Equality is positional: two partitions are
/// equal when they have an equal origin string, the same start index and the same length.
/// </summary>
[DebuggerDisplay(
    "{m_Origin.Substring(0,m_StartIndex)} [ {m_Origin.Substring(m_StartIndex,m_PartitionLength)} ] {m_Origin.Substring(m_StartIndex + m_PartitionLength)}"
)]
public struct StringPartition : IEnumerable<char>, IEquatable<StringPartition>
{
    private readonly string m_Origin;
    private readonly int m_PartitionLength;
    private readonly int m_StartIndex;

    /// <summary>Creates a partition covering the whole of <paramref name="origin"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="origin"/> is <c>null</c>.</exception>
    public StringPartition(string origin)
        : this(origin, 0, origin == null ? 0 : origin.Length)
    {
    }

    /// <summary>
    /// Creates a partition running from <paramref name="startIndex"/> to the end of
    /// <paramref name="origin"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="origin"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="startIndex"/> is negative or greater than the length of <paramref name="origin"/>.
    /// </exception>
    public StringPartition(string origin, int startIndex)
        : this(origin, startIndex, origin == null ? 0 : origin.Length - startIndex)
    {
    }

    /// <summary>
    /// Creates a partition of <paramref name="origin"/> starting at <paramref name="startIndex"/>.
    /// A <paramref name="partitionLength"/> that runs past the end of the string is clipped to
    /// the characters that are available.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="origin"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="startIndex"/> is negative or greater than the length of
    /// <paramref name="origin"/>, or <paramref name="partitionLength"/> is negative.
    /// </exception>
    public StringPartition(string origin, int startIndex, int partitionLength)
    {
        if (origin == null) throw new ArgumentNullException(nameof(origin));
        if (startIndex < 0)
            throw new ArgumentOutOfRangeException(nameof(startIndex), "The value must be non negative.");
        // Without this check the clipped length below would become negative.
        if (startIndex > origin.Length)
            throw new ArgumentOutOfRangeException(nameof(startIndex),
                "The value must not exceed the length of the origin string.");
        if (partitionLength < 0)
            throw new ArgumentOutOfRangeException(nameof(partitionLength), "The value must be non negative.");

        m_Origin = string.Intern(origin);
        m_StartIndex = startIndex;
        int availableLength = m_Origin.Length - startIndex;
        m_PartitionLength = Math.Min(partitionLength, availableLength);
    }

    /// <summary>Gets the character at <paramref name="index"/>, counted from the partition start.</summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="index"/> is negative or not smaller than <see cref="Length"/>.
    /// </exception>
    public char this[int index]
    {
        get
        {
            // Check against the partition, not the origin string: otherwise an index past the
            // partition would silently return a character that does not belong to it.
            if (index < 0 || index >= m_PartitionLength)
                throw new ArgumentOutOfRangeException(nameof(index), "The index must lie within the partition.");
            return m_Origin[m_StartIndex + index];
        }
    }

    /// <summary>Number of characters in the partition.</summary>
    public int Length
    {
        get { return m_PartitionLength; }
    }

    #region IEnumerable<char> Members

    /// <summary>Enumerates the characters of the partition in order.</summary>
    public IEnumerator<char> GetEnumerator()
    {
        for (int i = 0; i < m_PartitionLength; i++)
        {
            yield return this[i];
        }
    }

    #endregion

    /// <summary>
    /// Positional equality: equal origin string, equal length and equal start index.
    /// Implements <see cref="IEquatable{T}"/> so generic comparisons do not box the struct.
    /// </summary>
    public bool Equals(StringPartition other)
    {
        return string.Equals(m_Origin, other.m_Origin) && m_PartitionLength == other.m_PartitionLength &&
               m_StartIndex == other.m_StartIndex;
    }

    /// <inheritdoc />
    public override bool Equals(object obj)
    {
        // "is" is false for null, so no separate null check is needed.
        return obj is StringPartition && Equals((StringPartition)obj);
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        unchecked
        {
            // m_Origin is null for default(StringPartition).
            int hashCode = (m_Origin != null ? m_Origin.GetHashCode() : 0);
            hashCode = (hashCode * 397) ^ m_PartitionLength;
            hashCode = (hashCode * 397) ^ m_StartIndex;
            return hashCode;
        }
    }

    /// <summary>
    /// Returns true when the characters of <paramref name="other"/> are the leading characters
    /// of this partition. Only characters are compared, not origins or positions.
    /// </summary>
    public bool StartsWith(StringPartition other)
    {
        if (Length < other.Length)
        {
            return false;
        }
        for (int i = 0; i < other.Length; i++)
        {
            if (this[i] != other[i])
            {
                return false;
            }
        }
        return true;
    }

    /// <summary>
    /// Splits the partition into the first <paramref name="splitAt"/> characters and the remainder.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="splitAt"/> is negative or greater than <see cref="Length"/>.
    /// </exception>
    public SplitResult Split(int splitAt)
    {
        if (splitAt < 0 || splitAt > m_PartitionLength)
            throw new ArgumentOutOfRangeException(nameof(splitAt),
                "The value must lie between zero and the partition length.");

        var head = new StringPartition(m_Origin, m_StartIndex, splitAt);
        var rest = new StringPartition(m_Origin, m_StartIndex + splitAt, Length - splitAt);
        return new SplitResult(head, rest);
    }

    /// <summary>
    /// Compares this partition with <paramref name="other"/> character by character and returns
    /// their common head together with what is left of each after it.
    /// </summary>
    public ZipResult ZipWith(StringPartition other)
    {
        // Index loop instead of two enumerators: same result, no iterator allocations.
        int limit = Math.Min(Length, other.Length);
        int splitIndex = 0;
        while (splitIndex < limit && this[splitIndex] == other[splitIndex])
        {
            splitIndex++;
        }

        SplitResult thisSplitted = Split(splitIndex);
        SplitResult otherSplitted = other.Split(splitIndex);
        return new ZipResult(thisSplitted.Head, thisSplitted.Rest, otherSplitted.Rest);
    }

    /// <summary>Returns the characters of the partition as an interned string.</summary>
    public override string ToString()
    {
        // m_Origin is null for default(StringPartition), which has no characters.
        string result = m_Origin == null
            ? string.Empty
            : m_Origin.Substring(m_StartIndex, m_PartitionLength);
        return string.Intern(result);
    }

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    /// <summary>Equality operator; same semantics as <see cref="Equals(StringPartition)"/>.</summary>
    public static bool operator ==(StringPartition left, StringPartition right)
    {
        return left.Equals(right);
    }

    /// <summary>Inequality operator; negation of <c>==</c>.</summary>
    public static bool operator !=(StringPartition left, StringPartition right)
    {
        return !(left == right);
    }
}
/// <summary>
/// Result of comparing two <see cref="StringPartition"/> values from their start: the common
/// head plus what remains of each partition after it.
/// </summary>
[DebuggerDisplay("Head: '{CommonHead}', This: '{ThisRest}', Other: '{OtherRest}', Kind: {MatchKind}")]
public struct ZipResult : IEquatable<ZipResult>
{
    private readonly StringPartition m_CommonHead;
    private readonly StringPartition m_OtherRest;
    private readonly StringPartition m_ThisRest;

    /// <summary>Creates a result from the common head and the two remainders.</summary>
    public ZipResult(StringPartition commonHead, StringPartition thisRest, StringPartition otherRest)
    {
        m_CommonHead = commonHead;
        m_ThisRest = thisRest;
        m_OtherRest = otherRest;
    }

    /// <summary>
    /// Classifies the comparison by which of the two remainders are empty.
    /// </summary>
    public MatchKind MatchKind
    {
        get
        {
            bool thisUsedUp = m_ThisRest.Length == 0;
            bool otherUsedUp = m_OtherRest.Length == 0;

            if (thisUsedUp)
            {
                return otherUsedUp ? MatchKind.ExactMatch : MatchKind.IsContained;
            }

            return otherUsedUp ? MatchKind.Contains : MatchKind.Partial;
        }
    }

    /// <summary>What is left of the other partition after the common head.</summary>
    public StringPartition OtherRest
    {
        get { return m_OtherRest; }
    }

    /// <summary>What is left of this partition after the common head.</summary>
    public StringPartition ThisRest
    {
        get { return m_ThisRest; }
    }

    /// <summary>The leading part shared by both partitions.</summary>
    public StringPartition CommonHead
    {
        get { return m_CommonHead; }
    }

    /// <summary>
    /// Two results are equal when their common heads and both remainders are equal.
    /// Implements <see cref="IEquatable{T}"/> so generic comparisons do not box the struct.
    /// </summary>
    public bool Equals(ZipResult other)
    {
        return
            m_CommonHead == other.m_CommonHead &&
            m_OtherRest == other.m_OtherRest &&
            m_ThisRest == other.m_ThisRest;
    }

    /// <inheritdoc />
    public override bool Equals(object obj)
    {
        // "is" is false for null, so no separate null check is needed.
        return obj is ZipResult && Equals((ZipResult)obj);
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        unchecked
        {
            int hashCode = m_CommonHead.GetHashCode();
            hashCode = (hashCode * 397) ^ m_OtherRest.GetHashCode();
            hashCode = (hashCode * 397) ^ m_ThisRest.GetHashCode();
            return hashCode;
        }
    }

    /// <summary>Equality operator; same semantics as <see cref="Equals(ZipResult)"/>.</summary>
    public static bool operator ==(ZipResult left, ZipResult right)
    {
        return left.Equals(right);
    }

    /// <summary>Inequality operator; negation of <c>==</c>.</summary>
    public static bool operator !=(ZipResult left, ZipResult right)
    {
        return !left.Equals(right);
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment