Created
March 20, 2020 01:17
-
-
Save shreydesai/72f02096695371d9dcaacd91a582917b to your computer and use it in GitHub Desktop.
Visualization for span-based outputs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Model output visualization script for span-based data. Requires
matplotlib for color maps.

Usage:
    python3 span_viz.py > index.html
    open index.html
"""
import argparse | |
import json | |
import random | |
import matplotlib.cm as cm | |
import matplotlib.colors as colors | |
def prepare_colors(map_name):
    """Build the CSS rules and value-to-class lookup for a color map.

    The named matplotlib color map is sampled at 0.0, 0.1, ..., 1.0 and one
    ``.color<i>`` CSS rule is generated per sample.

    Args:
        map_name: Name of a registered matplotlib color map
            (BuGn or YlOrBr work best).

    Returns:
        A ``(style_elems_str, color_dict)`` tuple. ``style_elems_str`` is the
        space-joined CSS rules; ``color_dict`` maps each sampled value
        (``i / 10.`` for ``i`` in 0..10) to the index ``i`` used in the
        matching ``.color<i>`` class name.

    Raises:
        ValueError: If ``map_name`` is not a registered color map.
    """
    # Local import: the module header only binds `matplotlib.cm` and
    # `matplotlib.colors`, not the top-level package name.
    import matplotlib

    # `cm.get_cmap` was deprecated in matplotlib 3.7 and removed in 3.9. The
    # `matplotlib.colormaps` registry (available since 3.5) is the supported
    # lookup; fall back to the legacy call only on older releases.
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is None:
        cmap = cm.get_cmap(map_name)
    else:
        try:
            cmap = registry[map_name]
        except KeyError as err:
            # Keep the exception type the legacy lookup raised.
            raise ValueError(
                f'{map_name!r} is not a registered color map'
            ) from err

    normer = colors.Normalize(vmin=0., vmax=1.)
    color_nums = [i / 10. for i in range(11)]
    colors_hex = [colors.to_hex(cmap(normer(x))) for x in color_nums]

    style_elems_str = ' '.join(
        f'.color{i} {{ background-color: {color_hex}; }}'
        for i, color_hex in enumerate(colors_hex)
    )
    color_dict = {x: i for i, x in enumerate(color_nums)}
    return (style_elems_str, color_dict)
def render_span(tokens, labels, color_dict, gold=False):
    """Render tokens as a run of colored HTML ``<span>`` elements.

    Args:
        tokens: Sequence of token strings to display.
        labels: Sequence indexed in parallel with ``tokens``. With
            ``gold=True`` a label equal to 1 marks a true token; otherwise
            each label is a probability that is rounded to one decimal and
            looked up in ``color_dict``.
        color_dict: Mapping from a rounded probability to the index used in
            the ``color<i>`` CSS class name.
        gold: If true, true tokens get the ``color-gold`` class and all
            other tokens get ``colorNA``.

    Returns:
        The spans joined by single spaces (an empty string for no tokens).

    Raises:
        IndexError: If ``labels`` is shorter than ``tokens``.
        KeyError: If a rounded probability is not a key of ``color_dict``.
    """
    # Function-scope import keeps this fix independent of the file header.
    from html import escape

    html_tokens = []
    for i, token in enumerate(tokens):
        if gold:
            token_class = '-gold' if labels[i] == 1 else 'NA'
        else:
            token_class = color_dict[round(labels[i], 1)]
        # Escape the token text so characters such as '<' or '&' are shown
        # literally instead of being parsed as markup.
        html_tokens.append(
            f'<span class="color{token_class}">{escape(token)} </span>'
        )
    return ' '.join(html_tokens)
def render_block(true_span_html, pred_span_html):
    """Render one two-column block comparing true and predicted spans.

    Args:
        true_span_html: HTML for the gold span, placed in the left column.
        pred_span_html: HTML for the predicted span, placed in the right
            column.

    Returns:
        An HTML fragment with a heading row ("True" / "Pred"), a content
        row holding the two spans, and a trailing ``<hr>``. Both arguments
        are inserted verbatim, so they must already be valid HTML.
    """
    # The headings previously closed with </h3>, which did not match the
    # opening <h4> tags.
    return f"""
    <div class="row">
        <div class="col-sm"><h4>True</h4></div>
        <div class="col-sm"><h4>Pred</h4></div>
    </div>
    <div class="row">
        <div class="col-sm summ">{true_span_html}</div>
        <div class="col-sm summ">{pred_span_html}</div>
    </div>
    <hr>
    """
def render_page(style_elems_str, block_html):
    """Render the complete HTML page.

    Args:
        style_elems_str: CSS rules for the ``.color<i>`` classes, inserted
            verbatim into the page's ``<style>`` element.
        block_html: Already-rendered HTML for all comparison blocks,
            inserted verbatim into the page body.

    Returns:
        The HTML document as a string. The page links the Roboto Mono web
        font and a ``bootstrap.min.css`` file referenced by relative path.
    """
    # Declare the encoding explicitly: without it a browser opening the
    # generated file has to guess and may garble non-ASCII tokens.
    return f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <link href="https://fonts.googleapis.com/css?family=Roboto+Mono&display=swap" rel="stylesheet">
        <link href="bootstrap.min.css" rel="stylesheet">
        <style>
        {style_elems_str}
        .color-gold {{ background-color: #D4AF37; }}
        .summ {{ font-family: "Roboto Mono", monospace; font-size: 12px; }}
        .highlight {{ background-color: #F9F900; }}
        </style>
    </head>
    <body>
        <div class="container">
            <h2>Span Visualization</h2>
            <div>{block_html}</div>
        </div>
    </body>
    </html>
    """
if __name__ == '__main__':
    # The only command-line option is the matplotlib color map name.
    cli = argparse.ArgumentParser()
    cli.add_argument('--color_map', type=str, default='YlOrBr')
    options = cli.parse_args()

    # Built-in demo data. Each entry is a whitespace-tokenized sentence, one
    # 0/1 gold label per token, and one score in [0, 1] per token.
    examples = [
        (
            "Possible Scenario While there are several scenarios where conflict between the United States and China is possible , some analysts believe that a conflict over Taiwan remains the most likely place where the PRC and the U.S. would come to blows .",
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
            [0.18, 0.89, 0.88, 0.56, 0.93, 0.36, 0.26, 0.59, 0.62, 0.18, 0.31, 0.7, 0.66, 0.67, 0.13, 0.94, 0.69, 0.12, 0.26, 0.39, 0.57, 0.14, 0.3, 0.92, 0.03, 0.79, 0.26, 0.27, 0.04, 0.23, 0.48, 0.47, 0.24, 0.53, 0.65, 0.75, 0.91, 0.08, 0.59, 0.24, 0.23, 0.29]
        ),
        (
            "If thwarted in its initial efforts to stop Chinese aggression against Taiwan , the United States may be tempted to resort to stronger measures and attack mainland China .",
            [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0],
            [0.52, 0.55, 0.96, 0.14, 0.06, 0.39, 0.65, 0.89, 0.66, 0.44, 0.09, 0.6, 0.77, 0.26, 0.4, 0.82, 0.39, 0.89, 0.1, 0.33, 0.6, 0.56, 0.99, 0.34, 0.15, 0.13, 0.77, 0.74, 0.92]
        ),
        (
            "It is also important to remember that nuclear weapons are an asymmetric response to American conventional superiority .",
            [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [0.76, 0.19, 0.54, 0.32, 0.33, 0.05, 0.34, 0.14, 0.43, 0.88, 0.54, 0.53, 0.46, 0.77, 0.35, 0.55, 0.73, 0.28]
        ),
    ]

    rendered_blocks = []
    css_rules, class_lookup = prepare_colors(options.color_map)
    for sentence, gold_labels, scores in examples:
        words = sentence.split()
        # Left column: gold highlighting. Right column: score-based shading.
        gold_html = render_span(words, gold_labels, class_lookup, gold=True)
        scored_html = render_span(words, scores, class_lookup)
        rendered_blocks.append(render_block(gold_html, scored_html))

    # The page goes to stdout; redirect it to a file to view in a browser.
    print(render_page(css_rules, ''.join(rendered_blocks)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment