Created October 15, 2018 10:59
Save allwefantasy/4213dd7461c6d00bb482f0e26cd5394c to your computer and use it in GitHub Desktop.
cifar train
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Data preparation (disabled). Uncomment the statements below to rebuild the
-- parquet table /tmp/cifar_train_data that the active part of this script loads.
-- Pipeline:
--   * ImageLoaderExt reads the image directory; its `code` block applies
--     Resize(28, 28) -> MatToTensor() -> ImageFrameToSample() to every image.
--   * labelStr is the second "_"-separated field of imageName, cut at its first
--     "." (assumes file names like <id>_<label>.png -- TODO confirm).
--   * StringIndex (model at /tmp/si) maps labelStr to a numeric labelIndex.
--   * The stored label is labelIndex + 1, presumably because the BigDL
--     classifier expects 1-based class ids (verify).
--   * The result (label, features) is saved as parquet.
-- NOTE(review): images are read from .../cifar/test although the output table is
-- named cifar_train_data -- confirm which split is intended.
-- NOTE(review): the doubly-commented MnistLoaderExt call looks like a leftover
-- from an MNIST version of this script.
-- set json = '''{}''';
-- load jsonStr.`json` as emptyData;
-- -- run emptyData as MnistLoaderExt.`` where
-- -- mnistDir="/Users/allwefantasy/Downloads/mnist"
-- -- as data;
-- run emptyData as ImageLoaderExt.`/Users/allwefantasy/Downloads/cifar/test`
-- where numOfImageTasks="4"
-- and code='''
-- def apply(params:Map[String,String]) = {
-- Resize(28, 28) ->
-- MatToTensor() -> ImageFrameToSample()
-- }
-- '''
-- as data;
-- select split(split(imageName,"_")[1],"\\.")[0] as labelStr,features from data as newdata;
-- train newdata as StringIndex.`/tmp/si` where inputCol="labelStr" and outputCol="labelIndex" as newdata1;
-- predict newdata as StringIndex.`/tmp/si` as newdata2;
-- select (cast(labelIndex as int) + 1) as label,features from newdata2 as newdata3;
-- save overwrite newdata3 as parquet.`/tmp/cifar_train_data`;
-- Reload the table written by the (commented-out) preparation step.
load parquet.`/tmp/cifar_train_data` as newdata3;

-- newdata4 is the table handed to BigDLClassifyExt below; each label is wrapped
-- in a one-element float array (presumably the label format that trainer expects).
select
  array(cast(label as float)) as label,
  features
from newdata3 as newdata4;

-- Debug aid: uncomment to inspect one prepared row.
-- select * from newdata4 limit 1 as output;
-- Train a small LeNet-style CNN with BigDL on the rows of newdata4.
-- newdata4 provides `features` and `label`; label is a one-element float array
-- holding the class id written by the preparation step (labelIndex + 1).
-- The same path, /tmp/bigdl, is passed to the predict statement that follows,
-- so it is presumably where the trained model is persisted.
-- NOTE: fitParam.0.featureSize and the input shape declared in the Scala code
-- below must describe the same (channels, height, width) layout.
train newdata4 as BigDLClassifyExt.`/tmp/bigdl` where
fitParam.0.featureSize="[3,28,28]"
and fitParam.0.classNum="10"
and fitParam.0.maxEpoch="50"
and fitParam.0.code='''
def apply(params:Map[String,String])={
    // Builds the network. params carries the fitParam.0 values (classNum is read below).
    val model = Sequential()
    // Declare the input as channels-first (3, 28, 28) so it agrees with
    // featureSize="[3,28,28]" and with the channels-first conv layers.
    // Reshape only reinterprets the flat buffer and does not transpose, so the
    // previous declaration Shape(28, 28, 3) misdescribed the layout rather than converting it.
    model.add(Reshape(Array(3, 28, 28), inputShape = Shape(3, 28, 28)))
    // Two convolution + max-pooling stages.
    model.add(Convolution2D(6, 5, 5, activation = "tanh").setName("conv1_5x5"))
    model.add(MaxPooling2D())
    model.add(Convolution2D(12, 5, 5, activation = "tanh").setName("conv2_5x5"))
    model.add(MaxPooling2D())
    // Classifier head: flatten, one hidden layer, then softmax over classNum classes.
    model.add(Flatten())
    model.add(Dense(100, activation = "tanh").setName("fc1"))
    model.add(Dense(params("classNum").toInt, activation = "softmax").setName("fc2"))
    // Return the model explicitly instead of relying on the value returned by the last add call.
    model
}
'''
;
-- Score newdata4 with the model trained above (same /tmp/bigdl path).
-- NOTE(review): this predicts on the same rows that were used for training, so
-- it does not measure generalization.
predict newdata4 as BigDLClassifyExt.`/tmp/bigdl`;
-- Alternative (disabled): register the trained model as a SQL function and score rows.
-- NOTE(review): `mnistPredict` looks like a leftover name from the MNIST example.
-- NOTE(review): `data` is only defined in the commented-out preparation step, before
-- any `label` column is derived -- newdata4 may have been intended. Also,
-- vec_argmax presumably returns a 0-based index while the stored labels are
-- 1-based -- confirm before comparing predict_label with label.
-- register BigDLClassifyExt.`/tmp/bigdl` as mnistPredict;
-- select
-- vec_argmax(mnistPredict(vec_dense(features))) as predict_label,
-- label from data
-- as output;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.