Skip to content

Instantly share code, notes, and snippets.

@PirosB3
Last active December 12, 2015 08:19
Show Gist options
  • Save PirosB3/4743581 to your computer and use it in GitHub Desktop.
Save PirosB3/4743581 to your computer and use it in GitHub Desktop.
import unittest
class Trie(object):
    """Dict-like container that stores values under sliceable keys.

    Keys (typically strings) are resolved by walking a tree of ``TrieNode``
    objects starting at ``_rootNode``; the value lives on the node reached
    at the end of the key.
    """

    def __init__(self):
        # Root of the node tree; it corresponds to no key element itself.
        self._rootNode = TrieNode()

    def __setitem__(self, key, value):
        """Store *value* under *key*, creating missing nodes along the path."""
        node = self._rootNode.searchInnerNode(key, True)
        node.value = value

    def __getitem__(self, key):
        """Return the value stored under *key*.

        Raises:
            KeyError: if the path for *key* does not exist, or the node
                exists (e.g. as a prefix of a longer key) but holds no value.
        """
        node = self._rootNode.searchInnerNode(key)
        # ``None`` is the "nothing stored" marker. Compare by identity so
        # that falsy payloads such as 0, "" or False stay retrievable.
        if node.value is None:
            raise KeyError(key)
        return node.value
class TrieNode(object):
    """A single trie node: an optional value plus a dict of child nodes."""

    def __init__(self, initialData=None):
        """Create a node.

        Args:
            initialData: optional dict mapping a one-element key slice
                (e.g. a single character) to a child ``TrieNode``. The dict
                is used as-is, not copied.
        """
        # Test for None explicitly so that a caller-supplied empty dict is
        # kept (and shared) instead of being silently replaced.
        self._innerData = {} if initialData is None else initialData
        # Payload attached to this node; stays None until one is assigned.
        self.value = None

    def searchInnerNode(self, word, altering=False):
        """Return the descendant reached by following *word* element by element.

        Args:
            word: sliceable key such as a string; each one-element slice
                selects a child.
            altering: when true, nodes missing along the path are created;
                otherwise a missing node raises.

        Raises:
            KeyError: a node on the path is missing and *altering* is false.
        """
        # The first step is always taken, even for an empty word (the lookup
        # then uses the empty slice as its index).
        node = self.getInnerNode(word[:1], altering)
        # Iterate instead of recursing so long keys cannot hit the
        # interpreter's recursion limit.
        for position in range(1, len(word)):
            node = node.getInnerNode(word[position:position + 1], altering)
        return node

    def getInnerNode(self, index, altering=False):
        """Return the direct child stored under *index*.

        Args:
            index: key of the child in this node's child dict.
            altering: when true, a missing child is created, stored and
                returned.

        Raises:
            KeyError: the child is missing and *altering* is false.
        """
        try:
            return self._innerData[index]
        except KeyError:
            # Handled below, outside the handler, so the error raised for
            # the caller is not chained to this internal dict lookup.
            pass
        if not altering:
            raise KeyError(index)
        newTrieNode = TrieNode()
        self._innerData[index] = newTrieNode
        return newTrieNode
class TrieNodeTestCase(unittest.TestCase):
    """Unit tests for ``TrieNode`` child lookup and path traversal."""

    def test_get_node_no_create(self):
        """getInnerNode returns an existing child and raises for a missing one."""
        n = TrieNode()
        child = TrieNode()
        n._innerData['a'] = child
        self.assertIs(child, n.getInnerNode('a'))
        with self.assertRaises(KeyError):
            n.getInnerNode('k')

    def test_get_node_create(self):
        """With altering=True a missing child is created and then retrievable."""
        n = TrieNode()
        aNode = n.getInnerNode('a', True)
        self.assertIs(aNode, n.getInnerNode('a'))

    def test_search_child_node_no_create(self):
        """searchInnerNode follows an existing path and raises past its end."""
        eNode = TrieNode()
        n = TrieNode({
            'h': TrieNode({
                'e': eNode,
            }),
        })
        self.assertIs(eNode, n.searchInnerNode('he'))
        # Use the path search here: getInnerNode('hel') would only look for
        # a direct child literally keyed 'hel' and raise trivially.
        with self.assertRaises(KeyError):
            n.searchInnerNode('hel')

    def test_search_child_node_create(self):
        """With altering=True searchInnerNode extends an existing path."""
        eNode = TrieNode()
        n = TrieNode({
            'h': TrieNode({
                'e': eNode,
            }),
        })
        lNode = n.searchInnerNode('hel', True)
        # The new node must hang off the existing 'e' node, not the root.
        self.assertIn('l', eNode._innerData)
        self.assertNotIn('hel', n._innerData)
        self.assertIs(lNode, n.searchInnerNode('hel'))
class TrieTestCase(unittest.TestCase):
    """Unit tests for the dict-style item access of ``Trie``."""

    def test_get_node(self):
        """Reading returns stored values and raises KeyError otherwise."""
        t = Trie()
        t._rootNode.searchInnerNode('hello', True).value = 'world'
        self.assertEqual('world', t['hello'])
        # One context per lookup: a statement placed after the first raising
        # line inside a single assertRaises block would never execute.
        with self.assertRaises(KeyError):
            t['hel']  # node exists as a prefix but holds no value
        with self.assertRaises(KeyError):
            t['hex']  # no node exists for this path at all

    def test_set_node(self):
        """Writing creates the path and keys sharing a prefix share nodes."""
        t = Trie()
        t['hello'] = 'world'
        t['hex'] = 'no'
        self.assertEqual('world', t._rootNode.searchInnerNode('hello').value)
        eNode = t._rootNode.searchInnerNode('he')
        self.assertIn('x', eNode._innerData)
        self.assertIn('l', eNode._innerData)
# Run the test cases defined in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment