#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import csv
import tempfile
import time
import logging
import optparse
import itertools

import dj_database_url
import psycopg2
import psycopg2.extras

import dedupe
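
# Overview: this script links records in data_a and data_b. Pairs that agree
# exactly on a set of shared columns are joined in SQL first; dedupe then
# scores only the remainder, with candidate blocks staged in PostgreSQL so
# the comparisons never have to fit in memory all at once.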
class Database(object):
    """Mixin that stages blocked pairs in PostgreSQL so large linkages
    don't exhaust RAM."""

    def __init__(self, *args, **kwargs):
        # Pull our connection out of kwargs before dedupe's __init__ runs.
        self.con = kwargs.pop('connection')
        super().__init__(*args, **kwargs)

    def match(self, data_1, data_2, threshold=0.5):
        blocked_pairs = self._blockData(data_1, data_2)
        return self.matchBlocks(blocked_pairs, threshold)

    def _blockData(self, data_1, data_2):
        cur = self.con.cursor()

        if not self.loaded_indices:
            self.blocker.indexAll(data_2)

        # Stage the target records' block keys in a scratch table.
        cur.execute("DROP TABLE IF EXISTS block")
        cur.execute("""CREATE TABLE block
                       (block_key VARCHAR(200), record_id INTEGER)""")
        self.con.commit()

        for block_key, record_id in self.blocker(data_2.items(), target=True):
            cur.execute("INSERT INTO block VALUES (%s, %s)",
                        (block_key, record_id))

        cur.execute("CREATE INDEX ON block (block_key)")
        self.con.commit()

        for each in self._blockGenerator(data_1, data_2):
            yield each

        cur.close()
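
    # A possible speed-up (a sketch, not part of the original script): the
    # row-at-a-time INSERTs above can be batched with
    # psycopg2.extras.execute_values, e.g.
    #   psycopg2.extras.execute_values(
    #       cur, "INSERT INTO block VALUES %s",
    #       self.blocker(data_2.items(), target=True))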
    def _blockGenerator(self, data_1, data_2):
        cur = self.con.cursor()

        # Group the source records' block keys by record id.
        block_groups = itertools.groupby(self.blocker(data_1.items()),
                                         lambda x: x[1])

        i = 0
        for record_id, blocks in block_groups:
            A = [(record_id, data_1[record_id], set())]

            block_keys, _ = zip(*blocks)

            # Find every target record that shares a block key with this
            # source record.
            cur.execute("""
                SELECT DISTINCT record_id
                FROM block
                WHERE block_key IN %s""", (block_keys,))

            match_ids = (row['record_id'] for row in cur.fetchall())

            B = [(record_id, data_2[record_id], set())
                 for record_id in match_ids]

            i += 1
            if i % 10000 == 0:
                print(i, "records")
                print(time.time() - start_time, "seconds")

            if B:
                yield (A, B)

        cur.close()
class RecordLink(Database, dedupe.RecordLink):
    pass


class StaticRecordLink(Database, dedupe.StaticRecordLink):
    pass
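
# Database comes first in the bases so that, under Python's method
# resolution order, its __init__ pops the `connection` keyword argument
# before dedupe's own __init__ receives the rest.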
## Logging

# Dedupe uses Python logging to show or suppress verbose output. Added
# for convenience. To enable verbose output, run
# `python pgsql_big_dedupe_example.py -v`.
optp = optparse.OptionParser()
optp.add_option('-v', '--verbose', dest='verbose', action='count',
                help='Increase verbosity (specify multiple times for more)'
                )
(opts, args) = optp.parse_args()

log_level = logging.WARNING
if opts.verbose:
    if opts.verbose == 1:
        log_level = logging.INFO
    elif opts.verbose >= 2:
        log_level = logging.DEBUG
logging.getLogger().setLevel(log_level)
# ## Setup; change these file names to start over with fresh
# settings/training
settings_file = 'fig_reclink_settings'
training_file = 'fig_reclink_training.json'

start_time = time.time()
# Set the database connection from an environment variable using
# [dj_database_url](https://github.com/kennethreitz/dj-database-url).
# For example:
#   export DATABASE_URL=postgres://user:password@host/mydatabase
db_conf = dj_database_url.config()

if not db_conf:
    raise Exception(
        'set DATABASE_URL environment variable with your connection, e.g. '
        'export DATABASE_URL=postgres://user:password@host/mydatabase'
    )

# RealDictCursor returns each row as a dict keyed by column name.
con = psycopg2.connect(database=db_conf['NAME'],
                       user=db_conf['USER'],
                       password=db_conf['PASSWORD'],
                       host=db_conf['HOST'],
                       cursor_factory=psycopg2.extras.RealDictCursor)
c = con.cursor()
con.commit()
# Link the records that agree exactly first, so dedupe only has to score
# the remainder. FIELDS was undefined in the original; it is assumed here
# to be the comma-separated list of columns shared by data_a and data_b.
# Adjust it to your schema.
FIELDS = 'bpcounty, age, bpparish'

CROSS_MATCH = """
    CREATE TABLE IF NOT EXISTS cross_match_exact AS
    SELECT source.donor_id AS source,
           target.donor_id AS target
    FROM data_a AS source
    INNER JOIN data_b AS target
    USING ({fields})"""

c.execute(CROSS_MATCH.format(fields=FIELDS))
c.execute("CREATE INDEX IF NOT EXISTS source_idx ON cross_match_exact (source)")
c.execute("CREATE INDEX IF NOT EXISTS target_idx ON cross_match_exact (target)")
con.commit()
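
# Optional sanity check: inspect the exact-match table before moving on,
# e.g. `SELECT count(*) FROM cross_match_exact;` in psql.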
B_SELECT = """
    SELECT data_a.*
    FROM data_a
    LEFT JOIN cross_match_exact
    ON donor_id = source
    WHERE target IS NULL"""

M8_SELECT = """
    SELECT data_b.*
    FROM data_b
    LEFT JOIN cross_match_exact
    ON donor_id = target
    WHERE source IS NULL"""
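
# The two SELECTs above are anti-joins: the LEFT JOIN plus the IS NULL
# filter keeps only the rows that found no exact match, which is exactly
# the set dedupe still needs to score.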
if os.path.exists(settings_file):
    print('reading from ', settings_file)
    with open(settings_file, 'rb') as sf:
        deduper = StaticRecordLink(sf, num_cores=4, connection=con)

else:
    # Define the fields dedupe will pay attention to: an exact match on
    # birthplace county, a numeric comparison on age, and a CRF string
    # comparison on birthplace parish.
    fields = [
        {'field': 'bpcounty', 'type': 'Exact'},
        {'field': 'age', 'type': 'Price'},
        {'field': 'bpparish', 'type': 'String', 'crf': True}
    ]
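
    # Hypothetical variation: for a field that is often null, dedupe's
    # field definitions accept a 'has missing' flag so the model can learn
    # from absence, e.g.
    #   {'field': 'bpparish', 'type': 'String', 'has missing': True}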
    # Create a new deduper object and pass our data model to it.
    deduper = RecordLink(fields, num_cores=4, connection=con)

    # Named cursors run server side with psycopg2, so rows stream to the
    # client instead of being fetched all at once.
    with con.cursor('donor_select') as cur:
        cur.execute(B_SELECT)
        temp_d = {i: row for i, row in enumerate(cur)}

    with con.cursor('tower_select') as cur:
        cur.execute(M8_SELECT)
        temp_z = {i: row for i, row in enumerate(cur)}

    deduper.sample(temp_d, temp_z)
    # If we have training data saved from a previous run of dedupe,
    # look for it and load it in.
    #
    # __Note:__ if you want to train from
    # scratch, delete the training_file
    if os.path.exists(training_file):
        print('reading labeled examples from ', training_file)
        with open(training_file) as tf:
            deduper.readTraining(tf)
    # ## Active learning

    print('starting active labeling...')

    # Starts the training loop. Dedupe will find the next pair of records
    # it is least certain about and ask you to label them as duplicates
    # or not. Use the 'y', 'n', and 'u' keys to flag duplicates;
    # press 'f' when you are finished.
    dedupe.convenience.consoleLabel(deduper)

    # When finished, save our labeled training pairs to disk
    with open(training_file, 'w') as tf:
        deduper.writeTraining(tf)

    # Notice our argument here.
    #
    # `recall` is the proportion of true duplicate pairs that the learned
    # rules must cover. You may want to reduce this if you are making
    # too many blocks and too many comparisons.
    deduper.train(recall=0.90)

    with open(settings_file, 'wb') as sf:
        deduper.writeSettings(sf)

    # We can now remove some of the memory-hogging objects we used
    # for training.
    deduper.cleanupTraining()
with con.cursor('donor_select') as cur:
    cur.execute(B_SELECT)
    data_1 = {row['donor_id']: row for row in cur}

with con.cursor('tower_select') as cur:
    cur.execute(M8_SELECT)
    data_2 = {row['donor_id']: row for row in cur}

clustered_dupes = deduper.match(data_1, data_2, threshold=0.0)
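
# threshold=0.0 keeps every candidate pair along with its score; filtering
# to a working cutoff can then be done in SQL against cross_match_fuzzy.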
# Writing out results
c.execute("DROP TABLE IF EXISTS cross_match_fuzzy")

print('creating cross_match_fuzzy table')
c.execute("CREATE TABLE cross_match_fuzzy "
          "(source INTEGER, target INTEGER, "
          " score FLOAT)")

csv_file = tempfile.NamedTemporaryFile(prefix='cross_match_fuzzy_',
                                       delete=False, mode='w', newline='')
csv_writer = csv.writer(csv_file)

for (source, target), score in clustered_dupes:
    csv_writer.writerow((source, target, score))

csv_file.close()

with open(csv_file.name, 'r') as f:
    c.copy_expert("COPY cross_match_fuzzy FROM STDIN CSV", f)
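
# Staging the scores in a temporary CSV and bulk-loading it with COPY is
# much faster than issuing one INSERT per matched pair.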
os.remove(csv_file.name)
con.commit()

c.close()
con.close()