Created May 29, 2015 07:09
-
-
Save danielrenshaw/8fd71250f9cba5a530c2 to your computer and use it in GitHub Desktop.
Theano diff (against a4e182d) that modifies theano/compile/debugmode.py so that debugprint can report NaN and Inf checks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
16c17,18 | |
--- | |
> import re | |
> | |
517c519,521 | |
< scan_ops=None, profile=None): | |
--- | |
> scan_ops=None, profile=None, include_nan_info=False, | |
> include_inf_info=False, recursion_rules='ALWAYS', | |
> print_test_value=False): | |
558c563,586 | |
--- | |
> recursion_rules = set([recursion_rule for recursion_rule in recursion_rules.split(',') if len(recursion_rule) > 0]) | |
> | |
> def nan_inf_info(prefix, enabled, checker): | |
> if r is not None and enabled: | |
> if hasattr(r, 'tag') and r.tag is not None and hasattr(r.tag, 'test_value'): | |
> if isinstance(r.tag.test_value, numpy.ndarray): | |
> mask = checker(r.tag.test_value) | |
> | |
> if mask.size > 0 and mask.all(): | |
> return ' <%s: ALL>' % prefix, 'ALL_' + prefix in recursion_rules | |
> elif mask.any(): | |
> return ' <%s: SOME>' % prefix, 'SOME_' + prefix in recursion_rules | |
> | |
> return ' <%s: NONE>' % prefix, 'NO_' + prefix in recursion_rules | |
> | |
> return ' <%s: NOT_NDARRAY>' % prefix, 'NO_TEST_VALUE' in recursion_rules | |
> | |
> return ' <%s: NO_TEST_VALUE>' % prefix, 'NO_TEST_VALUE' in recursion_rules | |
> | |
> return '', enabled and 'NO_TEST_VALUE' in recursion_rules | |
> | |
> nan_info, nan_recurse = nan_inf_info('NANS', include_nan_info, numpy.isnan) | |
> inf_info, inf_recurse = nan_inf_info('INFS', include_inf_info, numpy.isinf) | |
> | |
575c604,608 | |
--- | |
> if r is not None and print_test_value and hasattr(r, 'tag') and r.tag is not None and hasattr(r.tag, 'test_value'): | |
> test_value = ' %s %s' % (getattr(r.tag.test_value, 'shape', None), re.sub(r'\s+', ' ', repr(r.tag.test_value))) | |
> else: | |
> test_value = '' | |
> | |
610,616c643,651 | |
< print('%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op, | |
< id_str, | |
< type_str, | |
< r_name, | |
< destroy_map_str, | |
< view_map_str, | |
< o), file=file) | |
--- | |
> print('%s%s %s%s%s%s \'%s\' %s %s %s%s' % (prefix, a.op, | |
> id_str, | |
> type_str, | |
> nan_info, | |
> inf_info, | |
> r_name, | |
> destroy_map_str, | |
> view_map_str, | |
> o, test_value), file=file) | |
618,624c653,661 | |
< print('%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op, | |
< a.outputs.index(r), | |
< id_str, type_str, | |
< r_name, | |
< destroy_map_str, | |
< view_map_str, | |
< o), file=file) | |
--- | |
> print('%s%s.%i %s%s%s%s \'%s\' %s %s %s%s' % (prefix, a.op, | |
> a.outputs.index(r), | |
> id_str, type_str, | |
> nan_info, | |
> inf_info, | |
> r_name, | |
> destroy_map_str, | |
> view_map_str, | |
> o, test_value), file=file) | |
633,644c670,682 | |
< print("%s%s %s%s '%s' %s %s %s --> " | |
< "%8.2es %4.1f%% %8.2es %4.1f%%" | |
< % (prefix, a.op, | |
< id_str, | |
< type_str, | |
< r_name, | |
< destroy_map_str, | |
< view_map_str, | |
< o, op_time, | |
< op_time_percent, | |
< tot_time, | |
< tot_time_percent), file=file) | |
--- | |
> print('%s%s %s%s%s%s \'%s\' %s %s %s%s --> %8.2es %4.1f%% %8.2es %4.1f%%'\ | |
> % (prefix, a.op, | |
> id_str, | |
> type_str, | |
> nan_info, | |
> inf_info, | |
> r_name, | |
> destroy_map_str, | |
> view_map_str, | |
> o, test_value, op_time, | |
> op_time_percent, | |
> tot_time, | |
> tot_time_percent), file=file) | |
646,657c684,696 | |
< print("%s%s.%i %s%s '%s' %s %s %s --> " | |
< "%8.2es %4.1f%% %8.2es %4.1f%%" | |
< % (prefix, a.op, | |
< a.outputs.index(r), | |
< id_str, type_str, | |
< r_name, | |
< destroy_map_str, | |
< view_map_str, | |
< o, op_time, | |
< op_time_percent, | |
< tot_time, | |
< tot_time_percent), file=file) | |
--- | |
> print('%s%s.%i %s%s%s%s \'%s\' %s %s %s%s --> %8.2es %4.1f%% %8.2es %4.1f%%'\ | |
> % (prefix, a.op, | |
> a.outputs.index(r), | |
> id_str, type_str, | |
> nan_info, | |
> inf_info, | |
> r_name, | |
> destroy_map_str, | |
> view_map_str, | |
> o, test_value, op_time, | |
> op_time_percent, | |
> tot_time, | |
> tot_time_percent), file=file) | |
659c699,700 | |
--- | |
> recurse = nan_recurse or inf_recurse or 'ALWAYS' in recursion_rules | |
> | |
674,678c715,723 | |
< debugprint(i, new_prefix, depth=depth - 1, done=done, | |
< print_type=print_type, file=file, order=order, | |
< ids=ids, stop_on_name=stop_on_name, | |
< prefix_child=new_prefix_child, | |
< scan_ops=scan_ops, profile=profile) | |
--- | |
> if recurse: | |
> debugprint(i, new_prefix, depth=depth - 1, done=done, | |
> print_type=print_type, file=file, order=order, | |
> ids=ids, stop_on_name=stop_on_name, | |
> prefix_child=new_prefix_child, scan_ops=scan_ops, | |
> profile=profile, include_nan_info=include_nan_info, | |
> include_inf_info=include_inf_info, | |
> recursion_rules=','.join(recursion_rules), | |
> print_test_value=print_test_value) | |
683c728 | |
< print('%s%s %s%s' % (prefix, r, id_str, type_str), file=file) | |
--- | |
> print('%s%s %s%s%s%s%s' % (prefix, r, id_str, type_str, nan_info, inf_info, test_value), file=file) | |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.