Skip to content

Instantly share code, notes, and snippets.

@shivamMg
Created January 14, 2025 18:47
Show Gist options
  • Save shivamMg/a0cae1a88ab252e0aaead727115acbd6 to your computer and use it in GitHub Desktop.
Save shivamMg/a0cae1a88ab252e0aaead727115acbd6 to your computer and use it in GitHub Desktop.
LRU Cache with TTL
import time
from datetime import datetime, timedelta
class Node:
    """A doubly-linked-list node holding one cache entry.

    Attributes:
        key: The cache key. It is stored on the node so the owning cache can
            delete the matching dict entry when this node is evicted.
        value: The cached value.
        expiry: The entry's deadline. The node never interprets it; the
            owning cache chooses the clock and compares against it.
        next: Neighbour toward the tail of the list, or ``None``.
        prev: Neighbour toward the head of the list, or ``None``.
    """

    # A cache may hold many nodes; __slots__ drops the per-instance __dict__.
    __slots__ = ("key", "value", "expiry", "next", "prev")

    def __init__(self, key, value, expiry):
        self.key = key
        self.value = value
        self.expiry = expiry
        self.next = None
        self.prev = None

    def __repr__(self):
        # next/prev are deliberately omitted: following them would recurse
        # through the whole list (and back again).
        return (
            f"{type(self).__name__}(key={self.key!r}, "
            f"value={self.value!r}, expiry={self.expiry!r})"
        )
class DoublyLinkedList:
    """Minimal doubly linked list of nodes, ordered head -> tail.

    Provides only what the LRU cache needs: push at the head, unlink an
    arbitrary node, and pop the tail. All operations are O(1).

    Invariants: ``head.prev`` and ``tail.next`` are always ``None``. A node
    that leaves the list is fully detached (``prev``/``next`` reset to
    ``None``), so it can safely be re-inserted and does not keep stale
    references into the list.
    """

    def __init__(self):
        self.head = self.tail = None

    def add_head(self, node):
        """Insert *node* at the head; *node* must not currently be linked."""
        # Reset prev explicitly. A node being re-inserted may otherwise keep a
        # back-pointer from its old position.
        node.prev = None
        node.next = self.head
        if self.head is None:
            self.tail = node
        else:
            self.head.prev = node
        self.head = node

    def remove(self, node):
        """Unlink *node*, which must currently be in this list."""
        prev_node, next_node = node.prev, node.next
        if node is self.head:
            self.head = next_node
        else:
            prev_node.next = next_node
        if node is self.tail:
            self.tail = prev_node
        else:
            next_node.prev = prev_node
        # Detach so the removed node holds no references into the list.
        node.prev = node.next = None

    def remove_tail(self) -> "Node | None":
        """Unlink and return the tail node, or ``None`` if the list is empty."""
        node = self.tail
        if node is not None:
            self.remove(node)
        return node
class LRUCache:
    """LRU Cache implementation with TTL.

    Entries are kept in ``self.dict`` (key -> Node) for O(1) lookup and in
    ``self.list``, ordered from most recently used (head) to least recently
    used (tail). Both ``set`` and a successful ``get`` mark an entry as most
    recently used. Inserting a new key into a full cache evicts the tail.

    Expired entries are removed lazily: only when ``get`` encounters them.
    Until then they still count toward ``size``.

    Deadlines use ``time.monotonic()``, so wall-clock adjustments (NTP
    corrections, DST changes in naive local time) cannot make entries
    expire early or live too long.
    """

    def __init__(self, size: int):
        """Create a cache that holds at most *size* entries.

        A non-positive *size* yields a cache that stores nothing.
        """
        self.size = size
        self.dict = {}
        self.list = DoublyLinkedList()

    def _move_to_front(self, node):
        """Mark *node*, already in the cache, as the most recently used."""
        if node is not self.list.head:
            self.list.remove(node)
            self.list.add_head(node)

    def set(self, key: str, value: str, ttl: int):  # ttl is in seconds
        """Insert or update *key* with *value*, valid for *ttl* seconds.

        Updating an existing key replaces its value, restarts its TTL and
        marks it most recently used.

        Raises:
            ValueError: if *ttl* is not positive.
        """
        # An explicit raise instead of `assert`, which is stripped under -O.
        if ttl <= 0:
            raise ValueError(f"ttl must be positive, got {ttl!r}")
        expiry = time.monotonic() + ttl

        node = self.dict.get(key)
        if node is not None:
            node.value = value
            node.expiry = expiry
            self._move_to_front(node)
            return

        if self.size <= 0:
            # Zero capacity: nothing can be stored. Without this check,
            # eviction would pop from an empty list and crash.
            return
        # A loop (not a single eviction) also copes with self.size having
        # been lowered after construction.
        while len(self.dict) >= self.size:
            evicted = self.list.remove_tail()  # least recently used
            del self.dict[evicted.key]
        node = Node(key, value, expiry)
        self.dict[key] = node
        self.list.add_head(node)

    def get(self, key: str):
        """Return the value for *key*, or ``None`` if it is absent or expired.

        A hit marks the entry most recently used, so it is evicted last.
        """
        node = self.dict.get(key)
        if node is None:
            return None
        if node.expiry <= time.monotonic():
            self.list.remove(node)
            del self.dict[key]
            return None
        # Without this promotion, reads never refresh recency and the cache
        # would evict in insertion order (FIFO) instead of LRU order.
        self._move_to_front(node)
        return node.value
import unittest
import time
from lru_cache_ttl import LRUCache
class TestLRUCache(unittest.TestCase):
    """Behavioral tests for LRUCache: lookup, size eviction and TTL expiry.

    The tests use unittest's assertion methods rather than bare ``assert``.
    Bare asserts are stripped when Python runs with ``-O``, which would make
    every test pass vacuously. The expiry tests use real ``time.sleep``
    calls and therefore take a few seconds.
    """

    def test_set_and_get(self):
        cache = LRUCache(2)
        cache.set("a", "1", 2)
        cache.set("b", "2", 4)
        self.assertEqual(cache.get("a"), "1")
        self.assertEqual(cache.get("b"), "2")

    def test_get_missing_key(self):
        cache = LRUCache(2)
        self.assertIsNone(cache.get("missing"))

    def test_eviction_by_size(self):
        cache = LRUCache(2)
        cache.set("a", "1", 2)
        cache.set("b", "2", 4)
        cache.set("c", "3", 8)
        self.assertIsNone(cache.get("a"))  # evicted due to size
        self.assertEqual(cache.get("b"), "2")
        self.assertEqual(cache.get("c"), "3")

    def test_update_existing_key_does_not_evict(self):
        cache = LRUCache(2)
        cache.set("a", "1", 10)
        cache.set("b", "2", 10)
        cache.set("a", "3", 10)  # existing key at capacity: updated in place
        self.assertEqual(cache.get("a"), "3")
        self.assertEqual(cache.get("b"), "2")

    def test_eviction_by_expiry(self):
        cache = LRUCache(2)
        cache.set("a", "1", 2)
        cache.set("b", "2", 4)
        time.sleep(3)
        self.assertIsNone(cache.get("a"))  # expired
        self.assertEqual(cache.get("b"), "2")
        time.sleep(2)
        self.assertIsNone(cache.get("b"))  # expired

    def test_reset_key(self):
        cache = LRUCache(1)
        cache.set("d", "5", 2)
        self.assertEqual(cache.get("d"), "5")
        cache.set("d", "5", 1)  # reset ttl
        self.assertEqual(cache.get("d"), "5")
        cache.set("d", "6", 1)  # reset ttl and value
        self.assertEqual(cache.get("d"), "6")
        time.sleep(2)
        self.assertIsNone(cache.get("d"))  # expired

    def test_reset_extends_ttl(self):
        cache = LRUCache(1)
        cache.set("e", "1", 1)
        time.sleep(0.5)
        cache.set("e", "1", 2)  # new deadline: ~2.5 s after the first set
        time.sleep(1.0)  # past the original 1 s deadline
        self.assertEqual(cache.get("e"), "1")
        time.sleep(1.5)  # past the refreshed deadline
        self.assertIsNone(cache.get("e"))
if __name__ == "__main__":
    # Executed directly (not imported): discover and run every TestCase
    # defined in this module. "__main__" is unittest.main's default module.
    unittest.main(module="__main__")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment