Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Created April 4, 2019 10:35
Compares the data generated by two DataFlows, e.g. for regression tests.
import itertools

import numpy as np
from tensorpack import DataFlow
class CompareData(DataFlow):
    """
    Compares that two DataFlows generate equal data, raises ValueError if not.

    Iterates the wrapped DataFlows ``a`` and ``b`` in lockstep and checks
    each pair of datapoints with ``compare_trees``. Useful for regression
    tests, e.g. checking that a modified pipeline still yields the same data.
    """

    def __init__(self, a, b):
        """
        :param a: first DataFlow.
        :param b: second DataFlow; ``b.size()`` must equal ``a.size()``
            (checked with ``assert``, so only when assertions are enabled).
        """
        self.a = a
        self.b = b
        assert a.size() == b.size(), \
            "Both DataFlows must have the same size! {} != {}".format(a.size(), b.size())

    def reset_state(self):
        """Reset the state of both wrapped DataFlows."""
        for d in [self.a, self.b]:
            d.reset_state()

    def size(self):
        """
        Return the minimum size among all.
        """
        return min(self.a.size(), self.b.size())

    def get_data(self):
        """
        Compare the two DataFlows datapoint by datapoint.

        Yields the return value of ``compare_trees`` for each pair of
        datapoints (one item per pair).

        :raises ValueError: if a pair of datapoints differs, or if one
            DataFlow is exhausted while the other still produces data.
        """
        missing = object()  # private sentinel: cannot collide with a real datapoint
        it_a = self.a.get_data()
        it_b = self.b.get_data()
        try:
            pairs = itertools.zip_longest(it_a, it_b, fillvalue=missing)
            for i, (data_a, data_b) in enumerate(pairs):
                # A length mismatch is a difference in the data too; plain zip()
                # would silently stop at the shorter DataFlow and hide it.
                if data_a is missing or data_b is missing:
                    shorter = 'a' if data_a is missing else 'b'
                    raise ValueError(
                        'DataFlow %s is exhausted after %d datapoints, the other is not'
                        % (shorter, i))
                yield compare_trees(data_a, data_b)
        finally:
            # Deterministically release the underlying generators (also when the
            # consumer stops early); plain iterators have no close() and are skipped.
            for it in (it_a, it_b):
                close = getattr(it, 'close', None)
                if close is not None:
                    close()
def compare_trees(a, b, path="", verbose=True):
    """
    Recursively check that two (possibly nested) datapoints are equal.

    Lists and tuples are interchangeable and compared element-wise. Numpy
    arrays must agree in dtype and shape, then in values: floating/complex
    arrays within ``np.allclose`` default tolerance (NaNs in the same
    positions count as equal), all other dtypes (int, bool, str, object, ...)
    exactly. Anything else is compared with ``!=``.

    :param a: first (sub)tree.
    :param b: second (sub)tree.
    :param path: slash-separated position of this node inside the root tree
        (e.g. "/0/1"), used as a prefix of messages; "" for the root.
    :param verbose: if True (default), print "<path>: OK" for every node that
        compares equal, including the root.
    :raises ValueError: at the first mismatch, with its path in the message.
    """
    types = [type(x) for x in [a, b]]
    # list vs. tuple is tolerated: the same datapoint may use either container.
    if types[0] != types[1] and set(types) != {list, tuple}:
        raise ValueError('%s: Non-equal types: %s' % (path, types))
    if types[0] in [list, tuple]:
        if len(a) != len(b):
            raise ValueError('%s: Non-equal sequence lengths: %d %d' % (path, len(a), len(b)))
        for i, (item_a, item_b) in enumerate(zip(a, b)):
            compare_trees(item_a, item_b, path='%s/%d' % (path, i), verbose=verbose)
    elif types[0] == np.ndarray:
        if a.dtype != b.dtype:
            raise ValueError('%s: Non-equal dtypes of numpy arrays: %s, %s' % (path, a.dtype, b.dtype))
        if a.shape != b.shape:
            raise ValueError('%s: Non-equal shapes of numpy arrays: %s, %s' % (path, a.shape, b.shape))
        if np.issubdtype(a.dtype, np.inexact):
            # Tolerance is only meaningful for floating/complex data.
            if not np.allclose(a, b, equal_nan=True):
                raise ValueError('%s: Numpy array values are not close: %s %s' % (path, a, b))
        elif not np.array_equal(a, b):
            # Exact comparison: allclose's relative tolerance would accept e.g.
            # 1000000 vs 1000001, and it raises TypeError on non-numeric dtypes.
            raise ValueError('%s: Numpy array values are not equal: %s %s' % (path, a, b))
    else:
        if a != b:
            raise ValueError('%s: Non-equal values: %s %s' % (path, a, b))
    if verbose:
        print('%s: OK' % path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment