-
-
Save bryant1410/4d81cd1e325302bd9f08 to your computer and use it in GitHub Desktop.
This function one-hot encodes columns of a pandas DataFrame rather than a NumPy feature matrix. One advantage is that the names of the newly created columns are known, so they are easy to identify.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""Small script that shows how to do one-hot encoding of categorical
columns in a pandas DataFrame.

See:
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html
https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.DictVectorizer.html
"""
import random

import numpy
import pandas
from sklearn.feature_extraction import DictVectorizer
def one_hot_dataframe(data, cols, replace=False):
    """One-hot encode categorical columns of a pandas DataFrame.

    Uses scikit-learn's DictVectorizer so that the generated indicator
    columns keep readable names (e.g. ``'e=Boston'``), which makes the new
    columns easy to identify.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame containing the columns to encode.
    cols : str or list of str
        Name(s) of the columns to encode. A single column name is accepted
        for convenience.
    replace : bool, default False
        If true, the original ``cols`` are dropped from the returned frame,
        leaving only their encoded versions.

    Returns
    -------
    tuple
        ``(data, vec_data, vec)``:

        - ``data``: the input frame joined with the encoded columns.
        - ``vec_data``: a DataFrame holding only the encoded columns, indexed
          like ``data``.
        - ``vec``: the fitted DictVectorizer. It can be reused via
          ``vec.transform`` to encode new data with the same columns.

    Notes
    -----
    DictVectorizer one-hot encodes string values only. Numeric values found
    in ``cols`` are passed through as a single numeric feature.
    """
    # A bare column name would make data[cols] a Series, whose to_dict()
    # does not produce per-row records.
    if isinstance(cols, str):
        cols = [cols]
    vec = DictVectorizer()
    # orient='records' yields one {column: value} dict per row, which is the
    # input format DictVectorizer expects. (The old `outtype=` keyword no
    # longer exists in pandas.)
    records = data[cols].to_dict(orient='records')
    # fit_transform returns a scipy sparse matrix by default; densify it for
    # the DataFrame constructor.
    encoded = vec.fit_transform(records).toarray()
    # get_feature_names() was removed in scikit-learn 1.2; only fall back to
    # it on older releases that lack get_feature_names_out().
    if hasattr(vec, 'get_feature_names_out'):
        feature_names = list(vec.get_feature_names_out())
    else:
        feature_names = vec.get_feature_names()
    # Reuse the caller's index so the join below aligns row-for-row.
    vec_data = pandas.DataFrame(encoded, columns=feature_names, index=data.index)
    if replace:
        data = data.drop(cols, axis=1)
    data = data.join(vec_data)
    return (data, vec_data, vec)
def main():
    """Demonstrate one_hot_dataframe on a small random DataFrame."""
    # Three random numeric columns.
    df = pandas.DataFrame(numpy.random.randn(25, 3), columns=['a', 'b', 'c'])
    n_rows = len(df)
    # Two random categorical (string) columns to be encoded.
    df['e'] = [random.choice(('Chicago', 'Boston', 'New York')) for _ in range(n_rows)]
    df['f'] = [random.choice(('Chrome', 'Firefox', 'Opera', 'Safari')) for _ in range(n_rows)]
    # print(...) with a single argument works on both Python 2 and 3.
    print('Original DataFrame')
    print('------------------')
    print(df)
    # Vectorize the categorical columns e & f, replacing the originals.
    df, _, _ = one_hot_dataframe(df, ['e', 'f'], replace=True)
    print('Encoded DataFrame')
    print('-----------------')
    print(df)


if __name__ == '__main__':
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Example output
Original DataFrame
------------------
a b c e f
0 -0.219222 -0.368154 0.388479 New York Opera
1 1.879536 -0.033210 -0.099437 New York Firefox
2 0.909419 -0.498084 0.084163 New York Safari
3 -0.002199 -0.692806 -0.844436 New York Opera
4 -0.109549 -0.367305 -0.520999 Chicago Firefox
5 -0.400515 -1.202466 -1.664337 New York Chrome
6 -2.241892 -0.888160 -0.332380 New York Chrome
7 -0.432767 -1.794931 0.975878 Chicago Chrome
8 -1.401193 -0.478224 0.112729 Chicago Safari
9 -1.493518 0.584824 0.652820 New York Opera
10 0.525359 -0.885912 0.474492 Boston Firefox
11 0.671226 -0.733788 0.272915 Boston Chrome
12 0.775901 -0.163745 0.628414 Boston Opera
13 -1.158007 -0.495240 1.183522 New York Chrome
14 -1.200085 1.083380 -0.692171 Boston Safari
15 0.872763 -2.119172 -0.169185 Boston Chrome
16 1.423514 -1.802891 -2.947628 Boston Safari
17 -0.547940 -0.788654 -1.065005 Boston Safari
18 -0.380440 2.050783 1.548453 New York Firefox
19 -0.095913 1.260104 0.196552 Boston Opera
20 -1.558961 1.240931 -0.165927 Boston Safari
21 1.111618 -0.309371 -0.803404 Chicago Chrome
22 0.348182 -1.200900 0.307754 New York Firefox
23 -0.834901 0.188590 -1.115227 New York Chrome
24 1.463240 -1.559017 0.954684 New York Chrome
Encoded DataFrame
-----------------
a b c e=Boston e=Chicago e=New York f=Chrome f=Firefox f=Opera f=Safari
0 -0.219222 -0.368154 0.388479 0 0 1 0 0 1 0
1 1.879536 -0.033210 -0.099437 0 0 1 0 1 0 0
2 0.909419 -0.498084 0.084163 0 0 1 0 0 0 1
3 -0.002199 -0.692806 -0.844436 0 0 1 0 0 1 0
4 -0.109549 -0.367305 -0.520999 0 1 0 0 1 0 0
5 -0.400515 -1.202466 -1.664337 0 0 1 1 0 0 0
6 -2.241892 -0.888160 -0.332380 0 0 1 1 0 0 0
7 -0.432767 -1.794931 0.975878 0 1 0 1 0 0 0
8 -1.401193 -0.478224 0.112729 0 1 0 0 0 0 1
9 -1.493518 0.584824 0.652820 0 0 1 0 0 1 0
10 0.525359 -0.885912 0.474492 1 0 0 0 1 0 0
11 0.671226 -0.733788 0.272915 1 0 0 1 0 0 0
12 0.775901 -0.163745 0.628414 1 0 0 0 0 1 0
13 -1.158007 -0.495240 1.183522 0 0 1 1 0 0 0
14 -1.200085 1.083380 -0.692171 1 0 0 0 0 0 1
15 0.872763 -2.119172 -0.169185 1 0 0 1 0 0 0
16 1.423514 -1.802891 -2.947628 1 0 0 0 0 0 1
17 -0.547940 -0.788654 -1.065005 1 0 0 0 0 0 1
18 -0.380440 2.050783 1.548453 0 0 1 0 1 0 0
19 -0.095913 1.260104 0.196552 1 0 0 0 0 1 0
20 -1.558961 1.240931 -0.165927 1 0 0 0 0 0 1
21 1.111618 -0.309371 -0.803404 0 1 0 1 0 0 0
22 0.348182 -1.200900 0.307754 0 0 1 0 1 0 0
23 -0.834901 0.188590 -1.115227 0 0 1 1 0 0 0
24 1.463240 -1.559017 0.954684 0 0 1 1 0 0 0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment