Skip to content

Instantly share code, notes, and snippets.

@allwefantasy
Last active November 9, 2018 03:06
Show Gist options
  • Save allwefantasy/c57b58621493d702f35ff8ecd4c4fcce to your computer and use it in GitHub Desktop.
Save allwefantasy/c57b58621493d702f35ff8ecd4c4fcce to your computer and use it in GitHub Desktop.
bigdl1
-- Data preparation: load the train and test sets from Parquet and reshape the
-- label column into the form used by the train statement below.
-- NOTE(review): the schema (a numeric label column plus a features column) is
-- only implied by the select statements below, confirm against the data files.
set trainDataPath="/tmp/cifar_train_data";
set testDataPath="/tmp/cifar_test_data";
-- Register the raw Parquet data as temporary tables using the paths set above.
load parquet.`${trainDataPath}` as tmpTrainData;
load parquet.`${testDataPath}` as tmpTestData;
-- Wrap each scalar label in a single-element float array and keep features as is.
-- Presumably BigDLClassifyExt expects an array-typed label column, confirm.
-- The testData alias is referenced by name in the evaluation settings below.
select array(cast(label as float)) as label,features from tmpTrainData as trainData;
select array(cast(label as float)) as label,features from tmpTestData as testData;
-- Train a small convolutional classifier on trainData with BigDLClassifyExt
-- and store the trained model under /tmp/bigdl.
train trainData as BigDLClassifyExt.`/tmp/bigdl` where
-- presumably suppresses Spark log output while training, confirm in the module
disableSparkLog = "true"
-- presumably the shape each features vector is reshaped to (channels, height, width)
-- NOTE(review): CIFAR-10 images are 32x32, confirm the stored data was resized to 28x28
and fitParam.0.featureSize="[3,28,28]"
-- read as classNum by the model code below to size the output layer
and fitParam.0.classNum="10"
-- presumably the number of training epochs
and fitParam.0.maxEpoch="10"
-- Evaluate on the testData table after every epoch and print the Loss and Top1Accuracy metrics
and fitParam.0.evaluate.trigger.everyEpoch="true"
and fitParam.0.evaluate.batchSize="1000"
and fitParam.0.evaluate.table="testData"
and fitParam.0.evaluate.methods="Loss,Top1Accuracy"
-- For unbalanced classes, uncomment and supply class weights (presumably one per class)
-- and fitParam.0.criterion.classWeight="[......]"
-- Training and validation summaries are written to the same directory.
-- NOTE(review): presumably the module uses separate subfolders for each, confirm they do not collide
and fitParam.0.summary.trainDir="/tmp/bigdl-summary"
and fitParam.0.summary.validateDir="/tmp/bigdl-summary"
and fitParam.0.code='''
def apply(params:Map[String,String])={
// Builds the network. params presumably carries the fitParam.0 settings as strings
// (classNum is read from it for the output layer below).
val model = Sequential()
// NOTE(review): Reshape reinterprets the element order, it does not transpose axes.
// The declared input shape is (28, 28, 3) while featureSize is [3,28,28], so if the
// stored features are channel-last the channels end up scrambled. Confirm the layout.
model.add(Reshape(Array(3, 28, 28), inputShape = Shape(28, 28, 3)))
// Two convolution + max-pooling stages with tanh activations.
model.add(Convolution2D(6, 5, 5, activation = "tanh").setName("conv1_5x5"))
model.add(MaxPooling2D())
model.add(Convolution2D(12, 5, 5, activation = "tanh").setName("conv2_5x5"))
model.add(MaxPooling2D())
// Fully connected head: 100 tanh units, then one softmax output per class.
model.add(Flatten())
model.add(Dense(100, activation = "tanh").setName("fc1"))
model.add(Dense(params("classNum").toInt, activation = "softmax").setName("fc2"))
// Return the model explicitly rather than relying on the value of the last add call.
model
}
'''
;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment