SparkTorch Example
from sparktorch import SparkTorch, serialize_torch_obj
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.functions import rand
from pyspark.ml.pipeline import Pipeline
import torch
import torch.nn as nn
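
# Create a local Spark session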
spark = SparkSession.builder.appName("examples").master('local[*]').getOrCreate()
# Read in mnist_train.csv dataset
df = spark.read.option("inferSchema", "true").csv('examples/mnist_train.csv').orderBy(rand()).repartition(2)
# Create a simple neural network
network = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)
# Build the pytorch object
torch_obj = serialize_torch_obj(
    model=network,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,
    lr=0.001
)
# Setup features (In this csv, the label is the first column)
vector_assembler = VectorAssembler(inputCols=df.columns[1:785], outputCol='features')
# Setup a SparkTorch model for training
# Note: This uses the barrier execution mode, which is sensitive to the number of partitions
spark_model = SparkTorch(
    inputCol='features',
    labelCol='_c0',
    torchObj=torch_obj,
    iters=1000,
    miniBatch=256,  # Internal mini-batch size used on each partition
    earlyStopPatience=40,  # Stop early if validation loss does not improve
    validationPct=0.2  # Fraction of data held out for validation
)
# Create and save the Pipeline
Pipeline(stages=[vector_assembler, spark_model]).fit(df).save('mnist_model')
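
# Optional: a minimal sketch of loading the saved pipeline for inference.
# This assumes SparkTorch's PysparkPipelineWrapper helper for unwrapping saved
# SparkTorch stages; adjust to your installed version if it differs.
# from sparktorch import PysparkPipelineWrapper
# from pyspark.ml.pipeline import PipelineModel
#
# loaded_pipeline = PysparkPipelineWrapper.unwrap(PipelineModel.load('mnist_model'))
# predictions = loaded_pipeline.transform(df)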