Created
December 4, 2019 04:47
-
-
Save RicherMans/5317aa2b553b8229c8ca5cf523262564 to your computer and use it in GitHub Desktop.
Train and evaluate a simple GMM for spoofing detection using HDF5 features
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.mixture import GaussianMixture | |
import h5py | |
import fire | |
import numpy as np | |
import pandas as pd | |
from pypeln import thread as pr | |
from joblib import dump, load | |
from tqdm import tqdm | |
def eval(model, features, labels, output='scores.txt'):
    """Score every labelled utterance with a real/spoof model pair.

    NOTE: the name shadows the ``eval`` builtin; it is kept because it is
    the function exposed on the command line.

    Args:
        model: Path to a joblib dump holding a dict with the keys ``'real'``
            and ``'spoof'``; each value must provide a ``score(X)`` method.
        features: Path to an HDF5 file that contains one dataset per
            filename listed in ``labels``.
        labels: Path to a CSV file with ``filename`` and ``label`` columns.
        output: Path of the text file the scores are written to.

    Each output line has the form ``<filename> <label> <label> <score>``,
    where ``score`` is ``real.score(x) - spoof.score(x)``.
    """
    label_table = pd.read_csv(labels)
    model_dump = load(model)
    real_gmm = model_dump['real']
    spoof_gmm = model_dump['spoof']

    with h5py.File(features, 'r') as store, open(output, 'w') as wp:

        def calc_score(item):
            fname, label = item
            # Materialise the dataset once (as the training script does) so
            # both models score the same in-memory array rather than each
            # being handed the lazy h5py dataset separately.
            feature = store[fname][()]
            score = real_gmm.score(feature) - spoof_gmm.score(feature)
            return fname, score, label

        scored = pr.map(calc_score,
                        label_table[['filename', 'label']].values,
                        workers=4,
                        maxsize=2)
        for k, score, label in tqdm(scored, total=len(label_table)):
            # The label is written into both middle columns; presumably the
            # first one stands in for a system-type field of the score
            # format -- TODO confirm against the scoring tool.
            wp.write("{} {} {} {}\n".format(k, label, label, score))
# Script entry point: expose `eval` as a command-line interface through fire.
if __name__ == '__main__':
    fire.Fire(eval)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.mixture import GaussianMixture | |
from sklearn.cluster import MiniBatchKMeans | |
from sklearn.decomposition import PCA | |
import h5py | |
import fire | |
import numpy as np | |
import pandas as pd | |
from tqdm import tqdm | |
import kaldi_io | |
from joblib import dump, load | |
import logging | |
# Root logger setup: emit "<timestamp> <message>" for INFO and above.
FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
def _build_gmm(samples, n_components, cov_type, max_iter, max_kmeans_iter):
    """Return an unfitted GaussianMixture whose means come from k-means."""
    # compute_labels=False: only the cluster centres are used afterwards.
    kmeans = MiniBatchKMeans(n_clusters=n_components,
                             max_iter=max_kmeans_iter,
                             compute_labels=False,
                             random_state=0)
    kmeans.fit(samples)
    return GaussianMixture(n_components=n_components,
                           verbose=True,
                           max_iter=max_iter,
                           verbose_interval=1,
                           reg_covar=1e-4,
                           covariance_type=cov_type,
                           init_params='random',
                           means_init=kmeans.cluster_centers_)


def main(features, labels, out='model.pkl', **kwargs):
    """Train one GMM on real and one on spoofed features and dump both.

    Args:
        features: Path to an HDF5 file that contains one dataset per
            filename listed in ``labels``.
        labels: Path to a CSV file with ``filename`` and ``label`` columns.
            Rows labelled ``'spoof'`` feed the spoof model; every other row
            feeds the real model.
        out: Path of the joblib dump; it stores a dict with the keys
            ``'spoof'`` and ``'real'``.
        **kwargs: Optional overrides: ``n_components`` (default 512),
            ``cov_type`` (default ``'diag'``), ``max_iter`` (GMM training
            iterations, default 100) and ``max_kmeans_iter`` (k-means
            initialisation iterations, default 5).

    Raises:
        ValueError: If the labels contain no real or no spoof sample.
    """
    train_labels = pd.read_csv(labels)

    # A bit dirty to hold everything in memory, but the data size is
    # reasonable.
    real_features, spoof_features = [], []
    with h5py.File(features, 'r') as store:
        for row in tqdm(train_labels.itertuples(), total=len(train_labels)):
            sample = store[row.filename][()]
            if row.label == 'spoof':
                spoof_features.append(sample)
            else:
                real_features.append(sample)

    # np.concatenate fails with an unhelpful message on an empty list, so
    # report a missing class explicitly.
    if not real_features or not spoof_features:
        raise ValueError(
            "Need at least one real and one spoof sample, got {} real and "
            "{} spoof".format(len(real_features), len(spoof_features)))

    real_features = np.concatenate(real_features)
    spoof_features = np.concatenate(spoof_features)
    logging.info("Loading features done")
    # NOTE: an optional PCA(40) projection of both feature sets was tried at
    # this point and is currently disabled.

    # Gaussian components
    n_components = kwargs.get('n_components', 512)
    cov_type = kwargs.get('cov_type', 'diag')
    # Max iter for GMM training
    max_iter = kwargs.get('max_iter', 100)
    # Max iter for KMeans init
    max_kmeans_iter = kwargs.get('max_kmeans_iter', 5)

    logging.info("Estimating component means with KMeans")
    real_gmm = _build_gmm(real_features, n_components, cov_type, max_iter,
                          max_kmeans_iter)
    spoof_gmm = _build_gmm(spoof_features, n_components, cov_type, max_iter,
                           max_kmeans_iter)

    logging.info("Bonafide GMM %s", real_gmm)
    logging.info("Spoof GMM %s", spoof_gmm)
    logging.info("Training Sample size: %s Real, %s Spoof",
                 real_features.shape, spoof_features.shape)

    logging.info("Training Real")
    real_gmm.fit(real_features)
    logging.info("Training Spoof")
    spoof_gmm.fit(spoof_features)

    logging.info("Dumping model to %s", out)
    dump({'spoof': spoof_gmm, 'real': real_gmm}, out)
# Script entry point: expose `main` as a command-line interface through fire.
if __name__ == '__main__':
    fire.Fire(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment