Last active
January 22, 2024 13:20
-
-
Save Ovid/1a1d7869b29816cc4c82d00ad42aa2fc to your computer and use it in GitHub Desktop.
Force validation of overridden methods in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def enforce_overrides(cls):
    """Class decorator validating the use of ``@override`` on a class's methods.

    Example::

        from abc import ABC, abstractmethod

        class A(ABC):
            @abstractmethod
            def foo(self):
                pass

        @enforce_overrides
        class B(A):
            @override
            def foo(self):
                pass

    Every callable attribute defined directly on ``cls`` is checked against
    the classes ``cls`` inherits from (parents, grandparents, and so on);
    non-callable attributes are ignored.

    Args:
        cls: The class to validate.

    Returns:
        ``cls`` itself, unchanged, when every check passes.

    Raises:
        TypeError: If a method overrides an inherited method without being
            marked with ``@override``, or is marked with ``@override`` but
            no ancestor class defines a method of that name.
    """
    direct_bases = cls.__bases__
    # Search the whole MRO rather than only the direct bases, so a method
    # inherited from a grandparent is recognised as overridable.  `object`
    # is only taken into account when it is a direct base: otherwise every
    # subclass defining __init__/__repr__/... would suddenly need @override.
    ancestors = [
        klass
        for klass in cls.__mro__[1:]
        if klass is not object or object in direct_bases
    ]

    for name, method in vars(cls).items():
        if not callable(method):
            continue  # ignore attributes

        # First ancestor (in MRO order) that defines this name, if any.
        defining = next((klass for klass in ancestors if name in vars(klass)), None)

        # NOTE: "_is_overriden" (sic) is the flag name set by @override.
        if getattr(method, "_is_overriden", False):
            if defining is None:
                raise TypeError(f"Method '{name}' in '{cls.__name__}' is not overriding any method of the base classes")
        elif defining is not None:
            raise TypeError(f"Method '{name}' in '{cls.__name__}' is overriding a method in '{defining}' class but is not marked with @override")
    return cls
def override(method):
    """Mark *method* as intentionally overriding a base-class method.

    Sets the ``_is_overriden`` flag on the given object and hands the very
    same object back, so it can be used as a plain decorator.
    """
    setattr(method, "_is_overriden", True)
    return method
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment