@tsukanov-as
Forked from 45deg/gen.py
Created June 8, 2023 11:57
Generating a Spiral Dataset for Classification in Python
import numpy as np
from numpy import pi
# import matplotlib.pyplot as plt
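N = 400  # number of points per class
# sqrt of a uniform sample skews theta toward larger angles, so points are
# spread roughly evenly along the arm instead of bunching near the center
theta = np.sqrt(np.random.rand(N))*2*pi # np.linspace(0,2*pi,100)
# first arm: radius grows linearly with the angle
r_a = 2*theta + pi
data_a = np.array([np.cos(theta)*r_a, np.sin(theta)*r_a]).T
x_a = data_a + np.random.randn(N,2)  # add unit-variance Gaussian noise
# second arm: same curve with the radius negated, i.e. rotated by 180 degrees
r_b = -2*theta - pi
data_b = np.array([np.cos(theta)*r_b, np.sin(theta)*r_b]).T
x_b = data_b + np.random.randn(N,2)
# append the class label (0 for arm a, 1 for arm b) as a third column
res_a = np.append(x_a, np.zeros((N,1)), axis=1)
res_b = np.append(x_b, np.ones((N,1)), axis=1)
# stack both classes and shuffle the rows in place
res = np.append(res_a, res_b, axis=0)
np.random.shuffle(res)
# write x, y, label to CSV; comments="" keeps the header free of a "# " prefix
np.savetxt("result.csv", res, delimiter=",", header="x,y,label", comments="", fmt='%.5f')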
# plt.scatter(x_a[:,0],x_a[:,1])
# plt.scatter(x_b[:,0],x_b[:,1])
# plt.show()
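The script above draws from NumPy's global random state, so each run produces a different dataset. Below is a minimal sketch of the same construction wrapped in a function with an explicit seed. The function name `make_spirals` and the parameters `n`, `noise`, and `seed` are hypothetical and not part of the original gist. It assumes NumPy's `default_rng` generator is available.

```python
import numpy as np

def make_spirals(n=400, noise=1.0, seed=0):
    """Return a shuffled (2*n, 3) array with columns x, y, label.

    Hypothetical helper: mirrors the script above but takes a seed
    so the output is reproducible.
    """
    rng = np.random.default_rng(seed)
    # same angle sampling as the script: sqrt of a uniform sample
    theta = np.sqrt(rng.random(n)) * 2 * np.pi
    r = 2 * theta + np.pi
    arm = np.column_stack([np.cos(theta) * r, np.sin(theta) * r])
    # class 0 is the arm itself, class 1 is the arm reflected through the origin
    x_a = arm + noise * rng.standard_normal((n, 2))
    x_b = -arm + noise * rng.standard_normal((n, 2))
    labels = np.concatenate([np.zeros(n), np.ones(n)])
    res = np.column_stack([np.vstack([x_a, x_b]), labels])
    rng.shuffle(res)  # shuffles rows in place
    return res
```

Calling `make_spirals(seed=1)` twice should return identical arrays, and passing `noise=0.0` returns points lying exactly on the two arms, which can help when checking the geometry by eye.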