Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save east825/1e74fdedbaf2c49e15470badb26153b7 to your computer and use it in GitHub Desktop.
Save east825/1e74fdedbaf2c49e15470badb26153b7 to your computer and use it in GitHub Desktop.
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities working with arbitrarily nested structures."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class NestTest(test.TestCase):
  """Tests for the `nest` helpers that operate on nested structures.

  As exercised below, tuples, namedtuples and dicts count as nested
  structures, while lists, strings, sets, numpy arrays, tensors and
  `SparseTensorValue`s are treated as single leaf values.
  """

  def testFlattenAndPack(self):
    """Round-trips tuples, namedtuples and scalars through flatten/pack."""
    structure = ((3, 4), 5, (6, 7, (9, 10), 8))
    flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
    self.assertEqual(
        nest.pack_sequence_as(structure, flat),
        (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

    # Namedtuples must survive the round trip with their field names intact.
    point = collections.namedtuple("Point", ["x", "y"])
    structure = (point(x=4, y=2), ((point(x=1, y=0),),))
    flat = [4, 2, 1, 0]
    self.assertEqual(nest.flatten(structure), flat)
    restructured_from_flat = nest.pack_sequence_as(structure, flat)
    self.assertEqual(restructured_from_flat, structure)
    self.assertEqual(restructured_from_flat[0].x, 4)
    self.assertEqual(restructured_from_flat[0].y, 2)
    self.assertEqual(restructured_from_flat[1][0][0].x, 1)
    self.assertEqual(restructured_from_flat[1][0][0].y, 0)

    # A non-nested value flattens to a one-element list and packs back to the
    # bare value.
    self.assertEqual([5], nest.flatten(5))
    self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
    self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

    # Error cases: too many values for a scalar structure, a flat sequence
    # that is not a sequence, and a list used as the structure.
    with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
      nest.pack_sequence_as("scalar", [4, 5])

    with self.assertRaisesRegexp(TypeError, "flat_sequence"):
      nest.pack_sequence_as([4, 5], "bad_sequence")

    with self.assertRaises(ValueError):
      nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])

  def testFlattenDictOrder(self):
    """`flatten` orders dicts by key, including OrderedDicts."""
    ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
    plain = {"d": 3, "b": 1, "a": 0, "c": 2}
    ordered_flat = nest.flatten(ordered)
    plain_flat = nest.flatten(plain)
    self.assertEqual([0, 1, 2, 3], ordered_flat)
    self.assertEqual([0, 1, 2, 3], plain_flat)

  def testPackDictOrder(self):
    """Packing orders dicts by key, including OrderedDicts."""
    ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
    plain = {"d": 0, "b": 0, "a": 0, "c": 0}
    seq = [0, 1, 2, 3]
    ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
    plain_reconstruction = nest.pack_sequence_as(plain, seq)
    self.assertEqual(
        collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
        ordered_reconstruction)
    self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)

  def testFlattenAndPackWithDicts(self):
    """Round-trips a structure mixing tuples, namedtuples and dicts."""
    # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
    named_tuple = collections.namedtuple("A", ("b", "c"))
    mess = (
        "z",
        named_tuple(3, 4),
        {
            "c": (
                1,
                collections.OrderedDict([
                    ("b", 3),
                    ("a", 2),
                ]),
            ),
            "b": 5
        },
        17
    )
    flattened = nest.flatten(mess)
    self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])

    # Same shape as `mess` but with unrelated leaf values: only the structure
    # should influence how `flattened` is packed back.
    structure_of_mess = (
        14,
        named_tuple("a", True),
        {
            "c": (
                0,
                collections.OrderedDict([
                    ("b", 9),
                    ("a", 8),
                ]),
            ),
            "b": 3
        },
        "hi everybody",
    )
    unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
    self.assertEqual(unflattened, mess)

    # Check also that the OrderedDict was created, with the correct key order.
    unflattened_ordered_dict = unflattened[2]["c"][1]
    self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
    self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])

  def testFlattenSparseValue(self):
    """`flatten` keeps each `SparseTensorValue` as a single leaf."""
    st = sparse_tensor.SparseTensorValue([[0]], [0], [1])
    single_value = st
    list_of_values = [st, st, st]
    # NOTE: `(st)` is just `st`, so this is the tuple `(st, (st, st))`.
    nest_of_values = ((st), ((st), (st)))
    dict_of_values = {"foo": st, "bar": st, "baz": st}
    self.assertEqual([st], nest.flatten(single_value))
    # The list is returned whole, as one element.
    self.assertEqual([[st, st, st]], nest.flatten(list_of_values))
    self.assertEqual([st, st, st], nest.flatten(nest_of_values))
    self.assertEqual([st, st, st], nest.flatten(dict_of_values))

  def testIsSequence(self):
    """`is_sequence` is true for tuples and dicts only."""
    self.assertFalse(nest.is_sequence("1234"))
    self.assertFalse(nest.is_sequence([1, 3, [4, 5]]))
    self.assertTrue(nest.is_sequence(((7, 8), (5, 6))))
    self.assertFalse(nest.is_sequence([]))
    self.assertFalse(nest.is_sequence({1, 2}))
    ones = array_ops.ones([2, 3])
    self.assertFalse(nest.is_sequence(ones))
    self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
    self.assertFalse(nest.is_sequence(np.ones((4, 5))))
    self.assertTrue(nest.is_sequence({"foo": 1, "bar": 2}))
    self.assertFalse(
        nest.is_sequence(sparse_tensor.SparseTensorValue([[0]], [0], [1])))

  def testAssertSameStructure(self):
    """Checks `assert_same_structure` acceptance and its error types.

    Each call expected to raise is the only statement inside its
    `assertRaises*` context; the calls between contexts are expected to pass.
    """
    structure1 = (((1, 2), 3), 4, (5, 6))
    structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    structure_different_num_elements = ("spam", "eggs")
    structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
    structure_dictionary = {"foo": 2, "bar": 4, "baz": {"foo": 5, "bar": 6}}
    structure_dictionary_diff_nested = {
        "foo": 2,
        "bar": 4,
        "baz": {
            "foo": 5,
            "baz": 6
        }
    }

    # Matching nesting passes regardless of leaf values; any two leaves match.
    nest.assert_same_structure(structure1, structure2)
    nest.assert_same_structure("abc", 1.0)
    nest.assert_same_structure("abc", np.array([0, 1]))
    nest.assert_same_structure("abc", constant_op.constant([0, 1]))

    with self.assertRaisesRegexp(ValueError,
                                 "don't have the same nested structure"):
      nest.assert_same_structure(structure1, structure_different_num_elements)

    with self.assertRaisesRegexp(ValueError,
                                 "don't have the same nested structure"):
      nest.assert_same_structure((0, 1), np.array([0, 1]))

    with self.assertRaisesRegexp(ValueError,
                                 "don't have the same nested structure"):
      nest.assert_same_structure(0, (0, 1))

    with self.assertRaisesRegexp(ValueError,
                                 "don't have the same nested structure"):
      nest.assert_same_structure(structure1, structure_different_nesting)

    # Namedtuples must be of the same type, not merely have the same fields.
    named_type_0 = collections.namedtuple("named_0", ("a", "b"))
    named_type_1 = collections.namedtuple("named_1", ("a", "b"))
    self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
                      named_type_0("a", "b"))

    nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))

    self.assertRaises(TypeError, nest.assert_same_structure,
                      named_type_0(3, 4), named_type_1(3, 4))

    with self.assertRaisesRegexp(ValueError,
                                 "don't have the same nested structure"):
      nest.assert_same_structure(named_type_0(3, 4), named_type_0((3,), 4))

    with self.assertRaisesRegexp(ValueError,
                                 "don't have the same nested structure"):
      nest.assert_same_structure(((3,), 4), (3, (4,)))

    # Despite the `_list` suffix these are dicts; they differ in one key.
    structure1_list = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)}
    structure2_list = {"a": ((1, 2), 3), "b": 4, "d": (5, 6)}

    with self.assertRaisesRegexp(TypeError,
                                 "don't have the same sequence type"):
      nest.assert_same_structure(structure1, structure1_list)

    nest.assert_same_structure(structure1, structure2, check_types=False)
    nest.assert_same_structure(structure1, structure1_list, check_types=False)

    with self.assertRaisesRegexp(ValueError, "don't have the same set of keys"):
      nest.assert_same_structure(structure1_list, structure2_list)

    with self.assertRaisesRegexp(ValueError, "don't have the same set of keys"):
      nest.assert_same_structure(structure_dictionary,
                                 structure_dictionary_diff_nested)

    # The key-set mismatches above are expected to pass with
    # `check_types=False`.
    nest.assert_same_structure(
        structure_dictionary,
        structure_dictionary_diff_nested,
        check_types=False)
    nest.assert_same_structure(
        structure1_list, structure2_list, check_types=False)

  def testMapStructure(self):
    """Checks `map_structure` results and its argument validation."""
    structure1 = (((1, 2), 3), 4, (5, 6))
    structure2 = (((7, 8), 9), 10, (11, 12))

    structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
    nest.assert_same_structure(structure1, structure1_plus1)
    self.assertAllEqual(
        [2, 3, 4, 5, 6, 7],
        nest.flatten(structure1_plus1))

    structure1_plus_structure2 = nest.map_structure(
        lambda x, y: x + y, structure1, structure2)
    self.assertEqual(
        (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
        structure1_plus_structure2)

    # Scalars are mapped directly.
    self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
    self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

    with self.assertRaisesRegexp(TypeError, "callable"):
      nest.map_structure("bad", structure1_plus1)

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, 3, (3,))

    with self.assertRaisesRegexp(TypeError, "same sequence type"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), {"a": (3, 4), "b": 5})

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

    # A nesting mismatch is still an error when type checking is disabled.
    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
                         check_types=False)

    # Unknown keyword arguments are rejected, with or without `check_types`.
    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      nest.map_structure(lambda x: None, structure1, foo="a")

    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")

  def testAssertShallowStructure(self):
    """Checks the errors raised by `assert_shallow_structure`.

    The shallow structure is the first argument and the input structure is
    the second, as the expected error messages below show.
    """
    inp_ab = ("a", "b")
    inp_abc = ("a", "b", "c")
    expected_message = (
        "The two structures don't have the same sequence length. Input "
        "structure has length 2, while shallow structure has length 3.")
    with self.assertRaisesRegexp(ValueError, expected_message):
      nest.assert_shallow_structure(inp_abc, inp_ab)

    inp_ab1 = ((1, 1), (2, 2))
    inp_ab2 = {"a": (1, 1), "b": (2, 2)}
    # `(type|class)` matches both spellings that `repr(type)` can produce.
    expected_message = (
        "The two structures don't have the same sequence type. Input structure "
        "has type <(type|class) 'tuple'>, while shallow structure has type "
        "<(type|class) 'dict'>.")
    with self.assertRaisesRegexp(TypeError, expected_message):
      nest.assert_shallow_structure(inp_ab2, inp_ab1)
    nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)

    inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
    inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
    expected_message = (
        r"The two structures don't have the same keys. Input "
        r"structure has keys \['c'\], while shallow structure has "
        r"keys \['d'\].")
    with self.assertRaisesRegexp(ValueError, expected_message):
      nest.assert_shallow_structure(inp_ab2, inp_ab1)

    # OrderedDicts with the same keys in a different insertion order match.
    inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
    inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
    nest.assert_shallow_structure(inp_ab, inp_ba)

  def testFlattenUpTo(self):
    """Checks `flatten_up_to(shallow_tree, input_tree)` on several shapes."""
    input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
    shallow_tree = ((True, True), (False, True))

    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)

    self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)])
    self.assertEqual(flattened_shallow_tree, [True, True, False, True])

    input_tree = ((("a", 1), (("b", 2), (("c", 3), (("d", 4))))))
    shallow_tree = (("level_1", ("level_2", ("level_3", ("level_4")))))

    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                              input_tree)
    input_tree_flattened = nest.flatten(input_tree)

    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
    self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

    ## Shallow non-list edge-case.
    # A non-sequence shallow tree yields the whole input as one element.
    # Using iterable elements.
    input_tree = ["input_tree"]
    shallow_tree = "shallow_tree"
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    input_tree = ("input_tree_0", "input_tree_1")
    shallow_tree = "shallow_tree"
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    # Using non-iterable elements.
    input_tree = (0,)
    shallow_tree = 9
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    input_tree = (0, 1)
    shallow_tree = 9
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    ## Both non-list edge-case.
    # Using iterable elements.
    input_tree = "input_tree"
    shallow_tree = "shallow_tree"
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    # Using non-iterable elements.
    input_tree = 0
    shallow_tree = 0
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [input_tree])
    self.assertEqual(flattened_shallow_tree, [shallow_tree])

    ## Input non-list edge-case.
    # Only the failing call belongs inside each `assertRaisesRegexp` context;
    # the shallow-tree checks that follow must run after the context exits.
    # Using iterable elements.
    input_tree = "input_tree"
    shallow_tree = ("shallow_tree",)
    expected_message = ("If shallow structure is a sequence, input must also "
                        "be a sequence. Input has type: <(type|class) 'str'>.")
    with self.assertRaisesRegexp(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, list(shallow_tree))

    input_tree = "input_tree"
    shallow_tree = ("shallow_tree_9", "shallow_tree_8")
    with self.assertRaisesRegexp(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, list(shallow_tree))

    # Using non-iterable elements.
    input_tree = 0
    shallow_tree = (9,)
    expected_message = ("If shallow structure is a sequence, input must also "
                        "be a sequence. Input has type: <(type|class) 'int'>.")
    with self.assertRaisesRegexp(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, list(shallow_tree))

    input_tree = 0
    shallow_tree = (9, 8)
    with self.assertRaisesRegexp(TypeError, expected_message):
      nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_shallow_tree, list(shallow_tree))

    # Using dict.
    input_tree = {"a": ((2, 2), (3, 3)), "b": ((4, 9), (5, 5))}
    shallow_tree = {"a": (True, True), "b": (False, True)}

    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)

    self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)])
    self.assertEqual(flattened_shallow_tree, [True, True, False, True])

  def testMapStructureUpTo(self):
    """Checks `map_structure_up_to` with namedtuple and tuple shallow trees."""
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    op_tuple = collections.namedtuple("op_tuple", "add, mul")
    inp_val = ab_tuple(a=2, b=3)
    inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
    # `inp_val` is the shallow tree, so each `op_tuple` in `inp_ops` reaches
    # the lambda whole instead of being flattened into its fields.
    out = nest.map_structure_up_to(
        inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
    self.assertEqual(out.a, 6)
    self.assertEqual(out.b, 15)

    data_list = ((2, 4, 6, 8), ((1, 3, 5, 7, 9), (3, 5, 7)))
    name_list = ("evens", ("odds", "primes"))
    # Each name is paired with the whole data tuple at the matching position.
    out = nest.map_structure_up_to(
        name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
        name_list, data_list)
    self.assertEqual(out, ("first_4_evens", ("first_5_odds", "first_3_primes")))
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  test.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment