Last active
August 1, 2024 08:45
-
-
Save iAnanich/79b89a282dee18530ed664fc1c81ebba to your computer and use it in GitHub Desktop.
(De)Normalizer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses | |
from typing import Self | |
import numpy as np | |
__all__ = ( | |
"MinMaxNormalizer", | |
) | |
@dataclasses.dataclass(frozen=True) | |
class MinMaxNormalizer: | |
""" | |
Dataclass for normalizing and denormalizing data using min-max normalization scaling to 0..1. | |
Attributes: | |
min_val (float): The minimum value of the data before normalization. | |
max_val (float): The maximum value of the data before normalization. | |
Usage Examples: | |
>>> import numpy as np | |
>>> from normalizer import MinMaxNormalizer # replace with the actual module name | |
>>> # Create a numpy array | |
>>> data = np.array([1, 2, 3, 4, 5]) | |
>>> # Create a normalizer and normalize the data | |
>>> normalizer, normalized_data = MinMaxNormalizer.from_regular(data) | |
>>> print(normalized_data) | |
[0. 0.25 0.5 0.75 1. ] | |
>>> # Denormalize the data | |
>>> denormalized_data = normalizer.denormalize(normalized_data) | |
>>> print(denormalized_data) | |
[1. 2. 3. 4. 5.] | |
Note: | |
The `from_regular` class method is used to create a `_MinMaxNormalizer` instance and normalize the data in one step. | |
The `normalize` and `denormalize` methods are used to normalize and denormalize the data, respectively. | |
""" | |
min_val: float | |
max_val: float | |
@classmethod | |
def from_regular(cls, regular: np.ndarray) -> (Self, np.ndarray): | |
obj = cls( | |
min_val=regular.min(), | |
max_val=regular.max(), | |
) | |
return obj, obj.normalize(regular=regular) | |
def normalize(self, regular: np.ndarray) -> np.ndarray: | |
return (regular - self.min_val) / (self.max_val - self.min_val) | |
def denormalize(self, normalized: np.ndarray) -> np.ndarray: | |
return normalized * (self.max_val - self.min_val) + self.min_val |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pytest | |
from normalizer import MinMaxNormalizer | |
class TestMinMaxNormalizer:
    """Tests for `MinMaxNormalizer`: construction, normalization and the round trip back."""

    @pytest.fixture
    def normalizer_and_data(self) -> tuple[MinMaxNormalizer, np.ndarray, np.ndarray]:
        """Provide a normalizer fitted on 1..5 together with the raw and normalized arrays."""
        data = np.array([1, 2, 3, 4, 5])
        normalizer, normalized = MinMaxNormalizer.from_regular(data)
        return normalizer, data, normalized

    def test_acceptance(self, normalizer_and_data):
        """Denormalizing the normalized data restores the original values."""
        normalizer, data, normalized = normalizer_and_data
        denormalized = normalizer.denormalize(normalized)
        assert np.allclose(denormalized, data)

    # Expected values are written as exact fractions instead of hand-truncated
    # decimals, so they do not depend on the tolerance to match.
    @pytest.mark.parametrize("regular, normalized", [
        (np.array([1, 2, 3, 4, 5]), np.array([0., 0.25, 0.5, 0.75, 1.])),
        (np.array([10, 20, 30, 40, 50]), np.array([0., 0.25, 0.5, 0.75, 1.])),
        (np.array([[1, 2], [3, 4]]), np.array([[0, 1], [2, 3]]) / 3),
        (
            np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
            np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) / 7,
        ),
    ])
    def test_normalize_and_back(self, regular, normalized):
        """Normalization matches the expected values for 1-D, 2-D and 3-D input, and inverts."""
        normalizer, actual_normalized = MinMaxNormalizer.from_regular(regular)
        assert np.allclose(actual_normalized, normalized, rtol=1e-3)
        # Denormalize the expected values (not the computed ones) so that
        # `denormalize` is checked independently of `normalize`.
        actual_denormalized = normalizer.denormalize(normalized)
        assert np.allclose(actual_denormalized, regular, rtol=1e-3)

    def test_from_regular(self, normalizer_and_data):
        """`from_regular` records the data's extrema and returns the normalized array."""
        normalizer, data, normalized = normalizer_and_data
        assert normalizer.min_val == 1
        assert normalizer.max_val == 5
        assert np.allclose(normalized, np.array([0., 0.25, 0.5, 0.75, 1.]))

    def test_empty_array(self):
        """An empty array has no minimum or maximum, so construction raises ValueError."""
        with pytest.raises(ValueError):
            MinMaxNormalizer.from_regular(np.array([]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
... Or just use https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html