Supermemo 2 Algorithm, Unobscured (Python 3)
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Polyglot spaced repetition algorithm"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A spaced repetition algorithm is a function that takes the user's subjective rating of how difficult an item was when they were tested on it, and delays the next review of that item by a corresponding amount of time. Across many items, these delays produce an ordering that changes after every answer, and the next item to review is always pulled from this dynamically calculated ordering.\n",
"\n",
"\n",
"$$\\text{number of items in this lesson} = n$$\n",
"$$\\text{list of items to review} = E = \\{i_0, \\ldots , i_{n-1} \\}$$\n",
"$$\\text{order} = f(x) = \\ldots$$\n",
"\n",
"Based on \"SM2+\":\n",
"\n",
"$$\\text{performance rating from the user} = x$$\n",
"$$\\text{easiness previously observed} = e_{t-1}$$\n",
"$$\\text{easiness} = e = -0.8 + 0.28x +0.02x^2 + e_{t-1}$$\n",
"$$\\text{consecutive correct answers} = r$$\n",
"\n",
"For correct answers:\n",
"\n",
"$$f(x) = 6e^{r-1} = 6(-0.8 + 0.28x + 0.02x^2 + e_{t-1})^{r-1}$$\n",
"\n",
"For incorrect answers (in days):\n",
"\n",
"$$f(x) = 1$$\n",
"\n",
"(Incorrect items come back after 1 day because you want to retest as soon as possible: the user clearly did not remember the item well.)\n",
"\n",
"SM2's delay boils down to:\n",
"\n",
"$$ \\text{rating of how easy the question was} = x$$\n",
"$$0 \\leq x$$\n",
"$$f(x) = \\left( x+x^2 \\right)^r$$\n",
"\n",
"\n",
"$$x = \\{x_0, x_1, \\ldots, x_t\\}$$\n",
"$$x_i \\in \\{0,1,2,3,4,5\\} \\> \\forall \\, i \\leq t$$\n",
"$$r = \\text{consecutive correct answers, or 0 if } x_t \\text{ was incorrect}$$\n",
"$$f_t(x_t) = a\\left(\\max\\left(1.3,\\ 2.5 + \\sum_{i=0}^{t}\\left(b+cx_i+dx_i^2\\right)\\right)\\right)^{\\theta r}$$"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"\n",
"def sm2(x: List[int], a=6.0, b=-0.8, c=0.28, d=0.02, theta=0.2) -> float:\n",
"    \"\"\"\n",
"    Returns the (fractional) number of days to delay the next review of an item,\n",
"    based on the history of answers x to a given question, where\n",
"        x == 0: Incorrect, Hardest\n",
"        x == 1: Incorrect, Hard\n",
"        x == 2: Incorrect, Medium\n",
"        x == 3: Correct, Medium\n",
"        x == 4: Correct, Easy\n",
"        x == 5: Correct, Easiest\n",
"    @param x The history of answers in the above scoring, oldest first.\n",
"    @param theta When larger, the delays for correct answers will increase.\n",
"    \"\"\"\n",
"    assert x, \"x must contain at least one answer\"\n",
"    assert all(0 <= x_i <= 5 for x_i in x)\n",
"    correct_x = [x_i >= 3 for x_i in x]\n",
"    # If you got the last question incorrect, just return 1\n",
"    if not correct_x[-1]:\n",
"        return 1.0\n",
"\n",
"    # Calculate the latest consecutive answer streak\n",
"    num_consecutively_correct = 0\n",
"    for correct in reversed(correct_x):\n",
"        if correct:\n",
"            num_consecutively_correct += 1\n",
"        else:\n",
"            break\n",
"\n",
"    # Easiness starts at 2.5, is floored at 1.3, and is raised to a power that grows with the streak\n",
"    return a*(max(1.3, 2.5 + sum(b+c*x_i+d*x_i*x_i for x_i in x)))**(theta*num_consecutively_correct)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9.458300066760838"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[2,1,3,3,4,1,2,3,4])"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7.329342743905306"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[3])"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"15.068222995109046"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[3, 1, 5, 3, 5])"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[0, 0, 1])"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[0])"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[0, 0, 1, 2])"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6.323243712370701"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[0, 1, 2, 3])"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7.018687988451514"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[0, 1, 2, 3, 3])"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[0, 1, 2, 3, 2])"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10.595539477949636"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[0, 1, 2, 3, 3, 5])"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"154.94996798742838"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sm2(x=[5, 5, 5, 5, 5, 5, 5])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
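As a check on the notebook outputs above, the delay for two of the histories can be worked by hand. This is only a sketch under the notebook's defaults (a = 6, b = -0.8, c = 0.28, d = 0.02, theta = 0.2, starting score 2.5, floor 1.3), with values rounded to two decimals.

```latex
\begin{aligned}
x = [3]:\quad & r = 1,\quad \sum_i \left(b + c x_i + d x_i^2\right) = -0.8 + 0.84 + 0.18 = 0.22 \\
& f = 6 \cdot \max(1.3,\ 2.5 + 0.22)^{0.2 \cdot 1} = 6 \cdot 2.72^{0.2} \approx 7.33 \\
x = [5,5,5,5,5,5,5]:\quad & r = 7,\quad \sum_i \left(b + c x_i + d x_i^2\right) = 7 \cdot (-0.8 + 1.4 + 0.5) = 7.7 \\
& f = 6 \cdot \max(1.3,\ 2.5 + 7.7)^{0.2 \cdot 7} = 6 \cdot 10.2^{1.4} \approx 154.95
\end{aligned}
```

Both results agree with the values the notebook prints for `sm2(x=[3])` and `sm2(x=[5, 5, 5, 5, 5, 5, 5])`.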
from typing import List


def supermemo_2(x: List[int], a=6.0, b=-0.8, c=0.28, d=0.02, assumed_score=2.5, min_score=1.3, theta=1.0) -> float:
    """
    Returns the number of days until seeing a problem again based on the
    history of answers x to the problem, where the meaning of x is:
        x == 0: Incorrect, Hardest
        x == 1: Incorrect, Hard
        x == 2: Incorrect, Medium
        x == 3: Correct, Medium
        x == 4: Correct, Easy
        x == 5: Correct, Easiest
    @param x The history of answers in the above scoring, oldest first.
    @param assumed_score The score an item starts with before any answers are counted.
    @param min_score The lowest value the accumulated score is allowed to take.
    @param theta When larger, the delays for correct answers will increase.
    """
    assert x, "x must contain at least one answer"
    assert all(0 <= x_i <= 5 for x_i in x)
    correct = [x_i >= 3 for x_i in x]
    # If you got the last question incorrect, just return 1
    if not correct[-1]:
        return 1.0
    # Calculate the latest consecutive answer streak
    r = 0
    for c_i in reversed(correct):
        if c_i:
            r += 1
        else:
            break
    # Accumulated score (floored at min_score), raised to a power that grows with the streak
    return a*(max(min_score, assumed_score + sum(b+c*x_i+d*x_i*x_i for x_i in x)))**(theta*r)
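The notebook describes the result as a dynamically calculated ordering of items, so here is a minimal sketch of how the delay could be turned into a review order. The names `delay_days` and `review_order` are hypothetical and not part of the gist. `delay_days` is a compact restatement of the function above, using `theta=0.2` as in the notebook, and every item is assumed to have been last seen at the same moment.

```python
from datetime import datetime, timedelta
from typing import Dict, List


def delay_days(x: List[int], a=6.0, b=-0.8, c=0.28, d=0.02,
               assumed_score=2.5, min_score=1.3, theta=0.2) -> float:
    """Compact restatement of the delay function above (hypothetical name)."""
    # A wrong latest answer (rating below 3) always comes back after one day
    if x[-1] < 3:
        return 1.0
    # Count the streak of correct answers ending at the latest one
    r = 0
    for x_i in reversed(x):
        if x_i < 3:
            break
        r += 1
    score = max(min_score, assumed_score + sum(b + c * x_i + d * x_i * x_i for x_i in x))
    return a * score ** (theta * r)


def review_order(histories: Dict[str, List[int]], last_seen: datetime) -> List[str]:
    """Hypothetical helper: item names sorted by when they next fall due."""
    due = {name: last_seen + timedelta(days=delay_days(h)) for name, h in histories.items()}
    return sorted(due, key=due.get)


# Assumption for illustration: all three items were last reviewed on the same day
print(review_order({"a": [5, 5, 5], "b": [3], "c": [0]}, datetime(2024, 1, 1)))
# ['c', 'b', 'a']
```

The missed item comes first because its delay is fixed at one day, and the item with the longest streak of easy answers comes last.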
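/**
 * Returns the number of days until the next test of an item.
 * `recent` is the latest rating (0-5) and `x` holds the earlier ratings,
 * expected most recent first, since the streak is counted from the front of the history.
 */
function daysTillNextTestAlgorithm(recent, x, a = 6.0, b = -0.8, c = 0.28, d = 0.02, theta = 0.2) {
  // Ratings below 4 count as incorrect in this port: retest after one day
  if (recent < 4) {
    return 1
  }
  const history = [recent, ...x]
  // Calculate latest correctness streak
  let streak = 0
  for (let i = 0; i < history.length; i++) {
    if (history[i] > 3) {
      streak++
    } else {
      break
    }
  }
  // Sum up the history
  const historySum = history.reduce(
    (prev, val) => prev + (b + (c * val) + (d * val * val)),
    0
  )
  // Score starts at 2.5 and is floored at 1.3 before being raised to theta * streak
  return a * Math.pow(Math.max(1.3, 2.5 + historySum), theta * streak)
}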