Skip to content

Instantly share code, notes, and snippets.

@chryss
Created May 15, 2014 02:53
Show Gist options
  • Save chryss/3d708c66b5efdfe39741 to your computer and use it in GitHub Desktop.
Save chryss/3d708c66b5efdfe39741 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# encoding: utf-8
"""
evaluatehfdi_debug.py
Debugging scikit-learn Gaussian Mixture Model parameter choice.
Created by Chris Waigl on 2014-05-13.
Copyright (c) 2014 Christine F. Waigl. MIT License.
"""
from __future__ import print_function
import os.path
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from scipy import stats
from sklearn.mixture import GMM
import logging
# Only errors are reported; main() logs one when a mixture fit does not converge.
logging.basicConfig(level=logging.ERROR)
LOGGER = logging.getLogger('pygaarst-scripts.evaluatehfdi_debug')
# files and directories
basedir = "."  # directory the input data file is loaded from
outdir = "."  # directory the PNG plots are written to
datafile = "hfdi_debug.npz"  # NumPy archive; main() reads arrays arr_0..arr_3 from it
firename = 'woodriver'  # used in the plot title and the output file name
# save images?
SAVEIMG = True
# number of points for plotting and summing
# (size of the evenly spaced x grid on [-1, 1] on which the fitted PDFs are evaluated)
N = 201
# histogram bin edges: half-open range starting at -1 in steps of 0.02
BINS = np.arange(-1, 1, 0.02)
# passed as `thresh` to the GMM constructor in main()
# (presumably the EM convergence threshold -- confirm against the scikit-learn docs)
THRESH = 0.001
# what to iterate over
# (band pairs, matched positionally to arr_0..arr_3 in main())
comb = [(196, 216), (197, 216), (196, 215), (197, 215)]
def plot_results(hfdis_all, pdf_fire, pdf_bg, over, cutoff, b1, b2, firename):
    """Plot the HFDI histogram with both fitted mixture curves and save it.

    A normalised histogram of ``hfdis_all`` (bin edges ``BINS``) is drawn in
    light grey and the two fitted curves are overlaid on an ``N``-point grid
    spanning [-1, 1].  The figure is written as a 150 dpi PNG into ``outdir``
    and then closed.

    Parameters
    ----------
    hfdis_all : array of sample values; its ``size`` appears in the legend.
    pdf_fire, pdf_bg : curves for the fire and background components,
        evaluated on the ``N``-point grid.
    over, cutoff : accepted but not used by the plot.
    b1, b2 : band identifiers, shown in the title and the file name.
    firename : fire name, shown capitalised in the title and used verbatim
        in the file name.
    """
    # Global matplotlib settings: affects every figure created afterwards.
    mpl.rcParams.update({'font.size': 18, 'font.family': 'Calibri'})
    _, axis = plt.subplots(1, 1, figsize=(15, 9))

    # NOTE(review): newer matplotlib releases replaced `normed` with
    # `density` -- confirm the target matplotlib version before changing it.
    axis.hist(
        hfdis_all,
        bins=BINS,
        normed=True,
        label="%s total SWIR sample pixels" % hfdis_all.size,
        color="lightgrey"
    )

    # Both fitted curves share the same x grid; background is drawn first.
    grid = np.linspace(-1, 1, N)
    fitted = (
        (pdf_bg, "GMM fit: background"),
        (pdf_fire, "GMM fit: fire"),
    )
    for curve, caption in fitted:
        axis.plot(grid, curve, linewidth=3.0, alpha=0.8, label=caption)

    axis.legend()
    title = "Modified HFDI histogram fitting, %s fire, bands %s and %s" % (
        firename.capitalize(), b1, b2)
    axis.set_title(title)

    filename = "%s_HFDI_GMM_debug_%s_%s.png" % (firename, b1, b2)
    plt.savefig(os.path.join(outdir, filename), dpi=150)
    plt.close()
def rewrite_params(g):
    """Summarise every component of a fitted Gaussian mixture.

    Reads ``means_``, ``covars_`` and ``weights_`` from ``g`` and returns a
    list with one ``(mean, std, weight)`` tuple per mixture component, in
    component order.  Only the first feature of each component is used:
    ``mean`` is ``means_[i][0]`` and ``std`` is the square root of
    ``covars_[i][0]``.  All three values are rounded to four decimals.

    The number of components is taken from ``len(g.weights_)``, so the
    function works for mixtures of any size, not only two components.
    """
    def _component(i):
        # (mean, standard deviation, weight) of component i, rounded.
        return (
            np.round(g.means_[i][0], decimals=4),
            np.round(np.sqrt(g.covars_[i][0]), decimals=4),
            np.round(g.weights_[i], decimals=4),
        )

    return [_component(i) for i in range(len(g.weights_))]
def rewrite_pdf(para, idx, grid=None):
    """Evaluate the weighted normal PDF of one mixture component.

    Parameters
    ----------
    para : sequence of ``(mean, std, weight)`` entries, one per component.
    idx : index into ``para`` of the component to evaluate.
    grid : optional array of x values at which to evaluate the curve.
        When omitted, ``N`` evenly spaced points on [-1, 1] are used.

    Returns
    -------
    Array of ``weight * normal_pdf(x; mean, std)`` for every x in the grid.
    """
    if grid is None:
        # Default grid: N evenly spaced points on [-1, 1].
        grid = np.linspace(-1, 1, N)
    component = para[idx]
    mean, std, weight = component[0], component[1], component[2]
    return weight * stats.norm.pdf(grid, loc=mean, scale=std)
def overlap(curve1, curve2):
    """Measure how much two curves sampled on [-1, 1] overlap.

    Both curves are expected to be sampled on the same evenly spaced grid
    spanning [-1, 1].  The grid size is taken from the curves themselves,
    so the result stays consistent for any sampling density.

    Returns
    -------
    (score, cutoff)
        ``score`` is the mean of the pointwise minimum of the two curves
        (a relative overlap measure, not a normalised integral).
        ``cutoff`` is the x position at which that pointwise minimum is
        largest -- presumably where the two curves cross.
    """
    mins = np.minimum(curve1, curve2)
    npoints = mins.size
    # Index of the largest pointwise minimum, mapped back to an x position.
    cutoff = np.argmax(mins)
    return mins.sum() / npoints, np.linspace(-1, 1, npoints)[cutoff]
def main():
    """Fit a two-component GMM to each sample array and print the results.

    For every array ``arr_0`` .. ``arr_3`` in the data file (paired
    positionally with the band combinations in ``comb``):

    * fit a two-component Gaussian mixture,
    * order the components by mean -- the lower-mean one is treated as
      background, the higher-mean one as fire,
    * compute the overlap score and cutoff of the two weighted PDFs,
    * optionally save a diagnostic plot (``SAVEIMG``).

    One tuple per array is collected, laid out as
    ``(overlap, cutoff, fire mean, fire std, fire weight,
    background mean, background std, background weight, b1, b2)``.
    The list is sorted (smallest overlap first) and printed.
    """
    arrs = ['arr_0', 'arr_1', 'arr_2', 'arr_3']
    output = []
    # The .npz archive keeps a file handle open; the context manager makes
    # sure it is closed once all arrays have been processed.
    with np.load(os.path.join(basedir, datafile)) as npzfile:
        for arr, (b1, b2) in zip(arrs, comb):
            hfdis_all = npzfile[arr]
            g = GMM(n_components=2, thresh=THRESH)
            g.fit(hfdis_all)
            if not g.converged_:
                # Best effort: report the problem but still use the fit.
                LOGGER.error("Gaussian mixture fit didn't converge")
            # Tuples sort by their first element, i.e. by component mean.
            gpara = sorted(rewrite_params(g))
            bg, fire = gpara[0], gpara[1]
            pdf_fire = rewrite_pdf(gpara, 1)
            pdf_bg = rewrite_pdf(gpara, 0)
            over, cutoff = overlap(pdf_fire, pdf_bg)
            output.append((
                over, cutoff,
                fire[0], fire[1], fire[2],
                bg[0], bg[1], bg[2],
                b1, b2
            ))
            if SAVEIMG:
                plot_results(
                    hfdis_all, pdf_fire, pdf_bg, over, cutoff, b1, b2,
                    firename)
    output.sort()
    print(output)
# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment