Created
December 19, 2019 21:16
-
-
Save MaxHalford/d2303aad59d787443fefa36ec56e94b4 to your computer and use it in GitHub Desktop.
Over/under/hybrid sampling a stream
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Resampling with a target distribution" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Three strategies:\n", | |
" \n", | |
"1. Classes can only be under-sampled (rejection sampling)\n", | |
"2. Classes can only be over-sampled (Poisson sampling)\n", | |
"3. Classes can be both under- and over-sampled\n", | |
"\n", | |
"In the first two cases the expected number of samples is fully determined by the actual and desired distributions. In the third case there is an extra degree of freedom: the expected number of samples can be chosen freely." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Generating data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"A: 9832 (9.83%)\n", | |
"B: 30175 (30.18%)\n", | |
"C: 59993 (59.99%)\n" | |
] | |
} | |
], | |
"source": [ | |
"import collections\n", | |
"from numpy import random\n", | |
"\n", | |
"# Class distribution observed in the simulated stream.\n", | |
"actual = dict(A=.1, B=.3, C=.6)\n", | |
"\n", | |
"# Class distribution the resampled stream should follow.\n", | |
"desired = dict(A=.3, B=.5, C=.2)\n", | |
"\n", | |
"# Running (cumulative) probabilities of `actual`, in key order. The last\n", | |
"# entry is pinned to 1 so that every draw in [0, 1) falls below some threshold.\n", | |
"actual_cumsum = dict(\n", | |
"    A=actual['A'],\n", | |
"    B=actual['A'] + actual['B'],\n", | |
"    C=1.,\n", | |
")\n", | |
"\n", | |
"def stream_values(n, dist=None):\n", | |
"    \"\"\"Yield `n` labels drawn independently from a discrete distribution.\n", | |
"\n", | |
"    `dist` maps each label to its probability and defaults to the\n", | |
"    module-level `actual`. Sampling is by inverse CDF: a uniform draw in\n", | |
"    [0, 1) is compared with the running cumulative probabilities, taken in\n", | |
"    the mapping's key order.\n", | |
"    \"\"\"\n", | |
"    if dist is None:\n", | |
"        dist = actual\n", | |
"\n", | |
"    labels = list(dist)\n", | |
"    if not labels:\n", | |
"        return\n", | |
"\n", | |
"    # Derive the thresholds from `dist` itself rather than from a separately\n", | |
"    # hard-coded cumulative table, which silently goes stale if `dist` changes.\n", | |
"    thresholds = []\n", | |
"    running = 0.\n", | |
"    for label in labels:\n", | |
"        running += dist[label]\n", | |
"        thresholds.append(running)\n", | |
"\n", | |
"    for _ in range(n):\n", | |
"        p = random.random()\n", | |
"        for label, threshold in zip(labels, thresholds):\n", | |
"            if p < threshold:\n", | |
"                yield label\n", | |
"                break\n", | |
"        else:\n", | |
"            # Round-off (or probabilities summing to slightly less than 1)\n", | |
"            # can leave the last threshold below 1; such draws go to the\n", | |
"            # last label instead of being silently dropped.\n", | |
"            yield labels[-1]\n", | |
" \n", | |
"def count(values):\n", | |
"    \"\"\"Print every label's count and share of `values`, sorted by label.\n", | |
"\n", | |
"    Accepts any iterable, including generators and samplers: the total is\n", | |
"    taken from the tally itself instead of `len(values)`.\n", | |
"    \"\"\"\n", | |
"    counts = collections.Counter(values)\n", | |
"    total = sum(counts.values())\n", | |
"    # `tally` rather than `count`, so the function's own name is not shadowed.\n", | |
"    for label, tally in sorted(counts.items()):\n", | |
"        print(f'{label}: {tally} ({tally / total:.2%})')\n", | |
" \n", | |
"\n", | |
"# Simulate the stream and check that the empirical shares match `actual`.\n", | |
"n = 10 ** 5\n", | |
"values = [label for label in stream_values(n)]\n", | |
"count(values)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Under-sampling\n", | |
"\n", | |
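"This is done via rejection sampling: each value is kept with a probability that depends on its class, so no value is ever duplicated." | |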
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"A: 9832 (29.48%)\n", | |
"B: 16826 (50.45%)\n", | |
"C: 6696 (20.08%)\n", | |
"Expected 33333 samples, got 33354\n" | |
] | |
} | |
], | |
"source": [ | |
"class UnderSampler:\n", | |
"    \"\"\"Resample a stream towards `desired` via rejection (under-)sampling.\n", | |
"\n", | |
"    With f = desired, g = actual and M = max_x f(x) / g(x), each incoming\n", | |
"    value x is kept with probability f(x) / (M * g(x)). The class attaining\n", | |
"    the maximum (`pivot_`) is always kept and the others are thinned, so no\n", | |
"    value is ever emitted twice and about len(values) / M values come out.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    values : iterable of labels to resample.\n", | |
"    actual : mapping from each label to its probability in `values`.\n", | |
"    desired : mapping from each label to its target probability.\n", | |
"    seed : seed for the internal `numpy.random.RandomState`.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, values, actual, desired, seed=None):\n", | |
"        self.values = values\n", | |
"        self.actual = actual\n", | |
"        self.desired = desired\n", | |
"        self.rng = random.RandomState(seed)\n", | |
"\n", | |
"    def __iter__(self):\n", | |
"        f = self.desired\n", | |
"        g = self.actual\n", | |
"\n", | |
"        pivot = max(g.keys(), key=lambda x: f[x] / g[x])\n", | |
"        M = f[pivot] / g[pivot]\n", | |
"\n", | |
"        # Publish the fitted constants before the first value is yielded, so\n", | |
"        # they are available even if the consumer stops early; assigning them\n", | |
"        # after the loop only worked when the stream was fully exhausted.\n", | |
"        self.pivot_ = pivot\n", | |
"        self.M_ = M\n", | |
"\n", | |
"        for x in self.values:\n", | |
"            # Acceptance probability; equals 1 for the pivot class.\n", | |
"            if self.rng.random() < f[x] / (M * g[x]):\n", | |
"                yield x\n", | |
" \n", | |
"\n", | |
"\n", | |
"# Under-sampling never duplicates values, so on average the output shrinks by a factor M.\n", | |
"sampler = UnderSampler(values, actual, desired, seed=42)\n", | |
"sample = [*sampler]\n", | |
"count(sample)\n", | |
"\n", | |
"expected_size = int(len(values) / sampler.M_)\n", | |
"print(f'Expected {expected_size} samples, got {len(sample)}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Over-sampling" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"A: 88828 (29.65%)\n", | |
"B: 151019 (50.41%)\n", | |
"C: 59729 (19.94%)\n", | |
"Expected 299999 samples, got 299576\n" | |
] | |
} | |
], | |
"source": [ | |
"class OverSampler:\n", | |
"    \"\"\"Resample a stream towards `desired` via Poisson over-sampling.\n", | |
"\n", | |
"    With f = desired, g = actual and M = max_x g(x) / f(x), each incoming\n", | |
"    value x is emitted k times with k ~ Poisson(M * f(x) / g(x)). Every\n", | |
"    rate is at least 1 (exactly 1 for `pivot_`), and about len(values) * M\n", | |
"    values come out.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    values : iterable of labels to resample.\n", | |
"    actual : mapping from each label to its probability in `values`.\n", | |
"    desired : mapping from each label to its target probability.\n", | |
"    seed : seed for the internal `numpy.random.RandomState`.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, values, actual, desired, seed=None):\n", | |
"        self.values = values\n", | |
"        self.actual = actual\n", | |
"        self.desired = desired\n", | |
"        self.rng = random.RandomState(seed)\n", | |
"\n", | |
"    def __iter__(self):\n", | |
"        f = self.desired\n", | |
"        g = self.actual\n", | |
"\n", | |
"        pivot = max(g.keys(), key=lambda x: g[x] / f[x])\n", | |
"        M = g[pivot] / f[pivot]\n", | |
"\n", | |
"        # Publish the fitted constants before the first value is yielded, so\n", | |
"        # they are available even if the consumer stops early; assigning them\n", | |
"        # after the loop only worked when the stream was fully exhausted.\n", | |
"        self.pivot_ = pivot\n", | |
"        self.M_ = M\n", | |
"\n", | |
"        for x in self.values:\n", | |
"            rate = M * f[x] / g[x]\n", | |
"            # Emit a Poisson-distributed number of copies of x.\n", | |
"            for _ in range(self.rng.poisson(rate)):\n", | |
"                yield x\n", | |
" \n", | |
" \n", | |
"# Every over-sampling rate is at least 1, so on average the output grows by a factor M.\n", | |
"sampler = OverSampler(values, actual, desired, seed=42)\n", | |
"sample = [*sampler]\n", | |
"count(sample)\n", | |
"\n", | |
"expected_size = int(len(values) * sampler.M_)\n", | |
"print(f'Expected {expected_size} samples, got {len(sample)}')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Hybrid method " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"A: 14792 (29.70%)\n", | |
"B: 25114 (50.43%)\n", | |
"C: 9893 (19.87%)\n", | |
"Expected 50000 samples, got 49799\n" | |
] | |
} | |
], | |
"source": [ | |
"class OverUnderSampler:\n", | |
"    \"\"\"Resample a stream towards `desired` with a freely chosen output size.\n", | |
"\n", | |
"    With f = desired and g = actual, each incoming value x is emitted k\n", | |
"    times with k ~ Poisson(sampling_rate * f(x) / g(x)). Classes whose rate\n", | |
"    is below 1 are thinned and those above 1 are duplicated, and about\n", | |
"    len(values) * sampling_rate values come out.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    values : iterable of labels to resample.\n", | |
"    sampling_rate : non-negative expected ratio of output size to input size.\n", | |
"    actual : mapping from each label to its probability in `values`.\n", | |
"    desired : mapping from each label to its target probability.\n", | |
"    seed : seed for the internal `numpy.random.RandomState`.\n", | |
"\n", | |
"    Raises\n", | |
"    ------\n", | |
"    ValueError\n", | |
"        If `sampling_rate` is negative.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, values, sampling_rate, actual, desired, seed=None):\n", | |
"        if sampling_rate < 0:\n", | |
"            raise ValueError(f'sampling_rate must be non-negative, got {sampling_rate}')\n", | |
"        self.values = values\n", | |
"        # Must be stored: `__iter__` reads it for every incoming value.\n", | |
"        self.sampling_rate = sampling_rate\n", | |
"        self.actual = actual\n", | |
"        self.desired = desired\n", | |
"        self.rng = random.RandomState(seed)\n", | |
"\n", | |
"    def __iter__(self):\n", | |
"        f = self.desired\n", | |
"        g = self.actual\n", | |
"\n", | |
"        for x in self.values:\n", | |
"            rate = self.sampling_rate * f[x] / g[x]\n", | |
"            # Emit a Poisson-distributed number of copies of x (possibly 0).\n", | |
"            for _ in range(self.rng.poisson(rate)):\n", | |
"                yield x\n", | |
" \n", | |
"# The hybrid sampler's expected output size is set directly through `sampling_rate`.\n", | |
"sampling_rate = 0.5\n", | |
"sampler = OverUnderSampler(\n", | |
"    values, sampling_rate, actual, desired, seed=42\n", | |
")\n", | |
"sample = [*sampler]\n", | |
"count(sample)\n", | |
"\n", | |
"expected_size = int(len(values) * sampling_rate)\n", | |
"print(f'Expected {expected_size} samples, got {len(sample)}')" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment