Created
November 28, 2018 06:49
-
-
Save marty1885/b94225e60a6c067b6831be57ad1c6fd6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Feed one observation to the online-learning RNN and return its prediction.
//
// On every call after the first, the (previous input -> current input) pair is
// appended to the training buffers: the previous input as the sample, the
// current input as its target. When input_ holds RNN_DATA_PER_EPOCH values,
// the network is switched to training mode and fitted on the buffered sequence
// for one epoch. It is then switched back to test mode with seq_len 1, and
// the buffers are cleared. Finally the network predicts the opponent's next
// move from `input`.
//
// @param input  3-element vector; presumably an encoding of the opponent's
//               latest move (TODO confirm against caller).
// @return       3-element xarray holding the values returned by nn_.predict().
xt::xarray<float> compute(xt::xarray<float> input)
{
	assert(input.size() == 3);

	// Save data for training. Nothing is recorded on the very first call,
	// because there is no previous input to pair with yet.
	if(!last_input_.empty()) {
		input_.insert(input_.end(), last_input_.begin(), last_input_.end());
		output_.insert(output_.end(), input.begin(), input.end());
	}
	last_input_ = vec_t(input.begin(), input.end());

	// Train once all needed data has been collected.
	// NOTE(review): input_ grows by 3 values per call, so this equality can only
	// fire if RNN_DATA_PER_EPOCH is a multiple of 3 -- confirm.
	if(input_.size() == RNN_DATA_PER_EPOCH) {
		assert(input_.size() == output_.size());
		// Put the network into "training mode" over the whole buffered sequence.
		nn_.at<recurrent_layer>(0).seq_len(RNN_DATA_PER_EPOCH);
		nn_.set_netphase(net_phase::train);
		nn_.fit<cross_entropy_multiclass>(optimizer_,
			std::vector<vec_t>({input_}), std::vector<vec_t>({output_}),
			1, 1, [](){}, [](){});
		// Leave "learning mode" and go back to single-step prediction.
		nn_.set_netphase(net_phase::test);
		nn_.at<recurrent_layer>(0).seq_len(1);
		input_.clear();
		output_.clear();
	}

	// Predict the opponent's next move. last_input_ already holds a vec_t copy
	// of `input`, so reuse it instead of converting the xarray a second time.
	vec_t out = nn_.predict(last_input_);
	assert(out.size() == 3);

	// Convert the prediction to an xarray.
	xt::xarray<float> r = xt::zeros<float>({3});
	for(size_t i = 0; i < out.size(); i++)
		r[i] = out[i];
	return r;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment