Created February 1, 2019 21:33
-
-
Save gvx/679597e0fce2c5f274c00f837c2db117 to your computer and use it in GitHub Desktop.
Generalised hashable types in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass, field, fields, MISSING | |
def immutable_field(*, default=MISSING, default_factory=MISSING, init=True,
                    repr=True, metadata=None):
    """Declare a dataclass field that takes part in equality and hashing.

    The field keeps dataclasses' default ``compare=True``, so it is used by
    the generated ``__eq__`` (and ``__hash__`` for an eq + frozen
    dataclass).  It is also tagged with ``metadata['frozen'] = True``, the
    marker ``halfmutable`` uses to refuse reassignment/deletion after
    construction.

    Args:
        default: Default value, as for ``dataclasses.field``.
        default_factory: Zero-argument default factory, as for
            ``dataclasses.field``.
        init: Whether the field is an ``__init__`` parameter.
        repr: Whether the field appears in the generated ``__repr__``
            (shadows the builtin here to mirror ``dataclasses.field``).
        metadata: Optional extra metadata mapping.  Its ``'frozen'`` key, if
            any, is always overridden with ``True``.

    Returns:
        A ``dataclasses.Field`` describing the field.
    """
    # Merge user metadata first so the immutability marker always wins.
    merged = {**(metadata or {}), 'frozen': True}
    return field(default=default, default_factory=default_factory,
                 init=init, repr=repr, metadata=merged)
def mutable_field(*, default=MISSING, default_factory=MISSING, init=True,
                  repr=True, metadata=None):
    """Declare a dataclass field that is ignored by equality and hashing.

    The field is created with ``compare=False``.  With ``hash`` left at its
    default of ``None``, dataclasses also leave it out of the generated
    ``__hash__``.  It is therefore safe to mutate on a ``halfmutable``
    instance without changing that instance's hash.

    Args:
        default: Default value, as for ``dataclasses.field``.
        default_factory: Zero-argument default factory, as for
            ``dataclasses.field``.
        init: Whether the field is an ``__init__`` parameter.
        repr: Whether the field appears in the generated ``__repr__``
            (shadows the builtin here to mirror ``dataclasses.field``).
        metadata: Optional metadata mapping, passed through unchanged.

    Returns:
        A ``dataclasses.Field`` describing the field.
    """
    return field(default=default, default_factory=default_factory,
                 init=init, repr=repr, compare=False, metadata=metadata)
def halfmutable(_cls=None, *, init=True, repr=True, order=False):
    """Make a hashable dataclass whose non-identity fields stay mutable.

    The class is turned into a ``frozen=True`` dataclass (``eq`` keeps its
    default of True), so dataclasses generates ``__eq__`` and ``__hash__``
    from the fields that take part in comparison/hashing.  The
    ``FrozenInstanceError``-raising ``__setattr__``/``__delattr__`` a frozen
    dataclass installs are then replaced by versions that only protect the
    fields feeding into equality, ordering or hashing.  Every other
    attribute, e.g. fields declared with ``mutable_field()``, can be
    assigned and deleted freely.  The generated ``__init__`` of a frozen
    dataclass sets fields via ``object.__setattr__``, so construction still
    works.

    A field is protected when any of these hold:

    * it carries ``metadata['frozen']`` (as set by ``immutable_field()``),
    * it has ``compare=True`` (the dataclass default, so plain annotated
      fields are protected too; they are part of ``__eq__``/``__hash__``),
    * it has ``hash=True``.

    Usable bare (``@halfmutable``) or with arguments
    (``@halfmutable(order=True)``).

    Args:
        _cls: The class when used without parentheses; ``None`` otherwise.
        init: Passed through to ``dataclass``.
        repr: Passed through to ``dataclass`` (shadows the builtin here to
            mirror the ``dataclass`` signature).
        order: Passed through to ``dataclass``.

    Returns:
        The decorated class, or a class decorator when ``_cls`` is ``None``.

    The installed ``__setattr__``/``__delattr__`` raise:
        ValueError: ``"immutable field: <name>"`` when assigning to or
            deleting a protected field after construction.
    """
    def wrap(cls):
        cls = dataclass(cls, init=init, repr=repr, order=order, frozen=True)
        # Anything that influences __eq__, ordering or __hash__ must not
        # change after construction, or dict/set membership breaks.
        protected = frozenset(
            f.name
            for f in fields(cls)
            if f.metadata.get('frozen') or f.compare or f.hash
        )

        def __setattr__(self, key, value):
            if key in protected:
                raise ValueError(f"immutable field: {key}")
            object.__setattr__(self, key, value)

        def __delattr__(self, key):
            if key in protected:
                raise ValueError(f"immutable field: {key}")
            object.__delattr__(self, key)

        # Make tracebacks/introspection show the owning class.
        __setattr__.__qualname__ = f"{cls.__qualname__}.__setattr__"
        __delattr__.__qualname__ = f"{cls.__qualname__}.__delattr__"
        # Overwrite the FrozenInstanceError versions installed by dataclass.
        cls.__setattr__ = __setattr__
        cls.__delattr__ = __delattr__
        return cls

    # Called with parentheses: @halfmutable(...) -> return the decorator.
    if _cls is None:
        return wrap
    # Called bare: @halfmutable -> decorate immediately.
    return wrap(_cls)
@halfmutable()
class MyHalfMutable:
    """Example record mixing identity fields and free-form state.

    ``spam`` and ``ham`` are declared with ``immutable_field()``: they
    determine ``==`` and ``hash()`` and are write-protected after
    construction.  ``bar`` and ``foo`` are declared with ``mutable_field()``:
    they are ignored by ``==`` and ``hash()`` and may be reassigned.
    """

    # Required constructor arguments (no defaults) come first.
    spam: int = immutable_field()
    bar: str = mutable_field()
    # Optional arguments with defaults.
    ham: str = immutable_field(default='my hovercraft is full of eels')
    foo: int = mutable_field(default=42)
if __name__ == "__main__":
    a = MyHalfMutable(spam=1, bar='hotdog')
    b = MyHalfMutable(spam=1, bar='nothing')
    some_dict = {a: a}
    # True: bar is a mutable field, so it is ignored by __eq__ and __hash__.
    print(b in some_dict)
    # Reassigning a mutable field is allowed and leaves the hash unchanged.
    a.bar = 'look ma, i can be modified'
    print(a in some_dict)
    try:
        a.spam = 10
    except ValueError as err:
        # Parts of an object that contribute to equality and the hash value
        # cannot be modified.
        print(f"error: {err}")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.