Skip to content

Instantly share code, notes, and snippets.

Created July 17, 2019 17:39
Show Gist options
  • Save SohierDane/a90ef46d79808fe3afc70c80bae45972 to your computer and use it in GitHub Desktop.
Save SohierDane/a90ef46d79808fe3afc70c80bae45972 to your computer and use it in GitHub Desktop.
Python equivalent of the Kuzushiji Recognition competition metric (https://www.kaggle.com/c/kuzushiji-recognition/).
Kaggle's backend uses a C# implementation of the same metric. This version is
provided for convenience only; in the event of any discrepancies the C# implementation
is the master version.
Tested on Python 3.6 with numpy 1.16.4 and pandas 0.24.2.
import argparse
import multiprocessing
import numpy as np
import pandas as pd
def define_console_parser():
    """
    Builds the command-line parser for scoring a submission file.

    Returns:
        argparse.ArgumentParser accepting ``--sub_path`` (submission CSV) and
        ``--solution_path`` (solution CSV). Both flags are required, so a
        missing path is reported by argparse with a usage message instead of
        reaching ``pd.read_csv(None)`` and failing with an unrelated error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--sub_path', type=str, required=True,
                        help='path to the submission CSV file')
    parser.add_argument('--solution_path', type=str, required=True,
                        help='path to the solution CSV file')
    return parser
def score_page(preds, truth):
    """
    Scores a single page.

    Args:
        preds: prediction string of labels and center points, i.e.
            space-separated ``label X Y`` triples; NaN/None/blank when the
            page has no predictions.
        truth: ground truth string of labels and bounding boxes, i.e.
            space-separated ``label X Y Width Height`` records; NaN/None/blank
            when the page has no characters.

    Returns:
        dict with int 'tp', 'fp' and 'fn' keys: the true/false positive and
        false negative counts for the page.

    Raises:
        ValueError: if both strings are non-empty and either one does not
            contain a whole number of records.
    """
    truth_indices = {
        'label': 0,
        'X': 1,
        'Y': 2,
        'Width': 3,
        'Height': 4,
    }
    preds_indices = {
        'label': 0,
        'X': 1,
        'Y': 2,
    }

    # split() with no argument ignores repeated/trailing whitespace, so a
    # stray space does not produce a phantom empty field. A blank string is
    # treated exactly like a missing (NaN) one.
    truth_tokens = [] if pd.isna(truth) else truth.split()
    preds_tokens = [] if pd.isna(preds) else preds.split()

    if not truth_tokens and not preds_tokens:
        return {'tp': 0, 'fp': 0, 'fn': 0}
    if not truth_tokens:
        # Every prediction on a page with no characters is a false positive.
        return {'tp': 0, 'fp': len(preds_tokens) // len(preds_indices), 'fn': 0}
    if not preds_tokens:
        # Every character on a page with no predictions is missed.
        return {'tp': 0, 'fp': 0, 'fn': len(truth_tokens) // len(truth_indices)}

    truth_step = len(truth_indices)
    if len(truth_tokens) % truth_step != 0:
        raise ValueError('Malformed solution string')
    truth_label = np.array(truth_tokens[truth_indices['label']::truth_step])
    truth_xmin = np.array(truth_tokens[truth_indices['X']::truth_step]).astype(float)
    truth_ymin = np.array(truth_tokens[truth_indices['Y']::truth_step]).astype(float)
    truth_xmax = truth_xmin + np.array(truth_tokens[truth_indices['Width']::truth_step]).astype(float)
    truth_ymax = truth_ymin + np.array(truth_tokens[truth_indices['Height']::truth_step]).astype(float)

    preds_step = len(preds_indices)
    if len(preds_tokens) % preds_step != 0:
        raise ValueError('Malformed prediction string')
    preds_label = np.array(preds_tokens[preds_indices['label']::preds_step])
    preds_x = np.array(preds_tokens[preds_indices['X']::preds_step]).astype(float)
    preds_y = np.array(preds_tokens[preds_indices['Y']::preds_step]).astype(float)

    preds_unused = np.ones(len(preds_label), dtype=bool)
    tp = 0
    fn = 0
    for xmin, xmax, ymin, ymax, label in zip(truth_xmin, truth_xmax, truth_ymin, truth_ymax, truth_label):
        # Matching = point strictly inside box & same character & prediction
        # not already credited to another box.
        matching = ((xmin < preds_x) & (xmax > preds_x)
                    & (ymin < preds_y) & (ymax > preds_y)
                    & (preds_label == label) & preds_unused)
        if matching.any():
            tp += 1
            # Consume the first matching prediction so it can't match twice.
            preds_unused[np.argmax(matching)] = False
        else:
            fn += 1
    # Predictions never consumed by any box are false positives.
    fp = int(preds_unused.sum())
    return {'tp': tp, 'fp': fp, 'fn': fn}
def kuzushiji_f1(sub, solution):
    """
    Calculates the competition metric.

    Args:
        sub: submissions, as a Pandas dataframe with 'image_id' and 'labels'
            columns.
        solution: solution, as a Pandas dataframe with 'image_id' and
            'labels' columns, listing the same images in the same row order
            as ``sub``.

    Returns:
        f1 score computed from tp/fp/fn counts summed over all pages; 0 when
        precision or recall is zero or undefined.

    Raises:
        ValueError: if the two frames do not contain identical image ids in
            identical order.
    """
    # np.array_equal also checks lengths; an elementwise == on arrays of
    # different lengths does not give a per-row comparison.
    if not np.array_equal(sub['image_id'].values, solution['image_id'].values):
        raise ValueError("Submission image id codes don't match solution")

    # The context manager shuts the worker processes down once scoring is
    # done. starmap blocks until every page has been scored.
    with multiprocessing.Pool() as pool:
        results = pool.starmap(score_page, zip(sub['labels'].values, solution['labels'].values))

    tp = sum(x['tp'] for x in results)
    fp = sum(x['fp'] for x in results)
    fn = sum(x['fn'] for x in results)

    if (tp + fp) == 0 or (tp + fn) == 0:
        return 0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    if precision > 0 and recall > 0:
        return (2 * precision * recall) / (precision + recall)
    return 0
if __name__ == '__main__':
    # Command-line entry point: load both CSVs, map Kaggle's column names onto
    # the ones kuzushiji_f1 expects, and print the resulting score.
    cli_args = define_console_parser().parse_args()
    column_map = {'rowId': 'image_id', 'PredictionString': 'labels'}
    sub = pd.read_csv(cli_args.sub_path).rename(columns=column_map)
    solution = pd.read_csv(cli_args.solution_path).rename(columns=column_map)
    f1_value = kuzushiji_f1(sub, solution)
    print('F1 score of: {0}'.format(f1_value))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment