Skip to content

Instantly share code, notes, and snippets.

Created July 3, 2020 22:08
Show Gist options
  • Save sjmielke/4aa2e40cde2d38cb2f931adc389f3929 to your computer and use it in GitHub Desktop.
Save sjmielke/4aa2e40cde2d38cb2f931adc389f3929 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# each metric is a list---they must all be of equal length:
# len([bg, cs, da, de, el, es, et, fi, fr, hu, it, lt, lv, nl, pl, pt, ro, sk, sl, sv])
raw_metrics = {
    "XMI into English": [102.3396, 96.9558, 99.6932, 96.53, 105.2602, 103.8174, 92.8232, 92.1413, 96.9629, 92.54, 92.0796, 89.171, 94.1704, 86.4946, 91.8679, 102.4575, 106.0929, 99.7949, 100.0662, 96.8784],
    "XMI from English": [106.2096, 102.8122, 103.3194, 104.0, 111.0332, 108.0881, 100.1656, 98.0227, 99.7149, 99.1018, 95.3096, 96.0001, 99.3214, 90.3871, 98.2996, 105.2407, 112.4211, 105.7748, 107.9064, 100.1222],
    "BLEU into English": [47.4, 42.4, 46.3, 44.0, 50.0, 50.6, 39.3, 38.2, 44.9, 38.4, 40.8, 37.6, 40.3, 38.3, 39.8, 48.3, 50.5, 44.2, 45.3, 43.7],
    "BLEU from English": [46.3, 34.7, 45.0, 36.3, 45.5, 50.2, 27.7, 30.5, 45.7, 30.3, 37.9, 31.0, 34.6, 34.9, 30.5, 46.7, 44.2, 39.8, 41.5, 41.3],
    "q_MT(T | S) into English": [51.8457, 57.2295, 54.4921, 57.6553, 48.9251, 50.3679, 61.3621, 62.044, 57.2224, 61.6453, 62.1057, 65.0143, 60.0149, 67.6907, 62.3174, 51.7278, 48.0924, 54.3904, 54.1191, 57.3069],
    "q_MT(T | S) from English": [50.2761, 61.2360, 49.3525, 63.6453, 52.6653, 51.2585, 62.3698, 60.5892, 55.1352, 67.5041, 63.2954, 63.1707, 57.0305, 69.3437, 65.0522, 54.104, 48.1175, 51.9047, 50.3312, 52.9726],
    "q_LM(T) into English": [154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853],
    "q_LM(T) from English": [156.4857, 164.0482, 152.6719, 167.6453, 163.6985, 159.3466, 162.5354, 158.6119, 154.8501, 166.6059, 158.605, 159.1708, 156.3519, 159.7308, 163.3518, 159.3447, 160.5386, 157.6795, 158.2376, 153.0948],
}

# All lists must share one length; that length is the number of points per
# translation direction. A mismatch raises ValueError, as the bare
# single-element unpack did before, but now with a readable message.
_metric_lengths = {len(values) for values in raw_metrics.values()}
if len(_metric_lengths) != 1:
    raise ValueError(
        f"all metric lists must have equal length, got lengths {sorted(_metric_lengths)}"
    )
(npoints,) = _metric_lengths

# One color per point: the first npoints entries belong to the "into English"
# values, the last npoints entries to the "from English" values (same order
# as the concatenation in metric2vals below).
colors = ["#9999ff"] * npoints + ["#cf0d0d"] * npoints

# Per metric: "into English" values followed by "from English" values.
metric2vals = {
    name: raw_metrics[f"{name} into English"] + raw_metrics[f"{name} from English"]
    for name in ("XMI", "BLEU", "q_MT(T | S)", "q_LM(T)")
}

# Smallest value and value range per metric over both directions.
metric2min = {k: min(vs) for k, vs in metric2vals.items()}
metric2diff = {k: max(vs) - min(vs) for k, vs in metric2vals.items()}

# each row and col is assigned one metric---a metric can appear in multiple!
rows = ("BLEU", "q_MT(T | S)", "XMI")
cols = ("q_MT(T | S)", "XMI", "q_LM(T)")
# Sizes of the rows / columns, index-aligned with `rows` / `cols`.
row_sizes = (400, 500, 400)
col_sizes = (400, 600, 400)
# check that this covers all unordered pairs
def check_assignment(rows, cols, metrics=None):
    """Report whether the row/column grid covers every unordered metric pair.

    Each (row, col) combination is one grid cell. A pair of distinct metrics
    counts as covered when some cell holds the two of them, in either order.

    Side effect: prints a "'Duplicates:'" line with the number of cells minus
    the number of distinct unordered off-diagonal pairs those cells represent
    (so cells with row == col are included in that count).

    Args:
        rows: metric names assigned to the grid rows.
        cols: metric names assigned to the grid columns.
        metrics: iterable of metric names that must be covered pairwise.
            Defaults to the keys of the module-level ``metric2vals``.

    Returns:
        True if every unordered pair of distinct names in ``metrics`` is
        covered by at least one cell, False otherwise.
    """
    from itertools import combinations

    if metrics is None:
        metrics = metric2vals  # iterating the dict yields the metric names

    cells = [(r, c) for r in rows for c in cols]
    # frozenset makes (a, b) and (b, a) compare equal and gives O(1) lookups.
    distinct_pairs = {frozenset(cell) for cell in cells if cell[0] != cell[1]}
    print("'Duplicates:'", len(cells) - len(distinct_pairs))

    # dict.fromkeys drops repeated names while keeping their order.
    names = list(dict.fromkeys(metrics))
    return all(
        frozenset(pair) in distinct_pairs for pair in combinations(names, 2)
    )
# All ordered-by-name pairs of distinct metrics, i.e. each unordered pair once.
necessary = [
    (first, second)
    for first in metric2vals
    for second in metric2vals
    if first < second
]
# Abort early if the chosen rows/cols leave some metric pair without a cell.
assert check_assignment(rows, cols)
# put it together
x_padding = 400
y_padding = 200
space = 100
# Canvas extent: all cells, one `space` per cell plus one extra, and the
# padding on both sides.
max_x = 2 * x_padding + (len(col_sizes) + 1) * space + sum(col_sizes)
max_y = 2 * y_padding + (len(row_sizes) + 1) * space + sum(row_sizes)
# Render the figure three times into /tmp/rug_{selected_color}.svg:
# selected_color None emits every point/line, 0 emits only "#9999ff" ones,
# and 1 emits only "#cf0d0d" ones (see the color filters below).
#
# NOTE(review): this copy of the script has lost its leading indentation, so
# the nesting of the statements below cannot be read off the layout and the
# code is not runnable as-is. The comments describe each statement group;
# restore the indentation from the original before relying on it.
# NOTE(review): each `if row == col:` below is immediately followed by the
# drawing code. Since the grid is meant to cover pairs of *different*
# metrics, a `continue` line was presumably lost after each of these guards
# -- TODO confirm against the original.
for selected_color in [None, 0, 1]:
with open(f"/tmp/rug_{selected_color}.svg", "wt") as f:
# SVG root: the viewBox starts at (-space - x_padding, -space - y_padding)
# and spans max_x by max_y.
print(f"<svg viewBox='{-space-x_padding} {-space-y_padding} {max_x} {max_y}' ", file=f)
# NOTE(review): both namespace attributes are emitted with empty values here;
# an SVG root normally declares the SVG and XLink namespace URIs -- confirm
# whether the URIs were stripped from this copy.
print("xmlns='' ", file=f)
print("xmlns:xlink=''>", file=f)
# White background covering the whole viewBox.
print(f"<rect x='{-space-x_padding}' y='{-space-y_padding}' width='{max_x}' height='{max_y}' fill='white' />", file=f)
# Rectangles
for i_row, (row, row_size) in enumerate(zip(rows, row_sizes)):
for i_col, (col, col_size) in enumerate(zip(cols, col_sizes)):
if row == col:
# Rectangles
# Top-left corner of the cell: sizes of all preceding columns/rows plus one
# `space` gap per preceding column/row.
x = sum(col_sizes[:i_col]) + i_col * space
y = sum(row_sizes[:i_row]) + i_row * space
# The group translates to the cell origin, so everything inside it uses
# cell-local coordinates.
print(f"<g transform='translate({x} {y})'>", file=f)
print(f"<rect width='{col_size}' height='{row_size}' fill='none' stroke='black' />", file=f)
# Points
for xval, yval, color in zip(metric2vals[col], metric2vals[row], colors):
# Linear scaling: a metric's minimum lands at 10% and its maximum at 90%
# of the cell extent.
x = col_size * (0.1 + 0.8 * (xval - metric2min[col]) / metric2diff[col])
y = row_size * (0.1 + 0.8 * (yval - metric2min[row]) / metric2diff[row])
# Emit the point unless a color is selected and this point has the other one.
if selected_color is None or color == ["#9999ff", "#cf0d0d"][selected_color]:
print(f"<circle cx='{x}' cy='{y}' r='7' fill='{color}' />", file=f)
print("</g>", file=f)
# Horizontal rug thingies
for i_row, (row, row_size) in enumerate(zip(rows, row_sizes)):
for i_col, (col, col_size) in enumerate(zip(cols, col_sizes)):
if row == col:
# go_from is the right edge of cell i_col (its own width included);
# go_to starts one gap further to the right.
go_from = sum(col_sizes[:i_col + 1]) + i_col * space
go_to = go_from + space
found_next = False
# Look at the columns to the right of this one.
# NOTE(review): with indentation lost it is unclear whether `go_to += ...`
# belongs inside the `if`, and whether a `break` followed `found_next = True`
# -- TODO confirm against the original.
for i_next_col in range(i_col + 1, len(cols)):
if cols[i_next_col] != row:
found_next = True
go_to += col_sizes[i_next_col] + space
if found_next:
for xval, yval, color in zip(metric2vals[col], metric2vals[row], colors):
# Absolute y: row offset plus the same 10%-90% scaling used for the points.
y = sum(row_sizes[:i_row]) + i_row * space
y += row_size * (0.1 + 0.8 * (yval - metric2min[row]) / metric2diff[row])
if selected_color is None or color == ["#9999ff", "#cf0d0d"][selected_color]:
# Half-transparent horizontal line from go_from to go_to at the point's height.
print(f"<line x1='{go_from}' y1='{y}' x2='{go_to}' y2='{y}' stroke='{color}' opacity='0.5' />", file=f)
# Vertical rug thingies
for i_row, (row, row_size) in enumerate(zip(rows, row_sizes)):
for i_col, (col, col_size) in enumerate(zip(cols, col_sizes)):
if row == col:
# Mirror of the horizontal case: go_from is the bottom edge of cell i_row;
# go_to starts one gap further down.
go_from = sum(row_sizes[:i_row + 1]) + i_row * space
go_to = go_from + space
found_next = False
# NOTE(review): same nesting / missing-`break` uncertainty as in the
# horizontal loop above -- TODO confirm.
for i_next_row in range(i_row + 1, len(rows)):
if rows[i_next_row] != col:
found_next = True
go_to += row_sizes[i_next_row] + space
if found_next:
# Note the swapped zip order compared with the horizontal loop: here `yval`
# comes from metric2vals[col] and is scaled with the column's min/range.
for xval, yval, color in zip(metric2vals[row], metric2vals[col], colors):
x = sum(col_sizes[:i_col]) + i_col * space
x += col_size * (0.1 + 0.8 * (yval - metric2min[col]) / metric2diff[col])
if selected_color is None or color == ["#9999ff", "#cf0d0d"][selected_color]:
# Half-transparent vertical line from go_from to go_to at the point's x position.
print(f"<line x1='{x}' y1='{go_from}' x2='{x}' y2='{go_to}' stroke='{color}' opacity='0.5' />", file=f)
# Labels
# Column labels: centered horizontally on each column at y = -space / 2.
for i_col, (col, col_size) in enumerate(zip(cols, col_sizes)):
# Rewrite "q_LM(" / "q_MT(" so that LM / MT is emitted in a tspan shifted by
# dy='15' at 70% font size, and the "(" shifts back by dy='-15'.
for s in ("LM", "MT"):
col = col.replace(f"q_{s}(", f"q<tspan dy='15' font-size='70%'>{s}</tspan><tspan dy='-15'>(</tspan>")
x = sum(col_sizes[:i_col]) + i_col * space + col_size / 2
y = -space / 2
print(f"<text x='{x}' y='{y}' text-anchor='middle' style='font-size: 400%; font-family: Times;'>{col}</text>", file=f)
# Row labels: placed at x = -space / 2, centered vertically on each row,
# and rotated by 270 degrees.
for i_row, (row, row_size) in enumerate(zip(rows, row_sizes)):
for s in ("LM", "MT"):
row = row.replace(f"q_{s}(", f"q<tspan dy='15' font-size='70%'>{s}</tspan><tspan dy='-15'>(</tspan>")
x = -space / 2
y = sum(row_sizes[:i_row]) + i_row * space + row_size / 2
print(f"<text text-anchor='middle' style='font-size: 400%; font-family: Times;' transform='translate({x} {y}) rotate(270)'>{row}</text>", file=f)
# Close the SVG root element.
print("</svg>", file=f)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment