@smurching
Last active November 17, 2017 03:11
KerasVectorTransformer
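# Import transformer and vector utilities
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from sparkdl.transformers import KerasVectorTransformer

# Get (or create) a SparkSession; in a pyspark shell or notebook `spark` already exists
spark = SparkSession.builder.getOrCreate()

# Create input DataFrame with a single vector-valued column, "features"
data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),
        (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),
        (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]
df = spark.createDataFrame(data, ["features"])

# Create a KerasVectorTransformer that applies the Keras model saved at modelFile
transformer = KerasVectorTransformer(inputCol="features",
                                     outputCol="transformed_features",
                                     modelFile="path/to/my/model.h5")

# Compute the result, which adds a column "transformed_features" of
# DenseVectors holding the Keras model's output for each row
result = transformer.transform(df)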
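The snippet above leaves the per-row behavior implicit. Below is a minimal pure-Python sketch of one plausible semantics, assuming the transformer densifies each input vector (missing sparse entries become explicit zeros), feeds it to the model as a one-row batch, and stores the model output as a dense vector. `SparseVec`, `to_dense`, `transform_rows`, and `toy_model` are hypothetical names introduced for illustration, and `toy_model` merely stands in for a loaded Keras model; this is not sparkdl's actual implementation.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple, Union


@dataclass
class SparseVec:
    """Hypothetical stand-in for a sparse vector: a size plus (index, value) pairs."""
    size: int
    entries: Sequence[Tuple[int, float]]


# A row is either a sparse vector or a dense sequence of floats (assumption).
Vec = Union[SparseVec, Sequence[float]]


def to_dense(vec: Vec) -> List[float]:
    """Expand a vector into a dense list of floats, filling unset positions with 0.0."""
    if isinstance(vec, SparseVec):
        dense = [0.0] * vec.size
        for index, value in vec.entries:
            dense[index] = float(value)
        return dense
    return [float(x) for x in vec]


def transform_rows(rows: Sequence[Vec],
                   model: Callable[[List[List[float]]], List[List[float]]]) -> List[List[float]]:
    """Apply `model` to each row as a one-row batch and collect the dense outputs."""
    outputs = []
    for vec in rows:
        batch = [to_dense(vec)]        # a batch containing a single example
        prediction = model(batch)[0]   # keep the single output row
        outputs.append(list(prediction))
    return outputs


def toy_model(batch: List[List[float]]) -> List[List[float]]:
    """Hypothetical stand-in for a Keras model: maps each input row to its sum."""
    return [[sum(row)] for row in batch]


# The same three rows as the DataFrame in the snippet above
rows = [
    SparseVec(5, [(1, 1.0), (3, 7.0)]),
    [2.0, 0.0, 3.0, 4.0, 5.0],
    [4.0, 0.0, 0.0, 6.0, 7.0],
]
result = transform_rows(rows, toy_model)
print(result)  # [[8.0], [14.0], [17.0]]
```

Here `toy_model` sums each densified row, so each output is a one-element vector. A real implementation would likely group many rows per batch rather than predicting one row at a time, since Keras inference is much cheaper per row on larger batches.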