Skip to content

Instantly share code, notes, and snippets.

@artemklevtsov
Last active November 25, 2018 08:29
Show Gist options
  • Save artemklevtsov/2abb0ba3d822d9c149e9cf7cc1f8f446 to your computer and use it in GitHub Desktop.
Save artemklevtsov/2abb0ba3d822d9c149e9cf7cc1f8f446 to your computer and use it in GitHub Desktop.
Quick, Draw! Doodle Recognition Challenge raw data processing
// [[Rcpp::plugins(cpp11)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(rapidjsonr)]]
// [[Rcpp::depends(RcppThread)]]
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <rapidjson/document.h>
#include <opencv2/opencv.hpp>
#include <Rcpp.h>
#include <RcppThread.h>
// Type aliases
using PointsVec = std::vector<cv::Point>;
using StrokesVec = std::vector<PointsVec>;
using StringVec = std::vector<std::string>;
using RcppThread::parallelFor;
// Static constants
// Side length of the square canvas in pixels
static const int SIZE = 256;
// Line type passed to cv::line
// See https://docs.opencv.org/3.4.4/d0/de1/group__core.html#gaf076ef45de481ac96e0ab3dc2c29a777
static const int LINE_TYPE = cv::LINE_4;
// Line thickness in pixels
static const int LINE_WIDTH = 3;
// Interpolation algorithm used when resizing
// https://docs.opencv.org/3.4.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
static const int RESIZE_TYPE = cv::INTER_LINEAR;
// PNG compression parameters (flag, value) passed to cv::imwrite
// https://docs.opencv.org/3.4.4/d4/da8/group__imgcodecs.html#gga292d81be8d76901bff7988d18d2b42acad2548321c69ab9c0582fd51e75ace1d0
// The C++ enumerator is used instead of the legacy CV_IMWRITE_* macro so the
// constant matches the other cv:: enumerators above and does not depend on
// the C-API compatibility headers.
static const std::vector<int> compression_params = {cv::IMWRITE_PNG_COMPRESSION, 5};
// Validate the JSON structure of a single stroke.
// A stroke must be an array that holds exactly two arrays of equal length.
// Throws std::runtime_error when any of these conditions is violated.
void check_json_array(const rapidjson::Value& arr) {
  // Size() may only be called on arrays, so the type is checked first;
  // otherwise a scalar/object stroke would trip rapidjson's internal assertion.
  if (!arr.IsArray() || arr.Size() != 2) {
    throw std::runtime_error("Element must be 2-dimensional array.");
  }
  if (!arr[0].IsArray() || !arr[1].IsArray()) {
    throw std::runtime_error("One of the element is not array.");
  }
  if (arr[0].Size() != arr[1].Size()) {
    throw std::runtime_error("Size of the arrays not equal.");
  }
}
// Parse a JSON drawing into a vector of strokes.
// The document must be an array of strokes; each stroke is validated with
// check_json_array() and converted into a vector of cv::Point.
// Throws std::runtime_error on malformed JSON, on a wrong structure, or when
// a coordinate is not a number.
StrokesVec extract_points(const std::string& json) {
  rapidjson::Document doc;
  doc.Parse(json.c_str());
  if (doc.HasParseError()) {
    throw std::runtime_error("JSON string parsing error.");
  }
  if (!doc.IsArray()) {
    throw std::runtime_error("JSON string is not array.");
  }
  // Number of strokes
  const rapidjson::SizeType n_strokes = doc.Size();
  // Resulting object, returned by value
  StrokesVec strokes;
  strokes.reserve(n_strokes);
  for (rapidjson::SizeType i = 0; i < n_strokes; ++i) {
    const rapidjson::Value& stroke = doc[i];
    check_json_array(stroke);
    // stroke[0] feeds the first (x) argument of cv::Point,
    // stroke[1] feeds the second (y) argument
    const rapidjson::Value& xs = stroke[0];
    const rapidjson::Value& ys = stroke[1];
    const rapidjson::SizeType n_points = xs.Size();
    PointsVec points;
    points.reserve(n_points);
    for (rapidjson::SizeType p = 0; p < n_points; ++p) {
      // GetDouble() is only valid on numeric values
      if (!xs[p].IsNumber() || !ys[p].IsNumber()) {
        throw std::runtime_error("Point coordinate is not a number.");
      }
      points.emplace_back(xs[p].GetDouble(), ys[p].GetDouble());
    }
    // Move the finished stroke instead of copying it
    strokes.push_back(std::move(points));
  }
  return strokes;
}
// Draw strokes as polylines on a SIZE x SIZE canvas.
// color == false: single-channel image, background 255, lines 0.
// color == true:  three-channel image drawn in HSV (background (0, 0, 255),
//                 first stroke (0, 255, 220)); the hue is advanced by
//                 180 / n (integer division) after every stroke and the image
//                 is converted to BGR before it is returned.
// Strokes with fewer than two points produce no line segments.
cv::Mat ocv_draw_lines(const StrokesVec& strokes, bool color = false) {
  const int stype = color ? CV_8UC3 : CV_8UC1;
  const cv::Scalar bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  cv::Scalar col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img = cv::Mat(SIZE, SIZE, stype, bg);
  // Number of strokes; non-zero whenever the loop body runs
  const size_t n = strokes.size();
  for (const PointsVec& stroke : strokes) {
    // Number of points
    const size_t n_points = stroke.size();
    // Start at 1 and look back: `n_points - 1` would wrap around for an empty
    // stroke (unsigned arithmetic) and read far past the end of the vector.
    for (size_t i = 1; i < n_points; ++i) {
      cv::line(img, stroke[i - 1], stroke[i], col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      col[0] += 180 / n;
    }
  }
  if (color) {
    // Switch the color representation to BGR
    cv::cvtColor(img, img, cv::COLOR_HSV2BGR);
  }
  return img;
}
// Turn a JSON drawing into an image matrix.
// The strokes are parsed, rendered, and - unless scale is exactly 1.0 -
// resized by `scale` along both axes using RESIZE_TYPE.
cv::Mat process_json(const std::string& x, double scale = 1.0, bool color = false) {
  const StrokesVec strokes = extract_points(x);
  cv::Mat canvas = ocv_draw_lines(strokes, color);
  // No resizing requested: hand back the rendered canvas as is
  if (scale == 1.0) {
    return canvas;
  }
  // An empty cv::Size() makes cv::resize derive the target size from the factors
  cv::Mat resized;
  cv::resize(canvas, resized, cv::Size(), scale, scale, RESIZE_TYPE);
  return resized;
}
// Render every JSON drawing in `x` and save it as PNG to the matching path in
// `files`; the work is distributed with RcppThread::parallelFor.
// x      - JSON strings, one drawing each
// files  - output file names, same length as x
// scale  - resize factor forwarded to process_json()
// color  - draw colored strokes when true
// Returns `files` unchanged.
// Throws std::runtime_error when the lengths differ or an image cannot be written.
// [[Rcpp::export]]
StringVec json_vec_save(const StringVec& x, const StringVec& files, double scale = 1.0, bool color = false) {
  // files[i] is indexed with the range of x, so the lengths have to agree
  if (x.size() != files.size()) {
    throw std::runtime_error("Number of JSON strings and file names must be equal.");
  }
  const int n = static_cast<int>(x.size());
  parallelFor(0, n, [&x, &files, scale, color](int i) {
    const cv::Mat img = process_json(x[i], scale, color);
    // imwrite reports failure through its return value; do not ignore it,
    // otherwise the caller records a file that does not exist
    if (!cv::imwrite(files[i], img, compression_params)) {
      throw std::runtime_error("Failed to write image: " + files[i]);
    }
  });
  return files;
}
#!/usr/bin/env Rscript
## ---- Parse command-line arguments ----
# Usage text consumed by docopt. The layout is significant: usage patterns are
# indented under "Usage:", and every option is separated from its description
# by at least two spaces so that the description and its [default: ...] value
# are recognised.
doc <- '
Usage:
  train_nn.R --help
  train_nn.R [options]

Options:
  -h --help                   Show this message.
  -i --input-file=<file>      Input file [default: data/train_simplified.zip].
  -o --output-dir=<dir>       Output directory [default: data/images].
  -m --mapping-file=<file>    Mapping file [default: data/maping_images.csv]
  -s --scale-factor=<ratio>   Scale factor [default: 1.0].
  -c --color                  Use color lines [default: FALSE].
  --log-level=<level>         Log level [default: debug].
'
args <- docopt::docopt(doc)
# Print the usage text and exit successfully when help was requested
if (args[["help"]]) {
  cat(doc, file = stdout())
  quit(save = "no", status = 0L)
}
## ---- Constants ----
# Input zip archive, output image directory and mapping CSV file
zipfile <- args[["input-file"]]
outdir <- args[["output-dir"]]
mapping_file <- args[["mapping-file"]]
# Image resize factor
scale_factor <- as.numeric(args[["scale-factor"]])
# Whether colored lines are used
color <- isTRUE(args[["color"]])
# Extension appended to the generated image file names
image_ext <- ".png"
## ---- Logger initialization ----
suppressMessages(library(futile.logger))
# Logging threshold taken from --log-level (upper-cased)
invisible(flog.threshold(toupper(args[["log-level"]])))
# Send log records to the console appender
invisible(flog.appender(appender.console()))
# Log record format: timestamp, level, message
logger_format <- layout.format("~t [~l]: ~m", "%Y-%m-%d %H:%M:%OS")
# Apply the log format
invisible(flog.layout(logger_format))
## ---- Install packages ----
# Packages required by the rest of the script
pkgs <- c(
  "checkmate",
  "data.table",
  "Rcpp",
  "RcppThread",
  "rapidjsonr"
)
# Required packages that are not installed yet
to_inst <- setdiff(pkgs, rownames(installed.packages(fields = "Package")))
if (length(to_inst) > 0L) {
  flog.info("Install required packages")
  # Install only what is missing instead of reinstalling every package
  install.packages(to_inst, quiet = TRUE)
}
## ---- Load packages ----
flog.info("Load required packages")
for (pkg in pkgs) {
  suppressMessages(library(pkg, character.only = TRUE))
}
## ---- Argument checks ----
# The input archive must be a readable .zip file
assert_file_exists(
  zipfile,
  access = "r",
  extension = "zip",
  .var.name = "input-file"
)
# The mapping file must be writable; an existing file is accepted
assert_path_for_output(
  mapping_file,
  overwrite = TRUE,
  .var.name = "mapping-file"
)
# The scale factor must be a finite number within [0.01, 5]
assert_number(
  scale_factor,
  lower = 0.01,
  upper = 5,
  na.ok = FALSE,
  finite = TRUE,
  .var.name = "scale-factor"
)
# The color switch must be a single non-missing logical value
assert_flag(
  color,
  na.ok = FALSE,
  .var.name = "color"
)
## ---- Compile C++ functions ----
flog.info("Compile C++ code")
# Rcpp plugin that supplies the OpenCV compiler and linker flags via pkg-config
registerPlugin("opencv", function() {
  # NOTE(review): pkg-config module name; newer OpenCV installs may ship it
  # as "opencv4" - confirm for the target system
  pkg_config_name <- "opencv"
  pkg_config_bin <- Sys.which("pkg-config")
  assert_file_exists(pkg_config_bin, access = "x")
  list(env = list(
    # The variable must be spelled PKG_CXXFLAGS; the misspelled PKG_CXXFLAG
    # is never read by the build, so the include flags were silently dropped
    PKG_CXXFLAGS = system(paste(pkg_config_bin, "--cflags", pkg_config_name), intern = TRUE),
    PKG_LIBS = system(paste(pkg_config_bin, "--libs", pkg_config_name), intern = TRUE)
  ))
})
# Compile cv.cpp and export its functions into the global environment
sourceCpp(file = "cv.cpp", env = .GlobalEnv)
## ---- Data extraction ----
# Process one CSV file stored in the zip archive:
# extract it to the temporary directory, read the key_id, drawing and word
# columns, render every drawing to an image in `outdir`, and append the
# remaining columns (including the image file name) to the mapping file.
# The extracted CSV is removed when the function exits.
process <- function(filename) {
  # unzip(junkpaths = TRUE) keeps only the base name of the archived file,
  # so the extracted file lives directly in tempdir() under that name
  path <- file.path(tempdir(), basename(filename))
  on.exit(unlink(path))
  flog.debug("Unzip file")
  unzip(zipfile, files = filename, exdir = tempdir(), junkpaths = TRUE, unzip = getOption("unzip"))
  flog.debug("Read data")
  data <- fread(file = path, sep = ",", header = TRUE, select = c("key_id", "drawing", "word"))
  # Target image path for every drawing
  data[, filename := file.path(outdir, paste0(key_id, image_ext))]
  flog.debug("Write images")
  data[, json_vec_save(drawing, filename, scale_factor, color)]
  # The raw JSON is not needed in the mapping file
  data[, drawing := NULL]
  flog.debug("Write mapping")
  fwrite(x = data, file = mapping_file, append = TRUE, sep = ",", eol = "\n")
}
# Make sure the output directory exists before any image is written
if (!dir.exists(outdir)) {
  flog.debug("Create '%s' directory", outdir)
  # recursive = TRUE also creates missing parent directories of a nested path
  dir.create(outdir, recursive = TRUE)
}
flog.info("Load files from '%s'", zipfile)
# Names of all entries stored in the archive
files <- unzip(zipfile, list = TRUE)$Name
# Process the archive entries one by one
for (f in files) {
  flog.info("Process '%s'", f)
  process(f)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment