TensorFlow Keras Model Training Example with Apache Arrow Dataset
from functools import partial
import multiprocessing
import os
import socket
import sys

from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.csv
import tensorflow as tf
tf.enable_eager_execution()
import tensorflow_io.arrow as arrow_io

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
def write_csv(filename, num_records):
    """Generate sample data and write to a CSV file."""
    data = {'label': np.random.binomial(1, 0.5, num_records)}
    data['x0'] = np.random.randn(num_records) + 5 * data['label']
    data['x1'] = np.random.randn(num_records) + 5 * data['label']
    df = pd.DataFrame(data)
    df.to_csv(filename, index=False)
def read_and_process(filename):
    """Read the given CSV file and yield processed Arrow batches."""

    # Read a CSV file into an Arrow Table with threading enabled and
    # set block_size in bytes to break the file into chunks for granularity,
    # which determines the number of batches in the resulting pyarrow.Table
    opts = pyarrow.csv.ReadOptions(use_threads=True, block_size=4096)
    table = pyarrow.csv.read_csv(filename, opts)

    # Fit the feature transform
    df = table.to_pandas()
    scaler = StandardScaler().fit(df[['x0', 'x1']])

    # Iterate over batches in the pyarrow.Table and apply processing
    for batch in table.to_batches():
        df = batch.to_pandas()

        # Process the batch and apply feature transform
        X_scaled = scaler.transform(df[['x0', 'x1']])
        df_scaled = pd.DataFrame({'label': df['label'],
                                  'x0': X_scaled[:, 0],
                                  'x1': X_scaled[:, 1]})
        batch_scaled = pa.RecordBatch.from_pandas(df_scaled, preserve_index=False)

        yield batch_scaled
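# A minimal sketch (not part of the original example) of exercising the
# generator on its own; it assumes 'sample.csv' has already been written by
# write_csv below:
#
#   for batch in read_and_process('sample.csv'):
#       print(batch.schema, batch.num_rows)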
def read_and_process_dir(directory):
    """Read a directory of CSV files and yield processed Arrow batches."""
    for f in os.listdir(directory):
        if f.endswith(".csv"):
            filename = os.path.join(directory, f)
            for batch in read_and_process(filename):
                yield batch
def serve_csv_data(ip_addr, port_num, directory):
    """
    Create a socket and serve Arrow record batches as a stream read from the
    given directory containing CSV files.
    """

    # Create the socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind((ip_addr, port_num))
    sock.listen(1)

    # Serve forever, each client will get one iteration over data
    while True:
        conn, _ = sock.accept()
        outfile = conn.makefile(mode='wb')
        writer = None
        try:
            # Read directory and iterate over each batch in each file
            batch_iter = read_and_process_dir(directory)
            for batch in batch_iter:

                # Initialize the pyarrow writer on first batch
                if writer is None:
                    writer = pa.RecordBatchStreamWriter(outfile, batch.schema)

                # Write the batch to the client stream
                writer.write_batch(batch)

        # Cleanup client connection
        finally:
            if writer is not None:
                writer.close()
            outfile.close()
            conn.close()

    sock.close()
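# Rough sketch (an assumption, not used by this script): a plain pyarrow
# client could consume the stream served above along these lines:
#
#   client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#   client.connect(('127.0.0.1', 8888))
#   reader = pa.RecordBatchStreamReader(client.makefile(mode='rb'))
#   table = reader.read_all()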
def start_server_process(host_addr, host_port, serve_dir):
    """Start a process to serve CSV data as an Arrow stream."""
    server = multiprocessing.Process(target=serve_csv_data,
                                     args=(host_addr, host_port, serve_dir))
    server.daemon = True
    server.start()
def make_local_dataset(filename):
    """Make a TensorFlow Arrow Dataset that reads from a local CSV file."""

    # Read the local file and get a record batch iterator
    batch_iter = read_and_process(filename)

    # Create the Arrow Dataset as a stream from local iterator of record batches
    ds = arrow_io.ArrowStreamDataset.from_record_batches(
        batch_iter,
        columns=(0, 1, 2),
        output_types=(tf.int64, tf.float64, tf.float64),
        output_shapes=(tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([])),
        batch_mode='auto',
        record_batch_iter_factory=partial(read_and_process, filename))

    # Map the dataset to combine feature columns to single tensor
    ds = ds.map(lambda l, x0, x1: (tf.stack([x0, x1], axis=1), l))

    return ds
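# Sketch (an assumption): with eager execution enabled above, one batch of the
# dataset could be inspected directly, e.g.:
#
#   for features, labels in make_local_dataset('sample.csv').take(1):
#       print(features.shape, labels.shape)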
def make_remote_dataset(endpoint):
    """Make a TensorFlow Arrow Dataset that reads from a remote Arrow stream."""

    # Create the Arrow Dataset from a remote endpoint serving a stream
    ds = arrow_io.ArrowStreamDataset(
        [endpoint],
        columns=(0, 1, 2),
        output_types=(tf.int64, tf.float64, tf.float64),
        batch_mode='auto')

    # Map the dataset to combine feature columns to single tensor
    ds = ds.map(lambda l, x0, x1: (tf.stack([x0, x1], axis=1), l))

    return ds
def model_fit(ds):
    """Create and fit a Keras logistic regression model."""

    # Build the Keras model
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, input_shape=(2,), activation='sigmoid'))
    model.compile(optimizer='sgd', loss='mean_squared_error', metrics=['accuracy'])

    # Fit the model on the given dataset
    model.fit(ds, epochs=5, shuffle=False)

    return model
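# Sketch (an assumption): after fitting, the model could also be evaluated on
# a freshly created dataset to report loss and accuracy, e.g.:
#
#   loss, acc = model.evaluate(make_local_dataset('sample.csv'))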
if __name__ == '__main__':

    # Parse flag to run local or remote dataset
    run_remote = False
    if len(sys.argv) >= 2 and sys.argv[1] == '--run-remote':
        run_remote = True

    # Write sample data as a CSV file
    filename = 'sample.csv'
    num_records = 1000
    write_csv(filename, num_records)

    if not run_remote:
        print('Running model fit on local file: {}'.format(filename))
        make_dataset_fn = partial(make_local_dataset,
                                  filename=filename)
    else:
        host_addr = '127.0.0.1'
        host_port = 8888
        serve_dir = './'
        print('Running model fit with remote host: {}:{}, serving directory: {}'
              .format(host_addr, host_port, serve_dir))
        start_server_process(host_addr, host_port, serve_dir)
        make_dataset_fn = partial(make_remote_dataset,
                                  endpoint='{}:{}'.format(host_addr, host_port))

    # Create the dataset
    ds = make_dataset_fn()

    # Fit the model
    model = model_fit(ds)

    print("Fit model with weights: {}".format(model.get_weights()))