Skip to content

Instantly share code, notes, and snippets.

@savan77
Created October 20, 2020 21:50
Show Gist options
  • Select an option

  • Save savan77/c47d72ffab8382b619d8cdae263e8e37 to your computer and use it in GitHub Desktop.

Select an option

Save savan77/c47d72ffab8382b619d8cdae263e8e37 to your computer and use it in GitHub Desktop.
OCR Inference
def _crop_box(img, box):
    """Return the region of ``img`` covered by ``box`` (inclusive bounds).

    ``box`` is a dict with ``xmin``, ``xmax``, ``ymin``, ``ymax`` keys.
    NumPy image arrays are indexed rows (y) first, then columns (x), so the
    y range must be the first slice.
    """
    return img[box['ymin']:box['ymax'] + 1, box['xmin']:box['xmax'] + 1]


def run(checkpoint, batch_size, dataset_name, image_path_pattern, annotations,
        image_dir='/mnt/data/datasets/images', output_dir='/mnt/output'):
    """Run OCR on every annotated box and write the predictions to disk.

    Each image listed in ``annotations`` is read from ``image_dir`` by its
    basename. Every box is cropped from the *original* image, resized to the
    dataset's input size and fed through the model. The decoded text of each
    box is written, one line per box, to ``<output_dir>/<image stem>.txt``.

    Args:
        checkpoint: Checkpoint path handed to ``ChiefSessionCreator``.
        batch_size: Batch size handed to ``create_model``.
            NOTE(review): every feed holds a single crop, so this presumably
            has to be 1 — confirm against ``create_model``.
        dataset_name: Dataset name used to build the model and to look up
            the model's input image size.
        image_path_pattern: Unused; kept for interface compatibility.
        annotations: Mapping of image path -> iterable of box dicts with
            integer pixel keys ``xmin``, ``xmax``, ``ymin``, ``ymax``
            (treated as inclusive bounds).
        image_dir: Directory the images are loaded from. The default is the
            previously hard-coded location.
        output_dir: Directory the ``.txt`` prediction files are written to.
            The default is the previously hard-coded location.
    """
    images_placeholder, endpoints = create_model(batch_size, dataset_name)
    session_creator = monitored_session.ChiefSessionCreator(
        checkpoint_filename_with_path=checkpoint)
    width, height = get_dataset_image_size(dataset_name)
    with monitored_session.MonitoredSession(
            session_creator=session_creator) as sess:
        for path, boxes in annotations.items():
            print("Processing: ", path)
            name = os.path.basename(path)
            img = cv2.imread(os.path.join(image_dir, name))
            if img is None:
                # cv2.imread signals a missing/unreadable file by returning
                # None rather than raising; skip instead of crashing on slicing.
                print("Skipping unreadable image: ", path)
                continue
            # NOTE(review): cv2 loads channels in BGR order. If the model was
            # trained on RGB input, convert with
            # cv2.cvtColor(img, cv2.COLOR_BGR2RGB) — confirm.

            # splitext keeps names like "a.b.jpg" distinct ("a.b", not "a").
            out_path = os.path.join(output_dir,
                                    os.path.splitext(name)[0] + '.txt')
            # Open once per image so predictions for all boxes are kept, and
            # close it deterministically.
            with open(out_path, 'w', encoding='utf-8') as file_writer:
                for box in boxes:
                    crop = _crop_box(img, box)
                    if crop.size == 0:
                        # Out-of-range or degenerate boxes give an empty
                        # array, which PIL cannot turn into an image.
                        print("Skipping empty box: ", box)
                        continue
                    # LANCZOS is the same filter the removed ANTIALIAS alias
                    # named. Use a new name so `img` keeps the full image
                    # for the next box.
                    model_input = PIL.Image.fromarray(crop).resize(
                        (width, height), PIL.Image.LANCZOS)
                    predictions = sess.run(
                        endpoints.predicted_text,
                        feed_dict={images_placeholder:
                                   np.asarray(model_input)[np.newaxis, ...]})
                    # Only one crop is fed, so take the first prediction.
                    text = predictions.tolist()[0].decode('utf-8')
                    file_writer.write(text + '\n')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment