Last active: October 15, 2018 11:32
-
-
Save allwefantasy/59d1f89026d8049e23f182c7d1de2870 to your computer and use it in GitHub Desktop.
cifar2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Download the CIFAR-10 images from https://github.com/allwefantasy/spark-deep-learning-toy/releases
-- and point the ImageLoaderExt path further below at the downloaded `train` directory.

-- Placeholder input table: the `run ... as ImageLoaderExt` statement needs an input table,
-- and this one is built from an empty JSON object (presumably its contents are ignored
-- and the images are read from the loader's path instead).
set json = '''{}''';
load jsonStr.`json` as emptyData;

-- Where the fitted StringIndex model (class name -> numeric index mapping) is saved.
set labelMappingPath = "/tmp/si";
-- Where the preprocessed (label, features) rows are written as parquet.
set imageConvertPath = "/tmp/cifar_train_data";
-- Where the trained BigDL model is saved and later loaded for prediction.
set modelPath = "/tmp/bigdl";
-- Directory holding the CIFAR-10 training images; adjust to your local download location.
set trainImagePath = "/Users/allwefantasy/Downloads/cifar/train";

-- Load the images under trainImagePath into table `data`. The `apply` function returns the
-- per-image transformer chain Resize(28, 28) -> MatToTensor() -> ImageFrameToSample(),
-- i.e. every image is resized to 28x28 before being converted.
-- (The original statement had a stray `and` directly after `where`; parameters start right after `where`.)
run emptyData as ImageLoaderExt.`${trainImagePath}` where
code='''
def apply(params:Map[String,String]) = {
  Resize(28, 28) ->
    MatToTensor() -> ImageFrameToSample()
}
'''
as data;
-- Convert the image name into a numeric label.
-- labelStr is the second "_"-separated part of imageName with everything from the first "."
-- removed; assuming names like "123_frog.png", this yields "frog" (naming scheme not verified here).
select split(split(imageName,"_")[1],"\\.")[0] as labelStr,features from data as newdata;

-- Fit a StringIndex mapping labelStr -> labelIndex and save it under labelMappingPath.
-- The output table newdata1 is not used afterwards; this step runs only to produce the saved mapping.
train newdata as StringIndex.`${labelMappingPath}` where inputCol="labelStr" and outputCol="labelIndex" as newdata1;
-- Apply the saved mapping to add the labelIndex column.
predict newdata as StringIndex.`${labelMappingPath}` as newdata2;
-- Shift the index by one. Presumably labelIndex starts at 0 and BigDL expects 1-based class labels; confirm.
select (cast(labelIndex as int) + 1) as label,features from newdata2 as newdata3;

-- Save the image processing result and read it back, so training runs from the parquet copy.
save overwrite newdata3 as parquet.`${imageConvertPath}`;
load parquet.`${imageConvertPath}` as newdata3;
-- Wrap the label in a single-element float array (the form passed to BigDLClassifyExt below).
select array(cast(label as float)) as label,features from newdata3 as newdata4;
-- Train a LeNet-5 style CNN on newdata4 and save it under modelPath.
-- featureSize matches the 3 x 28 x 28 features produced above; classNum is read inside `apply`
-- via params("classNum"), and `apply` returns the model to train.
-- NOTE(review): the Reshape layer declares inputShape (28, 28, 3) while featureSize is [3,28,28].
-- Both hold 2352 values, so the reshape only relabels dimensions; confirm this matches the
-- tensor layout produced by MatToTensor.
train newdata4 as BigDLClassifyExt.`${modelPath}` where
fitParam.0.featureSize="[3,28,28]"
and fitParam.0.classNum="10"
and fitParam.0.maxEpoch="50"
and fitParam.0.code='''
def apply(params:Map[String,String])={
  val model = Sequential()
  model.add(Reshape(Array(3, 28, 28), inputShape = Shape(28, 28, 3)))
  model.add(Convolution2D(6, 5, 5, activation = "tanh").setName("conv1_5x5"))
  model.add(MaxPooling2D())
  model.add(Convolution2D(12, 5, 5, activation = "tanh").setName("conv2_5x5"))
  model.add(MaxPooling2D())
  model.add(Flatten())
  model.add(Dense(100, activation = "tanh").setName("fc1"))
  model.add(Dense(params("classNum").toInt, activation = "softmax").setName("fc2"))
  model
}
'''
;
-- Batch prediction with the trained model (here on the training data itself).
predict newdata4 as BigDLClassifyExt.`${modelPath}` as predictdata;

-- Deploy with the API server: register the saved model under a name, then call it as a
-- function in SQL.
-- NOTE(review): vec_argmax presumably returns a 0-based position, while `label` above is
-- 1-based (+1 shift); confirm before comparing them.
-- NOTE(review): in this script `data` has no `label` column. Presumably `data` refers to the
-- API server's request table when this runs there; confirm.
-- register BigDLClassifyExt.`${modelPath}` as cifarPredict;
-- select
-- vec_argmax(cifarPredict(vec_dense(features))) as predict_label,
-- label from data
-- as output;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.