Last active: July 18, 2022 11:55
Save achinta/0a6dab8e3ed5c12c1e3e7298c282a824 to your computer and use it in GitHub Desktop.
Category Encoder - fit partial
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Iterable | |
class CategoryEncoder(object):
    """
    Incrementally encode categories as consecutive integers.

    Once its ``fit`` method has been called, ``sklearn.preprocessing.LabelEncoder``
    cannot encode new categories. In this encoder, ``fit`` can be called any
    number of times: categories it has not seen before receive the next free
    code, while the encoding of existing categories never changes.

    Codes start at ``start`` (default 0) and increase by one for each new
    category, in order of first appearance. Categories that were never fitted
    transform to -1.

    Input dispatch: a ``str``/``bytes`` value or any non-iterable value is a
    single category; any other iterable is a collection of categories.
    Categories must be hashable, since they are used as dict keys.
    """

    start = 0

    def __init__(self, start=0):
        """
        :param start: code assigned to the first category seen (default 0).
        """
        self.start = start
        # Per-instance mapping (category -> code). A class-level dict would be
        # shared by every CategoryEncoder and leak categories between them.
        self.mapping = {}

    @staticmethod
    def _is_collection(l):
        """Return True if *l* should be treated as many categories, not one."""
        # str/bytes are Iterable, but as category labels they are single values.
        return isinstance(l, Iterable) and not isinstance(l, (str, bytes))

    def fit(self, l):
        """
        Assign codes to any categories in *l* not seen before.

        :param l: a single category or a collection of categories.
        :returns: self, so calls can be chained.
        """
        if not self._is_collection(l):
            l = [l]
        for o in l:
            if o not in self.mapping:
                # Next free code: one past the highest code issued so far.
                self.mapping[o] = len(self.mapping) + self.start
        return self

    def transform(self, l):
        """
        Map categories to their codes, using -1 for unseen categories.

        :param l: a single category or a collection of categories.
        :returns: an int for a single category, otherwise a list of ints.
        """
        if self._is_collection(l):
            return [self.mapping.get(o, -1) for o in l]
        return self.mapping.get(l, -1)

    def fit_transform(self, l):
        """
        Fit on *l*, then return its codes (see ``fit`` and ``transform``).

        :param l: a single category or a collection of categories.
        :returns: an int for a single category, otherwise a list of ints.
        """
        # Materialize one-shot iterators (e.g. generators) so both fit and
        # transform see the same items instead of transform getting nothing.
        if self._is_collection(l):
            l = list(l)
        self.fit(l)
        return self.transform(l)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.