Last active
July 26, 2022 20:07
-
-
Save dmmiller612/d89f23462f0cd53cb18b78f3be9fbf92 to your computer and use it in GitHub Desktop.
SparkFlow example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
from pyspark.ml.feature import VectorAssembler, OneHotEncoder
from pyspark.ml.pipeline import Pipeline
from pyspark.sql import SparkSession
from sparkflow.graph_utils import build_graph
from sparkflow.tensorflow_async import SparkAsyncDL
#simple tensorflow network | |
def small_model(): | |
x = tf.placeholder(tf.float32, shape=[None, 784], name='x') | |
y = tf.placeholder(tf.float32, shape=[None, 10], name='y') | |
layer1 = tf.layers.dense(x, 256, activation=tf.nn.relu) | |
layer2 = tf.layers.dense(layer1, 256, activation=tf.nn.relu) | |
out = tf.layers.dense(layer2, 10) | |
z = tf.argmax(out, 1, name='out') | |
loss = tf.losses.softmax_cross_entropy(y, out) | |
return loss | |
# Reuse the active SparkSession (pyspark shell / notebook) or create one so
# the script also runs standalone instead of failing with a NameError.
spark = SparkSession.builder.getOrCreate()

# Load the training data; column types are inferred from the CSV contents.
# NOTE(review): assumes a header-less file whose first column (_c0) is the
# digit label followed by 784 pixel columns -- confirm against the data set.
df = spark.read.option("inferSchema", "true").csv('mnist_train.csv')

# convert graph to json
tensorflow_graph = build_graph(small_model)

# Assemble and one hot encode
# Columns 1..784 become the feature vector; column _c0 becomes the one-hot
# 'labels' vector (dropLast=False keeps every category as its own slot).
va = VectorAssembler(inputCols=df.columns[1:785], outputCol='features')
encoded = OneHotEncoder(inputCol='_c0', outputCol='labels', dropLast=False)

# The tensor names below refer to the placeholders / output op created in
# small_model ('x', 'y' and 'out').
spark_model = SparkAsyncDL(
    inputCol='features',
    tensorflowGraph=tensorflow_graph,
    tfInput='x:0',
    tfLabel='y:0',
    tfOutput='out:0',
    tfLearningRate=.001,
    iters=20,
    # Hand the one-hot column produced by the encoder to the estimator so it
    # has values to feed into the 'y:0' label placeholder.
    labelCol='labels',
)

# Run feature assembly, label encoding and model training as one pipeline.
p = Pipeline(stages=[va, encoded, spark_model]).fit(df)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment