Last active: November 17, 2017 03:11
Save smurching/3e050c40a87c360ca9e5577556ddea05 to your computer and use it in GitHub Desktop.
KerasVectorTransformer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: apply a saved Keras model to a DataFrame column of MLlib vectors
# using KerasVectorTransformer.
#
# NOTE(review): `spark` is normally predefined by a notebook / pyspark shell.
# SparkSession.builder.getOrCreate() returns that existing session when one is
# active, and otherwise starts a new one, so this script also runs standalone.
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

# Import transformer
# NOTE(review): confirm the installed sparkdl version exports
# KerasVectorTransformer from sparkdl.transformers.
from sparkdl.transformers import KerasVectorTransformer

spark = SparkSession.builder.getOrCreate()

# Create input DataFrame: one sparse and two dense feature vectors, all of
# size 5, in a single column named "features".
data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),),
        (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),),
        (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)]
df = spark.createDataFrame(data, ["features"])

# Create KerasVectorTransformer.
# NOTE(review): modelFile is a placeholder path; the saved model presumably
# has to accept 5-element inputs to match the vectors above — confirm.
transformer = KerasVectorTransformer(inputCol="features",
                                     outputCol="transformed_features",
                                     modelFile="path/to/my/model.h5")

# Compute result, which has a column "transformed_features" of
# DenseVectors corresponding to Keras model output
result = transformer.transform(df)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.