Created September 20, 2022 14:14
-
-
Save dermesser/f16ee63c365d0876fc44cbb67ce6e181 to your computer and use it in GitHub Desktop.
An experiment in classifying words that look like names versus words that do not, using an RNN classifier.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using MKL | |
using DataFrames | |
using Flux | |
import ChainRulesCore: ignore_derivatives | |
import Distributions: Bernoulli | |
import CSV | |
import Random: Sampler | |
import BSON | |
import Flux.MLUtils: DataLoader | |
# Encoding alphabet: the 26 lowercase ASCII letters followed by a space.
# Position i in this vector is row i of every one-hot character encoding.
const ENC_RANGE = [collect('a':'z'); ' ']
"""
    generate_fake_data(out_csv="fake.csv"; n=2500, maxlen=5)

Write `n` random "fake" words to the CSV file `out_csv` as a single column `name`.

Each word consists of letters `a`-`z` drawn uniformly at random. Its length is drawn
uniformly from `5:maxlen+4`, so `maxlen` is the number of distinct possible lengths,
not the maximum length.

Throws `ArgumentError` if `maxlen` is not positive. Returns the value of `CSV.write`.
"""
function generate_fake_data(out_csv="fake.csv"; n=2500, maxlen=5)
    maxlen > 0 || throw(ArgumentError("maxlen must be positive, got $maxlen"))
    # Every encodable character except the trailing space of ENC_RANGE.
    letters = ENC_RANGE[begin:end-1]
    # Base `rand` is used here; `Random.randstring` would need the `Random` module name
    # to be bound, which `import Random: Sampler` does not do.
    words = [String(rand(letters, rand(5:maxlen+4))) for _ in 1:n]
    return CSV.write(out_csv, DataFrame(:name => words))
end
"""
    encode_string_onehot(s; padto=0) -> Matrix{Float32}

One-hot encode `s` against `ENC_RANGE`.

Column `j` encodes the `j`-th character of `s`, and row `i` corresponds to
`ENC_RANGE[i]`. If `s` has fewer than `padto` characters, it is right-padded with spaces
(the last entry of `ENC_RANGE`) to `padto` columns. Longer strings are never truncated.

Characters outside `ENC_RANGE` (e.g. upper-case letters) are not handled here;
`Flux.onehotbatch` rejects them with an error.
"""
function encode_string_onehot(s::AbstractString; padto=0)::Matrix{Float32}
    npad = max(padto - length(s), 0)
    return Flux.onehotbatch(s * " "^npad, ENC_RANGE)
end
"""
    decode_string_onehot(m) -> String

Inverse of `encode_string_onehot`.

Each column of `m` is mapped, via `Flux.onecold`, to the `ENC_RANGE` character at the
position of that column's largest entry. The characters are then joined into a string.
"""
function decode_string_onehot(m::AbstractMatrix)::String
    chars = Flux.onecold(m, ENC_RANGE)
    return join(chars)
end
"""
    prepare_training_data(namesfile="names.csv", fakefile="fake.csv"; train_frac=0.7)

Load real names and fake words, one-hot encode them, and split them randomly into a
training set and a test set.

Both CSV files must have a `name` column; any other columns are ignored. Every row of
`namesfile` is labelled `true` (name) and every row of `fakefile` is labelled `false`
(non-name). Each row is independently put into the training set with probability
`train_frac`, so split sizes only approximate `train_frac : 1 - train_frac`.

Returns `((train_data, train_labels), (test_data, test_labels))`:
- the `*_data` frames have a `name` column of `Matrix{Float32}` encodings;
- the `*_labels` frames have a `label` column of `Bool`.
"""
function prepare_training_data(
    namesfile="names.csv", fakefile="fake.csv"; train_frac=0.7
)::Tuple{Tuple{DataFrame, DataFrame}, Tuple{DataFrame, DataFrame}}
    # Keep only the `name` column of *both* inputs. Otherwise an extra column in the
    # fake-word file would make the `vcat` below throw.
    realdf = select!(DataFrame(CSV.File(namesfile)), :name)
    fakedf = select!(DataFrame(CSV.File(fakefile)), :name)
    words = vcat(realdf, fakedf)
    select!(words, :name => (col -> encode_string_onehot.(col)) => :name)
    # Real names come first in `words`, so they get label `true`; fake words get `false`.
    is_name = vcat(ones(Bool, nrow(realdf)), zeros(Bool, nrow(fakedf)))
    in_train = rand(Bernoulli(train_frac), nrow(words))
    in_test = .~in_train
    return ((words[in_train, :], DataFrame(:label => is_name[in_train])),
            (words[in_test, :], DataFrame(:label => is_name[in_test])))
end
"""
    classifier_model()

Build an untrained name/non-name classifier.

The first layer is a recurrent layer with 13 sigmoid units. It reads one encoded
character (a vector of length `length(ENC_RANGE)`) per step. A dense layer then maps the
13 hidden values to a single sigmoid output.
"""
function classifier_model()
    n_hidden = 13
    recurrent = RNN(length(ENC_RANGE) => n_hidden, Flux.σ)
    readout = Dense(n_hidden => 1, Flux.σ)
    return Chain(recurrent, readout)
end
"""
    save_model(model, filename="model.bson")

Reset the recurrent state of `model`, then write it to `filename` with BSON under the
key `model`.
"""
function save_model(model, filename="model.bson")
    # Clear leftover hidden state first so it is not serialized with the weights.
    Flux.reset!(model)
    # `@save` uses the variable name as the stored key, so the local must stay `model`.
    return BSON.@save filename model
end
"""
    load_model(filename="model.bson")

Read and return the value stored under the key `model` in the BSON file `filename`
(the key written by `save_model`).
"""
function load_model(filename="model.bson")
    # `@load` binds a local variable named after the requested key.
    BSON.@load filename model
    return model
end
"""
    apply_model(model, word) -> Float64

Run the recurrent `model` over `word` one column (one encoded character) at a time,
starting from a freshly reset hidden state.

Returns the first element of the model's output for the last column. A `word` with no
columns yields `0.0`.
"""
function apply_model(model, word)::Float64
    # The state reset is bookkeeping and must not be traced by the AD system.
    ignore_derivatives(() -> Flux.reset!(model))
    last_output = 0.0
    for char_vec in eachcol(word)
        last_output = model(char_vec)[1]
    end
    return last_output
end
"""
    evaluate(model, data) -> Vector

Score every word in `data` (an iterable of encoded words) with `apply_model`.

Returns one `Float64` per word, in iteration order.
"""
function evaluate(model, data)::Vector
    scores = [apply_model(model, w) for w in data]
    return scores
end
"""
    eloss(model, data, labels)

Mean squared error between the model's scores for the words in `data` and the targets
in `labels` (`1`/`true` = name, `0`/`false` = non-name).
"""
function eloss(model, data, labels)
    pred = evaluate(model, data)
    # Flux loss functions take the prediction first: `mse(ŷ, y)`. MSE happens to be
    # symmetric, but keeping the documented order stays correct if the loss is swapped
    # for an asymmetric one such as binary cross-entropy.
    return Flux.Losses.mse(pred, labels)
end
"""
    accuracy(model, data, labels)

Fraction of words in `data` whose rounded score equals the corresponding entry of
`labels`.

Scores are rounded to the nearest integer with ties to even, so `0.5` rounds to `0`.
Returns `NaN` when `data` is empty.
"""
function accuracy(model, data, labels)
    rounded = round.(Int, evaluate(model, data))
    return count(rounded .== labels) / length(rounded)
end
""" | |
Expects a model as returned by classifier_model(), | |
and data frames `data` and `labels` with columns respectively `name` and `label`. | |
data.name should be Float32 data, labels.label should be Bool or Int or Float. | |
""" | |
function train_classifier(model, data, labels; savemodel="trained.bson", epochs=10, batchsize=1) | |
dl = DataLoader((data.name, labels.label); batchsize=batchsize, shuffle=true) | |
p = Flux.params(model) | |
loss(d, l) = begin | |
eloss(model, d, l) | |
end | |
opt = Flux.ADAM() | |
for i in 1:epochs | |
count = 0 | |
cb() = begin | |
if div(count, 1000) < div((count+batchsize), 1000) | |
print(" $(count+batchsize)") | |
end | |
count += batchsize | |
end | |
Flux.train!(loss, p, dl, opt; cb=cb) | |
#for t in dl | |
# d, l = t | |
# grads = Flux.gradient(() -> loss(d, l), p) | |
# Flux.Optimise.update!(opt, p, grads) | |
# cb() | |
#end | |
if !isnothing(savemodel) | |
save_model(model, savemodel) | |
end | |
println("\nEpoch $i: Accuracy $(accuracy(model, data.name, labels.label)), Loss $(loss(data.name, labels.label))") | |
end | |
model | |
end |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.