Created
March 20, 2023 13:53
-
-
Save pythonlessons/1a94b44407b4b4f2b509290c6b1a602f to your computer and use it in GitHub Desktop.
handwriting_recognition_pytorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Location of the model checkpoint file. ModelCheckpoint and Model2onnx are
# both given this path, so it is built once to keep the two in sync.
model_file_path = os.path.join(configs.model_path, 'model.pt')

# Create callbacks. The early-stopping, checkpoint and LR-schedule callbacks
# all monitor 'val_CER' with mode="min".
earlyStopping = EarlyStopping(monitor='val_CER', patience=20, mode="min", verbose=1)
modelCheckpoint = ModelCheckpoint(model_file_path, monitor='val_CER', mode="min", save_best_only=True, verbose=1)
tb_callback = TensorBoard(os.path.join(configs.model_path, 'logs'))
reduce_lr = ReduceLROnPlateau(monitor='val_CER', factor=0.9, patience=10, verbose=1, mode='min', min_lr=1e-6)
# NOTE(review): presumably converts the checkpoint at `saved_model_path` to
# ONNX once training ends -- confirm against the Model2onnx implementation.
model2onnx = Model2onnx(
    saved_model_path=model_file_path,
    input_shape=(1, configs.height, configs.width, 3),
    verbose=1,
    metadata={"vocab": configs.vocab},
)

# Create the model object that will handle training and testing of the network.
model = Model(network, optimizer, loss, metrics=[CERMetric(configs.vocab), WERMetric(configs.vocab)])
model.fit(
    train_dataProvider,
    test_dataProvider,
    epochs=1000,
    callbacks=[earlyStopping, modelCheckpoint, tb_callback, reduce_lr, model2onnx],
)

# Save training and validation datasets as csv files.
# NOTE(review): this runs only after model.fit() returns; if training is
# interrupted the CSV files are not written -- consider whether the split
# should be saved before training instead.
train_dataProvider.to_csv(os.path.join(configs.model_path, 'train.csv'))
test_dataProvider.to_csv(os.path.join(configs.model_path, 'val.csv'))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment