#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

#include <torch/torch.h>
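
// A ResNet built from BasicBlocks and trained on MNIST with the LibTorch C++
// frontend. The layer helpers and block/stage structure below follow the
// layout of the reference (torchvision-style) ResNet implementation.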

// 3x3 convolution with padding, used on the main path of each residual block.
torch::nn::Conv2d conv3x3(int64_t inputChannels, int64_t outputChannels, int64_t stride) {
  auto options = torch::nn::Conv2dOptions(inputChannels, outputChannels, /*kernel_size=*/3);
  options = options.stride(stride).padding(1).with_bias(false);
  return std::make_shared<torch::nn::Conv2dImpl>(options);
}

// 1x1 convolution, used to project the shortcut when channels or stride change.
torch::nn::Conv2d conv1x1(int64_t inputChannels, int64_t outputChannels, int64_t stride) {
  auto options = torch::nn::Conv2dOptions(inputChannels, outputChannels, /*kernel_size=*/1);
  options = options.stride(stride).with_bias(false);
  return std::make_shared<torch::nn::Conv2dImpl>(options);
}
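
// A basic residual block: two 3x3 conv + batch-norm pairs with a skip
// connection. When the block changes resolution or channel count, the
// `downsample` Sequential (1x1 conv + batch norm) projects the input so it
// can be added to the block's output; otherwise the identity is used.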
struct BasicBlock : torch::nn::Module {
  static const int64_t EXPANSION = 1;

  torch::nn::Conv2d conv1;
  torch::nn::BatchNorm bn1;
  torch::nn::Conv2d conv2;
  torch::nn::BatchNorm bn2;
  torch::nn::Sequential downsample;

  BasicBlock(int64_t inplanes, int64_t planes, int64_t stride, torch::nn::Sequential downsample)
      : conv1(conv3x3(inplanes, planes, stride)),
        bn1(torch::nn::BatchNorm(torch::nn::BatchNormOptions(planes))),
        conv2(conv3x3(planes, planes, /*stride=*/1)),
        bn2(torch::nn::BatchNorm(torch::nn::BatchNormOptions(planes))),
        downsample(downsample) {
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    register_module("downsample", downsample);
  }

  torch::Tensor forward(torch::Tensor x) {
    auto identity = x;

    auto out = conv1->forward(x);
    out = bn1->forward(out);
    out = torch::relu(out);

    out = conv2->forward(out);
    out = bn2->forward(out);

    // Project the input when the shortcut has to change shape.
    if (downsample->size() > 0) {
      identity = downsample->forward(x);
    }

    out = out + identity;
    out = torch::relu(out);
    return out;
  }
};
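
// The full network: a 7x7 stem convolution, max pooling, four stages of
// BasicBlocks (64/128/256/512 channels), global average pooling, and a final
// linear classifier. `layers[]` gives the number of blocks per stage;
// {2, 2, 2, 2} reproduces the ResNet-18 layout.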
struct ResNet : torch::nn::Module {
  int64_t inplanes = 64;

  torch::nn::Conv2dOptions conv1Options;
  torch::nn::Conv2d conv1;
  torch::nn::BatchNorm bn1;
  torch::nn::Sequential layer1;
  torch::nn::Sequential layer2;
  torch::nn::Sequential layer3;
  torch::nn::Sequential layer4;
  torch::nn::Linear fc1;

  ResNet(int64_t inputDepth, int layers[])
      : conv1Options(torch::nn::Conv2dOptions(inputDepth, 64, /*kernel_size=*/7).stride(2).padding(3).with_bias(false)),
        conv1(std::make_shared<torch::nn::Conv2dImpl>(conv1Options)),
        bn1(std::make_shared<torch::nn::BatchNormImpl>(64)),
        layer1(make_layer_basic(64, layers[0], /*stride=*/1)),
        layer2(make_layer_basic(128, layers[1], /*stride=*/2)),
        layer3(make_layer_basic(256, layers[2], /*stride=*/2)),
        layer4(make_layer_basic(512, layers[3], /*stride=*/2)),
        fc1(std::make_shared<torch::nn::LinearImpl>(512, /*out_features=*/10)) {  // 10 MNIST classes
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);
    register_module("fc1", fc1);
  }

  torch::nn::Sequential make_layer_basic(int64_t planes, int64_t blocks, int64_t stride) {
    torch::nn::Sequential downsample = std::make_shared<torch::nn::SequentialImpl>();
    // Project the shortcut whenever the stage changes resolution or width.
    if (stride != 1 || this->inplanes != planes * BasicBlock::EXPANSION) {
      downsample = torch::nn::Sequential(
          conv1x1(this->inplanes, planes * BasicBlock::EXPANSION, stride),
          torch::nn::BatchNorm(planes * BasicBlock::EXPANSION));
    }

    torch::nn::Sequential layers = std::make_shared<torch::nn::SequentialImpl>();
    // The first block of the stage applies the stride and the downsampled shortcut.
    layers->push_back(std::make_shared<BasicBlock>(this->inplanes, planes, stride, downsample));
    this->inplanes = planes * BasicBlock::EXPANSION;

    // The remaining blocks keep the resolution and use an identity shortcut.
    for (int64_t i = 1; i < blocks; i++) {
      torch::nn::Sequential empty_downsample = std::make_shared<torch::nn::SequentialImpl>();
      layers->push_back(std::make_shared<BasicBlock>(this->inplanes, planes, /*stride=*/1, empty_downsample));
    }
    return layers;
  }

  torch::Tensor forward(torch::Tensor x) {
    x = this->conv1->forward(x);
    x = this->bn1->forward(x);
    x = torch::relu(x);
    x = torch::max_pool2d(x, /*kernel_size=*/{3}, /*stride=*/{2}, /*padding=*/{1});

    x = this->layer1->forward(x);
    x = this->layer2->forward(x);
    x = this->layer3->forward(x);
    x = this->layer4->forward(x);

    x = torch::adaptive_avg_pool2d(x, {1, 1});
    x = x.view({-1, 512});
    auto logits = this->fc1->forward(x);
    // nll_loss expects log-probabilities, so return log_softmax rather than softmax.
    return torch::log_softmax(logits, /*dim=*/1);
  }
};

struct Options {
  std::string data_root{"data"};
  int32_t batch_size{64};
  int32_t epochs{10};
  double lr{0.01};
  double momentum{0.5};
  bool no_cuda{false};
  int32_t seed{1};
  int32_t test_batch_size{1000};
  int32_t log_interval{10};
};
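
// One pass over the training set: for each mini-batch, run the model, compute
// the negative log-likelihood loss on its log-probability output, backpropagate,
// take an SGD step, and print the running loss every `log_interval` batches.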
template <typename DataLoader>
void train(
    int32_t epoch,
    const Options& options,
    ResNet& model,
    torch::Device device,
    DataLoader& data_loader,
    torch::optim::SGD& optimizer,
    size_t dataset_size) {
  model.train();

  size_t batch_idx = 0;
  for (auto& batch : data_loader) {
    auto data = batch.data.to(device), targets = batch.target.to(device);

    optimizer.zero_grad();
    auto output = model.forward(data);
    auto loss = torch::nll_loss(output, targets);
    loss.backward();
    optimizer.step();

    if (batch_idx++ % options.log_interval == 0) {
      std::cout << "Train Epoch: " << epoch << " ["
                << batch_idx * batch.data.size(0) << "/" << dataset_size
                << "]\tLoss: " << loss.template item<float>() << std::endl;
    }
  }
}
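
// Evaluation with gradients disabled: accumulate the summed NLL loss over the
// whole test set, count correct argmax predictions, and report the average
// loss and accuracy.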
template <typename DataLoader>
void test(
    ResNet& model,
    torch::Device device,
    DataLoader& data_loader,
    size_t dataset_size) {
  torch::NoGradGuard no_grad;
  model.eval();

  double test_loss = 0;
  int32_t correct = 0;
  for (const auto& batch : data_loader) {
    auto data = batch.data.to(device), targets = batch.target.to(device);
    auto output = model.forward(data);
    test_loss += torch::nll_loss(
                     output,
                     targets,
                     /*weight=*/{},
                     Reduction::Sum)
                     .template item<float>();
    auto pred = output.argmax(1);
    correct += pred.eq(targets).sum().template item<int64_t>();
  }

  test_loss /= dataset_size;
  std::cout << "Test set: Average loss: " << test_loss
            << ", Accuracy: " << correct << "/" << dataset_size << std::endl;
}
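
// Element-wise normalization transform; 0.1307 and 0.3081 are the commonly
// used MNIST mean and standard deviation.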
struct Normalize : public torch::data::transforms::TensorTransform<> {
  Normalize(float mean, float stddev)
      : mean_(torch::tensor(mean)), stddev_(torch::tensor(stddev)) {}

  torch::Tensor operator()(torch::Tensor input) {
    return input.sub_(mean_).div_(stddev_);
  }

  torch::Tensor mean_, stddev_;
};

auto main(int argc, const char* argv[]) -> int {
  torch::manual_seed(0);

  Options options;

  torch::DeviceType device_type;
  if (torch::cuda::is_available() && !options.no_cuda) {
    std::cout << "CUDA available! Training on GPU" << std::endl;
    device_type = torch::kCUDA;
  } else {
    std::cout << "Training on CPU" << std::endl;
    device_type = torch::kCPU;
  }
  torch::Device device(device_type);

  // Two blocks per stage; MNIST images have a single input channel.
  int layers[4] = {2, 2, 2, 2};
  auto model = std::make_shared<ResNet>(1, layers);
  model->to(device);

  auto train_dataset =
      torch::data::datasets::MNIST(
          options.data_root, torch::data::datasets::MNIST::Mode::kTrain)
          .map(Normalize(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());
  const auto train_dataset_size = train_dataset.size().value();
  auto train_loader = torch::data::make_data_loader(std::move(train_dataset), options.batch_size);

  auto test_dataset =
      torch::data::datasets::MNIST(
          options.data_root, torch::data::datasets::MNIST::Mode::kTest)
          .map(Normalize(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());
  const auto test_dataset_size = test_dataset.size().value();
  auto test_loader = torch::data::make_data_loader(std::move(test_dataset), options.batch_size);

  torch::optim::SGD optimizer(
      model->parameters(),
      torch::optim::SGDOptions(options.lr).momentum(options.momentum));

  for (int32_t epoch = 1; epoch <= options.epochs; ++epoch) {
    train(epoch, options, *model, device, *train_loader, optimizer, train_dataset_size);
    test(*model, device, *test_loader, test_dataset_size);
  }
}