Skip to content

Instantly share code, notes, and snippets.

@hideaki-t
Last active December 19, 2015 07:59
Show Gist options
  • Save hideaki-t/5922987 to your computer and use it in GitHub Desktop.
Save hideaki-t/5922987 to your computer and use it in GitHub Desktop.
Switch/Case implementation using Enum
from enum import Enum
class Signal(Enum):
    """Three-state signal enum used by the tests in this file."""

    # Declared in one statement; names still bind in order to 1, 2 and 3.
    red, yellow, green = 1, 2, 3
def switch(e, **cases):
    """Dispatch on an Enum member, checking that the cases are exhaustive.

    Each keyword names a member of ``e``'s enum class and maps to a callable.
    The callable selected for ``e`` is invoked with ``e`` and its result is
    returned.  The special keyword ``_`` supplies a default handler; when it
    is present the cases no longer have to cover every member.

    Aliases (several names bound to the same value) are resolved to their
    canonical member, so a handler given under either name covers that
    member and aliases never have to be listed separately.  If both the
    canonical name and an alias are given, the canonical name's handler wins.

    Args:
        e: The enum member to dispatch on.
        **cases: Handlers keyed by member name, plus an optional ``_``
            default.  Every handler takes the member as its only argument.

    Returns:
        Whatever the selected handler returns.

    Raises:
        ValueError: If a keyword other than ``_`` is not a member name, or
            if no default is given and some member has no handler.  The
            offending names are available as a set in ``args[1]``.
    """
    cls = e.__class__
    members = cls.__members__  # name -> member mapping, aliases included

    unknown = set(cases).difference(members)
    # '_' only acts as the default when it is not itself a member name.
    has_default = '_' in unknown
    extra = unknown - {'_'}
    if extra:
        raise ValueError("extra cases {} are given for {}".format(
            extra, cls.__name__), extra)

    # Re-key the handlers by canonical member name so that aliases and
    # e.name (which is always canonical) agree with each other.
    handlers = {}
    for key, handler in cases.items():
        if key not in members:
            continue  # the '_' default; looked up directly at dispatch time
        canonical = members[key].name
        if key == canonical or canonical not in handlers:
            handlers[canonical] = handler

    if not has_default:
        required = {member.name for member in members.values()}
        missing = required.difference(handlers)
        if missing:
            raise ValueError(
                "given cases do not cover all names {}".format(missing),
                missing)

    if e.name in handlers:
        return handlers[e.name](e)
    return cases['_'](e)
if __name__ == '__main__':
    import unittest

    class TestSwitch(unittest.TestCase):
        """Checks switch() dispatch, default handling and validation errors."""

        def _rejected_names(self, **cases):
            """Run switch(Signal.red, **cases), which must raise.

            Returns the second exception argument, i.e. the set of names
            that switch() reported as offending.
            """
            with self.assertRaises(Exception) as cm:
                switch(Signal.red, **cases)
            return cm.exception.args[1]

        def test_cover_all(self):
            # One handler per member, no default.
            handlers = {
                'red': lambda x: "Go carefully {}".format(x),
                'green': lambda x: "Go! {}".format(x),
                'yellow': lambda x: "Go fast {}".format(x),
            }
            expectations = (
                (Signal.red, "Go carefully Signal.red"),
                (Signal.green, "Go! Signal.green"),
            )
            for member, expected in expectations:
                self.assertEqual(switch(member, **handlers), expected)

        def test_extra_default(self):
            # A redundant default is allowed; the explicit case still wins.
            handlers = {
                'red': lambda x: "Go carefully/r {}".format(x),
                'green': lambda x: "Go! {}".format(x),
                'yellow': lambda x: "Go fast {}".format(x),
                '_': lambda x: "Go carefully {}/_".format(x),
            }
            outcome = switch(Signal.yellow, **handlers)
            self.assertEqual(outcome, "Go fast Signal.yellow")

        def test_default(self):
            # 'red' has no explicit case, so the default handles it.
            handlers = {
                'green': lambda x: "Go! {}".format(x),
                'yellow': lambda x: "Go fast {}".format(x),
                '_': lambda x: "Go carefully {}".format(x),
            }
            outcome = switch(Signal.red, **handlers)
            self.assertEqual(outcome, "Go carefully Signal.red")

        def test_empty(self):
            # No cases at all: every member name is reported as uncovered.
            member_names = set(Signal.__members__.keys())
            self.assertEqual(self._rejected_names(), member_names)

        def test_extra(self):
            given = ('black', 'blue', 'red', 'green', 'yellow')
            handlers = {name: (lambda x: x) for name in given}
            reported = self._rejected_names(**handlers)
            self.assertEqual(
                reported, set(given) - set(Signal.__members__.keys()))

        def test_extra_with_default(self):
            given = ('blue', 'red', 'green', 'yellow')
            handlers = {name: (lambda x: x) for name in given}
            handlers['_'] = lambda x: x
            reported = self._rejected_names(**handlers)
            self.assertEqual(
                reported, set(given) - set(Signal.__members__.keys()))

        def test_extra_and_default_only(self):
            given = ('black', 'blue')
            handlers = {name: (lambda x: x) for name in given}
            handlers['_'] = lambda x: x
            reported = self._rejected_names(**handlers)
            self.assertEqual(
                reported, set(given) - set(Signal.__members__.keys()))

    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment