Skip to content

Instantly share code, notes, and snippets.

@jnothman
Last active December 17, 2015 18:09
Show Gist options
  • Save jnothman/5650845 to your computer and use it in GitHub Desktop.
Save jnothman/5650845 to your computer and use it in GitHub Desktop.
An extension of numpy.testing.assert_warns that checks the warning message and filename, and handles multiple warnings using a cascade of criteria.
"""An extension of numpy.testing.assert_warns
Handles checking of message and multiple warnings.
"""
from __future__ import absolute_import, print_function
import re
import sys
import unittest
import warnings
try:
    from collections.abc import Sequence, Mapping
except ImportError:
    try:
        from collections import Sequence, Mapping
    except ImportError:
        Sequence = (list, tuple)
        Mapping = (dict,)
# The following two classes are copied from python 2.6 warnings module (context
# manager)
class WarningMessage(object):
    """
    Holds the result of a single showwarning() call.

    Instances carry the attributes ``message``, ``category``, ``filename``,
    ``lineno``, ``file`` and ``line``, mirroring the arguments of
    ``showwarning()``.  ``repr()`` returns the same summary as ``str()``, so
    lists of recorded warnings are readable when shown in assertion messages.

    Notes
    -----
    `WarningMessage` is copied from the Python 2.6 warnings module,
    so it can be used in NumPy with older Python versions.
    """

    _WARNING_DETAILS = ("message", "category", "filename", "lineno", "file",
                        "line")

    def __init__(self, message, category, filename, lineno, file=None,
                 line=None):
        # Copy each named argument onto the instance under the same name.
        local_values = locals()
        for attr in self._WARNING_DETAILS:
            setattr(self, attr, local_values[attr])
        # Cache the category's class name for display; category may be None.
        if category:
            self._category_name = category.__name__
        else:
            self._category_name = None

    def __str__(self):
        return ("{message : %r, category : %r, filename : %r, lineno : %s, "
                "line : %r}" % (self.message, self._category_name,
                                self.filename, self.lineno, self.line))

    def __repr__(self):
        # Containers (e.g. the list of unmatched warnings put into an
        # AssertionError) use repr(); without this they would show only
        # object addresses.
        return self.__str__()
class WarningManager(object):
    """
    A context manager that copies and restores the warnings filter upon
    exiting the context.

    The 'record' argument specifies whether warnings should be captured by a
    custom implementation of ``warnings.showwarning()`` and be appended to a
    list returned by the context manager. Otherwise None is returned by the
    context manager. The objects appended to the list are `WarningMessage`
    instances whose attributes mirror the arguments to ``showwarning()``.

    The 'module' argument is to specify an alternative module to the module
    named 'warnings' and imported under that name. This argument is only useful
    when testing the warnings module itself.

    An instance may be entered only once; entering it a second time, or
    exiting it without having entered, raises `RuntimeError`.

    Notes
    -----
    `WarningManager` is a copy of the ``catch_warnings`` context manager
    from the Python 2.6 warnings module, with slight modifications.
    It is copied so it can be used in NumPy with older Python versions.
    `__exit__` accepts the exception triple passed by a ``with`` statement
    and can also be called with no arguments.
    """

    def __init__(self, record=False, module=None):
        self._record = record
        if module is None:
            self._module = sys.modules['warnings']
        else:
            self._module = module
        self._entered = False

    def __enter__(self):
        """Save filters and ``showwarning``; return the record list or None."""
        if self._entered:
            raise RuntimeError("Cannot enter %r twice" % self)
        self._entered = True
        # Remember the original list object to reinstall on exit, and give
        # the context a copy so filters added inside it never leak out.
        self._filters = self._module.filters
        self._module.filters = self._filters[:]
        self._showwarning = self._module.showwarning
        if self._record:
            log = []

            def showwarning(*args, **kwargs):
                log.append(WarningMessage(*args, **kwargs))
            self._module.showwarning = showwarning
            return log
        else:
            return None

    def __exit__(self, *exc_info):
        """Restore the filters and ``showwarning`` saved by `__enter__`.

        `exc_info` is the ``(type, value, traceback)`` triple a ``with``
        statement passes (or nothing, when called directly).  It is ignored
        and None is returned, so an exception raised inside the context
        propagates.
        """
        if not self._entered:
            raise RuntimeError("Cannot exit %r without entering first" % self)
        self._module.filters = self._filters
        self._module.showwarning = self._showwarning
class WarningMatcher(object):
    """Matches a specified number and description of warnings.

    Parameters
    ----------
    category : class, optional
        The expected warning class or a super-class (default matches any).
    message : regular expression or string, optional
        A regular expression expected to be matched (ignoring case) in the
        warning's message (default matches any).  An already-compiled
        pattern is used as-is, so its own flags apply.
    filename : regular expression or string, optional
        A regular expression expected to be matched in the warning's filename
        (default matches any).  Like `message`, a string pattern is compiled
        with ``re.IGNORECASE``.
    count : integer, '+', or '*' (default 1)
        The exact number of warnings that should match, or '*' for 0 or more,
        or '+' for 1 or more.

    Raises
    ------
    ValueError
        If `count` is not '*', '+' or an integer (bools are rejected).
    """

    def __init__(self, category=None, message=None, filename=None, count=1):
        # Validate before setting any state; bool is an int subclass but
        # almost certainly a caller mistake, so it is rejected.
        is_int = isinstance(count, int) and not isinstance(count, bool)
        if count not in ('*', '+') and not is_int:
            raise ValueError('"count" should be "*", "+" or an integer')
        self.category = category
        self.message = self._compile_re(message)
        self.filename = self._compile_re(filename)
        self.count = count

    def _compile_re(self, obj):
        """Return None or an object with ``search``; compile strings
        case-insensitively."""
        if obj is None or hasattr(obj, 'search'):
            return obj
        return re.compile(obj, flags=re.IGNORECASE)

    @classmethod
    def _from_arg(cls, arg):
        """Build a matcher from a matcher, a kwargs mapping or a category."""
        if isinstance(arg, cls):
            return arg
        if isinstance(arg, Mapping):
            return cls(**arg)
        return cls(arg)

    def matches_category(self, warning):
        return self.category is None or issubclass(warning.category, self.category)

    def matches_message(self, warning):
        return self.message is None or self.message.search(str(warning.message))

    def matches_filename(self, warning):
        return self.filename is None or self.filename.search(warning.filename)

    def matches(self, warning):
        """True-ish if `warning` satisfies every configured criterion."""
        return (self.matches_category(warning) and
                self.matches_message(warning) and self.matches_filename(warning))

    def __call__(self, record, err_prefix=''):
        """Check the matching warnings in `record` are of the specified count

        Raises `AssertionError` prefixed by `err_prefix` if the specified count
        of matching warnings is not found.

        Returns
        -------
        unmatched : list
            Any warnings that did not match this `WarningMatcher`'s criteria,
            plus matching ones beyond an integer `count`, in original order.
        """
        unmatched = []
        n_matched = 0
        for warning in record:
            # With an integer count, consumption stops once the count is
            # reached so later matches are left for subsequent matchers.
            # '*' and '+' never equal an int, so they consume every match.
            if n_matched != self.count and self.matches(warning):
                n_matched += 1
            else:
                unmatched.append(warning)
        if self.count == '+':
            if n_matched == 0:
                raise AssertionError('%sExpected 1 or more warnings matching '
                                     '%s' % (err_prefix, self))
        elif self.count != '*' and self.count != n_matched:
            raise AssertionError('%sExpected %d warnings matching %s, got '
                                 '%s' % (err_prefix, self.count, self,
                                         n_matched))
        return unmatched

    def __repr__(self):
        return '%s(%s)' % (self.__class__.__name__, self._kwargs_str())

    def _kwargs_str(self, with_count=True):
        """Render the non-default criteria as ``name=repr`` pairs."""
        parts = []
        if self.category is not None:
            parts.append(('category', self.category))
        if self.message is not None:
            parts.append(('message', self.message.pattern))
        if self.filename is not None:
            parts.append(('filename', self.filename.pattern))
        if with_count:
            parts.append(('count', self.count))
        return ', '.join('%s=%r' % part for part in parts)

    def __str__(self, with_count=False):
        args = self._kwargs_str(with_count=with_count)
        if not args:
            return '<any>'
        else:
            return '<%s>' % args
def _unmatched_warnings(matchers, err_prefix, func, *args, **kw):
    """Call `func`, record its warnings and check them against `matchers`.

    Parameters
    ----------
    matchers : matcher spec or sequence of matcher specs
        Each spec is a `WarningMatcher`, a mapping of `WarningMatcher`
        keyword arguments, or anything else (taken as its `category`).
        A non-sequence is treated as a single spec.
    err_prefix : str
        Prefix for any `AssertionError` message raised by a matcher.
    func : callable
        The callable under test, invoked as ``func(*args, **kw)``.

    Returns
    -------
    result
        The value returned by `func`.
    unmatched : list of `WarningMessage`
        Recorded warnings not consumed by any matcher, in order.

    Raises
    ------
    AssertionError
        From a matcher whose expected count is not met.
    """
    if not isinstance(matchers, Sequence):
        matchers = [matchers]
    # Build every matcher before calling func so an invalid spec fails fast.
    matchers = [WarningMatcher._from_arg(matcher) for matcher in matchers]
    # XXX: once we may depend on python >= 2.6, this can be replaced by the
    # warnings module context manager.
    ctx = WarningManager(record=True)
    log = ctx.__enter__()
    try:
        # Inside the try so the state patched by __enter__ is always
        # restored, even if installing the filter itself fails.
        warnings.simplefilter('always')
        result = func(*args, **kw)
    finally:
        ctx.__exit__()
    remaining = log
    for matcher in matchers:
        # Cascade: each matcher consumes its warnings and passes on the rest.
        remaining = matcher(remaining, err_prefix)
    return result, remaining
def assert_warns(warning_class, func, *args, **kw):
    """
    Fail unless the given callable issues the specified warning(s).

    The callable is invoked with positional arguments `args` and keyword
    arguments `kw` while all warnings are recorded under an ``'always'``
    filter; the previous warnings filters are restored afterwards.  Every
    recorded warning must be consumed by the matcher(s) in `warning_class`,
    and each matcher must find its expected number of warnings; otherwise
    an `AssertionError` is raised.

    Parameters
    ----------
    warning_class : class or mapping, or sequence thereof.
        The class(es) or matcher(s) for warnings that `func` is expected to
        raise. Entries may be mappings with the following keys:
        * 'category': the expected warning class or a super-class (default
          any).
        * 'message': a regular expression expected to be matched (ignoring
          case) in the warning's message (default any).
        * 'filename': a regular expression expected to be matched in the
          warning's filename (default any).
        * 'count': the exact number of warnings that should match, or '*'
          for 0 or more, or '+' for 1 or more (default 1).
        Each rule is applied in turn to consume warnings in their order of
        appearance. As such, specific rules should usually precede more general
        rules. For example, to require a `UserWarning`, but allow any others to
        also occur, use `warning_class=[UserWarning, {'count': '*'}]`.
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kw : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If a rule does not find its expected number of warnings, or if any
        warning is left unmatched by all rules.
    ValueError
        If a rule has an invalid 'count'.
    """
    # Not every callable has __name__ (e.g. functools.partial objects).
    err_prefix = 'When calling %s: ' % getattr(func, '__name__', repr(func))
    result, unmatched = _unmatched_warnings(warning_class, err_prefix, func,
                                            *args, **kw)
    if unmatched:
        raise AssertionError('%sWarnings %s not matched' % (err_prefix,
                                                             unmatched))
    return result
def assert_no_warnings(func, *args, **kw):
    """
    Fail if the given callable produces any warnings.

    All warnings are recorded under an ``'always'`` filter while `func`
    runs, so repeated warnings from the same location are still caught.

    Parameters
    ----------
    func : callable
        The callable to test.
    \\*args : Arguments
        Arguments passed to `func`.
    \\*\\*kw : Kwargs
        Keyword arguments passed to `func`.

    Returns
    -------
    The value returned by `func`.

    Raises
    ------
    AssertionError
        If `func` issued at least one warning.
    """
    # Not every callable has __name__ (e.g. functools.partial objects).
    func_name = getattr(func, '__name__', repr(func))
    err_prefix = 'When calling %s: ' % func_name
    # With no matchers, every recorded warning comes back as unmatched.
    # Named `caught` so the `warnings` module is not shadowed.
    result, caught = _unmatched_warnings([], err_prefix, func, *args, **kw)
    if caught:
        raise AssertionError("Got warnings when calling %s: %s"
                             % (func_name, caught))
    return result
class TestWarns(unittest.TestCase):
    """Tests for `assert_warns`, `assert_no_warnings` and the matcher rules."""

    @staticmethod
    def _make_warner(*warning_args):
        """Return a function that calls ``warnings.warn(*args)`` for each
        tuple in `warning_args`, in order, and then returns 3."""
        def fn():
            for warning in warning_args:
                warnings.warn(*warning)
            return 3
        return fn

    @staticmethod
    def _filename_pattern(func):
        """Return a regex for the filename of warnings issued by `func`.

        With the default stacklevel, CPython's ``warnings.warn`` attributes
        a warning to the calling frame's code filename, so this does not
        depend on how this module happens to be named.
        """
        return re.escape(func.__code__.co_filename)

    def setUp(self):
        # Snapshot the global filters so each test leaves them untouched.
        self._before_filters = sys.modules['warnings'].filters[:]

    def tearDown(self):
        sys.modules['warnings'].filters = self._before_filters

    def _assert_filters_unchanged(self, before_filters, tested_name):
        """Assert the global warnings filters still equal `before_filters`."""
        self.assertEqual(before_filters, sys.modules['warnings'].filters,
                         "%s does not preserve warnings state" % tested_name)

    def test_expected_warning(self):
        f = self._make_warner(("yo",))
        before_filters = sys.modules['warnings'].filters[:]
        self.assertEqual(assert_warns(UserWarning, f), 3)
        self._assert_filters_unchanged(before_filters, 'assert_warns')

    def test_expected_absence(self):
        self.assertEqual(assert_no_warnings(lambda x: x, 1), 1)

    def test_unexpected_warning(self):
        f = self._make_warner(("yo",))
        before_filters = sys.modules['warnings'].filters[:]
        self.assertRaises(AssertionError, assert_no_warnings, f)
        self._assert_filters_unchanged(before_filters, 'assert_no_warnings')

    def test_unexpected_type(self):
        f = self._make_warner(("yo", DeprecationWarning))
        before_filters = sys.modules['warnings'].filters[:]
        self.assertRaises(AssertionError, assert_warns, UserWarning, f)
        self._assert_filters_unchanged(before_filters, 'assert_warns')

    def test_unexpected_absence(self):
        self.assertRaises(AssertionError, assert_warns, DeprecationWarning,
                          lambda: 1)

    def test_count_warnings_exact(self):
        f0 = self._make_warner()
        f1 = self._make_warner(("msg 1",))
        f2 = self._make_warner(("msg 1",), ("msg 2",))
        assert_warns({'count': 0}, f0)
        self.assertRaises(AssertionError, assert_warns, {'count': 0}, f1)
        self.assertRaises(AssertionError, assert_warns, {'count': 0}, f2)
        assert_warns({'count': 1}, f1)
        self.assertRaises(AssertionError, assert_warns, {'count': 1}, f0)
        self.assertRaises(AssertionError, assert_warns, {'count': 1}, f2)
        assert_warns({'count': 2}, f2)
        self.assertRaises(AssertionError, assert_warns, {'count': 2}, f0)
        self.assertRaises(AssertionError, assert_warns, {'count': 2}, f1)

    def test_count_warnings_wildcard(self):
        f0 = self._make_warner()
        f1 = self._make_warner(("msg 1",))
        f2 = self._make_warner(("msg 1",), ("msg 2",))
        assert_warns({'count': '*'}, f0)
        assert_warns({'count': '*'}, f1)
        assert_warns({'count': '*'}, f2)
        assert_warns({'count': '+'}, f1)
        assert_warns({'count': '+'}, f2)
        self.assertRaises(AssertionError, assert_warns, {'count': '+'}, f0)
        self.assertRaises(ValueError, assert_warns, {'count': '!'}, f2)

    def test_filter_by_category(self):
        f_user = self._make_warner(("foo", UserWarning,))
        f_depr = self._make_warner(("foo", DeprecationWarning,))
        f_both = self._make_warner(("foo", UserWarning,),
                                   ("foo", DeprecationWarning,))
        assert_warns({}, f_user)
        assert_warns({}, f_depr)
        assert_warns({'count': 2}, f_both)
        assert_warns({'category': Warning}, f_user)
        assert_warns({'category': Warning}, f_depr)
        assert_warns({'category': Warning, 'count': 2}, f_both)
        assert_warns({'category': UserWarning}, f_user)
        self.assertRaises(AssertionError, assert_warns,
                          {'category': UserWarning}, f_depr)
        self.assertRaises(AssertionError, assert_warns,
                          {'category': UserWarning, 'count': 2}, f_both)
        assert_warns({'category': DeprecationWarning}, f_depr)
        self.assertRaises(AssertionError, assert_warns,
                          {'category': DeprecationWarning}, f_user)
        self.assertRaises(AssertionError, assert_warns,
                          {'category': DeprecationWarning, 'count': 2}, f_both)

    def test_filter_by_message(self):
        f_foo = self._make_warner(("foo",))
        f_bar = self._make_warner(("bar",))
        f_both = self._make_warner(("bar",), ("both",))
        assert_warns({}, f_foo)
        assert_warns({}, f_bar)
        assert_warns({'count': 2}, f_both)
        assert_warns({'message': 'fo'}, f_foo)
        assert_warns({'message': 'FO'}, f_foo)
        assert_warns({'message': 'o$'}, f_foo)
        assert_warns({'message': re.compile('oo$')}, f_foo)
        self.assertRaises(AssertionError, assert_warns,
                          {'message': re.compile('OO$')}, f_foo)
        self.assertRaises(AssertionError, assert_warns,
                          {'message': 'fo'}, f_bar)
        self.assertRaises(AssertionError, assert_warns,
                          {'message': re.compile('oo$')}, f_bar)
        self.assertRaises(AssertionError, assert_warns,
                          {'message': re.compile('oo$')}, f_both)
        self.assertRaises(AssertionError, assert_warns,
                          {'message': re.compile('oo$'), 'count': 2}, f_both)

    def test_filter_by_filename(self):
        # A different filename cannot easily be spoofed, so match this
        # module's own filename and check that an unrelated pattern fails.
        f_foo = self._make_warner(("foo",))
        own_file = self._filename_pattern(f_foo)
        assert_warns({'filename': own_file}, f_foo)
        self.assertRaises(AssertionError, assert_warns,
                          {'filename': r'!~`junk /\\'}, f_foo)

    def test_filter_conjunction(self):
        f = self._make_warner(("foo", DeprecationWarning))
        own_file = self._filename_pattern(f)
        assert_warns({'category': DeprecationWarning,
                      'message': 'foo',
                      'filename': own_file}, f)
        self.assertRaises(AssertionError, assert_warns, {
            'category': UserWarning,
            'message': 'foo',
            'filename': own_file}, f)
        self.assertRaises(AssertionError, assert_warns, {
            'category': DeprecationWarning,
            'message': 'bar',
            'filename': own_file}, f)
        self.assertRaises(AssertionError, assert_warns, {
            'category': DeprecationWarning,
            'message': 'foo',
            'filename': '!~junk \\/'}, f)

    def test_matchers_cascade(self):
        f = self._make_warner(('foo',), ('bar',))
        m1 = {'message': 'foo'}
        m2 = {'message': 'bar'}
        m_all = {}
        assert_warns([m1, m2], f)
        assert_warns([m2, m1], f)
        assert_warns([m_all, m2], f)
        # m_all consumes 'foo' first, leaving nothing for m1.
        self.assertRaises(AssertionError, assert_warns, [m_all, m1], f)

    def test_args_passed(self):
        def f(a, b):
            return a + b
        self.assertEqual(assert_no_warnings(f, 1, b=2), 3)

        def f(a, b):
            warnings.warn('foo')
            return a + b
        self.assertEqual(assert_warns({}, f, 1, b=2), 3)
Copyright (c) 2005-2013, NumPy Developers, Joel Nothman.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
* Neither the name of the NumPy Developers nor the names of any
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment