Created October 15, 2018 11:04
-
-
Save allwefantasy/1fd9ed63a892ca5292242a48180d14f7 to your computer and use it in GitHub Desktop.
cifar2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Build a placeholder table from a single empty JSON record; it is the input table that
-- the ImageLoaderExt `run` statement below operates on
-- (NOTE(review): assumed the loader does not depend on its rows — confirm).
set json = '''{}''';
load jsonStr.`json` as emptyData;

-- Load the training images from the directory below with numOfImageTasks="4".
-- The path is machine-specific: point it at your local copy of the CIFAR training set.
-- The embedded `code` defines the image pipeline:
--   Resize(28, 28) -> MatToTensor() -> ImageFrameToSample().
-- NOTE(review): CIFAR images are normally 32x32; the resize to 28x28 appears to match the
-- [3,28,28] featureSize used by the classifier trained later — confirm.
-- The resulting table `data` is read later through its imageName and features columns.
run emptyData as ImageLoaderExt.`/Users/allwefantasy/Downloads/cifar/train`
where numOfImageTasks="4"
and code='''
def apply(params:Map[String,String]) = {
Resize(28, 28) ->
MatToTensor() -> ImageFrameToSample()
}
'''
as data;
-- Derive the class name from the image file name: the part after the first "_" and before
-- the first "." (assumes file names shaped like <id>_<class>.<ext>, e.g.
-- 10000_airplane.png -> airplane — TODO confirm against the dataset).
-- substring_index(..., "/", -1) keeps only the text after the last "/", so an underscore in a
-- directory part of imageName cannot shift the split; when imageName contains no "/" it is
-- returned unchanged.
select split(split(substring_index(imageName,"/",-1),"_")[1],"\\.")[0] as labelStr,features from data as newdata;

-- Fit a string indexer on labelStr (stored at /tmp/si), then apply the same indexer to add
-- the labelIndex column.
train newdata as StringIndex.`/tmp/si` where inputCol="labelStr" and outputCol="labelIndex" as newdata1;
predict newdata as StringIndex.`/tmp/si` as newdata2;

-- Shift the index by one to form an integer label.
-- NOTE(review): presumably because the BigDL classifier expects 1-based class labels — confirm.
select (cast(labelIndex as int) + 1) as label,features from newdata2 as newdata3;

-- Persist the prepared table to parquet and read it back, so the steps below use the saved
-- copy (presumably to avoid re-running the image pipeline — confirm).
save overwrite newdata3 as parquet.`/tmp/cifar_train_data`;
load parquet.`/tmp/cifar_train_data` as newdata3;

-- Wrap the label in a single-element float array.
-- NOTE(review): assumed to be the label format BigDLClassifyExt expects — confirm.
select array(cast(label as float)) as label,features from newdata3 as newdata4;
-- select * from newdata4 limit 1 as output;
-- Train a BigDL classifier on newdata4 and save it to /tmp/bigdl; the predict step later
-- loads it from the same path.
-- fitParam.0.* is the only parameter group given here: 3x28x28 features, 10 classes and
-- 50 epochs. Its `code` builds the network:
--   Reshape -> Conv2D(6, 5x5, tanh) -> MaxPool -> Conv2D(12, 5x5, tanh) -> MaxPool
--   -> Flatten -> Dense(100, tanh) -> Dense(classNum, softmax)
-- where classNum is read from params (presumably fitParam.0.classNum — confirm).
-- apply's result is the value of the final model.add(...) call.
-- NOTE(review): assumed to be the Sequential model itself — confirm.
-- NOTE(review): the Reshape layer declares inputShape = Shape(28, 28, 3) while featureSize is
-- [3,28,28]; confirm both match the tensor layout produced by MatToTensor in the loader step.
train newdata4 as BigDLClassifyExt.`/tmp/bigdl` where
fitParam.0.featureSize="[3,28,28]"
and fitParam.0.classNum="10"
and fitParam.0.maxEpoch="50"
and fitParam.0.code='''
def apply(params:Map[String,String])={
val model = Sequential()
model.add(Reshape(Array(3, 28, 28), inputShape = Shape(28, 28, 3)))
model.add(Convolution2D(6, 5, 5, activation = "tanh").setName("conv1_5x5"))
model.add(MaxPooling2D())
model.add(Convolution2D(12, 5, 5, activation = "tanh").setName("conv2_5x5"))
model.add(MaxPooling2D())
model.add(Flatten())
model.add(Dense(100, activation = "tanh").setName("fc1"))
model.add(Dense(params("classNum").toInt, activation = "softmax").setName("fc2"))
}
'''
;
-- Score newdata4 with the model saved at /tmp/bigdl.
-- NOTE(review): newdata4 is the same table the model was trained on, so this measures
-- training-set fit, not held-out accuracy.
predict newdata4 as BigDLClassifyExt.`/tmp/bigdl` as predictdata;
-- Alternative: register the model as a SQL function and call it per row.
-- `label` is only introduced at the newdata3 step, so the query reads from newdata3 (integer
-- label plus features) rather than the raw image table `data`.
-- NOTE(review): labels were shifted to start at 1, while vec_argmax presumably returns a
-- 0-based index — confirm before comparing predict_label with label.
-- register BigDLClassifyExt.`/tmp/bigdl` as cifarPredict;
-- select
-- vec_argmax(cifarPredict(vec_dense(features))) as predict_label,
-- label from newdata3
-- as output;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.