Created November 18, 2012 21:38
-
-
Save ndawe/4107663 to your computer and use it in GitHub Desktop.
numpy_root histogram functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ROOT | |
import numpy as np | |
import ctypes | |
from rootpy.plotting import Hist, Hist2D, Hist3D | |
# Typed NULL ``double*``: a ctypes pointer to c_double that points nowhere.
# Not referenced by the functions below.
NULL_DOUBLE_P = ctypes.cast(None, ctypes.POINTER(ctypes.c_double))
def array_fill(hist, data, weights=None):
    """Fill a ROOT histogram from array-like data in bulk.

    Parameters
    ----------
    hist : ROOT.TH1, ROOT.TH2 or ROOT.TH3 (or a rootpy wrapper of one)
        Histogram to fill. Its dimensionality is read with
        ``hist.GetDimension()``.
    data : array_like
        For a 1-D histogram: ``N`` values, shape ``(N,)`` or ``(N, 1)``.
        For a D-dimensional histogram (D = 2 or 3): shape ``(N, D)``, whose
        columns are the x, y (and z) coordinates of each entry.
    weights : array_like, optional
        One weight per entry, shape ``(N,)``. Defaults to unit weights.

    Raises
    ------
    ValueError
        If the histogram dimensionality is unsupported, if ``data`` does not
        match it, or if ``weights`` does not hold one value per entry.
    """
    ndim = hist.GetDimension()
    if ndim not in (1, 2, 3):
        raise ValueError(
            "array_fill: unsupported histogram dimensionality {0}".format(ndim))

    # FillN reads raw ``const Double_t*`` buffers, so everything handed to
    # it must be float64; other dtypes would be rejected or misread.
    data = np.atleast_1d(np.asarray(data, dtype=np.float64))
    if ndim > 1 and data.size == 0:
        # Let an empty sequence mean "nothing to fill" for any dimensionality.
        data = data.reshape(0, ndim)

    if ndim == 1:
        if data.ndim == 2 and data.shape[1] == 1:
            data = data[:, 0]
        valid = data.ndim == 1
    else:
        valid = data.ndim == 2 and data.shape[1] == ndim
    if not valid:
        expected = "(N,)" if ndim == 1 else "(N, {0})".format(ndim)
        raise ValueError(
            "array_fill: data has shape {0}, but a {1}-dimensional histogram "
            "needs shape {2}".format(data.shape, ndim, expected))

    n = data.shape[0]
    if weights is None:
        weights = np.ones(n)
    else:
        weights = np.ascontiguousarray(weights, dtype=np.float64)
        if weights.shape != (n,):
            raise ValueError(
                "array_fill: weights has shape {0}, expected ({1},) to match "
                "the {1} entries in data".format(weights.shape, n))

    if n == 0:
        return

    if ndim == 1:
        hist.FillN(n, np.ascontiguousarray(data), weights)
    elif ndim == 2:
        # TH2 overrides the 1-D FillN(n, x, w) overload as a no-op, so the
        # x and y columns must go through the (n, x, y, w) overload.
        hist.FillN(n, np.ascontiguousarray(data[:, 0]),
                   np.ascontiguousarray(data[:, 1]), weights)
    else:
        # TH3 does not provide a working FillN, so fill entry by entry.
        for (x, y, z), w in zip(data, weights):
            hist.Fill(float(x), float(y), float(z), float(w))
def array2hist(data, has_overflow=False):
    """Inverse of ``hist2array``: not implemented yet.

    Parameters
    ----------
    data : array_like
        Presumably the bin contents to copy into a histogram (TODO confirm
        the intended layout).
    has_overflow : bool, optional
        Presumably whether ``data`` already includes the underflow/overflow
        bins (TODO confirm).

    Raises
    ------
    NotImplementedError
        Always. The signature has no target histogram or bin edges to copy
        ``data`` into, so there is nothing this function can fill yet.
    """
    raise NotImplementedError(
        "array2hist is not implemented: it needs a target histogram "
        "(or bin edges) to copy the array contents into")
def hist2array(hist, include_overflow=False, copy=False):
    """Return the bin contents of a ROOT histogram as a NumPy array.

    The array is in C order with shape ``(z, y, x)`` for TH3, ``(y, x)`` for
    TH2 and ``(x,)`` for TH1. The x index varies fastest, matching ROOT's
    global bin numbering.

    Parameters
    ----------
    hist : ROOT.TH1, ROOT.TH2 or ROOT.TH3 (or a rootpy wrapper of one)
        Histogram to read.
    include_overflow : bool, optional
        If True, keep the underflow (first) and overflow (last) bin along
        every axis. If False (default), return only the in-range bins.
    copy : bool, optional
        If False (default), the result is a view onto the histogram's own
        storage. It reflects later fills and must not be used after the
        histogram is deleted. If True, return an independent copy.

    Raises
    ------
    TypeError
        If ``hist`` is not a TH1/TH2/TH3, or its bin storage is not one of
        TArrayD/F/I/S/C.
    """
    # TH2 and TH3 both derive from TH1, so test the most specific class first.
    if isinstance(hist, ROOT.TH3):
        shape = (hist.GetNbinsZ() + 2,
                 hist.GetNbinsY() + 2,
                 hist.GetNbinsX() + 2)
    elif isinstance(hist, ROOT.TH2):
        shape = (hist.GetNbinsY() + 2, hist.GetNbinsX() + 2)
    elif isinstance(hist, ROOT.TH1):
        shape = (hist.GetNbinsX() + 2,)
    else:
        raise TypeError(
            "hist2array expects a ROOT TH1, TH2 or TH3 histogram, "
            "got {0!r}".format(type(hist).__name__))

    # Each histogram class also inherits the TArray that stores its bins
    # (e.g. TH1F from TArrayF). The buffer must be read with that element
    # type; assuming float64 misreads float/int histograms.
    for array_type, dtype in ((ROOT.TArrayD, np.float64),
                              (ROOT.TArrayF, np.float32),
                              (ROOT.TArrayI, np.int32),
                              (ROOT.TArrayS, np.int16),
                              (ROOT.TArrayC, np.int8)):
        if isinstance(hist, array_type):
            break
    else:
        raise TypeError(
            "hist2array: unsupported bin storage type for "
            "{0!r}".format(type(hist).__name__))

    arr = np.ndarray(shape=shape, dtype=dtype, buffer=hist.GetArray())
    if not include_overflow:
        # Drop the underflow and overflow bin on every axis.
        arr = arr[(slice(1, -1),) * arr.ndim]
    if copy:
        arr = arr.copy()
    return arr
# Demo: build 1-, 2- and 3-D double-precision histograms, fill the 1-D one
# from random normal samples, and print its contents.
h1 = Hist(10, 0, 1, type='d')
h2 = Hist2D(10, 0, 1, 5, 0, 1, type='d')
h3 = Hist3D(5, 0, 1, 10, 0, 1, 3, 0, 1, type='d')
# randn needs an integer sample count; a float such as 1E6 is rejected by
# modern NumPy.
array_fill(h1, np.random.randn(10 ** 6))
# print(...) with a single argument behaves the same on Python 2 and 3.
print(hist2array(h1))
# print(hist2array(h2))
# print(hist2array(h3))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.