Last active
December 19, 2015 07:59
-
-
Save hideaki-t/5922987 to your computer and use it in GitHub Desktop.
Switch/Case implementation using Enum
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from enum import Enum | |
class Signal(Enum):
    """Traffic-light colours used as the example enum for ``switch``."""

    # Members are bound in declaration order: red=1, yellow=2, green=3.
    red, yellow, green = 1, 2, 3
def switch(e, **cases):
    """Call the handler registered for Enum member *e* and return its result.

    Each keyword argument maps a member name of ``type(e)`` to a callable
    taking one argument; ``e`` itself is passed to the chosen handler.  The
    special keyword ``_`` is a default handler used for any member that has
    no case of its own (unless the enum really defines a member named ``_``).

    Exhaustiveness is checked on every call: every member must be covered,
    either by a case under its canonical name or one of its alias names, or
    by the ``_`` default.

    Raises:
        ValueError: if a keyword names no member of ``type(e)``; ``args[1]``
            is the set of unknown names.
        ValueError: if some member is uncovered and no ``_`` default is
            given; ``args[1]`` is the set of uncovered canonical names.
        (ValueError subclasses Exception, so ``except Exception`` still works.)
    """
    cls = type(e)
    # __members__ maps every name -- canonical names *and* aliases -- to a member.
    members = cls.__members__
    casenames = set(cases)

    extra = casenames - set(members) - {'_'}
    if extra:
        raise ValueError("extra cases {} are given for {}".format(
            extra, cls.__name__), extra)
    has_default = '_' in casenames and '_' not in members

    # Group names by the member they denote, so an alias counts as covering
    # its canonical member instead of being demanded as a separate case that
    # could never be dispatched (e.name is always the canonical name).
    names_by_member = {}
    for name, member in members.items():
        names_by_member.setdefault(member, []).append(name)

    missing = {member.name for member, names in names_by_member.items()
               if casenames.isdisjoint(names)}
    if missing and not has_default:
        raise ValueError(
            "given cases do not cover all names {}".format(missing), missing)

    # __members__ is in definition order, so the canonical name comes first;
    # then any alias; finally fall back to the default handler.
    for name in names_by_member.get(e, (e.name,)):
        if name in cases:
            return cases[name](e)
    return cases['_'](e)
if __name__ == '__main__':
    import unittest

    class TestSwitch(unittest.TestCase):
        """Exercise ``switch`` against the ``Signal`` enum.

        Success cases check the dispatched handler's return value; failure
        cases check the set carried in ``args[1]`` of the raised exception.
        """

        @staticmethod
        def _identity(x):
            return x

        @staticmethod
        def _full_cases():
            """Return one formatting handler per Signal member."""
            return {
                'red': lambda x: "Go carefully {}".format(x),
                'green': lambda x: "Go! {}".format(x),
                'yellow': lambda x: "Go fast {}".format(x),
            }

        def _rejected(self, **cases):
            """Assert switch(Signal.red, **cases) raises; return its args[1]."""
            with self.assertRaises(Exception) as caught:
                switch(Signal.red, **cases)
            return caught.exception.args[1]

        def test_cover_all(self):
            handlers = self._full_cases()
            self.assertEqual(switch(Signal.red, **handlers),
                             "Go carefully Signal.red")
            self.assertEqual(switch(Signal.green, **handlers),
                             "Go! Signal.green")

        def test_extra_default(self):
            handlers = {
                'red': lambda x: "Go carefully/r {}".format(x),
                'green': lambda x: "Go! {}".format(x),
                'yellow': lambda x: "Go fast {}".format(x),
                '_': lambda x: "Go carefully {}/_".format(x),
            }
            self.assertEqual(switch(Signal.yellow, **handlers),
                             "Go fast Signal.yellow")

        def test_default(self):
            handlers = {
                'green': lambda x: "Go! {}".format(x),
                'yellow': lambda x: "Go fast {}".format(x),
                '_': lambda x: "Go carefully {}".format(x),
            }
            self.assertEqual(switch(Signal.red, **handlers),
                             "Go carefully Signal.red")

        def test_empty(self):
            self.assertEqual(self._rejected(),
                             set(Signal.__members__.keys()))

        def test_extra(self):
            given = {'black', 'blue', 'red', 'green', 'yellow'}
            self.assertEqual(
                self._rejected(**dict.fromkeys(given, self._identity)),
                given - set(Signal.__members__.keys()))

        def test_extra_with_default(self):
            given = {'blue', 'red', 'green', 'yellow'}
            self.assertEqual(
                self._rejected(_=self._identity,
                               **dict.fromkeys(given, self._identity)),
                given - set(Signal.__members__.keys()))

        def test_extra_and_default_only(self):
            given = {'black', 'blue'}
            self.assertEqual(
                self._rejected(_=self._identity,
                               **dict.fromkeys(given, self._identity)),
                given - set(Signal.__members__.keys()))

    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment