Skip to content

Instantly share code, notes, and snippets.

@Laurian
Created November 16, 2017 17:31
Show Gist options
  • Save Laurian/7869355a000c803f26bb434935a367cb to your computer and use it in GitHub Desktop.
Save Laurian/7869355a000c803f26bb434935a367cb to your computer and use it in GitHub Desktop.
Experimental: loading dpwe/audfprint signatures into MySQL using code from worldveil/dejavu
{
"database": {
"host": "127.0.0.1",
"user": "root",
"passwd": "",
"db": "dejavu"
}
}
from __future__ import absolute_import
import abc
class Database(object):
    """
    Abstract interface every fingerprint storage backend has to provide.

    A backend subclasses this class, sets ``type`` to the name used to select
    it and implements all methods marked abstract below.
    """

    # Python 2 style metaclass declaration.
    __metaclass__ = abc.ABCMeta

    # Column names shared by all backends.
    FIELD_FILE_SHA1 = 'file_sha1'
    FIELD_SONG_ID = 'song_id'
    FIELD_SONGNAME = 'song_name'
    FIELD_OFFSET = 'offset'
    FIELD_HASH = 'hash'

    # Identifier of the concrete backend; configuration refers to a
    # Database subclass through this value.
    type = None

    def __init__(self):
        super(Database, self).__init__()

    def before_fork(self):
        """
        Hook invoked before the database instance is handed to a new process.
        Does nothing by default.
        """

    def after_fork(self):
        """
        Hook invoked, inside the new process, after the database instance
        has been handed over. Does nothing by default.
        """

    def setup(self):
        """
        Hook invoked on creation or shortly afterwards.
        Does nothing by default.
        """

    @abc.abstractmethod
    def empty(self):
        """
        Clear the database of all data.
        """

    @abc.abstractmethod
    def delete_unfingerprinted_songs(self):
        """
        Remove every song entry that has no fingerprints associated with it.
        """

    @abc.abstractmethod
    def get_num_songs(self):
        """
        Return the number of songs in the database.
        """

    @abc.abstractmethod
    def get_num_fingerprints(self):
        """
        Return the number of fingerprints in the database.
        """

    @abc.abstractmethod
    def set_song_fingerprinted(self, sid):
        """
        Mark a song as having all of its fingerprints in the database.

        sid: Song identifier
        """

    @abc.abstractmethod
    def get_songs(self):
        """
        Return all fully fingerprinted songs in the database.
        """

    @abc.abstractmethod
    def get_song_by_id(self, sid):
        """
        Return a song by its identifier.

        sid: Song identifier
        """

    @abc.abstractmethod
    def insert(self, hash, sid, offset):
        """
        Insert a single fingerprint into the database.

        hash: Part of a sha1 hash, in hexadecimal format
        sid: Song identifier this fingerprint belongs to
        offset: The offset this hash is from
        """

    @abc.abstractmethod
    def insert_song(self, song_name):
        """
        Insert a song name into the database and return the identifier
        of the new song.

        song_name: The name of the song.
        """

    @abc.abstractmethod
    def query(self, hash):
        """
        Return all fingerprint entries matching the given hash.

        hash: Part of a sha1 hash, in hexadecimal format
        """

    @abc.abstractmethod
    def get_iterable_kv_pairs(self):
        """
        Return all fingerprints in the database.
        """

    @abc.abstractmethod
    def insert_hashes(self, sid, hashes):
        """
        Insert many fingerprints at once.

        sid: Song identifier the fingerprints belong to
        hashes: A sequence of tuples in the format (hash, offset)
        - hash: Part of a sha1 hash, in hexadecimal format
        - offset: Offset this hash was created from/at.
        """

    @abc.abstractmethod
    def return_matches(self, hashes):
        """
        Search the database for pairs of (hash, offset) values.

        hashes: A sequence of tuples in the format (hash, offset)
        - hash: Part of a sha1 hash, in hexadecimal format
        - offset: Offset this hash was created from/at.

        Returns a sequence of (sid, offset_difference) tuples.
        sid: Song identifier
        offset_difference: (offset - database_offset)
        """
def get_database(database_type=None):
    """
    Look up the Database subclass registered under ``database_type``.

    The name is matched case-insensitively against the ``type`` attribute of
    the direct subclasses of Database; a missing/empty name selects "mysql".
    Raises TypeError when no subclass carries the requested name.
    """
    wanted = (database_type or "mysql").lower()
    found = next(
        (candidate for candidate in Database.__subclasses__()
         if candidate.type == wanted),
        None,
    )
    if found is None:
        raise TypeError("Unsupported database type supplied.")
    return found
# Import our default database handler
import database_sql
from __future__ import absolute_import
from itertools import izip_longest
import Queue
import MySQLdb as mysql
from MySQLdb.cursors import DictCursor
from database import Database
class SQLDatabase(Database):
    """
    MySQL implementation of the Database interface.

    Fingerprint hashes are sent to the server as hex strings and converted
    with UNHEX() into the binary(10) `hash` column; song file hashes go into
    a binary(20) column the same way.

    Queries:

    1) Find duplicates (shouldn't be any, though):

        select `hash`, `song_id`, `offset`, count(*) cnt
        from fingerprints
        group by `hash`, `song_id`, `offset`
        having cnt > 1
        order by cnt asc;

    2) Get number of hashes by song:

        select song_id, song_name, count(song_id) as num
        from fingerprints
        natural join songs
        group by song_id
        order by count(song_id) desc;

    3) get hashes with highest number of collisions

        select
            hash,
            count(distinct song_id) as n
        from fingerprints
        group by `hash`
        order by n DESC;

    => 26 different songs with same fingerprint (392 times):

        select songs.song_name, fingerprints.offset
        from fingerprints natural join songs
        where fingerprints.hash = "08d3c833b71c60a7b620322ac0c0aba7bf5a3e73";
    """

    type = "mysql"

    # tables
    FINGERPRINTS_TABLENAME = "fingerprints"
    SONGS_TABLENAME = "songs"

    # fields
    FIELD_FINGERPRINTED = "fingerprinted"

    # NOTE: the SQL templates below are kept flush-left on purpose so the
    # statement text sent to the server is unchanged.

    # creates
    CREATE_FINGERPRINTS_TABLE = """
CREATE TABLE IF NOT EXISTS `%s` (
`%s` binary(10) not null,
`%s` mediumint unsigned not null,
`%s` int unsigned not null,
INDEX (%s),
UNIQUE KEY `unique_constraint` (%s, %s, %s),
FOREIGN KEY (%s) REFERENCES %s(%s) ON DELETE CASCADE
) ENGINE=INNODB;""" % (
        FINGERPRINTS_TABLENAME, Database.FIELD_HASH,
        Database.FIELD_SONG_ID, Database.FIELD_OFFSET, Database.FIELD_HASH,
        Database.FIELD_SONG_ID, Database.FIELD_OFFSET, Database.FIELD_HASH,
        Database.FIELD_SONG_ID, SONGS_TABLENAME, Database.FIELD_SONG_ID
    )

    CREATE_SONGS_TABLE = """
CREATE TABLE IF NOT EXISTS `%s` (
`%s` mediumint unsigned not null auto_increment,
`%s` varchar(250) not null,
`%s` tinyint default 0,
`%s` binary(20) not null,
PRIMARY KEY (`%s`),
UNIQUE KEY `%s` (`%s`)
) ENGINE=INNODB;""" % (
        SONGS_TABLENAME, Database.FIELD_SONG_ID, Database.FIELD_SONGNAME, FIELD_FINGERPRINTED,
        Database.FIELD_FILE_SHA1,
        Database.FIELD_SONG_ID, Database.FIELD_SONG_ID, Database.FIELD_SONG_ID,
    )

    # inserts (ignores duplicates)
    INSERT_FINGERPRINT = """
INSERT IGNORE INTO %s (%s, %s, %s) values
(UNHEX(%%s), %%s, %%s);
""" % (FINGERPRINTS_TABLENAME, Database.FIELD_HASH, Database.FIELD_SONG_ID, Database.FIELD_OFFSET)

    INSERT_SONG = "INSERT INTO %s (%s, %s) values (%%s, UNHEX(%%s));" % (
        SONGS_TABLENAME, Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1)

    # selects
    # SELECT1 is the plain UNHEX lookup; it is not referenced by any method
    # of this class.
    SELECT1 = """
SELECT %s, %s FROM %s WHERE %s = UNHEX(%%s);
""" % (Database.FIELD_SONG_ID, Database.FIELD_OFFSET, FINGERPRINTS_TABLENAME, Database.FIELD_HASH)

    # SELECT pads the hex string to 20 digits before UNHEX so it can equal a
    # full binary(10) value.
    # NOTE(review): RPAD to 19 followed by LPAD to 20 prefixes one '0'; this
    # presumably mirrors how UNHEX treats odd-length input on insert, and
    # would not line up for even-length hex strings -- confirm against the
    # hash widths actually stored.
    SELECT = """
SELECT %s, %s FROM %s WHERE %s = UNHEX(LPAD(RPAD(%%s, 19, '0'), 20, '0'));
""" % (Database.FIELD_SONG_ID, Database.FIELD_OFFSET, FINGERPRINTS_TABLENAME, Database.FIELD_HASH)

    SELECT_MULTIPLE = """
SELECT HEX(%s), %s, %s FROM %s WHERE %s IN (%%s);
""" % (Database.FIELD_HASH, Database.FIELD_SONG_ID, Database.FIELD_OFFSET,
       FINGERPRINTS_TABLENAME, Database.FIELD_HASH)

    SELECT_ALL = """
SELECT %s, %s FROM %s;
""" % (Database.FIELD_SONG_ID, Database.FIELD_OFFSET, FINGERPRINTS_TABLENAME)

    SELECT_SONG = """
SELECT %s, HEX(%s) as %s FROM %s WHERE %s = %%s;
""" % (Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1, Database.FIELD_FILE_SHA1, SONGS_TABLENAME, Database.FIELD_SONG_ID)

    SELECT_NUM_FINGERPRINTS = """
SELECT COUNT(*) as n FROM %s
""" % (FINGERPRINTS_TABLENAME)

    SELECT_UNIQUE_SONG_IDS = """
SELECT COUNT(DISTINCT %s) as n FROM %s WHERE %s = 1;
""" % (Database.FIELD_SONG_ID, SONGS_TABLENAME, FIELD_FINGERPRINTED)

    SELECT_SONGS = """
SELECT %s, %s, HEX(%s) as %s FROM %s WHERE %s = 1;
""" % (Database.FIELD_SONG_ID, Database.FIELD_SONGNAME, Database.FIELD_FILE_SHA1, Database.FIELD_FILE_SHA1,
       SONGS_TABLENAME, FIELD_FINGERPRINTED)

    # drops
    DROP_FINGERPRINTS = "DROP TABLE IF EXISTS %s;" % FINGERPRINTS_TABLENAME
    DROP_SONGS = "DROP TABLE IF EXISTS %s;" % SONGS_TABLENAME

    # update
    UPDATE_SONG_FINGERPRINTED = """
UPDATE %s SET %s = 1 WHERE %s = %%s
""" % (SONGS_TABLENAME, FIELD_FINGERPRINTED, Database.FIELD_SONG_ID)

    # delete
    DELETE_UNFINGERPRINTED = """
DELETE FROM %s WHERE %s = 0;
""" % (SONGS_TABLENAME, FIELD_FINGERPRINTED)

    def __init__(self, **options):
        """
        Store the connection options and build the cursor factory.

        options: keyword arguments forwarded to every Cursor that
        ``self.cursor()`` creates.
        """
        super(SQLDatabase, self).__init__()
        self.cursor = cursor_factory(**options)
        self._options = options

    def after_fork(self):
        """
        Drop cached connections inherited from the parent process.
        """
        # Clear the cursor cache, we don't want any stale connections from
        # the previous process.
        Cursor.clear_cache()

    def setup(self):
        """
        Creates any non-existing tables required for dejavu to function.

        This also removes all songs whose `fingerprinted` flag is still 0.
        """
        with self.cursor() as cur:
            # songs first: the fingerprints table has a foreign key to it
            cur.execute(self.CREATE_SONGS_TABLE)
            cur.execute(self.CREATE_FINGERPRINTS_TABLE)
            cur.execute(self.DELETE_UNFINGERPRINTED)

    def empty(self):
        """
        Drops tables created by dejavu and then creates them again
        by calling `SQLDatabase.setup`.

        .. warning:
            This will result in a loss of data
        """
        with self.cursor() as cur:
            cur.execute(self.DROP_FINGERPRINTS)
            cur.execute(self.DROP_SONGS)

        self.setup()

    def delete_unfingerprinted_songs(self):
        """
        Removes all songs whose `fingerprinted` flag is 0.
        """
        with self.cursor() as cur:
            cur.execute(self.DELETE_UNFINGERPRINTED)

    def get_num_songs(self):
        """
        Returns number of songs the database has fingerprinted.
        """
        with self.cursor() as cur:
            cur.execute(self.SELECT_UNIQUE_SONG_IDS)

            # the query yields a single one-column row
            for count, in cur:
                return count
            return 0

    def get_num_fingerprints(self):
        """
        Returns number of fingerprints stored in the database.
        """
        with self.cursor() as cur:
            cur.execute(self.SELECT_NUM_FINGERPRINTS)

            # the query yields a single one-column row
            for count, in cur:
                return count
            return 0

    def set_song_fingerprinted(self, sid):
        """
        Set the fingerprinted flag to TRUE (1) once a song has been completely
        fingerprinted in the database.

        sid: Song identifier
        """
        with self.cursor() as cur:
            cur.execute(self.UPDATE_SONG_FINGERPRINTED, (sid,))

    def get_songs(self):
        """
        Yield, as dictionaries, the songs that have the fingerprinted flag
        set TRUE (1).
        """
        with self.cursor(cursor_type=DictCursor) as cur:
            cur.execute(self.SELECT_SONGS)
            for row in cur:
                yield row

    def get_song_by_id(self, sid):
        """
        Returns the song with the given ID as a dictionary (whatever
        ``fetchone`` gives when there is no such row).
        """
        with self.cursor(cursor_type=DictCursor) as cur:
            cur.execute(self.SELECT_SONG, (sid,))
            return cur.fetchone()

    def insert(self, hash, sid, offset):
        """
        Insert a (hash, song_id, offset) row into database.

        hash: hex string, converted with UNHEX() by the statement
        """
        with self.cursor() as cur:
            cur.execute(self.INSERT_FINGERPRINT, (hash, sid, offset))

    def insert_song(self, songname, file_hash):
        """
        Inserts song in the database and returns the ID of the inserted record.

        songname: name stored for the song
        file_hash: hex string, converted with UNHEX() by the statement
        """
        with self.cursor() as cur:
            cur.execute(self.INSERT_SONG, (songname, file_hash))
            return cur.lastrowid

    def query(self, hash):
        """
        Yield all (song_id, offset) tuples associated with hash.

        If hash is None, yields all entries in the
        database (be careful with that one!).
        """
        with self.cursor() as cur:
            if hash is None:
                # SELECT_ALL has no placeholders, so it must be executed
                # without a parameter tuple; passing one makes the driver's
                # string formatting fail.
                cur.execute(self.SELECT_ALL)
            else:
                cur.execute(self.SELECT, (hash,))
            for sid, offset in cur:
                yield (sid, offset)

    def get_iterable_kv_pairs(self):
        """
        Returns all (song_id, offset) tuples in database.
        """
        return self.query(None)

    def insert_hashes(self, sid, hashes):
        """
        Insert a series of fingerprints for one song.

        sid: Song identifier the fingerprints belong to
        hashes: iterable of (offset, hash) pairs -- note the order -- where
        hash is an integer; it is stored as its lowercase hex representation.
        """
        values = [
            ('{:x}'.format(hash_value), sid, offset)
            for offset, hash_value in hashes
        ]

        with self.cursor() as cur:
            # insert in batches of 1000 rows
            for split_values in grouper(values, 1000):
                cur.executemany(self.INSERT_FINGERPRINT, split_values)

    def return_matches(self, hashes):
        """
        Yield the (song_id, offset_diff) tuples associated with
        a list of (hash, sample_offset) values, hash being a hex string.
        """
        # Create a dictionary of hash => offset pairs for later lookups
        mapper = {}
        for hash, offset in hashes:
            mapper[hash.upper()] = offset

        # Get an iteratable of all the hashes we need
        values = mapper.keys()

        with self.cursor() as cur:
            for split_values in grouper(values, 1000):
                # Create our IN part of the query
                query = self.SELECT_MULTIPLE
                query = query % ', '.join(['UNHEX(%s)'] * len(split_values))

                cur.execute(query, split_values)

                for hash, sid, offset in cur:
                    # (sid, db_offset - song_sampled_offset)
                    # NOTE(review): this lookup assumes HEX(hash) coming back
                    # from the server equals the upper-cased input string --
                    # confirm for inputs shorter than the column width.
                    yield (sid, offset - mapper[hash])

    def __getstate__(self):
        """Pickle only the connection options, not the cursor factory."""
        return (self._options,)

    def __setstate__(self, state):
        """Restore the options and rebuild the cursor factory from them."""
        self._options, = state
        self.cursor = cursor_factory(**self._options)
def grouper(iterable, n, fillvalue=None):
    """
    Split ``iterable`` into consecutive groups of at most ``n`` items.

    The last group is padded with ``fillvalue`` and every group is then run
    through ``filter(None, ...)``, which removes the default padding -- and
    any other falsy item as well.
    """
    shared = iter(iterable)
    # n references to one iterator: each output tuple pulls n successive items
    padded_groups = izip_longest(*([shared] * n), fillvalue=fillvalue)
    return (filter(None, group) for group in padded_groups)
def cursor_factory(**factory_options):
    """
    Build a callable that creates Cursor objects with preset options.

    Keyword arguments given to the returned callable are merged with
    ``factory_options``; on a clash the factory's value is the one used.
    """
    def cursor(**options):
        merged = dict(options, **factory_options)
        return Cursor(**merged)
    return cursor
class Cursor(object):
    """
    Context manager that hands out an open cursor on a MySQL connection.

    Connections are recycled through a class-level queue holding at most
    five idle connections; a new one is opened only when the queue is empty.

    ```python
    # Use as context manager
    with Cursor() as cur:
        cur.execute(query)
    ```
    """
    _cache = Queue.Queue(maxsize=5)

    def __init__(self, cursor_type=mysql.cursors.Cursor, **options):
        """
        Take a cached connection, or open a new one with ``options``.

        cursor_type: cursor class passed to ``conn.cursor`` on entry
        options: keyword arguments for ``mysql.connect`` (only used when no
        cached connection is available)
        """
        super(Cursor, self).__init__()

        try:
            conn = self._cache.get_nowait()
        except Queue.Empty:
            conn = mysql.connect(**options)
        else:
            # Ping the connection before using it from the cache.
            conn.ping(True)

        self.conn = conn
        self.conn.autocommit(False)
        self.cursor_type = cursor_type

    @classmethod
    def clear_cache(cls):
        """Replace the connection cache with a fresh, empty queue."""
        cls._cache = Queue.Queue(maxsize=5)

    def __enter__(self):
        """Open and return a cursor of the configured type."""
        self.cursor = self.conn.cursor(self.cursor_type)
        return self.cursor

    def __exit__(self, extype, exvalue, traceback):
        """
        Roll back on a MySQL error, then close the cursor, commit and put
        the connection back in the cache (closing it if the cache is full).

        The exception, if any, is not suppressed.
        """
        # Errors raised by the driver are subclasses of MySQLError, so an
        # identity test would never match; and rollback lives on the
        # connection, not on the cursor.
        if extype is not None and issubclass(extype, mysql.MySQLError):
            self.conn.rollback()

        self.cursor.close()
        self.conn.commit()

        # Put it back on the queue
        try:
            self._cache.put_nowait(self.conn)
        except Queue.Full:
            self.conn.close()
import json
import sys
from database import get_database, Database
# Location of the JSON file with the database settings.
DATABASE_CONFIG_FILE = "./database.cnf"

# Load the settings; without a readable file there is nothing to do.
try:
    with open(DATABASE_CONFIG_FILE) as f:
        config = json.load(f)
except IOError as err:
    print("Cannot open configuration: %s. Exiting" % err)
    sys.exit(1)

# Resolve the backend class ("database_type" is optional), instantiate it
# with the "database" section as keyword arguments and prepare its tables.
db_cls = get_database(config.get("database_type"))
db_options = config.get("database", {})
db = db_cls(**db_options)
db.setup()
import numpy as np
# For reading/writing hashes to file
import struct
# Format string for writing binary data to file: two little-endian int32
# per record.
HASH_FMT = '<2i'
# Bytes literals: identical to the plain strings under Python 2, and still
# comparable with data read in binary mode under Python 3.
HASH_MAGIC = b'audfprinthashV00'  # 16 chars, FWIW
PEAK_FMT = '<2i'
PEAK_MAGIC = b'audfprintpeakV00'  # 16 chars, FWIW


def hashes_load(hashfilename):
    """
    Read a hash file: the HASH_MAGIC header followed by HASH_FMT records.

    hashfilename: path of the file to read

    Returns a list with one tuple of two ints per record, in file order.
    Trailing bytes that do not form a complete record are ignored.
    Raises IOError when the file does not start with HASH_MAGIC.
    """
    record = struct.Struct(HASH_FMT)
    with open(hashfilename, 'rb') as f:
        magic = f.read(len(HASH_MAGIC))
        if magic != HASH_MAGIC:
            raise IOError('%s is not a hash file (magic %s)'
                          % (hashfilename, magic))
        # One read for the whole payload instead of one per 8-byte record.
        payload = f.read()
    usable = len(payload) - len(payload) % record.size
    return [record.unpack_from(payload, pos)
            for pos in range(0, usable, record.size)]
###############################################################
# Register each reference recording as a song and store its hashes.
# Every song is inserted with the same placeholder file hash '0001'.
for program_name in ('program1', 'program2', 'program3',
                     'program4', 'program5'):
    hashes = hashes_load('./Samples/%s.afpt' % program_name)
    print(len(hashes))
    song_id = db.insert_song(program_name, '0001')
    db.insert_hashes(song_id, hashes)
################################################################
# Look up every hash of the query recording and dump, as JSON, one object
# per hash holding its hex form, its first field ('time') and the
# (song_id, offset) rows found for it.
hashes = hashes_load('./Samples/viral.afpt')
print(len(hashes))
results = []
for entry in hashes:
    print(entry)
    time_offset, hash_value = entry
    hash_ = '{:x}'.format(hash_value)
    print(hash_)
    matches = list(db.query(hash_))
    print(matches)
    results.append({'hash': hash_, 'time': time_offset, 'matches': matches})
print(json.dumps(results))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment