handwriting_recognition_pytorch
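import os

# Import paths below are assumed from the mltu package; adjust them if your installed version differs.
from mltu.torch.model import Model
from mltu.torch.metrics import CERMetric, WERMetric
from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx, ReduceLROnPlateau

# configs, network, optimizer, loss, train_dataProvider and test_dataProvider
# are expected to be defined earlier in the training script.

# create callbacks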
earlyStopping = EarlyStopping(monitor='val_CER', patience=20, mode="min", verbose=1)
modelCheckpoint = ModelCheckpoint(configs.model_path + '/model.pt', monitor='val_CER', mode="min", save_best_only=True, verbose=1)
tb_callback = TensorBoard(configs.model_path + '/logs')
reduce_lr = ReduceLROnPlateau(monitor='val_CER', factor=0.9, patience=10, verbose=1, mode='min', min_lr=1e-6)
model2onnx = Model2onnx(
    saved_model_path=configs.model_path + '/model.pt',
    input_shape=(1, configs.height, configs.width, 3),
    verbose=1,
    metadata={"vocab": configs.vocab}
)
# create model object that will handle training and testing of the network
model = Model(network, optimizer, loss, metrics=[CERMetric(configs.vocab), WERMetric(configs.vocab)])
model.fit(
    train_dataProvider,
    test_dataProvider,
    epochs=1000,
    callbacks=[earlyStopping, modelCheckpoint, tb_callback, reduce_lr, model2onnx]
)
# Save training and validation datasets as csv files
train_dataProvider.to_csv(os.path.join(configs.model_path, 'train.csv'))
test_dataProvider.to_csv(os.path.join(configs.model_path, 'val.csv'))
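
The callbacks above both monitor `val_CER` in `"min"` mode, but with different patience values (10 for `ReduceLROnPlateau`, 20 for `EarlyStopping`). The following is a minimal sketch of how such patience-based callbacks typically interact, assuming the usual Keras-style semantics. It is not the library's actual implementation: `simulate_callbacks` and its starting learning rate of `1e-3` are hypothetical and used only for illustration.

```python
def simulate_callbacks(val_cer, lr=1e-3, factor=0.9, lr_patience=10,
                       stop_patience=20, min_lr=1e-6):
    """Replay per-epoch validation CER values through two patience counters.

    Returns (stop_epoch, final_lr, best_cer); stop_epoch is None when
    early stopping never triggers. Hypothetical helper, not library code.
    """
    best = float("inf")
    lr_wait = 0    # epochs without improvement since the last LR reduction
    stop_wait = 0  # epochs without improvement since the last best value
    for epoch, cer in enumerate(val_cer, start=1):
        if cer < best:  # mode="min": lower CER counts as improvement
            best = cer
            lr_wait = 0
            stop_wait = 0
            continue
        lr_wait += 1
        stop_wait += 1
        if lr_wait >= lr_patience:
            # Shrink the learning rate, but never below min_lr
            lr = max(lr * factor, min_lr)
            lr_wait = 0
        if stop_wait >= stop_patience:
            return epoch, lr, best
    return None, lr, best


# Three improving epochs followed by a plateau of 25 epochs
history = [0.5, 0.4, 0.3] + [0.35] * 25
stop_epoch, final_lr, best_cer = simulate_callbacks(history)
print(stop_epoch, final_lr, best_cer)
```

With these settings the learning rate is reduced twice during the plateau (after 10 and 20 stagnant epochs), and the second reduction coincides with early stopping at epoch 23. In practice this means that `epochs=1000` acts as an upper bound rather than the expected training length.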