This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_ds(from_i, to_i):
    """Build a dataset that expands each index into a run of three integers.

    Each index ``x`` in ``[from_i, to_i)`` is replaced by the elements of
    ``range(x * 3, (x + 1) * 3)`` and the runs are flattened in index order
    with ``flat_map``.

    Args:
        from_i: First index (inclusive).
        to_i: Last index (exclusive).

    Returns:
        A ``tf.data.Dataset`` of the flattened integers.
    """
    def _expand(index):
        # Each index owns a contiguous block of three consecutive values.
        start = index * 3
        stop = (index + 1) * 3
        return tf.data.Dataset.range(start, stop)

    indices = tf.data.Dataset.range(from_i, to_i)
    return indices.flat_map(_expand)
def get_windowed_ds(i): | |
ds_from = get_ds(i, i+1) | |
ds_to = get_ds(i+1, i+2) | |
ds_concat = ds_from.concatenate(ds_to) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from functools import partial | |
def get_ds(from_i, to_i):
    """Return ``range(from_i, to_i)`` with every index ``x`` expanded to
    ``range(x * 3, (x + 1) * 3)`` and the pieces flattened in order.

    Example: ``get_ds(0, 2)`` yields 0, 1, 2, 3, 4, 5.
    """
    return tf.data.Dataset.range(from_i, to_i).flat_map(
        lambda idx: tf.data.Dataset.range(idx * 3, (idx + 1) * 3)
    )
# One dataset per index 0..4, each covering the single index [start, start + 1).
datasets = [get_ds(start, start + 1) for start in range(5)]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# flat_map consumes the inner datasets one after another (see the output
# below), so it brings no parallel-read performance improvement compared
# to interleave.
dataset = tf.data.Dataset.range(2).flat_map(
    lambda x: tf.data.Dataset.range(x * 3, (x + 1) * 3)
)
for t in dataset.take(5):
    print(t.numpy())
# Output:
# 0
# 1
# 2
# 3
# 4
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# interleave pulls elements from several inner datasets in round-robin order.
# cycle_length and block_length are set explicitly: by default cycle_length is
# chosen by the tf.data runtime (machine dependent), while the output listed
# below assumes both inner datasets are cycled one element at a time.
dataset = tf.data.Dataset.range(2).interleave(
    lambda x: tf.data.Dataset.range(x * 10, (x + 1) * 10),
    cycle_length=2,
    block_length=1,
)
for t in dataset.take(5):
    print(t.numpy())
# 0
# 10
# 1
# 11
# 2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from functools import partial | |
# A 1-D tensor holding the integers 0..99.
ds_raw = tf.constant([i for i in range(0, 100)])
ds = tf.data.Dataset.from_tensor_slices(ds_raw)
# a: non-overlapping windows of two consecutive values: [0, 1], [2, 3], ...
ds_a = ds.window(2, drop_remainder=True).flat_map(lambda x: x.batch(2))
# for b, skip 1 and keep only the last element of each window: 2, 4, ..., 98,
# i.e. the value that follows each window of a. drop_remainder=True discards
# the trailing one-element window [99]; without it b would end with 99, a
# value that does not follow any window of a.
ds_b = (
    ds.skip(1)
    .window(2, drop_remainder=True)
    .flat_map(lambda x: x.batch(2))
    .map(lambda x: x[-1])
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from functools import partial | |
def collate_pair(x):
    """Split a sliceable sequence into ``(x[:-1], x[-1])``.

    Mapped over batched windows below, so each window becomes a pair of
    (all elements but the last, the last element).
    """
    head = x[:-1]
    tail = x[-1]
    return head, tail
# A 1-D tensor holding the integers 0..99.
ds_raw = tf.constant([i for i in range(0, 100)])
ds = tf.data.Dataset.from_tensor_slices(ds_raw)
# Non-overlapping windows of 3 values, each batched into a single tensor and
# split into (first two values, last value): ([0, 1], 2), ([3, 4], 5), ...
# drop_remainder=True discards the trailing one-element window [99], which
# would otherwise collate to (<empty>, 99). The batch size equals the window
# size, so each window becomes exactly one batch.
ds = (
    ds.window(3, drop_remainder=True)
    .flat_map(lambda x: x.batch(3))
    .map(collate_pair)
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from functools import partial | |
def collate_pair(x, window):
    """Split a 2-D batch of rows into a feature column and a label.

    Args:
        x: 2-D batch indexed as ``x[row, column]``. Presumably rows of
            ``[value, flag]`` as built in ``ds_raw`` below — TODO confirm.
        window: Not read by this function; presumably bound by the caller
            (e.g. via ``functools.partial``) — TODO confirm.

    Returns:
        ``(x[:, 0], [x[-1, 1]])``: column 0 of every row ("a"), and a
        one-element list holding column 1 of the last row ("last sample of b").
    """
    features = x[:, 0]
    label = x[-1, 1]
    return features, [label]
# Rows of [i, flag] for i in 0..99, where flag is 1 when i is a multiple of 3
# and 0 otherwise.
ds_raw = tf.constant([[i, int(i % 3 == 0)] for i in range(100)])
ds = tf.data.Dataset.from_tensor_slices(ds_raw)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"model_spec": { | |
"name": "bitcoin_predictor", | |
"signature_name": "", | |
"version": "1" | |
}, | |
"metadata": { | |
"signature_def": { | |
... | |
"serving_raw": { |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_serve_raw_fn(model, tf_transform_output): | |
""" | |
Returns a tf.function that preforms preprocessing and inference. | |
For usage as the seving_raw signature. | |
""" | |
model.tft_layer = tf_transform_output.transform_features_layer() | |
@tf.function | |
def serve_raw_fn(features): | |
# you cancan also drop label key or non-needed keys here |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_input_graph(input_feature_keys) -> Tuple[Input, tf.keras.layers.Layer]: | |
""" | |
Creates the named input layers, strips the column names and provides | |
them as a plain tensor. | |
Returns: | |
Tuple[Input, tf.keras.layers.Layer]: Input for your model -- A layer with output shape | |
[None (batch size), input_window_size, len(input_feature_keys)] | |
""" | |
# if you are using Tensorflow Transform or Tensorflow Extended |