Last active
December 4, 2016 15:12
-
-
Save shriphani/d1c2a1ea6e35aeb2fa51 to your computer and use it in GitHub Desktop.
Caffe LFW Training
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// This script converts the lfw dataset to the leveldb format used | |
// by caffe to train siamese network. | |
// Usage: | |
// convert_lfw_data input_image_file input_label_file output_db_file | |
#include <fstream> // NOLINT(readability/streams) | |
#include <string> | |
#include "glog/logging.h" | |
#include "google/protobuf/text_format.h" | |
#include "leveldb/db.h" | |
#include "stdint.h" | |
#include "caffe/proto/caffe.pb.h" | |
#include "caffe/util/math_functions.hpp" | |
using namespace std; | |
// Reads the index-th training pair (two images) and its similarity label
// from the already-open binary files into caller-provided buffers.
//   image_file, label_file: open binary streams positioned anywhere.
//   index: zero-based item index.
//   rows, cols: per-image dimensions read from the image file header.
//   pixels: caller-owned buffer of at least 2*rows*cols bytes.
//   label: caller-owned 1-byte buffer for the similarity label.
void read_image(std::ifstream* image_file, std::ifstream* label_file,
    uint32_t index, uint32_t rows, uint32_t cols,
    char* pixels, char* label) {
  // Image file header is 12 bytes (num_items, rows, cols — see
  // convert_dataset). NOTE(review): the stride is rows*cols, not
  // 2*rows*cols, so pair i shares its second image with pair i+1's first —
  // presumably the writer lays records out this way; confirm against the
  // dataset generator.
  image_file->seekg(index * rows * cols + 12);
  image_file->read(pixels, rows * cols);                  // first image of the pair
  image_file->read(pixels + (rows * cols), rows * cols);  // second image of the pair
  // NOTE(review): an 8-byte label header is assumed here, yet
  // convert_dataset reads num_labels from offset 0 — verify the label file
  // layout (magic + count vs. count only).
  label_file->seekg(index + 8);
  label_file->read(label, 1);  // one byte: similar / dissimilar
}
void convert_dataset(const char* image_filename, const char* label_filename, | |
const char* db_filename) { | |
// Open files | |
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary); | |
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary); | |
CHECK(image_file) << "Unable to open file " << image_filename; | |
CHECK(label_file) << "Unable to open file " << label_filename; | |
// Read the magic and the meta data | |
int32_t num_items; | |
int32_t num_labels; | |
uint32_t rows; | |
uint32_t cols; | |
image_file.read(reinterpret_cast<char*>(&num_items), 4); | |
label_file.read(reinterpret_cast<char*>(&num_labels), 4); | |
CHECK_EQ(num_items, num_labels); | |
image_file.read(reinterpret_cast<char*>(&rows), 4); | |
image_file.read(reinterpret_cast<char*>(&cols), 4); | |
cout << num_items << ", " << num_labels << endl; | |
cout << rows << ", " << cols << endl; | |
// Open leveldb | |
leveldb::DB* db; | |
leveldb::Options options; | |
options.create_if_missing = true; | |
options.error_if_exists = true; | |
leveldb::Status status = leveldb::DB::Open( | |
options, db_filename, &db); | |
CHECK(status.ok()) << "Failed to open leveldb " << db_filename | |
<< ". Is it already existing?"; | |
char label_i; | |
char* pixels = new char[2 * rows * cols]; | |
const int kMaxKeyLength = 10; | |
char key[kMaxKeyLength]; | |
std::string value; | |
caffe::Datum datum; | |
datum.set_channels(2); // one channel for each image in the pair | |
datum.set_height(rows); | |
datum.set_width(cols); | |
LOG(INFO) << "A total of " << num_items << " items."; | |
LOG(INFO) << "Rows: " << rows << " Cols: " << cols; | |
for (int itemid = 0; itemid < num_items; ++itemid) { | |
read_image( &image_file, | |
&label_file, | |
itemid, | |
rows, | |
cols, | |
pixels, | |
&label_i ); | |
datum.set_data(pixels, 2*rows*cols); | |
datum.set_label(label_i); | |
datum.SerializeToString(&value); | |
snprintf(key, kMaxKeyLength, "%08d", itemid); | |
db->Put(leveldb::WriteOptions(), std::string(key), value); | |
} | |
delete db; | |
delete pixels; | |
} | |
// Entry point: expects exactly three arguments (input image file, input
// label file, output leveldb path). Prints usage and exits non-zero
// otherwise.
int main(int argc, char** argv) {
  if (argc != 4) {
    // BUG FIX: the usage text previously described the MNIST converter
    // (convert_mnist_data, yann.lecun.com download link); this tool is the
    // LFW converter per the header comment at the top of this file.
    printf("This script converts the LFW dataset to the leveldb format used\n"
           "by caffe to train a siamese network.\n"
           "Usage:\n"
           "  convert_lfw_data input_image_file input_label_file "
           "output_db_file\n");
    // Signal failure to the shell instead of returning success.
    return 1;
  }
  google::InitGoogleLogging(argv[0]);
  convert_dataset(argv[1], argv[2], argv[3]);
  return 0;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The train/test net protocol buffer definition
net: "examples/lfw_siamese/lfw_siamese_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of lfw_01, we have test batch size 100 and 6 test iterations,
# covering the full 600 testing images.
test_iter: 6
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.001
momentum: 0.9
weight_decay: 0.0000
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 50000
# snapshot intermediate results
snapshot: 5000
# BUG FIX: the prefix previously pointed at the MNIST example
# ("examples/siamese/mnist_siamese"); snapshots for this LFW experiment
# belong under examples/lfw_siamese/ like every other path in this setup.
snapshot_prefix: "examples/lfw_siamese/lfw_siamese"
# solver mode: CPU or GPU
solver_mode: GPU
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name: "lfw_siamese_train_test"
# ---- Data layers -----------------------------------------------------------
# Each leveldb record is a 2-channel Datum (the image pair stacked on the
# channel axis) plus a similarity label emitted on the "sim" top.
layer {
  name: "pair_data"
  type: "Data"
  top: "pair_data"
  top: "sim"
  include {
    phase: TRAIN
  }
  transform_param {
    scale: 0.00390625  # 1/256: scale raw bytes to [0, 1)
  }
  data_param {
    source: "examples/lfw_siamese/01_lfw_train_leveldb"
    batch_size: 64
  }
}
layer {
  name: "pair_data"
  type: "Data"
  top: "pair_data"
  top: "sim"
  include {
    phase: TEST
  }
  transform_param {
    scale: 0.00390625  # 1/256: scale raw bytes to [0, 1)
  }
  data_param {
    source: "examples/lfw_siamese/01_lfw_test_leveldb"
    batch_size: 100  # 100 x test_iter(6) = 600 test images (see solver)
  }
}
# Split the 2-channel pair along the channel axis into the two single-image
# inputs of the siamese twins.
layer {
  name: "slice_pair"
  type: "Slice"
  bottom: "pair_data"
  top: "data"
  top: "data_p"
  slice_param {
    slice_dim: 1   # channel axis
    slice_point: 1 # channel 0 -> data, channel 1 -> data_p
  }
}
# ---- First branch of the siamese network -----------------------------------
# Weights are shared with the second branch via identical param names
# (conv1_w, conv1_b, ...).
layer {
  name: "conv1"
  type: "Convolution"
  bottom: "data"
  top: "conv1"
  param {
    name: "conv1_w"
    lr_mult: 1
  }
  param {
    name: "conv1_b"
    lr_mult: 2  # biases conventionally learn at twice the base rate
  }
  convolution_param {
    num_output: 20
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool1"
  type: "Pooling"
  bottom: "conv1"
  top: "pool1"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv2"
  type: "Convolution"
  bottom: "pool1"
  top: "conv2"
  param {
    name: "conv2_w"
    lr_mult: 1
  }
  param {
    name: "conv2_b"
    lr_mult: 2
  }
  convolution_param {
    num_output: 50
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool2"
  type: "Pooling"
  bottom: "conv2"
  top: "pool2"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "ip1"
  type: "InnerProduct"
  bottom: "pool2"
  top: "ip1"
  param {
    name: "ip1_w"
    lr_mult: 1
  }
  param {
    name: "ip1_b"
    lr_mult: 2
  }
  inner_product_param {
    num_output: 500
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "relu1"
  type: "ReLU"
  bottom: "ip1"
  top: "ip1"  # in-place activation
}
layer {
  name: "ip2"
  type: "InnerProduct"
  bottom: "ip1"
  top: "ip2"
  param {
    name: "ip2_w"
    lr_mult: 1
  }
  param {
    name: "ip2_b"
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
# 2-D embedding used by the contrastive loss.
layer {
  name: "feat"
  type: "InnerProduct"
  bottom: "ip2"
  top: "feat"
  param {
    name: "feat_w"
    lr_mult: 1
  }
  param {
    name: "feat_b"
    lr_mult: 2
  }
  inner_product_param {
    num_output: 2
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
# ---- Second branch (mirror of the first) -----------------------------------
# Same architecture operating on data_p; shares every weight/bias with the
# first branch through the matching param names.
layer {
  name: "conv1_p"
  type: "Convolution"
  bottom: "data_p"
  top: "conv1_p"
  param {
    name: "conv1_w"
    lr_mult: 1
  }
  param {
    name: "conv1_b"
    lr_mult: 2
  }
  convolution_param {
    num_output: 20
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool1_p"
  type: "Pooling"
  bottom: "conv1_p"
  top: "pool1_p"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv2_p"
  type: "Convolution"
  bottom: "pool1_p"
  top: "conv2_p"
  param {
    name: "conv2_w"
    lr_mult: 1
  }
  param {
    name: "conv2_b"
    lr_mult: 2
  }
  convolution_param {
    num_output: 50
    kernel_size: 5
    stride: 1
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pool2_p"
  type: "Pooling"
  bottom: "conv2_p"
  top: "pool2_p"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "ip1_p"
  type: "InnerProduct"
  bottom: "pool2_p"
  top: "ip1_p"
  param {
    name: "ip1_w"
    lr_mult: 1
  }
  param {
    name: "ip1_b"
    lr_mult: 2
  }
  inner_product_param {
    num_output: 500
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "relu1_p"
  type: "ReLU"
  bottom: "ip1_p"
  top: "ip1_p"  # in-place activation
}
layer {
  name: "ip2_p"
  type: "InnerProduct"
  bottom: "ip1_p"
  top: "ip2_p"
  param {
    name: "ip2_w"
    lr_mult: 1
  }
  param {
    name: "ip2_b"
    lr_mult: 2
  }
  inner_product_param {
    num_output: 10
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "feat_p"
  type: "InnerProduct"
  bottom: "ip2_p"
  top: "feat_p"
  param {
    name: "feat_w"
    lr_mult: 1
  }
  param {
    name: "feat_b"
    lr_mult: 2
  }
  inner_product_param {
    num_output: 2
    weight_filler {
      type: "xavier"
    }
    bias_filler {
      type: "constant"
    }
  }
}
# ---- Loss ------------------------------------------------------------------
# Pulls embeddings of similar pairs (sim == 1) together and pushes
# dissimilar pairs apart up to the margin.
layer {
  name: "loss"
  type: "ContrastiveLoss"
  bottom: "feat"
  bottom: "feat_p"
  bottom: "sim"
  top: "loss"
  contrastive_loss_param {
    margin: 1
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env sh
# Launch Caffe training for the LFW siamese example.
# Run from the Caffe repository root so the relative paths resolve.
TOOLS=./build/tools
$TOOLS/caffe train --solver=examples/lfw_siamese/lfw_siamese_solver.prototxt
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I downloaded your convert_lfw_siamese_data.cpp, but I don't know how to create the training database. How do I assign a label to each image?