Skip to content

Instantly share code, notes, and snippets.

@honnibal
Created April 22, 2019 10:33
Show Gist options
  • Save honnibal/5c54ad3ddbb0504c0b25319b53d4c681 to your computer and use it in GitHub Desktop.
Save honnibal/5c54ad3ddbb0504c0b25319b53d4c681 to your computer and use it in GitHub Desktop.
Script to jury-rig a little spaced-repetition system out of the Prodigy annotation tool
"""See https://twitter.com/honnibal/status/1120020992636661767 """
import time
import srsly
from prodigy import recipe
from prodigy.components.db import connect
from prodigy.util import INPUT_HASH_ATTR, set_hashes
from prodigy.components.filters import filter_duplicates
def get_rank_priority(data):
    """Return a priority that is the reciprocal of the card's rank.

    Args:
        data: Card dict with a "rank" entry that ``float()`` accepts.

    Returns:
        ``1 / rank`` as a float, so rank 1 gives 1.0 and larger ranks
        give progressively smaller values.

    Raises:
        KeyError: If *data* has no "rank" entry.
        ZeroDivisionError: If the rank is zero.
    """
    return 1.0 / float(data["rank"])
def get_recency_priority(data, now=None):
    """Return a priority that grows with the time since the card was seen.

    The elapsed time is bucketed on a roughly logarithmic scale: a card
    seen less than 5 seconds ago scores 0.0, and the score rises in steps
    of 0.1 up to 0.8 for cards last seen 6,000,000 seconds ago or more.

    Args:
        data: Card dict with a "timestamp" entry (seconds since the epoch).
        now: Optional current time in seconds since the epoch. Defaults to
            ``time.time()``; passing it explicitly makes the result
            deterministic.

    Returns:
        One of 0.0, 0.1, ..., 0.8.

    Raises:
        KeyError: If *data* has no "timestamp" entry.
    """
    if now is None:
        now = time.time()
    seconds_since_seen = int(now) - data["timestamp"]
    # (exclusive upper bound in seconds, priority), in ascending order.
    # The priorities are spelled out as literals to keep them exact.
    buckets = (
        (5, 0.0),
        (20, 0.1),
        (60, 0.2),
        (600, 0.3),
        (6000, 0.4),
        (60000, 0.5),
        (600000, 0.6),
        (6000000, 0.7),
    )
    for upper_bound, priority in buckets:
        if seconds_since_seen < upper_bound:
            return priority
    return 0.8
def get_difficulty_priority(data):
    """Return a priority that is high for cards with low accuracy.

    Args:
        data: Card dict with a numeric "accuracy" entry.

    Returns:
        ``1 - accuracy``.

    Raises:
        KeyError: If *data* has no "accuracy" entry.
    """
    return 1 - data["accuracy"]
class Card:
    """A flash card: a task dict plus a scheduling priority.

    The wrapped dict is kept by reference in ``self.data`` and mutated in
    place. Cards order by ``priority`` alone, so two different cards with
    the same priority compare equal.
    """

    def __init__(self, data):
        """Wrap *data* and fill in defaults for the scheduling fields.

        Args:
            data: Task dict. It must contain the ``INPUT_HASH_ATTR`` key,
                which becomes the card id. Missing "timestamp", "history"
                and "accuracy" entries are added to it in place.
        """
        data.setdefault("timestamp", int(time.time()))
        data.setdefault("history", [])
        data.setdefault("accuracy", 0.5)
        self.data = data
        self.id = data[INPUT_HASH_ATTR]
        self.update_priority()

    def __repr__(self):
        card_id = getattr(self, "id", None)
        priority = getattr(self, "priority", None)
        return f"{type(self).__name__}(id={card_id!r}, priority={priority!r})"

    @property
    def accuracy(self):
        """The running accuracy stored in the card's data."""
        return self.data["accuracy"]

    def update_priority(self, recency_weight=0.6, difficulty_weight=0.4):
        """Recompute ``self.priority`` as a weighted sum of two scores.

        Args:
            recency_weight: Weight of the recency score.
            difficulty_weight: Weight of the difficulty score.
        """
        self.priority = sum((
            recency_weight * get_recency_priority(self.data),
            difficulty_weight * get_difficulty_priority(self.data),
        ))

    def update(self, response, acceptance, ema_decay=0.7):
        """Record a response and fold it into the running accuracy.

        The accuracy is an exponential moving average: the old value is
        scaled by *ema_decay*, and ``1 - ema_decay`` is added when
        *acceptance* equals "accept".

        Args:
            response: Value appended to the card's "history" list.
            acceptance: Counted as a success only when equal to "accept".
            ema_decay: Weight kept from the previous accuracy.
        """
        self.data["history"].append(response)
        self.data["accuracy"] *= ema_decay
        self.data["accuracy"] += (1 - ema_decay) * (acceptance == "accept")
        self.update_priority()

    def to_json(self):
        """Write the priority and display metadata into the data dict.

        Returns:
            ``self.data`` itself (not a copy), with "priority" and "meta"
            entries set.

        Raises:
            KeyError: If the data has no "rank" entry.
        """
        self.data["priority"] = self.priority
        self.data["meta"] = {
            "priority": "%.2f" % self.priority,
            "rank": int(self.data["rank"]),
            "recency": "%.1f" % get_recency_priority(self.data),
            "accuracy": "%.2f" % self.data["accuracy"],
        }
        return self.data

    # Comparisons look only at priority. For operands that are not cards
    # they return NotImplemented, so Python can try the reflected operation
    # or fall back to its default instead of raising AttributeError.

    def __eq__(self, other):
        if not isinstance(other, Card):
            return NotImplemented
        return self.priority == other.priority

    def __ne__(self, other):
        if not isinstance(other, Card):
            return NotImplemented
        return self.priority != other.priority

    def __lt__(self, other):
        if not isinstance(other, Card):
            return NotImplemented
        return self.priority < other.priority

    def __le__(self, other):
        if not isinstance(other, Card):
            return NotImplemented
        return self.priority <= other.priority

    def __gt__(self, other):
        if not isinstance(other, Card):
            return NotImplemented
        return self.priority > other.priority

    def __ge__(self, other):
        if not isinstance(other, Card):
            return NotImplemented
        return self.priority >= other.priority

    # Equality depends on the mutable priority, so cards stay unhashable.
    # Defining __eq__ already implies this; it is spelled out for clarity.
    __hash__ = None
class CardQueue:
    """An endless stream of cards, highest priority first.

    Iterating yields the cards' task dicts in priority order. When the end
    of the queue is reached the cards are re-sorted and iteration starts
    over, so the stream only ends if the queue is empty.
    """

    def __init__(self, cards):
        """Build the queue from an iterable of task dicts.

        Args:
            cards: Task dicts; each is wrapped in a ``Card``.
        """
        self.queue = [Card(card) for card in cards]
        self.index = {card.id: card for card in self.queue}
        self.i = 0
        self.progress = 0.0

    @property
    def avg_accuracy(self):
        """Mean accuracy over all cards, or 0.0 for an empty queue."""
        if not self.queue:
            return 0.0
        total_accuracy = sum(card.accuracy for card in self.queue)
        return total_accuracy / len(self.queue)

    def __iter__(self):
        """Yield task dicts forever, restarting when the queue runs out.

        An empty queue yields nothing instead of raising IndexError.
        """
        self.i = 0
        self.sort_queue()
        while self.queue:
            if self.i >= len(self.queue):
                print("Re-do queue", self.i, len(self.queue))
                self.sort_queue()
                self.i = 0
            card = self.queue[self.i]
            self.i += 1
            # to_json() returns the card's own dict, so the timestamp and
            # "html" writes below also show up in the yielded task.
            data = card.to_json()
            card.data["timestamp"] = int(time.time())
            data["html"] = " "  # to make html_template kick in
            yield data

    def sort_queue(self):
        """Refresh every card's priority and sort highest-first."""
        for card in self.queue:
            card.update_priority()
        self.queue.sort(reverse=True)

    def update(self, batch):
        """Apply a batch of answered tasks, then restart from the top.

        Args:
            batch: Answered task dicts. Each must carry the
                ``INPUT_HASH_ATTR`` key of a card in this queue, and an
                "answer" entry whenever it has an "accept" entry.

        Raises:
            KeyError: If a task's hash does not belong to any card.
        """
        for card_data in batch:
            card = self.index[card_data[INPUT_HASH_ATTR]]
            if "accept" not in card_data:
                # NOTE(review): sort_queue() below recomputes every
                # priority, so this zero does not survive the call.
                card.priority = 0.0
            else:
                card.update(card_data["accept"], card_data["answer"])
        self.sort_queue()
        self.i = 0
def add_gender_options(stream):
    """Attach the three German article choices to every example.

    Each example is modified in place and then yielded. Every example gets
    its own freshly built options list.

    Args:
        stream: Iterable of example dicts.

    Yields:
        The same dicts, each with an "options" entry listing "der", "die"
        and "das" under the ids 1, 2 and 3.
    """
    articles = ("der", "die", "das")
    for eg in stream:
        eg["options"] = [
            {"id": option_id, "text": article}
            for option_id, article in enumerate(articles, start=1)
        ]
        yield eg
@recipe("import_genders")
def import_genders(deck, jsonl_loc):
    """Import cards from a JSONL file into the dataset named *deck*.

    Every record gets the article options attached and its hashes set.
    Records that duplicate an earlier input are dropped before the rest
    are added to the dataset.

    Args:
        deck: Name of the dataset to add and fill.
        jsonl_loc: Path of the JSONL file to read.
    """
    DB = connect()
    cards = add_gender_options(srsly.read_jsonl(jsonl_loc))
    cards = list(filter_duplicates([set_hashes(card) for card in cards], by_input=True))
    DB.add_dataset(deck)
    # cards is already a list here, so it is passed on without copying.
    DB.add_examples(cards, [deck])
    DB.save()
@recipe("set-gender")
def set_gender(dataset_in, dataset_out):
    """After marking the genders, set the answers into the 'gender' key.

    Only examples answered "accept" with a non-empty "accept" list are
    kept. Their "answer" and "accept" entries are removed, and the first
    selected option id is translated into the article it stands for.

    Args:
        dataset_in: Name of the dataset holding the marked examples.
        dataset_out: Name of the dataset that receives the results.

    Raises:
        KeyError: If a selected option id is not 1, 2 or 3.
    """
    DB = connect()
    # An explicit mapping, so an unexpected id (0, negative) fails loudly
    # rather than wrapping around to the wrong article via list indexing.
    article_by_id = {1: "der", 2: "die", 3: "das"}
    examples = []
    for eg in list(DB.get_dataset(dataset_in)):
        if eg["answer"] != "accept" or not eg.get("accept"):
            continue
        eg.pop("answer")
        accept = eg.pop("accept")[0]
        eg["gender"] = article_by_id[accept]
        examples.append(eg)
    DB.add_dataset(dataset_out)
    DB.add_examples(examples, [dataset_out])
def update_state(DB, name, cards):
    """Prodigy doesn't let you update a dataset in-place. Instead, drop and re-add.

    The cards are serialised before the dataset is dropped, so a card that
    fails to serialise leaves the stored dataset untouched.

    Args:
        DB: Database object providing ``drop_dataset``, ``add_dataset``
            and ``add_examples``.
        name: Name of the dataset to replace.
        cards: Iterable of objects with a ``to_json()`` method.
    """
    examples = [card.to_json() for card in cards]
    DB.drop_dataset(name)
    DB.add_dataset(name)
    DB.add_examples(examples, [name])
# Template and JavaScript to handle the 'show answer on response' logic.
#
# HTML_TEMPLATE renders the task's text in the theme's large font size. In
# front of it sits an initially empty, bold, right-aligned span with the class
# "srs-article", which the script below fills in.
HTML_TEMPLATE = """<span style="font-size: {{theme.largeText}}px"><span class="srs-article" style="width: 2em; text-align: right; display: inline-block; font-weight: bold; margin-right: 5px"></span> {{text}}</span>"""
# JAVASCRIPT listens for 'prodigyupdate' events. Once the task has at least one
# selected option, it writes the task's lower-cased "gender" value into the
# ".srs-article" span. The span is coloured with the theme's accept colour when
# that value equals the text of the first selected option, and with the reject
# colour otherwise.
JAVASCRIPT = """
document.addEventListener('prodigyupdate', ({ detail }) => {
const accept = detail.task.accept || []
if (accept.length) {
const article = document.querySelector('.srs-article')
const opt = detail.task.options.find(({ id }) => id === accept[0])
const isCorrect = detail.task.gender.toLowerCase() === opt.text
article.textContent = detail.task.gender.toLowerCase()
article.style.color = isCorrect ? window.prodigy.theme.accept : window.prodigy.theme.reject
}
})
"""
@recipe("srs")
def srs(deck):
    """Perform spaced-repetition memory training.

    Loads the dataset named *deck* into a ``CardQueue`` and serves it
    through the "choice" interface. Answers are stored in the dataset
    ``raw_answers_<deck>``. On exit the deck dataset is rewritten from the
    queue's cards.

    Args:
        deck: Name of the dataset holding the cards.

    Returns:
        The recipe components dict: view id, dataset name, stream, update
        and exit callbacks, progress callback and config.
    """
    DB = connect()
    stream = CardQueue(list(DB.get_dataset(deck)))
    return {
        "view_id": "choice",
        "dataset": f"raw_answers_{deck}",
        "stream": stream,
        "update": stream.update,
        # Write the cards' updated state back into the deck dataset.
        "on_exit": lambda self: update_state(DB, deck, stream.queue),
        # Progress is reported as the mean accuracy over all cards.
        "progress": lambda *args, **kwargs: stream.avg_accuracy,
        "config": {
            "auto_exclude_current": False,
            "choice_auto_accept": False,
            "javascript": JAVASCRIPT,
            "html_template": HTML_TEMPLATE,
        },
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment