Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save goldsborough/b840957bcddf7d4e91c8e9a20750e543 to your computer and use it in GitHub Desktop.
Save goldsborough/b840957bcddf7d4e91c8e9a20750e543 to your computer and use it in GitHub Desktop.
AUTOGRAD_CONTAINER_CLASS(DefoggerModel) {
  // Multi level LSTM model from starcraft_defogger.
  // This is a C++ counterpart of a Python model: load_parameters() imports
  // the Python weights, and the static `layers` member can hold the Python
  // activations for side-by-side comparison.
 public:
  // Model hyperparameters with their default values. The groups mirror the
  // keyword-argument groups of the original model (general, "simple",
  // "multilvl_lstm").
  struct Parameters {
    int map_embsize = 64;
    int race_embsize = 8;
    int dec_convsize = 3;
    int dec_depth = 3;
    int dec_embsize = 128;
    int hid_dim = 256;
    float lstm_dropout = 0;
    std::string top_pooling = "mean";
    // simple kwargs
    bool bypass_encoder = false;
    int enc_convsize = 3;
    int enc_embsize = 256;
    int enc_depth = 3;
    int inp_embsize = 256;
    bool predict_delta = false;
    agutils::UpsampleMode upsample = agutils::UpsampleMode::Bilinear;
    // multilvl_lstm kwargs
    int midconv_kw = 3;
    int midconv_stride = 2;
    int midconv_depth = 2;
    int n_lvls = 2;
    std::string model_name = "multilvl_lstm";
  };

  // Builder with one BUILDER_ARG entry per Parameters field.
  // NOTE(review): setter names and storage semantics come from the
  // BUILDER / BUILDER_ARG macros, which are defined elsewhere — confirm there.
  struct Builder {
    BUILDER(DefoggerModel);
    BUILDER_ARG(int, map_embsize);
    BUILDER_ARG(int, race_embsize);
    BUILDER_ARG(int, dec_convsize);
    BUILDER_ARG(int, dec_depth);
    BUILDER_ARG(int, dec_embsize);
    BUILDER_ARG(int, hid_dim);
    BUILDER_ARG(float, lstm_dropout);
    BUILDER_ARG(std::string, top_pooling);
    BUILDER_ARG(bool, bypass_encoder);
    BUILDER_ARG(int, enc_convsize);
    BUILDER_ARG(int, enc_embsize);
    BUILDER_ARG(int, enc_depth);
    BUILDER_ARG(int, inp_embsize);
    BUILDER_ARG(bool, predict_delta);
    BUILDER_ARG(agutils::UpsampleMode, upsample);
    BUILDER_ARG(int, midconv_kw);
    BUILDER_ARG(int, midconv_stride);
    BUILDER_ARG(int, midconv_depth);
    BUILDER_ARG(int, n_lvls);
    BUILDER_ARG(std::string, model_name);
  };

  // Stores the convolution factory, nonlinearity and trunk geometry.
  // The sub-containers declared below are presumably built in
  // initialize_containers() (defined out of view — confirm).
  DefoggerModel(
      conv_builder conv,
      nonlin_type nonlin,
      int kernel_size,
      int n_inp_feats,
      int stride)
      : conv_(conv),
        nonlin_(nonlin),
        kernel_size_(kernel_size),
        n_inp_feats_(n_inp_feats),
        stride_(stride) {}

  void initialize_containers() override;
  // Reset the hidden state (to call before each game)
  void zero_hidden();
  torch::variable_list forward(torch::variable_list input) override;
  // Load all parameters from the python ones, read from the file at
  // `path_to_npz` (presumably an .npz archive, per the parameter name).
  void load_parameters(std::string const& path_to_npz);
  void cpu() override;
  void cuda() override;

 protected:
  void repackage_hidden();
  torch::Variable encode(torch::Variable x);
  torch::Variable do_rnn_middle(torch::Variable x, at::IntList sz, int i);
  // An empty `method` presumably selects a configured default
  // (cf. Parameters::top_pooling) — confirm in the implementation.
  torch::Variable pooling(torch::Variable x, std::string method = "");
  torch::variable_list trunk_encode_pool(torch::variable_list input);
  torch::Variable do_rnn(
      torch::Variable x, at::IntList size, torch::Variable & hidden);
  torch::variable_list do_heads(torch::Variable x);
  torch::variable_list forward_rest(torch::variable_list input);

  // Declaration order matters: the constructor's initializer list follows it.
  conv_builder conv_;
  nonlin_type nonlin_;
  torch::Container trunk_; // Map/race featurizer
  torch::Container sum_pool_embed_;
  torch::Container conv1x1_;
  // One entry per level of the multi-level model (presumably n_lvls entries).
  std::vector<torch::Container> midnets_;
  std::vector<torch::Container> midrnns_;
  torch::Container rnn_;
  torch::Container decoder_;
  torch::Container regression_head_;
  torch::Container unit_class_head_;
  torch::Container bldg_class_head_;
  torch::Container opbt_class_head_;
  torch::variable_list append_to_decoder_input_;
  // Recurrent state; presumably cleared by zero_hidden().
  std::vector<torch::Variable> hidden_;
  // NOTE(review): at::IntList is a non-owning ArrayRef view. If it is
  // assigned from a temporary tensor's sizes(), it dangles once that tensor
  // is destroyed. Consider owning storage (std::vector<int64_t>) after
  // checking the uses in the implementation file.
  at::IntList input_sz_;
  // Not set by the constructor. Default-initialized so it never holds an
  // indeterminate value before the implementation assigns it.
  int lstm_nlayers_ = 0;
  int kernel_size_;
  int n_inp_feats_;
  int stride_;

 public:
  // Global variable holding the activations of the python model for easy
  // comparisons everywhere in the code. Ugly but only temporary (and avoids
  // modifying all function signatures).
  static std::unique_ptr<cnpy::npz_t> layers;
  // Global variable used by external containers when comparing activations.
  static std::string prefix;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment