Last active
December 17, 2015 18:09
-
-
Save jnothman/5650845 to your computer and use it in GitHub Desktop.
An extension of numpy.testing.assert_warns that checks the warning message and filename, and handles multiple warnings using a cascade of criteria.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""An extension of numpy.testing.assert_warns | |
Handles checking of message and multiple warnings. | |
""" | |
from __future__ import absolute_import, print_function | |
import sys | |
import re | |
import warnings | |
try: | |
from collections import Sequence, Mapping | |
except ImportError: | |
Sequence = (list, tuple) | |
Mapping = (dict,) | |
# The following two classes are copied from python 2.6 warnings module (context | |
# manager) | |
class WarningMessage(object):
    """
    Record of the arguments passed to a single ``showwarning()`` call.

    Each constructor argument is kept as an attribute of the same name:
    ``message``, ``category``, ``filename``, ``lineno``, ``file`` and
    ``line``.

    Notes
    -----
    `WarningMessage` is copied from the Python 2.6 warnings module,
    so it can be used in NumPy with older Python versions.
    """

    _WARNING_DETAILS = ("message", "category", "filename", "lineno", "file",
                        "line")

    def __init__(self, message, category, filename, lineno, file=None,
                 line=None):
        self.message = message
        self.category = category
        self.filename = filename
        self.lineno = lineno
        self.file = file
        self.line = line
        # A falsy category (e.g. None) is reported as None by __str__.
        self._category_name = category.__name__ if category else None

    def __str__(self):
        template = ("{message : %r, category : %r, filename : %r, "
                    "lineno : %s, line : %r}")
        return template % (self.message, self._category_name,
                           self.filename, self.lineno, self.line)
class WarningManager(object):
    """
    A context manager that copies and restores the warnings filter upon
    exiting the context.

    The 'record' argument specifies whether warnings should be captured by a
    custom implementation of ``warnings.showwarning()`` and be appended to a
    list returned by the context manager. Otherwise None is returned by the
    context manager. The objects appended to the list are arguments whose
    attributes mirror the arguments to ``showwarning()``.

    The 'module' argument is to specify an alternative module to the module
    named 'warnings' and imported under that name. This argument is only useful
    when testing the warnings module itself.

    The manager may be used in a ``with`` statement, or driven manually by
    calling ``__enter__()`` and later ``__exit__()`` (with or without the
    exception-info arguments). Each instance may be entered only once.

    Notes
    -----
    `WarningManager` is a copy of the ``catch_warnings`` context manager
    from the Python 2.6 warnings module, with slight modifications.
    It is copied so it can be used in NumPy with older Python versions.
    """

    def __init__(self, record=False, module=None):
        self._record = record
        if module is None:
            self._module = sys.modules['warnings']
        else:
            self._module = module
        self._entered = False

    def __enter__(self):
        """Save the module's filters and ``showwarning`` hook.

        Returns
        -------
        log : list or None
            If ``record`` is true, a list to which a `WarningMessage` is
            appended for every warning shown while the context is active;
            otherwise None.

        Raises
        ------
        RuntimeError
            If this instance has already been entered.
        """
        if self._entered:
            raise RuntimeError("Cannot enter %r twice" % self)
        self._entered = True
        self._filters = self._module.filters
        # Install a copy so changes made inside the context do not alter the
        # saved list that __exit__ puts back.
        self._module.filters = self._filters[:]
        self._showwarning = self._module.showwarning
        if not self._record:
            return None
        log = []

        def showwarning(*args, **kwargs):
            log.append(WarningMessage(*args, **kwargs))

        self._module.showwarning = showwarning
        return log

    def __exit__(self, *exc_info):
        """Restore the saved filters and ``showwarning`` hook.

        Accepts and ignores the ``(exc_type, exc_value, traceback)`` arguments
        supplied by the ``with`` statement, so it works both there and when
        called manually with no arguments. Returns None, so an exception
        raised inside the context propagates.

        Raises
        ------
        RuntimeError
            If called before ``__enter__``.
        """
        if not self._entered:
            raise RuntimeError("Cannot exit %r without entering first" % self)
        self._module.filters = self._filters
        self._module.showwarning = self._showwarning
class WarningMatcher(object):
    """Matches a specified number and description of warnings

    Parameters
    ----------
    category : class, optional
        The expected warning class or a super-class (default matches any).
    message : regular expression or string, optional
        A regular expression expected to be matched (ignoring case) in the
        warning's message (default matches any).
    filename : regular expression or string, optional
        A regular expression expected to be matched in the warning's filename
        (default matches any). Like `message`, a string pattern is compiled
        case-insensitively; a pre-compiled pattern keeps its own flags.
    count : integer, '+', or '*' (default 1)
        The exact number of warnings that should match, or '*' for 0 or more,
        or '+' for 1 or more.

    Raises
    ------
    ValueError
        If `count` is not '*', '+' or an int.
    """

    def __init__(self, category=None, message=None, filename=None, count=1):
        self.category = category
        self.message = self._compile_re(message)
        self.filename = self._compile_re(filename)
        self.count = count
        # The exact type check also rejects bool (an int subclass).
        if count not in ('*', '+') and type(count) is not int:
            raise ValueError('"count" should be "*", "+" or an integer')

    def _compile_re(self, obj):
        """Return `obj` as a regex: None and objects with a ``search``
        method pass through unchanged; anything else is compiled with
        ``re.IGNORECASE``."""
        if obj is None or hasattr(obj, 'search'):
            return obj
        return re.compile(obj, flags=re.IGNORECASE)

    @classmethod
    def _from_arg(cls, arg):
        """Coerce `arg` to a matcher.

        A `WarningMatcher` is returned as-is, a mapping is used as keyword
        arguments, and anything else is taken as the `category`.
        """
        if isinstance(arg, cls):
            return arg
        if isinstance(arg, Mapping):
            return cls(**arg)
        return cls(arg)

    def matches_category(self, warning):
        return self.category is None or issubclass(warning.category, self.category)

    def matches_message(self, warning):
        return self.message is None or self.message.search(str(warning.message))

    def matches_filename(self, warning):
        return self.filename is None or self.filename.search(warning.filename)

    def matches(self, warning):
        """Truthy when `warning` satisfies every configured criterion."""
        return (self.matches_category(warning) and
                self.matches_message(warning) and self.matches_filename(warning))

    def __call__(self, record, err_prefix=''):
        """Check the matching warnings in `record` are of the specified count

        Raises `AssertionError` prefixed by `err_prefix` if the specified count
        of matching warnings is not found.

        Returns
        -------
        unmatched : list of `WarningMessage`
            Any warnings that did not match this `WarningMatcher`'s criteria.
            With an integer `count`, matching warnings beyond the first
            `count` are also returned here.
        """
        unmatched = []
        n_matched = 0
        for warning in record:
            # With an integer count, stop consuming once `count` warnings have
            # matched so surplus ones are left for later matchers / the
            # caller. '*' and '+' never equal an int, so they consume all.
            if n_matched != self.count and self.matches(warning):
                n_matched += 1
            else:
                unmatched.append(warning)
        if self.count == '*':
            pass
        elif self.count == '+':
            if n_matched == 0:
                raise AssertionError('%sExpected 1 or more warnings matching '
                                     '%s' % (err_prefix, self))
        elif self.count != n_matched:
            raise AssertionError('%sExpected %d warnings matching %s, got '
                                 '%s' % (err_prefix, self.count, self,
                                         n_matched))
        return unmatched

    def __repr__(self):
        return '%s(%s)' % (self.__class__.__name__, self._kwargs_str())

    def _kwargs_str(self, with_count=True):
        """Render the configured criteria as ``key=value`` pairs."""
        parts = []
        if self.category is not None:
            parts.append(('category', self.category))
        if self.message is not None:
            parts.append(('message', self.message.pattern))
        if self.filename is not None:
            parts.append(('filename', self.filename.pattern))
        if with_count:
            parts.append(('count', self.count))
        return ', '.join('%s=%r' % part for part in parts)

    def __str__(self, with_count=False):
        args = self._kwargs_str(with_count=with_count)
        if not args:
            return '<any>'
        else:
            return '<%s>' % args
def _unmatched_warnings(matchers, err_prefix, func, *args, **kw):
    """Call ``func(*args, **kw)`` recording warnings; return those unmatched.

    Parameters
    ----------
    matchers : matcher spec or sequence of matcher specs
        Each spec is a ``WarningMatcher``, a mapping of its keyword arguments,
        or a warning category. Matchers are applied in order, each consuming
        warnings from whatever the previous matchers left.
    err_prefix : str
        Prefix for any ``AssertionError`` raised by a matcher.
    func : callable
        The callable under test, invoked with `args` and `kw`.

    Returns
    -------
    result : object
        The value returned by `func`.
    unmatched : list of `WarningMessage`
        Recorded warnings not consumed by any matcher.

    Raises
    ------
    AssertionError
        If a matcher does not find its expected count of warnings.
    """
    if not isinstance(matchers, Sequence):
        matchers = [matchers]
    matchers = [WarningMatcher._from_arg(matcher) for matcher in matchers]
    # XXX: once we may depend on python >= 2.6, this can be replaced by the
    # warnings module context manager.
    ctx = WarningManager(record=True)
    remaining = ctx.__enter__()
    try:
        # Inside the try so the saved filters are restored even if this
        # fails. 'always' stops the default filtering from suppressing
        # repeats of a warning, so every emitted warning is recorded.
        warnings.simplefilter('always')
        result = func(*args, **kw)
    finally:
        ctx.__exit__()
    for matcher in matchers:
        remaining = matcher(remaining, err_prefix)
    return result, remaining
def assert_warns(warning_class, func, *args, **kw):
    """
    Fail unless the given callable raises exactly the specified warning(s).

    `func` is called with positional arguments `args` and keyword arguments
    `kw` while every warning it emits is recorded. Each matcher in
    `warning_class` must find its expected number of warnings, and no
    recorded warning may be left unmatched; otherwise an `AssertionError` is
    raised. Exceptions raised by `func` itself propagate unchanged.

    Parameters
    ----------
    warning_class : class or mapping, or sequence thereof.
        The class(es) or matcher(s) for warnings that `func` is expected to
        raise. Entries may be mappings with the following keys:

        * 'category': the expected warning class or a super-class (default
          any).
        * 'message': a regular expression expected to be matched (ignoring
          case) in the warning's message (default any).
        * 'filename': a regular expression expected to be matched in the
          warning's filename (default any).
        * 'count': the exact number of warnings that should match, or '*'
          for 0 or more, or '+' for 1 or more (default 1).

        Each rule is applied in turn to consume warnings in their order of
        appearance. As such, specific rules should usually precede more general
        rules. For example, to require a `UserWarning`, but allow any others to
        also occur, use `warning_class=[UserWarning, {'count': '*'}]`.
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kw : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If a matcher does not find its expected count of warnings, or if any
        recorded warning is left unmatched.
    """
    # Not every callable has a __name__ (e.g. functools.partial objects).
    func_name = getattr(func, '__name__', repr(func))
    err_prefix = 'When calling %s: ' % func_name
    result, unmatched = _unmatched_warnings(warning_class, err_prefix, func,
                                            *args, **kw)
    if unmatched:
        # Use each warning's __str__ so the message shows its details.
        raise AssertionError('%sWarnings %s not matched' % (
            err_prefix, '; '.join(str(w) for w in unmatched)))
    return result
def assert_no_warnings(func, *args, **kw):
    """
    Fail if the given callable produces any warnings.

    All warnings emitted while calling `func` are recorded, regardless of the
    active warning filters. Exceptions raised by `func` propagate unchanged.

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kw : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If `func` emitted at least one warning.
    """
    # Not every callable has a __name__ (e.g. functools.partial objects).
    func_name = getattr(func, '__name__', repr(func))
    err_prefix = 'When calling %s: ' % func_name
    # Named `emitted` rather than `warnings` to avoid shadowing the module.
    result, emitted = _unmatched_warnings([], err_prefix, func, *args, **kw)
    if emitted:
        raise AssertionError("Got warnings when calling %s: %s"
                             % (func_name, '; '.join(str(w) for w in emitted)))
    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TestWarns(unittest.TestCase):
    """Tests for `assert_warns` and `assert_no_warnings`.

    The global warnings filter list is snapshotted in ``setUp`` and restored
    in ``tearDown`` so a failing test cannot leak filter changes.
    """

    @staticmethod
    def _make_warner(*warning_args):
        """Return a function that calls ``warnings.warn(*args)`` for each
        tuple in `warning_args`, in order, and then returns 3."""
        def fn():
            for warning in warning_args:
                warnings.warn(*warning)
            return 3
        return fn

    def setUp(self):
        self._before_filters = sys.modules['warnings'].filters[:]

    def tearDown(self):
        sys.modules['warnings'].filters = self._before_filters

    def test_expected_warning(self):
        f = self._make_warner(("yo",))
        before_filters = sys.modules['warnings'].filters[:]
        assert_equal(assert_warns(UserWarning, f), 3)
        after_filters = sys.modules['warnings'].filters
        # Check that the warnings state is unchanged
        assert_equal(before_filters, after_filters,
                     "assert_warns does not preserve warnings state")

    # Previously misspelt ``text_expected_absence``; unittest only collects
    # methods starting with "test", so it never ran.
    def test_expected_absence(self):
        assert_equal(assert_no_warnings(lambda x: x, 1), 1)

    def test_unexpected_warning(self):
        f = self._make_warner(("yo",))
        before_filters = sys.modules['warnings'].filters[:]
        assert_raises(AssertionError, assert_no_warnings, f)
        after_filters = sys.modules['warnings'].filters
        # Check that the warnings state is unchanged
        assert_equal(before_filters, after_filters,
                     "assert_no_warnings does not preserve warnings state")

    def test_unexpected_type(self):
        f = self._make_warner(("yo", DeprecationWarning))
        before_filters = sys.modules['warnings'].filters[:]
        assert_raises(AssertionError, assert_warns, UserWarning, f)
        after_filters = sys.modules['warnings'].filters
        # Check that the warnings state is unchanged
        assert_equal(before_filters, after_filters,
                     "assert_warns does not preserve warnings state")

    def test_unexpected_absence(self):
        assert_raises(AssertionError, assert_warns, DeprecationWarning,
                      lambda: 1)

    def test_count_warnings_exact(self):
        f0 = self._make_warner()
        f1 = self._make_warner(("msg 1",))
        f2 = self._make_warner(("msg 1",), ("msg 2",))
        assert_warns({'count': 0}, f0)
        assert_raises(AssertionError, assert_warns, {'count': 0}, f1)
        assert_raises(AssertionError, assert_warns, {'count': 0}, f2)
        assert_warns({'count': 1}, f1)
        assert_raises(AssertionError, assert_warns, {'count': 1}, f0)
        assert_raises(AssertionError, assert_warns, {'count': 1}, f2)
        assert_warns({'count': 2}, f2)
        assert_raises(AssertionError, assert_warns, {'count': 2}, f0)
        assert_raises(AssertionError, assert_warns, {'count': 2}, f1)

    def test_count_warnings_wildcard(self):
        f0 = self._make_warner()
        f1 = self._make_warner(("msg 1",))
        f2 = self._make_warner(("msg 1",), ("msg 2",))
        assert_warns({'count': '*'}, f0)
        assert_warns({'count': '*'}, f1)
        assert_warns({'count': '*'}, f2)
        assert_warns({'count': '+'}, f1)
        assert_warns({'count': '+'}, f2)
        assert_raises(AssertionError, assert_warns, {'count': '+'}, f0)
        assert_raises(ValueError, assert_warns, {'count': '!'}, f2)

    def test_filter_by_category(self):
        f_user = self._make_warner(("foo", UserWarning,))
        f_depr = self._make_warner(("foo", DeprecationWarning,))
        f_both = self._make_warner(("foo", UserWarning,),
                                   ("foo", DeprecationWarning,))
        assert_warns({}, f_user)
        assert_warns({}, f_depr)
        assert_warns({'count': 2}, f_both)
        assert_warns({'category': Warning}, f_user)
        assert_warns({'category': Warning}, f_depr)
        assert_warns({'category': Warning, 'count': 2}, f_both)
        assert_warns({'category': UserWarning}, f_user)
        assert_raises(AssertionError, assert_warns,
                      {'category': UserWarning}, f_depr)
        assert_raises(AssertionError, assert_warns,
                      {'category': UserWarning, 'count': 2}, f_both)
        assert_warns({'category': DeprecationWarning}, f_depr)
        assert_raises(AssertionError, assert_warns,
                      {'category': DeprecationWarning}, f_user)
        assert_raises(AssertionError, assert_warns,
                      {'category': DeprecationWarning, 'count': 2}, f_both)

    def test_filter_by_message(self):
        f_foo = self._make_warner(("foo",))
        f_bar = self._make_warner(("bar",))
        f_both = self._make_warner(("bar",), ("both",))
        assert_warns({}, f_foo)
        assert_warns({}, f_bar)
        assert_warns({'count': 2}, f_both)
        assert_warns({'message': 'fo'}, f_foo)
        assert_warns({'message': 'FO'}, f_foo)
        assert_warns({'message': 'o$'}, f_foo)
        assert_warns({'message': re.compile('oo$')}, f_foo)
        # Pre-compiled patterns keep their own (case-sensitive) flags.
        assert_raises(AssertionError, assert_warns,
                      {'message': re.compile('OO$')}, f_foo)
        assert_raises(AssertionError, assert_warns,
                      {'message': 'fo'}, f_bar)
        assert_raises(AssertionError, assert_warns,
                      {'message': re.compile('oo$')}, f_bar)
        assert_raises(AssertionError, assert_warns,
                      {'message': re.compile('oo$')}, f_both)
        assert_raises(AssertionError, assert_warns,
                      {'message': re.compile('oo$'), 'count': 2}, f_both)

    def test_filter_by_filename(self):
        # XXX: need a way to spoof this
        # NOTE(review): assumes this test module's filename contains 'test_'.
        f_foo = self._make_warner(("foo",))
        assert_warns({'filename': 'test_'}, f_foo)
        assert_raises(AssertionError, assert_warns,
                      {'filename': r'!~`junk /\\'}, f_foo)

    def test_filter_conjunction(self):
        f = self._make_warner(("foo", DeprecationWarning))
        assert_warns({'category': DeprecationWarning,
                      'message': 'foo',
                      'filename': 'test_'}, f)
        assert_raises(AssertionError, assert_warns, {
            'category': UserWarning,
            'message': 'foo',
            'filename': 'test_'}, f)
        assert_raises(AssertionError, assert_warns, {
            'category': DeprecationWarning,
            'message': 'bar',
            'filename': 'test_'}, f)
        assert_raises(AssertionError, assert_warns, {
            'category': DeprecationWarning,
            'message': 'foo',
            'filename': '!~junk \\/'}, f)

    def test_matchers_cascade(self):
        f = self._make_warner(('foo',), ('bar',))
        m1 = {'message': 'foo'}
        m2 = {'message': 'bar'}
        m_all = {}
        assert_warns([m1, m2], f)
        assert_warns([m2, m1], f)
        assert_warns([m_all, m2], f)
        # m_all consumes 'foo' first, leaving nothing for m1.
        assert_raises(AssertionError, assert_warns, [m_all, m1], f)

    def test_args_passed(self):
        def f(a, b):
            return a + b
        assert_equal(assert_no_warnings(f, 1, b=2), 3)

        def f(a, b):
            warnings.warn('foo')
            return a + b
        assert_equal(assert_warns({}, f, 1, b=2), 3)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Copyright (c) 2005-2013, NumPy Developers, Joel Nothman. | |
All rights reserved. | |
Redistribution and use in source and binary forms, with or without | |
modification, are permitted provided that the following conditions are | |
met: | |
* Redistributions of source code must retain the above copyright | |
notice, this list of conditions and the following disclaimer. | |
* Redistributions in binary form must reproduce the above | |
copyright notice, this list of conditions and the following | |
disclaimer in the documentation and/or other materials provided | |
with the distribution. | |
* Neither the name of the NumPy Developers nor the names of any | |
contributors may be used to endorse or promote products derived | |
from this software without specific prior written permission. | |
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment