Skip to content

Instantly share code, notes, and snippets.

@sdejean28
Created January 23, 2021 15:53
Show Gist options
  • Save sdejean28/267d42a50d1d48b5da813ae3d5b2a6c6 to your computer and use it in GitHub Desktop.
Save sdejean28/267d42a50d1d48b5da813ae3d5b2a6c6 to your computer and use it in GitHub Desktop.
using Plots
using Flux
using ColorSchemes
using NNlib
using Flux: @epochs
# Model under training: a single dense layer with one input and one output.
m = Dense(1, 1)

"""
    loss(x, y)

Return the sum of squared differences between the prediction `m(x)` and the
target `y`.

The global model `m` is looked up each time the function is called.
"""
function loss(x, y)
    residual = m(x) .- y
    return sum(residual .^ 2)
end
# Training samples as (input, target) pairs; each side is a 1-element vector.
# Every literal is written as a float so that all pairs share one concrete type;
# mixing integer and float vectors would give the container an abstract element
# type.
dataset = [
    ([0.8], [1.0]),
    ([2.0], [3.0]),
    ([2.4], [2.0]),
    ([0.0], [0.5]),
    ([1.5], [2.0]),
    ([3.0], [2.5]),
    ([4.0], [1.5]),
]

# Start the figure with a scatter plot of the raw samples.
plot(dataset, seriestype = :scatter, legend = false)
# Number of fitted lines to draw. Line `j` comes from a freshly created model
# trained for `10 * j` passes over `dataset`.
N = 50
for j in 1:N
    # `global` is required here: when this file is run as a script, a plain
    # `m = ...` inside a top-level loop creates a loop-local `m`, while `loss`
    # keeps reading the global `m`. The parameters handed to `train!` would then
    # receive no gradient and the plotted line would be an untrained model.
    global m = Chain(Dense(1, 1))

    # Collected once per model; the parameter arrays are updated in place.
    ps = Flux.params(m)
    opt = Descent(0.01)
    for _ in 1:(10 * j)
        Flux.train!(loss, ps, dataset, opt)
    end

    # Evaluate the current model at x = 0, 1, ..., 5.
    x = 0:5
    y = Float64[m([xi])[1] for xi in x]

    if j < N
        # Intermediate fits: colour taken from the blue scheme at position j/N.
        plot!(x, y, linecolor = get(ColorSchemes.Blues_8, j / N))
    else
        # Last (longest-trained) fit: thick red line so it stands out.
        plot!(x, y, linecolor = get(ColorSchemes.Reds_6, 0.8), linewidth = 5)
    end
end
# Add the samples to the current figure once more, after the fitted lines,
# then show the finished figure.
plot!(dataset; seriestype = :scatter, legend = false) |> display
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment