SparkTorch Example
from sparktorch import SparkTorch, serialize_torch_obj
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.functions import rand
from pyspark.ml.pipeline import Pipeline
import torch
import torch.nn as nn

spark = SparkSession.builder.appName("examples").master('local[*]').getOrCreate()

# Read in the mnist_train.csv dataset, shuffle it, and repartition
df = spark.read.option("inferSchema", "true").csv('examples/mnist_train.csv').orderBy(rand()).repartition(2)

# Create a simple feed-forward neural network
network = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

# Serialize the PyTorch model, loss, and optimizer so they can be shipped to the workers
torch_obj = serialize_torch_obj(
    model=network,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,
    lr=0.001
)

# Set up the feature vector (in this csv, the label is the first column)
vector_assembler = VectorAssembler(inputCols=df.columns[1:785], outputCol='features')

# Set up a SparkTorch model for training
# Note: this uses barrier execution mode, which is sensitive to the number of partitions
spark_model = SparkTorch(
    inputCol='features',
    labelCol='_c0',
    torchObj=torch_obj,
    iters=1000,
    miniBatch=256,         # internal mini-batch size
    earlyStopPatience=40,  # early stopping based on validation loss
    validationPct=0.2      # fraction of the data held out for validation
)

# Create and save the Pipeline
Pipeline(stages=[vector_assembler, spark_model]).fit(df).save('mnist_model')
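Once the fitted pipeline has been saved, it can be reloaded for inference. The following is a minimal sketch, assuming SparkTorch exposes a PysparkPipelineWrapper.unwrap helper for rehydrating the saved SparkTorch stage and that an examples/mnist_test.csv file with the same column layout exists; both of those names are assumptions not shown in the gist above.

# Inference sketch (assumes PysparkPipelineWrapper.unwrap is available in sparktorch
# and that examples/mnist_test.csv has the same schema as the training csv)
from pyspark.ml.pipeline import PipelineModel
from sparktorch import PysparkPipelineWrapper

# Reload the saved pipeline and unwrap the SparkTorch stage so it can run predictions
loaded_pipeline = PysparkPipelineWrapper.unwrap(PipelineModel.load('mnist_model'))

# Score a held-out dataset with the same column layout (label first, 784 pixel columns)
test_df = spark.read.option("inferSchema", "true").csv('examples/mnist_test.csv')
predictions = loaded_pipeline.transform(test_df)
predictions.show(10)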