# Imports assumed from the surrounding notebook; height_width and
# build_dataset are defined elsewhere in the notebook.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class EmbeddingModel(keras.Model):
    def train_step(self, data):
        anchors, positives = data
        print("a,p", anchors.shape, positives.shape)
        with tf.GradientTape() as tape:
            # Run both anchors and positives through model.
            anchor_embeddings = self(anchors, training=True)
            positive_embeddings = self(positives, training=True)
            print("ae,pe", anchor_embeddings.shape, positive_embeddings.shape)
            # Calculate cosine similarity between anchors and positives. As they
            # have been normalised this is just the pairwise dot products.
            similarities = tf.einsum('ae,pe->ap', anchor_embeddings,
                                     positive_embeddings)
            print("s", similarities.shape)
            # Since we intend to use these as logits we scale them by a temperature.
            # This value would normally be chosen as a hyperparameter.
            temperature = 0.2
            similarities /= temperature
            # We use these similarities as logits for a softmax. The labels for
            # this call are just the sequence 0, 1, 2, ... since we want the main
            # diagonal values, which correspond to the anchor/positive pairs, to be
            # high. This loss will move embeddings for the anchor/positive pairs
            # together and move all other pairs apart.
            labels = tf.range(similarities.get_shape()[0])
            loss = self.compiled_loss(labels, similarities)
        # Calculate gradients and apply via optimizer.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update and return metrics (specifically the one for the loss value).
        self.compiled_metrics.update_state(labels, similarities)
        return {m.name: m.result() for m in self.metrics}


inputs = layers.Input(shape=(height_width, height_width, 3))
x = layers.Conv2D(filters=4, kernel_size=3, strides=2,
                  activation='relu')(inputs)
x = layers.Conv2D(filters=8, kernel_size=3, strides=2,
                  activation='relu')(x)
x = layers.Conv2D(filters=16, kernel_size=3, strides=2,
                  activation='relu')(x)
x = layers.GlobalAveragePooling2D()(x)
embeddings = layers.Dense(units=4, activation=None)(x)
embeddings = tf.nn.l2_normalize(embeddings, axis=-1)

model = EmbeddingModel(inputs, embeddings)

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer="adam", loss=loss_fn)

dataset = build_dataset(batch_size=32, num_batchs=30)
history = model.fit(dataset, epochs=10)
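As a quick sanity check (not part of the original gist), the einsum contraction 'ae,pe->ap' is just the anchor embeddings matrix-multiplied with the transposed positive embeddings, so entry [i, j] is the dot product (the cosine similarity, after l2 normalisation) of anchor i with positive j:

import tensorflow as tf

a = tf.math.l2_normalize(tf.random.normal((4, 8)), axis=-1)  # 4 anchor embeddings
p = tf.math.l2_normalize(tf.random.normal((4, 8)), axis=-1)  # 4 positive embeddings

sims_einsum = tf.einsum('ae,pe->ap', a, p)
sims_matmul = tf.matmul(a, p, transpose_b=True)

# Both are (4, 4) matrices of pairwise cosine similarities; the main diagonal
# holds the anchor/positive pairs that the loss pushes together.
print(tf.reduce_max(tf.abs(sims_einsum - sims_matmul)).numpy())  # ~0.0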
Epoch 1/10
a,p (None, 32, 32, 3) (None, 32, 32, 3)
ae,pe (None, 4) (None, 4)
s (None, None)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-28-e96bb720671d> in <module>()
     56
     57 dataset = build_dataset(batch_size=32, num_batchs=30)
---> 58 history = model.fit(dataset, epochs=10)

10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966       except Exception as e:  # pylint:disable=broad-except
    967         if hasattr(e, "ag_error_metadata"):
--> 968           raise e.ag_error_metadata.to_exception(e)
    969         else:
    970           raise

ValueError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:571 train_function  *
        outputs = self.distribute_strategy.run(
    <ipython-input-28-e96bb720671d>:31 train_step  *
        labels = tf.range(similarities.get_shape()[0])
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1571 range  **
        limit = ops.convert_to_tensor(limit, dtype=dtype, name="limit")
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:1341 convert_to_tensor
        ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py:321 _constant_tensor_conversion_function
        return constant(v, dtype=dtype, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py:262 constant
        allow_broadcast=True)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/constant_op.py:300 _constant_impl
        allow_broadcast=allow_broadcast))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/tensor_util.py:439 make_tensor_proto
        raise ValueError("None values not supported.")

    ValueError: None values not supported.
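The error comes from model.fit tracing train_step into a graph: the dataset's batch dimension is not statically known there, so similarities.get_shape()[0] is None and tf.range(None) fails. One possible alternative to hardcoding the shape (not in the original gist) is to read the batch size from the dynamic shape, which is a runtime scalar tensor rather than None:

# Hedged sketch: inside train_step, after computing `similarities`.
batch_size = tf.shape(similarities)[0]   # dynamic shape; defined even when the static shape is None
labels = tf.range(batch_size)            # 0, 1, ..., batch_size - 1 as a tensor
loss = self.compiled_loss(labels, similarities)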
# ONLY difference is the hardcoded tf.reshape of anchors & positives.
# Imports assumed from the surrounding notebook; height_width and
# build_dataset are defined elsewhere in the notebook.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class EmbeddingModel(keras.Model):
    def train_step(self, data):
        anchors, positives = data
        anchors = tf.reshape(anchors, (32, 32, 32, 3))      # <----- !!!!!!!!!!!!!!!!!!!!
        positives = tf.reshape(positives, (32, 32, 32, 3))  # <----- !!!!!!!!!!!!!!!!!!!!
        print("a,p", anchors.shape, positives.shape)
        with tf.GradientTape() as tape:
            # Run both anchors and positives through model.
            anchor_embeddings = self(anchors, training=True)
            positive_embeddings = self(positives, training=True)
            print("ae,pe", anchor_embeddings.shape, positive_embeddings.shape)
            # Calculate cosine similarity between anchors and positives. As they
            # have been normalised this is just the pairwise dot products.
            similarities = tf.einsum('ae,pe->ap', anchor_embeddings,
                                     positive_embeddings)
            print("s", similarities.shape)
            # Since we intend to use these as logits we scale them by a temperature.
            # This value would normally be chosen as a hyperparameter.
            temperature = 0.2
            similarities /= temperature
            # We use these similarities as logits for a softmax. The labels for
            # this call are just the sequence 0, 1, 2, ... since we want the main
            # diagonal values, which correspond to the anchor/positive pairs, to be
            # high. This loss will move embeddings for the anchor/positive pairs
            # together and move all other pairs apart.
            labels = tf.range(similarities.get_shape()[0])
            loss = self.compiled_loss(labels, similarities)
        # Calculate gradients and apply via optimizer.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update and return metrics (specifically the one for the loss value).
        self.compiled_metrics.update_state(labels, similarities)
        return {m.name: m.result() for m in self.metrics}


inputs = layers.Input(shape=(height_width, height_width, 3))
x = layers.Conv2D(filters=4, kernel_size=3, strides=2,
                  activation='relu')(inputs)
x = layers.Conv2D(filters=8, kernel_size=3, strides=2,
                  activation='relu')(x)
x = layers.Conv2D(filters=16, kernel_size=3, strides=2,
                  activation='relu')(x)
x = layers.GlobalAveragePooling2D()(x)
embeddings = layers.Dense(units=4, activation=None)(x)
embeddings = tf.nn.l2_normalize(embeddings, axis=-1)

model = EmbeddingModel(inputs, embeddings)

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer="adam", loss=loss_fn)

dataset = build_dataset(batch_size=32, num_batchs=30)
history = model.fit(dataset, epochs=10)
Epoch 1/10
a,p (32, 32, 32, 3) (32, 32, 32, 3)
ae,pe (32, 4) (32, 4)
s (32, 32)
a,p (32, 32, 32, 3) (32, 32, 32, 3)
ae,pe (32, 4) (32, 4)
s (32, 32)
30/30 [==============================] - 1s 21ms/step - loss: 3.3582
Epoch 2/10
30/30 [==============================] - 1s 21ms/step - loss: 2.8144
Epoch 3/10
30/30 [==============================] - 1s 21ms/step - loss: 2.5596
Epoch 4/10
30/30 [==============================] - 1s 24ms/step - loss: 2.2366
Epoch 5/10
30/30 [==============================] - 1s 20ms/step - loss: 2.1365
Epoch 6/10
30/30 [==============================] - 1s 20ms/step - loss: 1.9939
Epoch 7/10
30/30 [==============================] - 1s 19ms/step - loss: 1.8852
Epoch 8/10
30/30 [==============================] - 1s 19ms/step - loss: 1.8468
Epoch 9/10
30/30 [==============================] - 1s 20ms/step - loss: 1.7414
Epoch 10/10
30/30 [==============================] - 1s 19ms/step - loss: 1.7000
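build_dataset is not shown in the gist, but another way to avoid the hardcoded tf.reshape is to make the static batch size known to the traced train_step by batching the tf.data pipeline with drop_remainder=True. A minimal, hypothetical stand-in (random 32x32x3 anchor/positive pairs; the real build_dataset presumably differs):

import tensorflow as tf

def build_dataset(batch_size, num_batchs):
    # Hypothetical example data: random anchor/positive image pairs.
    n = batch_size * num_batchs
    anchors = tf.random.uniform((n, 32, 32, 3))
    positives = tf.random.uniform((n, 32, 32, 3))
    ds = tf.data.Dataset.from_tensor_slices((anchors, positives))
    # drop_remainder=True gives every batch a fully defined static shape,
    # so similarities.get_shape()[0] is 32 instead of None inside train_step.
    return ds.batch(batch_size, drop_remainder=True)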