Skip to content

Instantly share code, notes, and snippets.

@0xquad
Created March 27, 2016 16:48
Show Gist options
Save 0xquad/3f24f1cf3498fd10a27e to your computer and use it in GitHub Desktop.
Testing generators with unittest Mocks
#!/usr/bin/env python
def func():
    """Delegate to ``subfunc`` and hand back whatever it returns.

    ``subfunc`` is looked up as a module global each time ``func`` runs,
    which is what allows ``patch('__main__.subfunc', ...)`` to replace it.
    """
    value = subfunc()
    return value
def subfunc():
    """Return the constant 0 (the dependency the tests patch out)."""
    result = 0
    return result
def g():
    """Yield ``base + 1`` and then ``base + 2``, where ``base = func()``.

    ``func`` is called exactly once, and only when the generator is first
    advanced with ``next()``, not when ``g()`` is called. Any patch of
    ``subfunc`` therefore has to be active at that first ``next()``.
    """
    base = func()
    for offset in (1, 2):
        yield base + offset
if __name__ == '__main__':
    import unittest
    from unittest.mock import Mock, patch

    class TestGenerator(unittest.TestCase):
        """Show how ``unittest.mock.patch`` interacts with generators.

        A generator's body does not run until ``next()`` is called, so the
        patch on ``subfunc`` must still be active at that moment for the
        mock's return value to be observed.
        """

        def _assert_yields(self, gen, expected):
            """Advance *gen* one ``next()`` at a time, checking each value in
            *expected*, then check that the generator is exhausted."""
            for value in expected:
                self.assertEqual(next(gen), value)
            with self.assertRaises(StopIteration):
                next(gen)

        def test_001(self):
            # Unpatched: subfunc() returns 0, so g() yields 1 then 2.
            self._assert_yields(g(), [1, 2])

        @patch('__main__.subfunc')
        def test_002(self, subfunc_mock):
            # Decorator form: the patch covers the whole test, including
            # every next() that runs the generator body.
            subfunc_mock.return_value = 1000
            self._assert_yields(g(), [1001, 1002])
            subfunc_mock.assert_called_once_with()

        def test_003(self):
            # Context-manager form: creation and every next() happen inside
            # the with block, so the mock is in place throughout.
            with patch('__main__.subfunc') as subfunc_mock:
                subfunc_mock.return_value = 1000
                self._assert_yields(g(), [1001, 1002])
            subfunc_mock.assert_called_once_with()

        def test_004(self):
            # A patcher can be built up front and activated later; building
            # it does not patch anything by itself.
            subfunc_mock = Mock(return_value=1000)
            patcher = patch('__main__.subfunc', subfunc_mock)
            self._assert_yields(g(), [1, 2])
            with patcher:
                self._assert_yields(g(), [1001, 1002])
            # Only the patched run should have reached the mock.
            subfunc_mock.assert_called_once_with()

        def test_005(self):
            # Same idea with a plain function instead of a generator.
            subfunc_mock = Mock(return_value=1000)
            patcher = patch('__main__.subfunc', subfunc_mock)
            # Pin the real value: assertNotEqual(x, 1000) would accept
            # any wrong result as well.
            self.assertEqual(func(), 0)
            with patcher:
                self.assertEqual(func(), 1000)
            subfunc_mock.assert_called_once_with()

        def test_006(self):
            # Pitfall: g() only creates the generator object. A generator
            # created under a patch but first advanced after the patch has
            # exited runs against the real subfunc.
            with patch('__main__.subfunc', Mock(return_value=1000)):
                gen1 = g()
            self._assert_yields(gen1, [1, 2])
            # Conversely, once the first next() has run under the patch, the
            # patched value is already stored in the generator's local
            # `start`/`base`, so later steps still see it after the patch exits.
            with patch('__main__.subfunc', Mock(return_value=1000)):
                gen2 = g()
                self.assertEqual(next(gen2), 1001)
            self.assertEqual(next(gen2), 1002)
            with self.assertRaises(StopIteration):
                next(gen2)

    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment