Created
May 8, 2018 21:01
-
-
Save goldsborough/b840957bcddf7d4e91c8e9a20750e543 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Multi-level LSTM model ported from starcraft_defogger.
//
// The model mirrors a Python reference implementation:
//  - `Parameters` groups the model hyper-parameters (general ones first, then
//    the "simple" and "multilvl_lstm" keyword-argument groups of the Python
//    model), each with the default value shown below.
//  - `Builder` declares exactly one BUILDER_ARG per `Parameters` field.
//  - `load_parameters` loads the trained parameters exported from Python.
AUTOGRAD_CONTAINER_CLASS(DefoggerModel) {
 public:
  // Hyper-parameters of the model; defaults are the in-class initializers.
  struct Parameters {
    int map_embsize = 64;
    int race_embsize = 8;
    int dec_convsize = 3;
    int dec_depth = 3;
    int dec_embsize = 128;
    int hid_dim = 256;
    float lstm_dropout = 0;
    std::string top_pooling = "mean";
    // simple kwargs
    bool bypass_encoder = false;
    int enc_convsize = 3;
    int enc_embsize = 256;
    int enc_depth = 3;
    int inp_embsize = 256;
    bool predict_delta = false;
    agutils::UpsampleMode upsample = agutils::UpsampleMode::Bilinear;
    // multilvl_lstm kwargs
    int midconv_kw = 3;
    int midconv_stride = 2;
    int midconv_depth = 2;
    int n_lvls = 2;
    std::string model_name = "multilvl_lstm";
  };

  // Builder for DefoggerModel. There is one BUILDER_ARG per field of
  // `Parameters`, with matching names and types.
  struct Builder {
    BUILDER(DefoggerModel);
    BUILDER_ARG(int, map_embsize);
    BUILDER_ARG(int, race_embsize);
    BUILDER_ARG(int, dec_convsize);
    BUILDER_ARG(int, dec_depth);
    BUILDER_ARG(int, dec_embsize);
    BUILDER_ARG(int, hid_dim);
    BUILDER_ARG(float, lstm_dropout);
    BUILDER_ARG(std::string, top_pooling);
    BUILDER_ARG(bool, bypass_encoder);
    BUILDER_ARG(int, enc_convsize);
    BUILDER_ARG(int, enc_embsize);
    BUILDER_ARG(int, enc_depth);
    BUILDER_ARG(int, inp_embsize);
    BUILDER_ARG(bool, predict_delta);
    BUILDER_ARG(agutils::UpsampleMode, upsample);
    BUILDER_ARG(int, midconv_kw);
    BUILDER_ARG(int, midconv_stride);
    BUILDER_ARG(int, midconv_depth);
    BUILDER_ARG(int, n_lvls);
    BUILDER_ARG(std::string, model_name);
  };

  // Stores the given convolution builder, nonlinearity, kernel size, number
  // of input features and stride. Sub-containers are not created here; that
  // is left to initialize_containers().
  // NOTE(review): none of the `Parameters` values are passed to this
  // constructor; confirm how they reach the instance (presumably via Builder).
  DefoggerModel(
      conv_builder conv,
      nonlin_type nonlin,
      int kernel_size,
      int n_inp_feats,
      int stride)
      : conv_(conv),
        nonlin_(nonlin),
        kernel_size_(kernel_size),
        n_inp_feats_(n_inp_feats),
        stride_(stride) {}

  void initialize_containers() override;

  // Reset the hidden state (to call before each game).
  void zero_hidden();

  torch::variable_list forward(torch::variable_list input) override;

  // Load all parameters from the Python model's exported .npz archive at
  // `path_to_npz`.
  void load_parameters(std::string const& path_to_npz);

  void cpu() override;
  void cuda() override;

 protected:
  // NOTE(review): presumably detaches hidden_ from the autograd history, like
  // the Python helper of the same name. Confirm in the implementation.
  void repackage_hidden();
  torch::Variable encode(torch::Variable x);
  // `i` presumably indexes midnets_/midrnns_ (one per level). Confirm.
  torch::Variable do_rnn_middle(torch::Variable x, at::IntList sz, int i);
  // An empty `method` presumably falls back to Parameters::top_pooling.
  // Confirm.
  torch::Variable pooling(torch::Variable x, std::string method = "");
  torch::variable_list trunk_encode_pool(torch::variable_list input);
  torch::Variable do_rnn(
      torch::Variable x, at::IntList size, torch::Variable & hidden);
  torch::variable_list do_heads(torch::Variable x);
  torch::variable_list forward_rest(torch::variable_list input);

  conv_builder conv_;
  nonlin_type nonlin_;
  torch::Container trunk_; // Map/race featurizer
  torch::Container sum_pool_embed_;
  torch::Container conv1x1_;
  std::vector<torch::Container> midnets_;
  std::vector<torch::Container> midrnns_;
  torch::Container rnn_;
  torch::Container decoder_;
  torch::Container regression_head_;
  torch::Container unit_class_head_;
  torch::Container bldg_class_head_;
  torch::Container opbt_class_head_;
  torch::variable_list append_to_decoder_input_;
  std::vector<torch::Variable> hidden_;
  // NOTE(review): at::IntList is a non-owning view (ArrayRef<int64_t>). If
  // this is assigned from a temporary tensor's sizes(), it dangles once that
  // tensor is released. Consider storing an owning std::vector<int64_t>.
  at::IntList input_sz_;
  // Integer members are zero-initialized so that no construction path
  // (including any generated by the BUILDER macro) leaves them indeterminate.
  int lstm_nlayers_{0};
  int kernel_size_{0};
  int n_inp_feats_{0};
  int stride_{0};

 public:
  // Global variable holding the activations of the python model for easy
  // comparisons everywhere in the code. Ugly but only temporary (and avoids
  // modifying all function signatures).
  static std::unique_ptr<cnpy::npz_t> layers;
  // Global variable used by external containers when comparing activations.
  static std::string prefix;
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment