Skip to content

Instantly share code, notes, and snippets.

@shreydesai
Created March 20, 2020 01:17
Show Gist options
  • Save shreydesai/72f02096695371d9dcaacd91a582917b to your computer and use it in GitHub Desktop.
Save shreydesai/72f02096695371d9dcaacd91a582917b to your computer and use it in GitHub Desktop.
Visualization for span-based outputs
"""
Model output visualization script for span-based data. Requires
matplotlib for color maps.
Usage:
python3 span_viz.py > index.html
open index.html
"""
import argparse
import json
import random
import matplotlib.cm as cm
import matplotlib.colors as colors
def prepare_colors(map_name):
    """Create the CSS rules and the score-to-class lookup for a color map.

    The color map is sampled at 0.0, 0.1, ..., 1.0 (11 points). Sample
    number ``i`` becomes the CSS rule ``.color{i}``.

    Args:
        map_name: Name of a matplotlib color map (BuGn or YlOrBr work best).

    Returns:
        A ``(style_elems_str, color_dict)`` tuple. ``style_elems_str`` is the
        space-joined CSS rules; ``color_dict`` maps each sampled value
        (0.0, 0.1, ..., 1.0) to its rule index (0..10).
    """
    import matplotlib

    normer = colors.Normalize(vmin=0., vmax=1.)

    # `cm.get_cmap` was deprecated in matplotlib 3.7 and removed in 3.9.
    # Prefer the colormap registry's lookup when this matplotlib provides
    # it, and fall back to the legacy function on older releases.
    registry = getattr(matplotlib, 'colormaps', None)
    lookup = getattr(registry, 'get_cmap', None)
    if lookup is None:
        lookup = cm.get_cmap
    cmap = lookup(map_name)

    color_nums = [i / 10. for i in range(11)]
    colors_hex = [colors.to_hex(cmap(normer(x))) for x in color_nums]
    style_elems_str = ' '.join(
        f'.color{i} {{ background-color: {color_hex}; }}'
        for i, color_hex in enumerate(colors_hex)
    )
    color_dict = {x: i for i, x in enumerate(color_nums)}
    return (style_elems_str, color_dict)
def render_span(tokens, labels, color_dict, gold=False):
    """Render a token sequence as HTML ``<span>`` elements colored by label.

    Args:
        tokens: Sequence of token strings.
        labels: Per-token labels, indexed in parallel with ``tokens``. With
            ``gold`` enabled, a label equal to 1 marks a true-span token.
            Otherwise each label is rounded to one decimal place and looked
            up in ``color_dict``.
        color_dict: Mapping from a rounded label value to the numeric suffix
            of a ``color{n}`` CSS class.
        gold: When true, tokens labeled 1 get the ``color-gold`` class and
            all other tokens get ``colorNA``.

    Returns:
        The rendered spans joined by single spaces ('' for no tokens).

    Raises:
        IndexError: If ``labels`` is shorter than ``tokens``.
        KeyError: If a rounded label is not a key of ``color_dict``.
    """
    # Local import so this function does not rely on a module-level import.
    from html import escape

    html_tokens = []
    for i, token in enumerate(tokens):
        if gold:
            token_class = '-gold' if labels[i] == 1 else 'NA'
        else:
            token_class = color_dict[round(labels[i], 1)]
        # Escape the token text so characters such as '<' or '&' are shown
        # literally instead of being parsed as markup.
        safe_token = escape(str(token), quote=False)
        html_tokens.append(
            f'<span class="color{token_class}">{safe_token} </span>'
        )
    return ' '.join(html_tokens)
def render_block(true_span_html, pred_span_html):
    """Render one side-by-side comparison block.

    The block is a heading row ("True" / "Pred"), a row holding the two
    rendered spans, and a trailing horizontal rule.

    Args:
        true_span_html: HTML for the true span; inserted as-is.
        pred_span_html: HTML for the predicted span; inserted as-is.

    Returns:
        The block's HTML as a string.
    """
    # Each heading is opened and closed with a matching <h4> tag so the
    # emitted markup is well-formed.
    return f"""
<div class="row">
<div class="col-sm"><h4>True</h4></div>
<div class="col-sm"><h4>Pred</h4></div>
</div>
<div class="row">
<div class="col-sm summ">{true_span_html}</div>
<div class="col-sm summ">{pred_span_html}</div>
</div>
<hr>
"""
def render_page(style_elems_str, block_html):
    """Build the complete HTML document.

    Args:
        style_elems_str: CSS rules placed at the top of the page's
            ``<style>`` element; inserted as-is.
        block_html: HTML placed inside the page's container ``<div>``;
            inserted as-is.

    Returns:
        The full page as a string.
    """
    # Plain template: doubled braces are literal CSS braces, single-brace
    # fields are filled in by str.format below.
    page_template = """
<!DOCTYPE html>
<html>
<head>
<link href="https://fonts.googleapis.com/css?family=Roboto+Mono&display=swap" rel="stylesheet">
<link href="bootstrap.min.css" rel="stylesheet">
<style>
{style_elems_str}
.color-gold {{ background-color: #D4AF37; }}
.summ {{ font-family: "Roboto Mono", monospace; font-size: 12px; }}
.highlight {{ background-color: #F9F900; }}
</style>
</head>
<body>
<div class="container">
<h2>Span Visualization</h2>
<div>{block_html}</div>
</div>
</body>
</html>
"""
    return page_template.format(
        style_elems_str=style_elems_str,
        block_html=block_html,
    )
if __name__ == '__main__':
    # Command line: only the matplotlib color map name is configurable.
    cli = argparse.ArgumentParser()
    cli.add_argument('--color_map', type=str, default='YlOrBr')
    args = cli.parse_args()

    # Built-in examples. Each entry is
    # (space-separated sentence, 0/1 true labels, per-token predicted scores).
    inputs = [
        (
            "Possible Scenario While there are several scenarios where conflict between the United States and China is possible , some analysts believe that a conflict over Taiwan remains the most likely place where the PRC and the U.S. would come to blows .",
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0],
            [0.18, 0.89, 0.88, 0.56, 0.93, 0.36, 0.26, 0.59, 0.62, 0.18, 0.31, 0.7, 0.66, 0.67, 0.13, 0.94, 0.69, 0.12, 0.26, 0.39, 0.57, 0.14, 0.3, 0.92, 0.03, 0.79, 0.26, 0.27, 0.04, 0.23, 0.48, 0.47, 0.24, 0.53, 0.65, 0.75, 0.91, 0.08, 0.59, 0.24, 0.23, 0.29]
        ),
        (
            "If thwarted in its initial efforts to stop Chinese aggression against Taiwan , the United States may be tempted to resort to stronger measures and attack mainland China .",
            [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0],
            [0.52, 0.55, 0.96, 0.14, 0.06, 0.39, 0.65, 0.89, 0.66, 0.44, 0.09, 0.6, 0.77, 0.26, 0.4, 0.82, 0.39, 0.89, 0.1, 0.33, 0.6, 0.56, 0.99, 0.34, 0.15, 0.13, 0.77, 0.74, 0.92]
        ),
        (
            "It is also important to remember that nuclear weapons are an asymmetric response to American conventional superiority .",
            [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
            [0.76, 0.19, 0.54, 0.32, 0.33, 0.05, 0.34, 0.14, 0.43, 0.88, 0.54, 0.53, 0.46, 0.77, 0.35, 0.55, 0.73, 0.28]
        ),
    ]

    style_elems_str, color_dict = prepare_colors(args.color_map)

    # One comparison block per example: true span first, then predicted span.
    rendered_blocks = []
    for sentence, gold_labels, pred_scores in inputs:
        words = sentence.split()
        rendered_blocks.append(render_block(
            render_span(words, gold_labels, color_dict, gold=True),
            render_span(words, pred_scores, color_dict),
        ))

    # The finished page goes to stdout.
    print(render_page(style_elems_str, ''.join(rendered_blocks)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment