Last active
November 25, 2018 08:29
-
-
Save artemklevtsov/2abb0ba3d822d9c149e9cf7cc1f8f446 to your computer and use it in GitHub Desktop.
Quick, Draw! Doodle Recognition Challenge raw data processing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// [[Rcpp::plugins(cpp11)]]
// [[Rcpp::plugins(opencv)]]
// [[Rcpp::depends(rapidjsonr)]]
// [[Rcpp::depends(RcppThread)]]
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <rapidjson/document.h>
#include <opencv2/opencv.hpp>
#include <Rcpp.h>
#include <RcppThread.h>
// Сиинонимы для типов | |
using PointsVec = std::vector<cv::Point>; | |
using StrokesVec = std::vector<PointsVec>; | |
using StringVec = std::vector<std::string>; | |
using RcppThread::parallelFor; | |
// Статические константы | |
// Разамер изображения в пикселях | |
const static int SIZE = 256; | |
// Тип линии | |
// См. https://docs.opencv.org/3.4.4/d0/de1/group__core.html#gaf076ef45de481ac96e0ab3dc2c29a777 | |
const static int LINE_TYPE = cv::LINE_4; | |
// Толщина линии в пикселях | |
const static int LINE_WIDTH = 3; | |
// Алгоритм ресайза | |
// https://docs.opencv.org/3.4.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121 | |
const static int RESIZE_TYPE = cv::INTER_LINEAR; | |
// Параметры сжатия PNG | |
// https://docs.opencv.org/3.4.4/d4/da8/group__imgcodecs.html#gga292d81be8d76901bff7988d18d2b42acad2548321c69ab9c0582fd51e75ace1d0 | |
const static std::vector<int> compression_params = {CV_IMWRITE_PNG_COMPRESSION, 5}; | |
// Validate the layout of a single stroke.
// A valid stroke is a JSON array holding exactly two arrays of equal length
// (the coordinate vectors of the stroke).
// Throws std::runtime_error if any of these conditions is violated.
void check_json_array(const rapidjson::Value& arr) {
  // RapidJSON only defines Size() for arrays: on other types it merely hits
  // RAPIDJSON_ASSERT (assert() by default, compiled out under NDEBUG), so the
  // type must be checked before the length is read.
  if (!arr.IsArray() || arr.Size() != 2) {
    throw std::runtime_error("Element must be 2-dimensional array.");
  }
  if (!arr[0].IsArray() || !arr[1].IsArray()) {
    throw std::runtime_error("One of the element is not array.");
  }
  if (arr[0].Size() != arr[1].Size()) {
    throw std::runtime_error("Size of the arrays not equal.");
  }
}
// Parse a drawing JSON string into strokes.
// The input must be a JSON array of strokes. Each stroke is a pair of
// equal-length numeric arrays: the first array supplies cv::Point::x and the
// second cv::Point::y (Quick, Draw! stores strokes as [[x...], [y...]]).
// Coordinates are truncated to int when the points are built.
// Throws std::runtime_error if the string is not valid JSON or does not have
// this structure.
StrokesVec extract_points(const std::string& json) {
  rapidjson::Document doc;
  doc.Parse(json.c_str());
  if (doc.HasParseError()) {
    throw std::runtime_error("JSON string parsing error.");
  }
  if (!doc.IsArray()) {
    throw std::runtime_error("JSON string is not array.");
  }
  const rapidjson::SizeType n_strokes = doc.Size();
  // Returned by value; no heap ownership is involved.
  StrokesVec strokes;
  strokes.reserve(n_strokes);
  for (rapidjson::SizeType i = 0; i < n_strokes; ++i) {
    const rapidjson::Value& stroke = doc[i];
    check_json_array(stroke);
    const rapidjson::Value& xs = stroke[0];
    const rapidjson::Value& ys = stroke[1];
    const rapidjson::SizeType n_points = xs.Size();
    PointsVec points;
    points.reserve(n_points);
    for (rapidjson::SizeType p = 0; p < n_points; ++p) {
      // GetDouble() is only defined for numbers (RapidJSON merely asserts
      // otherwise), so reject non-numeric coordinates explicitly.
      if (!xs[p].IsNumber() || !ys[p].IsNumber()) {
        throw std::runtime_error("Point coordinates must be numbers.");
      }
      points.emplace_back(static_cast<int>(xs[p].GetDouble()),
                          static_cast<int>(ys[p].GetDouble()));
    }
    // Move the finished stroke instead of copying it.
    strokes.push_back(std::move(points));
  }
  return strokes;
}
// Render strokes as polylines on a SIZE x SIZE canvas.
// color = false: single-channel image with a white (255) background and black
// (0) lines.
// color = true: the image is drawn in HSV and converted to BGR before returning.
// The background is (H=0, S=0, V=255), i.e. white. Strokes start at
// (H=0, S=255, V=220), and the hue is advanced by 180 / n after every stroke,
// where n is the number of strokes.
cv::Mat ocv_draw_lines(const StrokesVec& strokes, bool color = false) {
  const int stype = color ? CV_8UC3 : CV_8UC1;
  const cv::Scalar bg = color ? cv::Scalar(0, 0, 255) : cv::Scalar(255);
  cv::Scalar col = color ? cv::Scalar(0, 255, 220) : cv::Scalar(0);
  cv::Mat img(SIZE, SIZE, stype, bg);
  const std::size_t n = strokes.size();
  for (const PointsVec& stroke : strokes) {
    // Join each point to its predecessor. Starting at 1 (instead of looping
    // up to size() - 1) prevents the unsigned bound from underflowing on an
    // empty stroke, which previously read far past the end of the vector.
    // A stroke with a single point draws nothing.
    for (std::size_t i = 1; i < stroke.size(); ++i) {
      cv::line(img, stroke[i - 1], stroke[i], col, LINE_WIDTH, LINE_TYPE);
    }
    if (color) {
      // NOTE(review): integer division - with more than 180 strokes the
      // step is 0 and every stroke gets the same hue.
      col[0] += 180 / n;
    }
  }
  if (color) {
    // Convert the HSV canvas to BGR.
    cv::cvtColor(img, img, cv::COLOR_HSV2BGR);
  }
  return img;
}
// Turn one drawing JSON string into an image.
// The strokes are parsed with extract_points and rendered with
// ocv_draw_lines. Unless scale is exactly 1.0, the result is then resized by
// `scale` along both axes using RESIZE_TYPE interpolation.
cv::Mat process_json(const std::string& x, double scale = 1.0, bool color = false) {
  cv::Mat canvas = ocv_draw_lines(extract_points(x), color);
  if (scale == 1.0) {
    // No scaling requested: hand back the rendered canvas as is.
    return canvas;
  }
  cv::Mat resized;
  cv::resize(canvas, resized, cv::Size(), scale, scale, RESIZE_TYPE);
  return resized;
}
// Render each JSON drawing in `x` and write it as a PNG to the matching path
// in `files`, processing elements in parallel with RcppThread::parallelFor.
// Parameters: x - drawing JSON strings; files - output paths, one per element
// of x; scale / color - forwarded to process_json.
// Returns `files` unchanged.
// Throws std::runtime_error if the two vectors differ in length, or if any
// element fails to parse, render or write. The failure is reported for the
// first failing element, using a 1-based index.
// [[Rcpp::export]]
StringVec json_vec_save(const StringVec& x, const StringVec& files, double scale = 1.0, bool color = false) {
  // files[i] is read for every i < x.size(); a shorter `files` would be indexed out of bounds.
  if (x.size() != files.size()) {
    throw std::runtime_error("Lengths of 'x' and 'files' must be equal.");
  }
  const int n = static_cast<int>(x.size());
  // One error slot per element. Each task writes only its own slot, so no
  // locking is needed. Exceptions are kept from escaping the worker threads
  // and are reported on the calling thread after the loop.
  std::vector<std::string> errors(x.size());
  parallelFor(0, n, [&x, &files, &errors, scale, color](int i) {
    try {
      cv::Mat img = process_json(x[i], scale, color);
      // cv::imwrite signals failure through its return value, not an exception.
      if (!cv::imwrite(files[i], img, compression_params)) {
        errors[i] = "cv::imwrite() failed to write the image.";
      }
    } catch (const std::exception& e) {
      errors[i] = e.what();
    }
  });
  for (std::size_t i = 0; i < errors.size(); ++i) {
    if (!errors[i].empty()) {
      throw std::runtime_error("Element " + std::to_string(i + 1) + " ('" +
                               files[i] + "'): " + errors[i]);
    }
  }
  return files;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env Rscript

## ---- Command-line argument parsing ----

# docopt specification. docopt separates sections with blank lines and needs
# at least two spaces between an option and its description; otherwise the
# description words are parsed as an option argument and the
# [default: ...] values are lost.
doc <- '
Usage:
  train_nn.R --help
  train_nn.R [options]

Options:
  -h --help                  Show this message.
  -i --input-file=<file>     Input file [default: data/train_simplified.zip].
  -o --output-dir=<dir>      Output directory [default: data/images].
  -m --mapping-file=<file>   Mapping file [default: data/maping_images.csv]
  -s --scale-factor=<ratio>  Scale factor [default: 1.0].
  -c --color                 Use color lines [default: FALSE].
  --log-level=<level>        Log level [default: debug].
'

args <- docopt::docopt(doc)

if (args[["help"]]) {
  cat(doc, file = stdout())
  quit(save = "no", status = 0L)
}

## ---- Constants ----

# Image resize factor
scale_factor <- as.double(args[["scale-factor"]])
# Whether to draw strokes in color
color <- isTRUE(args[["color"]])
# Input zip archive
zipfile <- args[["input-file"]]
# Directory for the rendered images
outdir <- args[["output-dir"]]
# CSV file receiving the key_id / word / filename mapping
mapping_file <- args[["mapping-file"]]
image_ext <- ".png"

## ---- Logger initialisation ----

suppressMessages(library(futile.logger))
# Log line format
logger_format <- layout.format("~t [~l]: ~m", "%Y-%m-%d %H:%M:%OS")
# Apply the log format
invisible(flog.layout(logger_format))
# Send log output to the console
invisible(flog.appender(appender.console()))
# Set the logging threshold
invisible(flog.threshold(toupper(args[["log-level"]])))

## ---- Package installation ----

# docopt and futile.logger are used above and must already be installed.
pkgs <- c(
  "checkmate",
  "data.table",
  "Rcpp",
  "RcppThread",
  "rapidjsonr"
)
to_inst <- setdiff(pkgs, rownames(installed.packages(fields = "Package")))
if (length(to_inst) > 0L) {
  flog.info("Install required packages: %s", paste(to_inst, collapse = ", "))
  # Install only the missing packages, not the whole list.
  install.packages(to_inst, quiet = TRUE)
}

## ---- Package loading ----

flog.info("Load required packages")
for (pkg in pkgs) {
  suppressMessages(library(pkg, character.only = TRUE))
}

## ---- Argument validation ----

assert_file_exists(zipfile, access = "r", extension = "zip", .var.name = "input-file")
assert_path_for_output(mapping_file, overwrite = TRUE, .var.name = "mapping-file")
assert_number(scale_factor, lower = 0.01, upper = 5, na.ok = FALSE, finite = TRUE, .var.name = "scale-factor")
assert_flag(color, na.ok = FALSE, .var.name = "color")
## ---- Compile C++ functions ----

flog.info("Compile C++ code")
# Rcpp plugin supplying the OpenCV compiler and linker flags reported by pkg-config.
registerPlugin("opencv", function() {
  pkg_config_name <- "opencv"
  pkg_config_bin <- Sys.which("pkg-config")
  assert_file_exists(pkg_config_bin, access = "x")
  pkg_config <- function(flag) {
    system(paste(shQuote(pkg_config_bin), flag, pkg_config_name), intern = TRUE)
  }
  list(env = list(
    # R's build reads PKG_CXXFLAGS; the former name PKG_CXXFLAG was ignored,
    # so the OpenCV include flags never reached the compiler.
    PKG_CXXFLAGS = pkg_config("--cflags"),
    PKG_LIBS = pkg_config("--libs")
  ))
})
sourceCpp(file = "cv.cpp", env = .GlobalEnv)

## ---- Data export ----

# Extract one CSV from the archive, render its drawings to image files in
# `outdir`, and append the key_id / word / filename rows to `mapping_file`.
process <- function(filename) {
  # unzip(junkpaths = TRUE) drops any directory part of the entry name, so the
  # extracted file sits directly in tempdir() under its base name.
  path <- file.path(tempdir(), basename(filename))
  on.exit(unlink(path))
  flog.debug("Unzip file")
  unzip(zipfile, files = filename, exdir = tempdir(), junkpaths = TRUE, unzip = getOption("unzip"))
  flog.debug("Read data")
  data <- fread(file = path, sep = ",", header = TRUE, select = c("key_id", "drawing", "word"))
  data[, filename := file.path(outdir, paste0(key_id, image_ext))]
  flog.debug("Write images")
  data[, json_vec_save(drawing, filename, scale_factor, color)]
  data[, drawing := NULL]
  flog.debug("Write mapping")
  fwrite(x = data, file = mapping_file, append = TRUE, sep = ",", eol = "\n")
}

if (!dir.exists(outdir)) {
  flog.debug("Create '%s' directory", outdir)
  # recursive = TRUE so nested defaults such as data/images work on a fresh checkout.
  dir.create(outdir, recursive = TRUE)
}
# process() appends to the mapping file, so start from an empty file to avoid
# duplicating rows left over from a previous run.
if (file.exists(mapping_file)) {
  flog.debug("Remove existing '%s' file", mapping_file)
  unlink(mapping_file)
}
flog.info("Load files from '%s'", zipfile)
files <- unzip(zipfile, list = TRUE)$Name
# Directory entries (names ending in "/") carry no data.
files <- files[!grepl("/$", files)]
for (f in files) {
  flog.info("Process '%s'", f)
  process(f)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment