Created
November 7, 2015 00:37
-
-
Save mickypaganini/70007d5b52553fb4d5a7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python2.7/site-packages/IPython/kernel/__init__.py:13: ShimWarning: The `IPython.kernel` package has been deprecated. You should import from ipykernel or jupyter_client instead.\n", | |
" \"You should import from ipykernel or jupyter_client instead.\", ShimWarning)\n" | |
] | |
} | |
], | |
"source": [ | |
"import pandas as pd\n", | |
"import pandautils as pu\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"from sklearn import linear_model\n", | |
"import cPickle as pickle\n", | |
"import glob\n", | |
"from numpy.lib.recfunctions import stack_arrays\n", | |
"from root_numpy import root2rec\n", | |
"%matplotlib notebook" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamPhotons' of branch 'HGamPhotons' with type '' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamPhotonsAux.' of branch 'HGamPhotonsAux.' with type 'xAOD::AuxContainerBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamAntiKt4EMTopoJets' of branch 'HGamAntiKt4EMTopoJets' with type '' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamAntiKt4EMTopoJetsAux.' of branch 'HGamAntiKt4EMTopoJetsAux.' with type 'xAOD::AuxContainerBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamAntiKt4EMTopoJets_AllSelAnyTag' of branch 'HGamAntiKt4EMTopoJets_AllSelAnyTag' with type '' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamAntiKt4EMTopoJets_AllSelAnyTagAux.' of branch 'HGamAntiKt4EMTopoJets_AllSelAnyTagAux.' with type 'xAOD::AuxContainerBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamElectrons' of branch 'HGamElectrons' with type '' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamElectronsAux.' of branch 'HGamElectronsAux.' with type 'xAOD::AuxContainerBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamMuons' of branch 'HGamMuons' with type '' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamMuonsAux.' of branch 'HGamMuonsAux.' with type 'xAOD::AuxContainerBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamEventInfo' of branch 'HGamEventInfo' with type 'xAOD::EventInfo_v1' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamEventInfoAux.' of branch 'HGamEventInfoAux.' with type 'xAOD::AuxInfoBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'EventInfo' of branch 'EventInfo' with type 'xAOD::EventInfo_v1' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamPhotons' of branch 'HGamPhotons' with type '' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamPhotonsAux.' of branch 'HGamPhotonsAux.' with type 'xAOD::AuxContainerBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamAntiKt4EMTopoJets' of branch 'HGamAntiKt4EMTopoJets' with type '' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamAntiKt4EMTopoJetsAux.' of branch 'HGamAntiKt4EMTopoJetsAux.' with type 'xAOD::AuxContainerBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamAntiKt4EMTopoJets_AllSelAnyTag' of branch 'HGamAntiKt4EMTopoJets_AllSelAnyTag' with type '' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamAntiKt4EMTopoJets_AllSelAnyTagAux.' of branch 'HGamAntiKt4EMTopoJets_AllSelAnyTagAux.' with type 'xAOD::AuxContainerBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamElectrons' of branch 'HGamElectrons' with type '' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamElectronsAux.' of branch 'HGamElectronsAux.' with type 'xAOD::AuxContainerBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamMuons' of branch 'HGamMuons' with type '' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamMuonsAux.' of branch 'HGamMuonsAux.' with type 'xAOD::AuxContainerBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamEventInfo' of branch 'HGamEventInfo' with type 'xAOD::EventInfo_v1' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'HGamEventInfoAux.' of branch 'HGamEventInfoAux.' with type 'xAOD::AuxInfoBase' (skipping)\n", | |
" weight_name)\n", | |
"/usr/local/lib/python2.7/site-packages/root_numpy/_tree.py:205: RootNumpyUnconvertibleWarning: cannot convert leaf 'EventInfo' of branch 'EventInfo' with type 'xAOD::EventInfo_v1' (skipping)\n", | |
" weight_name)\n" | |
] | |
} | |
], | |
"source": [ | |
"# -- directory holding the input ROOT files for both samples; change it in one place\n", | |
"MXAOD_DIR = '../MxAOD/framework-00-02-30_rel2.3.32/'\n", | |
"H300_df = pu.root2panda(MXAOD_DIR + '*H300*.root/*', 'CollectionTree')\n", | |
"yybb_df = pu.root2panda(MXAOD_DIR + '*A14NNPDF23LO_yybb*.root/*', 'CollectionTree')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# -- target labels: H300 sample -> 1, yybb sample -> 0\n", | |
"H300_df['class'], yybb_df['class'] = 1, 0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 102, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# -- stack both samples into one frame with a fresh 0..N-1 index (the jet-filling loop uses it as row number)\n", | |
"df = pd.concat((H300_df, yybb_df), ignore_index=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 103, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# -- every HGamAntiKt4EMTopoJetsAuxDyn.* branch is used as a per-jet feature\n", | |
"features = [col for col in df.columns if 'HGamAntiKt4EMTopoJetsAuxDyn' in col]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 104, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# -- separate the training features from the target label\n", | |
"y = df['class'].values\n", | |
"X = df[features].values" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 105, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# -- fix the RNG seed so the shuffle, and hence the train/test split, is reproducible\n", | |
"SEED = 42\n", | |
"np.random.seed(SEED)\n", | |
"# -- random permutation of the event indices; X and y are permuted together.\n", | |
"# -- NOTE: the jet tensor `data` built later from the unshuffled `df` must be\n", | |
"# -- indexed with this same `ix` to stay aligned with y.\n", | |
"ix = np.random.permutation(X.shape[0])\n", | |
"X, y = X[ix], y[ix]\n", | |
"# -- divide the sample in half for training and testing (explicit integer division)\n", | |
"n = X.shape[0] // 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# -- No StandardScaler is applied to X here: each jet variable is standardised\n", | |
"# -- further below, directly on the padded (event, jet, variable) tensor `data`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 106, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# -- branch used to order the jets inside each event (descending) when filling `data`\n", | |
"SORT_COL = 'HGamAntiKt4EMTopoJetsAuxDyn.pt'\n", | |
"# -- the largest jet multiplicity in the sample fixes the padded sequence length\n", | |
"MAX_N_JETS = max(len(jets) for jets in df[SORT_COL])\n", | |
"N_VARIABLES = X.shape[1]\n", | |
"# -- zero-padded tensor, shape (n_events, MAX_N_JETS, N_VARIABLES)\n", | |
"data = np.zeros((df.shape[0], MAX_N_JETS, N_VARIABLES))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 108, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# -- keep only the jet-level feature columns, in `features` order\n", | |
"df_sliced = df.loc[:, features]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 149, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Processing 0 of 61244.\n", | |
"Processing 1000 of 61244.\n", | |
"Processing 2000 of 61244.\n", | |
"Processing 3000 of 61244.\n", | |
"Processing 4000 of 61244.\n", | |
"Processing 5000 of 61244.\n", | |
"Processing 6000 of 61244.\n", | |
"Processing 7000 of 61244.\n", | |
"Processing 8000 of 61244.\n", | |
"Processing 9000 of 61244.\n", | |
"Processing 10000 of 61244.\n", | |
"Processing 11000 of 61244.\n", | |
"Processing 12000 of 61244.\n", | |
"Processing 13000 of 61244.\n", | |
"Processing 14000 of 61244.\n", | |
"Processing 15000 of 61244.\n", | |
"Processing 16000 of 61244.\n", | |
"Processing 17000 of 61244.\n", | |
"Processing 18000 of 61244.\n", | |
"Processing 19000 of 61244.\n", | |
"Processing 20000 of 61244.\n", | |
"Processing 21000 of 61244.\n", | |
"Processing 22000 of 61244.\n", | |
"Processing 23000 of 61244.\n", | |
"Processing 24000 of 61244.\n", | |
"Processing 25000 of 61244.\n", | |
"Processing 26000 of 61244.\n", | |
"Processing 27000 of 61244.\n", | |
"Processing 28000 of 61244.\n", | |
"Processing 29000 of 61244.\n", | |
"Processing 30000 of 61244.\n", | |
"Processing 31000 of 61244.\n", | |
"Processing 32000 of 61244.\n", | |
"Processing 33000 of 61244.\n", | |
"Processing 34000 of 61244.\n", | |
"Processing 35000 of 61244.\n", | |
"Processing 36000 of 61244.\n", | |
"Processing 37000 of 61244.\n", | |
"Processing 38000 of 61244.\n", | |
"Processing 39000 of 61244.\n", | |
"Processing 40000 of 61244.\n", | |
"Processing 41000 of 61244.\n", | |
"Processing 42000 of 61244.\n", | |
"Processing 43000 of 61244.\n", | |
"Processing 44000 of 61244.\n", | |
"Processing 45000 of 61244.\n", | |
"Processing 46000 of 61244.\n", | |
"Processing 47000 of 61244.\n", | |
"Processing 48000 of 61244.\n", | |
"Processing 49000 of 61244.\n", | |
"Processing 50000 of 61244.\n", | |
"Processing 51000 of 61244.\n", | |
"Processing 52000 of 61244.\n", | |
"Processing 53000 of 61244.\n", | |
"Processing 54000 of 61244.\n", | |
"Processing 55000 of 61244.\n", | |
"Processing 56000 of 61244.\n", | |
"Processing 57000 of 61244.\n", | |
"Processing 58000 of 61244.\n", | |
"Processing 59000 of 61244.\n", | |
"Processing 60000 of 61244.\n", | |
"Processing 61000 of 61244.\n" | |
] | |
} | |
], | |
"source": [ | |
"# -- fill `data`: for each event, stack its per-jet branches into a (N_VARIABLES, n_jets)\n", | |
"# -- array, order the jets by descending SORT_COL and write them transposed into the\n", | |
"# -- zero-padded tensor. `i` is the row position, independent of the frame's index.\n", | |
"for i, (_, event) in enumerate(df_sliced.iterrows()):\n", | |
"    if i % 1000 == 0:\n", | |
"        print 'Processing %s of %s.' % (i, df.shape[0])\n", | |
"    order = np.argsort(event[SORT_COL])[::-1]\n", | |
"    jet_vars = np.array([v.tolist() for v in event.values])[:, order]\n", | |
"    # -- truncate identically on both sides so an event longer than MAX_N_JETS cannot break the assignment\n", | |
"    n_jets = min(jet_vars.shape[1], MAX_N_JETS)\n", | |
"    data[i, :n_jets, :] = jet_vars[:, :n_jets].T" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 150, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Processing 0 of 18.\n", | |
"Processing 1 of 18.\n", | |
"Processing 2 of 18.\n", | |
"Processing 3 of 18.\n", | |
"Processing 4 of 18.\n", | |
"Processing 5 of 18.\n", | |
"Processing 6 of 18.\n", | |
"Processing 7 of 18.\n", | |
"Processing 8 of 18.\n", | |
"Processing 9 of 18.\n", | |
"Processing 10 of 18.\n", | |
"Processing 11 of 18.\n", | |
"Processing 12 of 18.\n", | |
"Processing 13 of 18.\n", | |
"Processing 14 of 18.\n", | |
"Processing 15 of 18.\n", | |
"Processing 16 of 18.\n", | |
"Processing 17 of 18.\n" | |
] | |
} | |
], | |
"source": [ | |
"# -- standardise every jet variable (zero mean, unit variance) over real jets only.\n", | |
"# -- Padded jet slots are all-zero across every variable, so they are identified with\n", | |
"# -- np.any(data != 0, axis=-1); a `> 0` mask would also skip genuine negative values\n", | |
"# -- of any signed variable and leave them unscaled.\n", | |
"jet_mask = np.any(data != 0, axis=-1)\n", | |
"scale = {}\n", | |
"for v in xrange(N_VARIABLES):\n", | |
"    print 'Processing %s of %s.' % (v, N_VARIABLES)\n", | |
"    vals = data[:, :, v][jet_mask]\n", | |
"    # -- nan-aware statistics so NaNs read from the input branches do not poison mean/sd\n", | |
"    m, s = np.nanmean(vals), np.nanstd(vals)\n", | |
"    # -- a constant variable (s == 0) is only centred instead of producing 0/0 = NaN\n", | |
"    data[:, :, v][jet_mask] = (vals - m) / (s if s > 0 else 1.0)\n", | |
"    scale[v] = {'mean' : m, 'sd' : s}\n", | |
"\n", | |
"# -- remaining NaNs (from the input branches) are set to 0, the padding value\n", | |
"data[np.isnan(data)] = 0.0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 173, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from keras.models import Sequential\n", | |
"from keras.layers.core import TimeDistributedDense, Dropout, Activation, Dense, RepeatVector\n", | |
"from keras.layers.recurrent import LSTM, JZS1" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 192, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# -- many-to-one recurrent classifier. Each event is a sequence of MAX_N_JETS jets with\n", | |
"# -- N_VARIABLES features each (the shape of `data`), and the target is the scalar 0/1\n", | |
"# -- `class`, so the network emits a single sigmoid probability per event.\n", | |
"model = Sequential()\n", | |
"model.add(JZS1(32, input_shape=(MAX_N_JETS, N_VARIABLES)))\n", | |
"model.add(Dense(1))\n", | |
"model.add(Activation('sigmoid'))\n", | |
"\n", | |
"# -- one sigmoid unit with 0/1 targets calls for binary (not categorical) cross-entropy\n", | |
"model.compile(loss='binary_crossentropy', optimizer='adam')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 208, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# -- `data` rows follow the unshuffled `df`, while y was permuted with `ix`;\n", | |
"# -- index `data` with the same permutation so that data[ix[k]] pairs with y[k]\n", | |
"model.fit(data[ix[:n]], y[:n], batch_size=4, nb_epoch=20)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 206, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(18,)" | |
] | |
}, | |
"execution_count": 206, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# -- sanity check: each jet slot of `data` holds N_VARIABLES values\n", | |
"data[0, 0].shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment