Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save pierrelouisbescond/2ee07df32ee784b994ebf360f887447b to your computer and use it in GitHub Desktop.
Save pierrelouisbescond/2ee07df32ee784b994ebf360f887447b to your computer and use it in GitHub Desktop.
import pandas as pd
import numpy as np
import plotly.graph_objects as go
# Build a small synthetic regression dataset: three correlated, noisy
# features (x1, x2, x3) and a target y that is a fixed linear combination
# of them. A seeded generator makes every run produce the same data.
SEED = 42
rng = np.random.default_rng(SEED)

# Let's start by creating our index: dataset_size evenly spaced points on [0, 20]
dataset_size = 1000
idx = np.linspace(0, 20, dataset_size)

# x1, x2 have a cyclical behavior, quite close to each other:
# x1 = cos(idx) + uniform noise in [0, 0.2)
# x2 = cos(idx) + uniform noise in [-0.5, 0.5)
x1 = np.cos(idx) + 0.2 * rng.random(dataset_size)
x2 = np.cos(idx) - 0.5 + rng.random(dataset_size)

# We initiate the DataFrame
df = pd.DataFrame({"x1": x1, "x2": x2}, index=idx)

# x3 - our third feature - is a linear interpolation between x1's high and
# low peaks: keep x1 only where |x1| > 0.9, then interpolate across the gaps.
df["x3"] = df["x1"].where((df["x1"] > 0.9) | (df["x1"] < -0.9))
# pandas only fills forward by default, so limit_direction="both" is needed
# to also fill any NaNs before the first peak. Values after the last peak are
# held at the last peak value. Uniform noise in [-0.1, 0.1) is then added.
df["x3"] = (
    df["x3"].interpolate(limit_direction="both")
    - 0.1
    + 0.2 * rng.random(dataset_size)
)

# The y target is a weighted combination of all x features
df["y"] = 0.5 * df["x3"] + 0.3 * df["x2"] + 0.2 * df["x1"]
# Draw each column as its own line against the shared index, so the
# relationship between the features (x1, x2, x3) and the target (y) is
# visible at a glance. Traces are added in column order: x1, x2, x3, y.
fig = go.Figure()
for column_name in ("x1", "x2", "x3", "y"):
    # The legend label is the column name itself
    fig.add_trace(go.Scatter(x=df.index, y=df[column_name], name=column_name))
fig.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment