Created
July 3, 2020 22:08
-
-
Save sjmielke/4aa2e40cde2d38cb2f931adc389f3929 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# each metric is a list---they must all be of equal length:
# len([bg, cs, da, de, el, es, et, fi, fr, hu, it, lt, lv, nl, pl, pt, ro, sk, sl, sv])
raw_metrics = {
    "XMI into English": [102.3396, 96.9558, 99.6932, 96.53, 105.2602, 103.8174, 92.8232, 92.1413, 96.9629, 92.54, 92.0796, 89.171, 94.1704, 86.4946, 91.8679, 102.4575, 106.0929, 99.7949, 100.0662, 96.8784],
    "XMI from English": [106.2096, 102.8122, 103.3194, 104.0, 111.0332, 108.0881, 100.1656, 98.0227, 99.7149, 99.1018, 95.3096, 96.0001, 99.3214, 90.3871, 98.2996, 105.2407, 112.4211, 105.7748, 107.9064, 100.1222],
    "BLEU into English": [47.4, 42.4, 46.3, 44.0, 50.0, 50.6, 39.3, 38.2, 44.9, 38.4, 40.8, 37.6, 40.3, 38.3, 39.8, 48.3, 50.5, 44.2, 45.3, 43.7],
    "BLEU from English": [46.3, 34.7, 45.0, 36.3, 45.5, 50.2, 27.7, 30.5, 45.7, 30.3, 37.9, 31.0, 34.6, 34.9, 30.5, 46.7, 44.2, 39.8, 41.5, 41.3],
    "q_MT(T | S) into English": [51.8457, 57.2295, 54.4921, 57.6553, 48.9251, 50.3679, 61.3621, 62.044, 57.2224, 61.6453, 62.1057, 65.0143, 60.0149, 67.6907, 62.3174, 51.7278, 48.0924, 54.3904, 54.1191, 57.3069],
    "q_MT(T | S) from English": [50.2761, 61.2360, 49.3525, 63.6453, 52.6653, 51.2585, 62.3698, 60.5892, 55.1352, 67.5041, 63.2954, 63.1707, 57.0305, 69.3437, 65.0522, 54.104, 48.1175, 51.9047, 50.3312, 52.9726],
    "q_LM(T) into English": [154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853, 154.1853],
    "q_LM(T) from English": [156.4857, 164.0482, 152.6719, 167.6453, 163.6985, 159.3466, 162.5354, 158.6119, 154.8501, 166.6059, 158.605, 159.1708, 156.3519, 159.7308, 163.3518, 159.3447, 160.5386, 157.6795, 158.2376, 153.0948],
}

# Every list must have the same number of entries; fail loudly (still a
# ValueError, as the bare unpacking used to raise) if that is not the case.
_lengths = {len(values) for values in raw_metrics.values()}
if len(_lengths) != 1:
    raise ValueError(
        f"all metric lists must have one common length, got lengths {sorted(_lengths)}"
    )
(npoints,) = _lengths

# One color per plotted value: the first `npoints` entries ("into English")
# use the first color, the next `npoints` ("from English") the second.
colors = ["#9999ff"] * npoints + ["#cf0d0d"] * npoints

# Per metric: the "into English" values followed by the "from English" values.
metric2vals = {
    name: raw_metrics[f"{name} into English"] + raw_metrics[f"{name} from English"]
    for name in ("XMI", "BLEU", "q_MT(T | S)", "q_LM(T)")
}

# Minimum and value range (max - min) per metric, used to normalise
# coordinates when plotting.
metric2min = {name: min(values) for name, values in metric2vals.items()}
metric2diff = {name: max(values) - min(values) for name, values in metric2vals.items()}

# each row and col is assigned one metric---a metric can appear in multiple!
rows = ("BLEU", "q_MT(T | S)", "XMI")
cols = ("q_MT(T | S)", "XMI", "q_LM(T)")
# Sizes of the grid cells, parallel to `rows` / `cols`.
row_sizes = (400, 500, 400)
col_sizes = (400, 600, 400)
# check that this covers all unordered pairs
def check_assignment(rows, cols, metrics=None):
    """Return whether the row/column grid shows every pair of distinct metrics.

    The grid has one cell per (row metric, column metric) combination. A pair
    of metrics counts as covered when some cell holds it, in either order.

    As a side effect, prints ``'Duplicates:'`` followed by the number of grid
    cells minus the number of distinct unordered pairs among the off-diagonal
    cells. That count therefore includes both repeated pairs and cells whose
    row and column metric are the same.

    Args:
        rows: Metric names assigned to the grid rows.
        cols: Metric names assigned to the grid columns.
        metrics: Iterable of metric names that must all be paired with each
            other. Defaults to the keys of the module-level ``metric2vals``.

    Returns:
        True if every unordered pair of distinct metrics appears in the grid,
        False otherwise.
    """
    from itertools import combinations

    if metrics is None:
        metrics = metric2vals  # iterating the dict yields the metric names
    # Drop repeated names (keeping order) so no (name, name) pair is produced.
    names = list(dict.fromkeys(metrics))

    cells = [(r, c) for r in rows for c in cols]
    # Unordered pairs shown by off-diagonal cells; frozenset ignores order.
    shown_pairs = {frozenset(cell) for cell in cells if cell[0] != cell[1]}
    print("'Duplicates:'", len(cells) - len(shown_pairs))

    return all(frozenset(pair) in shown_pairs for pair in combinations(names, 2))
# All unordered pairs of distinct metric names (each listed once, smaller name
# first). Not referenced by the visible remainder of this script.
necessary = [(k1, k2) for k1 in metric2vals for k2 in metric2vals if k1 < k2]
# Refuse to draw a grid that leaves some pair of metrics without a plot.
assert check_assignment(rows, cols), "rows/cols do not cover every pair of metrics"
# put it together
x_padding = 400
y_padding = 200
space = 100  # gap between neighbouring grid cells
# Canvas size: the grid itself plus a margin of `space + padding` on each side.
max_x = sum(col_sizes) + len(col_sizes) * space + space + 2 * x_padding
max_y = sum(row_sizes) + len(row_sizes) * space + space + 2 * y_padding

# Colors selectable via `selected_color` (index into this tuple).
_PALETTE = ("#9999ff", "#cf0d0d")


def _cell_origin(sizes, index):
    """Return the offset of the `index`-th cell along one axis of the grid."""
    return sum(sizes[:index]) + index * space


def _scale(metric, value, size):
    """Map `value` of `metric` into the inner 80% of a cell edge of length `size`.

    The metric's minimum lands at 10% of the edge and its maximum at 90%.
    A metric without any spread is placed in the middle instead of dividing
    by zero.
    """
    span = metric2diff[metric]
    if span == 0:
        return size * 0.5
    return size * (0.1 + 0.8 * (value - metric2min[metric]) / span)


def _gap_after(names, sizes, index, skipped):
    """Return the distance from the end of cell `index` to the next drawn cell.

    Cells whose metric equals `skipped` are not drawn (row metric == column
    metric), so they are bridged. Returns None when no drawn cell follows.
    """
    gap = space
    for nxt in range(index + 1, len(names)):
        if names[nxt] != skipped:
            return gap
        gap += sizes[nxt] + space
    return None


def _label_markup(name):
    """Return `name` with the "LM"/"MT" of a leading-style ``q_XX(`` as a subscript."""
    for s in ("LM", "MT"):
        name = name.replace(f"q_{s}(", f"q<tspan dy='15' font-size='70%'>{s}</tspan><tspan dy='-15'>(</tspan>")
    return name


def _write_svg(out, selected_color):
    """Write the complete scatter-plot grid as SVG markup to `out`.

    Args:
        out: Writable text file object.
        selected_color: None to draw every point, or an index into the
            two-color palette to draw only points of that color.
    """
    wanted = None if selected_color is None else _PALETTE[selected_color]

    def shown(color):
        return wanted is None or color == wanted

    # Every drawn grid cell in row-major order; cells that would plot a
    # metric against itself are left out.
    cells = [
        (i_row, row, row_size, i_col, col, col_size)
        for i_row, (row, row_size) in enumerate(zip(rows, row_sizes))
        for i_col, (col, col_size) in enumerate(zip(cols, col_sizes))
        if row != col
    ]

    print(f"<svg viewBox='{-space-x_padding} {-space-y_padding} {max_x} {max_y}' ", file=out)
    print("xmlns='http://www.w3.org/2000/svg' ", file=out)
    print("xmlns:xlink='http://www.w3.org/1999/xlink'>", file=out)
    print(f"<rect x='{-space-x_padding}' y='{-space-y_padding}' width='{max_x}' height='{max_y}' fill='white' />", file=out)

    # Rectangles and points
    # NOTE: SVG's y axis points down and no flip is applied here, so larger
    # values of the row metric are drawn lower inside a cell.
    for i_row, row, row_size, i_col, col, col_size in cells:
        left = _cell_origin(col_sizes, i_col)
        top = _cell_origin(row_sizes, i_row)
        print(f"<g transform='translate({left} {top})'>", file=out)
        print(f"<rect width='{col_size}' height='{row_size}' fill='none' stroke='black' />", file=out)
        for xval, yval, color in zip(metric2vals[col], metric2vals[row], colors):
            if shown(color):
                cx = _scale(col, xval, col_size)
                cy = _scale(row, yval, row_size)
                print(f"<circle cx='{cx}' cy='{cy}' r='7' fill='{color}' />", file=out)
        print("</g>", file=out)

    # Horizontal rug thingies: connect each point's height across the gap
    # between a cell and the next drawn cell to its right.
    for i_row, row, row_size, i_col, col, col_size in cells:
        gap = _gap_after(cols, col_sizes, i_col, skipped=row)
        if gap is None:
            continue
        go_from = _cell_origin(col_sizes, i_col) + col_size
        go_to = go_from + gap
        top = _cell_origin(row_sizes, i_row)
        for yval, color in zip(metric2vals[row], colors):
            if shown(color):
                y = top + _scale(row, yval, row_size)
                print(f"<line x1='{go_from}' y1='{y}' x2='{go_to}' y2='{y}' stroke='{color}' opacity='0.5' />", file=out)

    # Vertical rug thingies: the same, across the gap to the next drawn cell
    # below.
    for i_row, row, row_size, i_col, col, col_size in cells:
        gap = _gap_after(rows, row_sizes, i_row, skipped=col)
        if gap is None:
            continue
        go_from = _cell_origin(row_sizes, i_row) + row_size
        go_to = go_from + gap
        left = _cell_origin(col_sizes, i_col)
        for xval, color in zip(metric2vals[col], colors):
            if shown(color):
                x = left + _scale(col, xval, col_size)
                print(f"<line x1='{x}' y1='{go_from}' x2='{x}' y2='{go_to}' stroke='{color}' opacity='0.5' />", file=out)

    # Labels: column names centred above the grid ...
    for i_col, (col, col_size) in enumerate(zip(cols, col_sizes)):
        label = _label_markup(col)
        x = _cell_origin(col_sizes, i_col) + col_size / 2
        y = -space / 2
        print(f"<text x='{x}' y='{y}' text-anchor='middle' style='font-size: 400%; font-family: Times;'>{label}</text>", file=out)
    # ... and row names rotated by 270 degrees to the left of it.
    for i_row, (row, row_size) in enumerate(zip(rows, row_sizes)):
        label = _label_markup(row)
        x = -space / 2
        y = _cell_origin(row_sizes, i_row) + row_size / 2
        print(f"<text text-anchor='middle' style='font-size: 400%; font-family: Times;' transform='translate({x} {y}) rotate(270)'>{label}</text>", file=out)

    print("</svg>", file=out)


# One file with all points, then one file per color.
for selected_color in [None, 0, 1]:
    with open(f"/tmp/rug_{selected_color}.svg", "wt", encoding="utf-8") as f:
        _write_svg(f, selected_color)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment