Skip to content

Instantly share code, notes, and snippets.

@ntrrgc
Created December 6, 2024 12:45
Show Gist options
  • Save ntrrgc/1d8d2d94086b24eee466a9125ad6b1ed to your computer and use it in GitHub Desktop.
Save ntrrgc/1d8d2d94086b24eee466a9125ad6b1ed to your computer and use it in GitHub Desktop.
A Python MultiDict as a small copy-pasteable library, because I find myself needing one surprisingly often
from collections.abc import ItemsView, Set, KeysView
from typing import Iterable, Iterator, NamedTuple, Hashable
class MultiDict[K: Hashable, V: Hashable]:
"""
Associates sets of values with a key.
Values can repeat for different keys if needed.
"""
def __init__(self):
self.__vals_by_key: dict[K, set[V]] = {}
def add(self, key: K, value: V):
self.__vals_by_key.setdefault(key, set())
self.__vals_by_key[key].add(value)
def has_single(self, key: K, value: V) -> bool:
value_set = self.__vals_by_key.get(key, set())
return value in value_set
def has_key(self, key: K) -> bool:
value_set = self.__vals_by_key.get(key)
if value_set is None:
return False
assert value_set, "there must not be empty sets"
return True
def remove_single(self, key: K, value: V):
value_set = self.__vals_by_key[key]
value_set.remove(value)
if not value_set:
del self.__vals_by_key[key]
def remove_key(self, key: K):
del self.__vals_by_key[key]
def discard_single(self, key: K, value: V):
value_set = self.__vals_by_key.get(key)
if not value_set:
return
value_set.discard(value)
if not value_set:
del self.__vals_by_key[key]
def discard_key(self, key: K):
self.__vals_by_key.pop(key, None)
def clear(self):
self.__vals_by_key.clear()
def keys(self) -> KeysView[K]:
return self.__vals_by_key.keys()
def values(self) -> Iterator[V]:
for value_set in self.__vals_by_key.values():
assert value_set, "there must not be empty sets"
yield from value_set
def items(self) -> ItemsView[K, Set[V]]:
return self.__vals_by_key.items()
def get(self, key: K) -> Set[V]:
return self.__vals_by_key.get(key, set())
def __getitem__(self, key: K) -> Set[V]:
return self.__vals_by_key[key]
def count_keys(self) -> int:
return len(self.__vals_by_key)
def __contains__(self, key: K) -> bool:
return self.has_key(key)
def __len__(self) -> int:
return self.count_keys()
def __iter__(self) -> Iterator[K]:
return iter(self.keys())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment