Skip to content

Instantly share code, notes, and snippets.

@mitjat
Created April 29, 2015 20:59
Show Gist options
  • Select an option

  • Save mitjat/2e878b87c31f96317866 to your computer and use it in GitHub Desktop.

Select an option

Save mitjat/2e878b87c31f96317866 to your computer and use it in GitHub Desktop.
(Stratified) sampling without replacement of lines on stdin.
#!/usr/bin/env python
"""
(Stratified) sampling without replacement of lines on stdin.
Preserves random N lines of input (without replacement), discards others.
Optionally, lines can be grouped into classes based on a value extracted from the line;
in this case, N lines per group are preserved.
Order is not preserved.
Examples:
sample.py -n100 # 100 samples
sample.py -n10 -k2 # 10 sample lines for each distinct value of the second column in a tab-separated input
sample.py -n10 -k 'len(line)' # 10 sample lines for each line length
"""
import os
import random
import sys
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from collections import defaultdict
class ReservoirSample(object):
    """
    Reservoir sampling (Algorithm R). Creates a uniform sample, without replacement,
    of up to `size` elements from a stream; access it via the `result` attribute.

    While fewer than `size` elements have been kept, every element is appended in
    arrival order. After that, the n-th element overwrites a uniformly chosen slot
    with probability size/n.

    `rng` is any object with a `randint(a, b)` method (e.g. a `random.Random`
    instance). It defaults to the module-level `random` functions, so seeding
    `random` globally makes the sample reproducible.
    """

    def __init__(self, size, rng=None):
        self.size = size  # output sample size
        self.seen = 0  # number of elements considered so far
        self.result = []
        self._rng = random if rng is None else rng

    def add(self, el):
        """Offer one stream element `el` to the sample."""
        self.seen += 1
        if len(self.result) < self.size:
            self.result.append(el)
        else:
            # Uniform index in [0, seen); it lands inside the reservoir with
            # probability size/seen, in which case `el` evicts that slot.
            s = self._rng.randint(0, self.seen - 1)
            if s < self.size:
                self.result[s] = el

    def __repr__(self):
        return 'ReservoirSample(size=%r, seen=%r, result=%r)' % (self.size, self.seen, self.result)
def build_class_func(spec):
    """
    Build the function that maps an input line to its class (stratum) key.

    `spec` is the raw value of the -k option:
      * None: every line belongs to the single class None.
      * A comma-separated list of positive ints, e.g. "2" or "1,3": the key is the
        tuple of those 1-based tab-separated columns.
      * Anything else: a Python expression evaluated once per line, with `line`
        (the raw line, including its trailing newline) and `cols` (the line split
        on tabs, trailing newline removed) in scope.

    Raises ValueError for a column index below 1 (or an unparsable digit string)
    and SyntaxError for an expression that does not compile.
    """
    if spec is None:
        return lambda line: None

    parts = [part.strip() for part in spec.split(',')]
    if all(part.isdigit() for part in parts):
        # Materialize the indices. In Python 3 map() is a one-shot iterator, so a
        # lazily mapped sequence would be exhausted after the first line and every
        # later line would fall into the class ().
        indices = tuple(int(part) for part in parts)
        if min(indices) < 1:
            # Column 0 would otherwise silently select the *last* column via cols[-1].
            raise ValueError('column indices are 1-based; got %r' % spec)

        def class_func(line):
            cols = line.rstrip('\n').split('\t')
            return tuple(cols[i - 1] for i in indices)

        return class_func

    # SECURITY: -k is evaluated as arbitrary Python code; never pass it untrusted input.
    code = compile(spec, filename='class_expression', mode='eval')

    def class_func(line):
        cols = line.rstrip('\n').split('\t')
        # `line` and `cols` go in the globals mapping so that nested scopes inside
        # the expression (generator expressions, lambdas) can see them; module-level
        # names (os, sys, random, ...) stay reachable at the expression's top level
        # through the locals mapping.
        return eval(code, {'line': line, 'cols': cols}, globals())

    return class_func


def main(argv=None):
    """
    Command-line entry point: sample lines from stdin and write them to stdout.

    `argv` is the argument list to parse; None means sys.argv[1:] (argparse default).
    """
    # ArgumentParser's first positional parameter is `prog`, so the docstring must be
    # passed as `description`; RawDescriptionHelpFormatter keeps its line breaks.
    arg_parser = ArgumentParser(description=__doc__,
                                formatter_class=RawDescriptionHelpFormatter)
    arg_parser.add_argument('-n', metavar="NUM", type=int, required=True,
                            help="Number of samples per class.")
    arg_parser.add_argument('-k', metavar="COL|EXPR", default=None,
                            help="How to extract the class. If this is a comma-separated list of ints, assume input is tab-separated and "
                            "use those columns as class. Otherwise, the argument is treated as a python expression that is passed through "
                            "eval(); the variable `line` will hold the current line during evaluation, and `cols` will hold its tab-separated "
                            "columns. Default: don't extract the class.")
    arg_parser.add_argument('--seed', type=int, default=19071985,
                            help="Random seed; a fixed seed makes the sample reproducible. Default: %(default)s.")
    args = arg_parser.parse_args(argv)

    if args.n < 0:
        arg_parser.error('-n must be non-negative, got %d' % args.n)
    try:
        class_func = build_class_func(args.k)
    except (ValueError, SyntaxError) as err:
        arg_parser.error('invalid -k %r: %s' % (args.k, err))

    # Sample the input
    random.seed(args.seed)
    samples = defaultdict(lambda: ReservoirSample(args.n))  # class -> sample
    for line in sys.stdin:
        # Terminate a final line that lacks a newline. It then gets the same class
        # as identical lines (e.g. for -k 'len(line)'), and it cannot be glued to the
        # following line in the output.
        if not line.endswith('\n'):
            line += '\n'
        samples[class_func(line)].add(line)

    # Write sample to output (order is not preserved)
    for sample in samples.values():
        sys.stdout.writelines(sample.result)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment