Skip to content

Instantly share code, notes, and snippets.

@allwefantasy
Created October 15, 2018 11:04
Show Gist options
  • Save allwefantasy/1fd9ed63a892ca5292242a48180d14f7 to your computer and use it in GitHub Desktop.
Save allwefantasy/1fd9ed63a892ca5292242a48180d14f7 to your computer and use it in GitHub Desktop.
cifar2
-- Register a JSON document containing only {} as table emptyData.
-- NOTE(review): presumably emptyData is just a placeholder input for the
-- run statement below - confirm.
set json = '''{}''';
load jsonStr.`json` as emptyData;
-- Run ImageLoaderExt against the cifar train directory with numOfImageTasks
-- set to 4. The embedded code snippet defines the per-image chain:
-- Resize(28, 28), then MatToTensor, then ImageFrameToSample.
-- The output is registered as table data. The statements further down read
-- its imageName and features columns.
run emptyData as ImageLoaderExt.`/Users/allwefantasy/Downloads/cifar/train`
where numOfImageTasks="4"
and code='''
def apply(params:Map[String,String]) = {
Resize(28, 28) ->
MatToTensor() -> ImageFrameToSample()
}
'''
as data;
-- Step 1: derive a textual label from the image file name. The name is split
-- on underscores, the second piece is kept, and that piece is cut at the
-- first dot.
select
    split(split(imageName, "_")[1], "\\.")[0] as labelStr,
    features
from data as newdata;

-- Step 2: fit a StringIndex model on labelStr (stored at /tmp/si), then apply
-- it to the same table so that a labelIndex column is available in newdata2.
train newdata as StringIndex.`/tmp/si`
where inputCol="labelStr"
and outputCol="labelIndex"
as newdata1;

predict newdata as StringIndex.`/tmp/si` as newdata2;

-- Step 3: cast labelIndex to int and add 1.
-- NOTE(review): presumably the classifier expects labels starting at 1 - confirm.
select
    (cast(labelIndex as int) + 1) as label,
    features
from newdata2 as newdata3;

-- Step 4: write the labelled rows to parquet and read them back under the
-- same table name.
-- NOTE(review): presumably a checkpoint so the image preprocessing is not
-- recomputed during training - confirm.
save overwrite newdata3 as parquet.`/tmp/cifar_train_data`;
load parquet.`/tmp/cifar_train_data` as newdata3;

-- Step 5: wrap the label in a one-element float array.
select
    array(cast(label as float)) as label,
    features
from newdata3 as newdata4;

-- Debugging aid, disabled: inspect a single prepared row.
-- select * from newdata4 limit 1 as output;
-- Train a BigDLClassifyExt model on newdata4 and store it at /tmp/bigdl.
-- Parameters of fit group 0:
--   featureSize  [3,28,28]
--   classNum     10 (also read inside the code snippet through params)
--   maxEpoch     50
--   code         Scala snippet that builds the network as a Sequential model:
--                Reshape to (3, 28, 28) from a declared input shape (28, 28, 3),
--                Convolution2D with 6 filters of 5x5 and tanh, MaxPooling2D,
--                Convolution2D with 12 filters of 5x5 and tanh, MaxPooling2D,
--                Flatten, Dense with 100 units and tanh, and a final softmax
--                Dense layer whose width is classNum.
-- NOTE(review): featureSize is [3,28,28] while the first layer declares
-- inputShape (28, 28, 3) - confirm which layout the features actually use.
train newdata4 as BigDLClassifyExt.`/tmp/bigdl` where
fitParam.0.featureSize="[3,28,28]"
and fitParam.0.classNum="10"
and fitParam.0.maxEpoch="50"
and fitParam.0.code='''
def apply(params:Map[String,String])={
val model = Sequential()
model.add(Reshape(Array(3, 28, 28), inputShape = Shape(28, 28, 3)))
model.add(Convolution2D(6, 5, 5, activation = "tanh").setName("conv1_5x5"))
model.add(MaxPooling2D())
model.add(Convolution2D(12, 5, 5, activation = "tanh").setName("conv2_5x5"))
model.add(MaxPooling2D())
model.add(Flatten())
model.add(Dense(100, activation = "tanh").setName("fc1"))
model.add(Dense(params("classNum").toInt, activation = "softmax").setName("fc2"))
}
'''
;
-- Apply the model stored at /tmp/bigdl to newdata4, the same table it was
-- trained on, and register the result as table predictdata.
predict newdata4 as BigDLClassifyExt.`/tmp/bigdl` as predictdata;
-- Disabled alternative: register the model as a SQL function named
-- mnistPredict and select the argmax of its output next to the label.
-- NOTE(review): this fragment reads label from table data, but in this script
-- the label column is only created later (newdata3 and newdata4), so it looks
-- stale - confirm before re-enabling.
-- register BigDLClassifyExt.`/tmp/bigdl` as mnistPredict;
-- select
-- vec_argmax(mnistPredict(vec_dense(features))) as predict_label,
-- label from data
-- as output;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment