import os
from datetime import datetime
import time
import threading
import json
from kafka import KafkaProducer
from kafka.errors import KafkaError
from sklearn.model_selection import train_test_split
import pandas as pd
import tensorflow as tf
import tensorflow_io as tfio

print("tensorflow-io version: {}".format(tfio.__version__))
print("tensorflow version: {}".format(tf.__version__))
COLUMNS = [
    # labels
    'class',
    # low-level features
    'lepton_1_pT',
    'lepton_1_eta',
    'lepton_1_phi',
    'lepton_2_pT',
    'lepton_2_eta',
    'lepton_2_phi',
    'missing_energy_magnitude',
    'missing_energy_phi',
    # high-level derived features
    'MET_rel',
    'axial_MET',
    'M_R',
    'M_TR_2',
    'R',
    'MT2',
    'S_R',
    'M_Delta_R',
    'dPhi_r_b',
    'cos(theta_r1)'
]
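
# Assumes 'SUSY.csv.gz' has been downloaded locally (the SUSY dataset from the
# UCI Machine Learning Repository: ~5 million collision events, 18 feature
# columns plus the 'class' label). Only the first chunk of 100,000 rows is used.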
susy_iterator = pd.read_csv('SUSY.csv.gz', header=None, names=COLUMNS, chunksize=100000)
susy_df = next(susy_iterator)
susy_df.head()

# Number of datapoints and columns
len(susy_df), len(susy_df.columns)

# Number of datapoints belonging to each class (0: background noise, 1: signal)
len(susy_df[susy_df["class"]==0]), len(susy_df[susy_df["class"]==1])

train_df, test_df = train_test_split(susy_df, test_size=0.4, shuffle=True)
print("Number of training samples: ", len(train_df))
print("Number of testing samples: ", len(test_df))
x_train_df = train_df.drop(["class"], axis=1)
y_train_df = train_df["class"]
x_test_df = test_df.drop(["class"], axis=1)
y_test_df = test_df["class"]

# The labels are set as the Kafka message keys so that the data can be stored
# in multiple partitions. This enables efficient data retrieval using
# consumer groups.
x_train = list(filter(None, x_train_df.to_csv(index=False).split("\n")[1:]))
y_train = list(filter(None, y_train_df.to_csv(index=False).split("\n")[1:]))
x_test = list(filter(None, x_test_df.to_csv(index=False).split("\n")[1:]))
y_test = list(filter(None, y_test_df.to_csv(index=False).split("\n")[1:]))
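
# Each x_* element is one CSV-serialized row (18 comma-separated floats) and the
# matching y_* element is the label string, e.g. (values illustrative):
#   x_train[0] -> "0.9728,0.6535,...,1.1389"
#   y_train[0] -> "1.0"
# The label travels as the Kafka message key, the features as the message value.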
NUM_COLUMNS = len(x_train_df.columns)
len(x_train), len(y_train), len(x_test), len(y_test)

def error_callback(exc):
    raise Exception('Error while sending data to kafka: {0}'.format(str(exc)))

def write_to_kafka(topic_name, items):
    count = 0
    producer = KafkaProducer(bootstrap_servers=['127.0.0.1:9092'])
    for message, key in items:
        producer.send(topic_name,
                      key=key.encode('utf-8'),
                      value=message.encode('utf-8')
                      ).add_errback(error_callback)
        count += 1
    producer.flush()
    print("Wrote {0} messages into topic: {1}".format(count, topic_name))

write_to_kafka("susy-train", zip(x_train, y_train))
write_to_kafka("susy-test", zip(x_test, y_test))
def decode_kafka_item(item):
    message = tf.io.decode_csv(item.message, [[0.0] for i in range(NUM_COLUMNS)])
    key = tf.strings.to_number(item.key)
    return (message, key)

BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 64
train_ds = tfio.IODataset.from_kafka('susy-train', partition=0, offset=0)
train_ds = train_ds.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE)
train_ds = train_ds.map(decode_kafka_item)
train_ds = train_ds.batch(BATCH_SIZE)
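
# tfio.IODataset.from_kafka reads partition 0 of 'susy-train' starting at
# offset 0 and yields (message, key) string pairs; decode_kafka_item parses the
# CSV value into NUM_COLUMNS floats and the key into the numeric label. Note
# that shuffling only mixes within SHUFFLE_BUFFER_SIZE consecutive messages.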
OPTIMIZER="adam" | |
LOSS=tf.keras.losses.BinaryCrossentropy(from_logits=True) | |
METRICS=['accuracy'] | |
EPOCHS=10 | |
# design/build the model | |
model = tf.keras.Sequential([ | |
tf.keras.layers.Input(shape=(NUM_COLUMNS,)), | |
tf.keras.layers.Dense(128, activation='relu'), | |
tf.keras.layers.Dropout(0.2), | |
tf.keras.layers.Dense(256, activation='relu'), | |
tf.keras.layers.Dropout(0.4), | |
tf.keras.layers.Dense(128, activation='relu'), | |
tf.keras.layers.Dropout(0.4), | |
tf.keras.layers.Dense(1, activation='sigmoid') | |
]) | |
print(model.summary()) | |
# compile the model | |
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS) | |
# fit the model | |
model.fit(train_ds, epochs=EPOCHS) | |
test_ds = tfio.experimental.streaming.KafkaGroupIODataset(
    topics=["susy-test"],
    group_id="testcg",
    servers="127.0.0.1:9092",
    stream_timeout=10000,
    configuration=[
        "session.timeout.ms=7000",
        "max.poll.interval.ms=8000",
        "auto.offset.reset=earliest"
    ],
)
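
# KafkaGroupIODataset joins the 'testcg' consumer group and streams messages
# from 'susy-test' until no new data arrives for stream_timeout (10 s here).
# Because the group commits its offsets, re-running this step with the same
# group_id would see no messages; use a fresh group_id to read the topic again.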
def decode_kafka_test_item(raw_message, raw_key):
    message = tf.io.decode_csv(raw_message, [[0.0] for i in range(NUM_COLUMNS)])
    key = tf.strings.to_number(raw_key)
    return (message, key)

test_ds = test_ds.map(decode_kafka_test_item)
test_ds = test_ds.batch(BATCH_SIZE)

res = model.evaluate(test_ds)
print("test loss, test acc:", res)
online_train_ds = tfio.experimental.streaming.KafkaBatchIODataset(
    topics=["susy-train"],
    group_id="cgonline",
    servers="127.0.0.1:9092",
    stream_timeout=30000,  # in milliseconds; to block indefinitely, set it to -1.
    configuration=[
        "session.timeout.ms=7000",
        "max.poll.interval.ms=8000",
        "auto.offset.reset=earliest"
    ],
)
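
# Unlike the datasets above, KafkaBatchIODataset yields a sequence of small
# tf.data datasets, one per poll of newly arrived messages, so the model can be
# trained incrementally as data lands in the topic. With stream_timeout=30000
# the stream ends after ~30 s without new messages.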
def error_callback(exc):
    raise Exception('Error while sending data to kafka: {0}'.format(str(exc)))

def write_to_kafka_after_sleep(topic_name, items):
    time.sleep(30)
    print("#" * 100)
    print("Writing messages into topic: {0} after a nice sleep!".format(topic_name))
    print("#" * 100)
    count = 0
    producer = KafkaProducer(bootstrap_servers=['127.0.0.1:9092'])
    for message, key in items:
        producer.send(topic_name,
                      key=key.encode('utf-8'),
                      value=message.encode('utf-8')
                      ).add_errback(error_callback)
        count += 1
    producer.flush()
    print("#" * 100)
    print("Wrote {0} messages into topic: {1}".format(count, topic_name))
    print("#" * 100)
def decode_kafka_online_item(raw_message, raw_key):
    message = tf.io.decode_csv(raw_message, [[0.0] for i in range(NUM_COLUMNS)])
    key = tf.strings.to_number(raw_key)
    return (message, key)

# Produce more training data in the background while the loop below consumes it.
thread = threading.Thread(target=write_to_kafka_after_sleep,
                          args=("susy-train", zip(x_train, y_train)))
thread.daemon = True
thread.start()

for mini_ds in online_train_ds:
    mini_ds = mini_ds.shuffle(buffer_size=32)
    mini_ds = mini_ds.map(decode_kafka_online_item)
    mini_ds = mini_ds.batch(32)
    if len(mini_ds):  # skip polls that returned no messages
        model.fit(mini_ds, epochs=3)
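
# Once no new messages arrive within stream_timeout, the loop exits and the
# incrementally trained model can be persisted; a minimal sketch (path is
# illustrative):
#   model.save("susy_kafka_model")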