Created May 3, 2018 21:52
Save conormm/d03a04ebb940a80e163583948c1b3f46 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.datasets import make_moons | |
import pandas as pd | |
import numpy as np | |
import torch | |
from torch.autograd import Variable | |
import torch.nn as nn | |
from torch.utils.data import Dataset, DataLoader | |
import torch.nn.functional as F | |
def to_categorical(y, num_classes=None):
    """Return a one-hot (1-of-K) encoding of integer class labels.

    Args:
        y: Integer class label(s) -- a scalar, list, or integer array -- each
            expected to lie in ``[0, num_classes)``.
        num_classes: Number of classes K. If None, it is inferred as
            ``max(y) + 1`` (0 when ``y`` is empty).

    Returns:
        A ``uint8`` numpy array of shape ``np.shape(y) + (num_classes,)`` with
        a 1 at each label's class position and 0 elsewhere.

    Raises:
        ValueError: If any label is negative. Without this check numpy's
            negative indexing would silently wrap (e.g. -1 -> last class).
        IndexError: If any label is >= ``num_classes`` (raised by numpy).
    """
    labels = np.asarray(y)
    if labels.size and labels.min() < 0:
        raise ValueError("class labels must be non-negative")
    if num_classes is None:
        num_classes = int(labels.max()) + 1 if labels.size else 0
    # Row i of the identity matrix is the one-hot vector for class i, so
    # fancy-indexing with the labels gathers one row per label.
    return np.eye(num_classes, dtype='uint8')[y]
# Two interleaving half-moon classes: X holds the 2-D points and y_ the raw
# integer class labels (0 or 1).
X, y_ = make_moons(n_samples=1000, noise=.1)
# One-hot encode the labels into y. The original referenced `y_` without ever
# binding it, which raised a NameError.
y = to_categorical(y_, 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.