/* C++ version of tokenization for BERT
Copyright (C) 2019 luistung

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. */
#include <iostream>
#include <fstream>
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>
#include <boost/algorithm/string.hpp>
#include <utf8proc.h>

//https://unicode.org/reports/tr15/#Norm_Forms
//https://ssl.icu-project.org/apiref/icu4c/uchar_8h.html

const std::wstring stripChar = L" \t\n\r\v\f";

using Vocab = std::unordered_map<std::wstring, size_t>;
using InvVocab = std::unordered_map<size_t, std::wstring>;
class BasicTokenizer {
public:
    BasicTokenizer(bool doLowerCase = true);
    std::vector<std::wstring> tokenize(const std::string& text) const;

private:
    std::wstring cleanText(const std::wstring& text) const;
    bool isControl(const wchar_t& ch) const;
    bool isWhitespace(const wchar_t& ch) const;
    bool isPunctuation(const wchar_t& ch) const;
    bool isChineseChar(const wchar_t& ch) const;
    std::wstring tokenizeChineseChars(const std::wstring& text) const;
    bool isStripChar(const wchar_t& ch) const;
    std::wstring strip(const std::wstring& text) const;
    std::vector<std::wstring> split(const std::wstring& text) const;
    std::wstring runStripAccents(const std::wstring& text) const;
    std::vector<std::wstring> runSplitOnPunc(const std::wstring& text) const;

    bool mDoLowerCase;
};
class WordpieceTokenizer {
public:
    WordpieceTokenizer(std::shared_ptr<Vocab> vocab, const std::wstring& unkToken = L"[UNK]", size_t maxInputCharsPerWord = 200);
    std::vector<std::wstring> tokenize(const std::wstring& text) const;

private:
    std::shared_ptr<Vocab> mVocab;
    std::wstring mUnkToken;
    size_t mMaxInputCharsPerWord;
};

class FullTokenizer {
public:
    FullTokenizer(const std::string& vocabFile, bool doLowerCase = true);
    std::vector<std::wstring> tokenize(const std::string& text) const;
    std::vector<size_t> convertTokensToIds(const std::vector<std::wstring>& text) const;

private:
    std::shared_ptr<Vocab> mVocab;
    InvVocab mInvVocab;
    std::string mVocabFile;
    bool mDoLowerCase;
    BasicTokenizer mBasicTokenizer;
    WordpieceTokenizer mWordpieceTokenizer;
};
static std::string normalize_nfd(const std::string& s) {
    std::string ret;
    // utf8proc_NFD returns a newly malloc'd, NUL-terminated buffer (or NULL on error)
    char* result = (char*)utf8proc_NFD((unsigned char*)s.c_str());
    if (result) {
        ret = std::string(result);
        free(result);
        result = NULL;
    }
    return ret;
}

static bool isStripChar(const wchar_t& ch) {
    return stripChar.find(ch) != std::wstring::npos;
}

static std::wstring strip(const std::wstring& text) {
    std::wstring ret = text;
    if (ret.empty()) return ret;
    size_t pos = 0;
    while (pos < ret.size() && isStripChar(ret[pos])) pos++;
    if (pos != 0) ret = ret.substr(pos, ret.size() - pos);
    pos = ret.size() - 1;
    while (pos != (size_t)-1 && isStripChar(ret[pos])) pos--;
    return ret.substr(0, pos + 1);
}

static std::vector<std::wstring> split(const std::wstring& text) {
    std::vector<std::wstring> result;
    boost::split(result, text, boost::is_any_of(stripChar));
    return result;
}

static std::vector<std::wstring> whitespaceTokenize(const std::wstring& text) {
    std::wstring rtext = strip(text);
    if (rtext.empty()) return std::vector<std::wstring>();
    return split(rtext);  // split the stripped text, not the original
}
static std::wstring convertToUnicode(const std::string& text) {
    size_t i = 0;
    std::wstring ret;
    while (i < text.size()) {
        wchar_t codepoint;
        utf8proc_ssize_t forward = utf8proc_iterate((utf8proc_uint8_t*)&text[i], text.size() - i, (utf8proc_int32_t*)&codepoint);
        if (forward < 0) return L"";
        ret += codepoint;
        i += forward;
    }
    return ret;
}

static std::string convertFromUnicode(const std::wstring& wText) {
    char dst[64];
    std::string ret;
    for (auto ch : wText) {
        utf8proc_ssize_t num = utf8proc_encode_char(ch, (utf8proc_uint8_t*)dst);
        if (num <= 0) return "";
        ret += std::string(dst, dst + num);
    }
    return ret;
}

static std::wstring tolower(const std::wstring& s) {
    std::wstring ret(s.size(), L' ');
    for (size_t i = 0; i < s.size(); i++) {
        ret[i] = utf8proc_tolower(s[i]);
    }
    return ret;
}

static std::shared_ptr<Vocab> loadVocab(const std::string& vocabFile) {
    std::shared_ptr<Vocab> vocab(new Vocab);
    size_t index = 0;
    std::ifstream ifs(vocabFile, std::ifstream::in);
    if (!ifs) {
        throw std::runtime_error("open file failed");
    }
    std::string line;
    while (getline(ifs, line)) {
        std::wstring token = convertToUnicode(line);
        if (token.empty()) break;
        token = strip(token);
        (*vocab)[token] = index;
        index++;
    }
    return vocab;
}
BasicTokenizer::BasicTokenizer(bool doLowerCase)
    : mDoLowerCase(doLowerCase) {
}

std::wstring BasicTokenizer::cleanText(const std::wstring& text) const {
    std::wstring output;
    for (const wchar_t& cp : text) {
        if (cp == 0 || cp == 0xfffd || isControl(cp)) continue;
        if (isWhitespace(cp)) output += L" ";
        else output += cp;
    }
    return output;
}

bool BasicTokenizer::isControl(const wchar_t& ch) const {
    if (ch == L'\t' || ch == L'\n' || ch == L'\r') return false;
    auto cat = utf8proc_category(ch);
    if (cat == UTF8PROC_CATEGORY_CC || cat == UTF8PROC_CATEGORY_CF) return true;
    return false;
}

bool BasicTokenizer::isWhitespace(const wchar_t& ch) const {
    if (ch == L' ' || ch == L'\t' || ch == L'\n' || ch == L'\r') return true;
    auto cat = utf8proc_category(ch);
    if (cat == UTF8PROC_CATEGORY_ZS) return true;
    return false;
}

bool BasicTokenizer::isPunctuation(const wchar_t& ch) const {
    if ((ch >= 33 && ch <= 47) || (ch >= 58 && ch <= 64) ||
        (ch >= 91 && ch <= 96) || (ch >= 123 && ch <= 126)) return true;
    auto cat = utf8proc_category(ch);
    if (cat == UTF8PROC_CATEGORY_PD || cat == UTF8PROC_CATEGORY_PS
        || cat == UTF8PROC_CATEGORY_PE || cat == UTF8PROC_CATEGORY_PC
        || cat == UTF8PROC_CATEGORY_PO // sometimes ¶ belongs to category So
        || cat == UTF8PROC_CATEGORY_PI
        || cat == UTF8PROC_CATEGORY_PF) return true;
    return false;
}
bool BasicTokenizer::isChineseChar(const wchar_t& ch) const {
    if ((ch >= 0x4E00 && ch <= 0x9FFF) ||
        (ch >= 0x3400 && ch <= 0x4DBF) ||
        (ch >= 0x20000 && ch <= 0x2A6DF) ||
        (ch >= 0x2A700 && ch <= 0x2B73F) ||
        (ch >= 0x2B740 && ch <= 0x2B81F) ||
        (ch >= 0x2B820 && ch <= 0x2CEAF) ||
        (ch >= 0xF900 && ch <= 0xFAFF) ||
        (ch >= 0x2F800 && ch <= 0x2FA1F))
        return true;
    return false;
}

std::wstring BasicTokenizer::tokenizeChineseChars(const std::wstring& text) const {
    std::wstring output;
    for (auto& ch : text) {
        if (isChineseChar(ch)) {
            output += L' ';
            output += ch;
            output += L' ';
        }
        else
            output += ch;
    }
    return output;
}

std::wstring BasicTokenizer::runStripAccents(const std::wstring& text) const {
    // Strips accents from a piece of text.
    std::wstring nText;
    try {
        nText = convertToUnicode(normalize_nfd(convertFromUnicode(text)));
    } catch (std::bad_cast& e) {
        std::cerr << "bad_cast" << std::endl;
        return L"";
    }
    std::wstring output;
    for (auto& ch : nText) {
        auto cat = utf8proc_category(ch);
        if (cat == UTF8PROC_CATEGORY_MN) continue;
        output += ch;
    }
    return output;
}

std::vector<std::wstring> BasicTokenizer::runSplitOnPunc(const std::wstring& text) const {
    size_t i = 0;
    bool startNewWord = true;
    std::vector<std::wstring> output;
    while (i < text.size()) {
        wchar_t ch = text[i];
        if (isPunctuation(ch)) {
            output.push_back(std::wstring(&ch, 1));
            startNewWord = true;
        }
        else {
            if (startNewWord) output.push_back(std::wstring());
            startNewWord = false;
            output[output.size() - 1] += ch;
        }
        i++;
    }
    return output;
}

std::vector<std::wstring> BasicTokenizer::tokenize(const std::string& text) const {
    std::wstring nText = convertToUnicode(text);
    nText = cleanText(nText);
    nText = tokenizeChineseChars(nText);
    const std::vector<std::wstring>& origTokens = whitespaceTokenize(nText);
    std::vector<std::wstring> splitTokens;
    for (std::wstring token : origTokens) {
        if (mDoLowerCase) {
            token = tolower(token);
            token = runStripAccents(token);
        }
        const auto& tokens = runSplitOnPunc(token);
        splitTokens.insert(splitTokens.end(), tokens.begin(), tokens.end());
    }
    return whitespaceTokenize(boost::join(splitTokens, L" "));
}
WordpieceTokenizer::WordpieceTokenizer(const std::shared_ptr<Vocab> vocab, const std::wstring& unkToken, size_t maxInputCharsPerWord)
    : mVocab(vocab),
      mUnkToken(unkToken),
      mMaxInputCharsPerWord(maxInputCharsPerWord) {
}

std::vector<std::wstring> WordpieceTokenizer::tokenize(const std::wstring& text) const {
    // Greedy longest-match-first wordpiece tokenization.
    std::vector<std::wstring> outputTokens;
    for (auto& token : whitespaceTokenize(text)) {
        if (token.size() > mMaxInputCharsPerWord) {
            outputTokens.push_back(mUnkToken);
            continue;  // skip overly long tokens entirely, as in the Python reference
        }
        bool isBad = false;
        size_t start = 0;
        std::vector<std::wstring> subTokens;
        while (start < token.size()) {
            size_t end = token.size();
            std::wstring curSubstr;
            bool hasCurSubstr = false;
            while (start < end) {
                std::wstring substr = token.substr(start, end - start);
                if (start > 0) substr = L"##" + substr;
                if (mVocab->find(substr) != mVocab->end()) {
                    curSubstr = substr;
                    hasCurSubstr = true;
                    break;
                }
                end--;
            }
            if (!hasCurSubstr) {
                isBad = true;
                break;
            }
            subTokens.push_back(curSubstr);
            start = end;
        }
        if (isBad) outputTokens.push_back(mUnkToken);
        else outputTokens.insert(outputTokens.end(), subTokens.begin(), subTokens.end());
    }
    return outputTokens;
}
FullTokenizer::FullTokenizer(const std::string& vocabFile, bool doLowerCase) :
    mVocab(loadVocab(vocabFile)),
    mBasicTokenizer(BasicTokenizer(doLowerCase)),
    mWordpieceTokenizer(WordpieceTokenizer(mVocab)) {
    for (auto& v : *mVocab) mInvVocab[v.second] = v.first;
}

std::vector<std::wstring> FullTokenizer::tokenize(const std::string& text) const {
    std::vector<std::wstring> splitTokens;
    for (auto& token : mBasicTokenizer.tokenize(text))
        for (auto& subToken : mWordpieceTokenizer.tokenize(token))
            splitTokens.push_back(subToken);
    return splitTokens;
}

std::vector<size_t> FullTokenizer::convertTokensToIds(const std::vector<std::wstring>& text) const {
    std::vector<size_t> ret(text.size());
    for (size_t i = 0; i < text.size(); i++) {
        ret[i] = (*mVocab)[text[i]];
    }
    return ret;
}

int main() {
    FullTokenizer* pTokenizer = nullptr;
    try {
        pTokenizer = new FullTokenizer("data/chinese_L-12_H-768_A-12/vocab.txt");
    }
    catch (const std::exception& e) {
        std::cerr << "construct FullTokenizer failed: " << e.what() << std::endl;
        return -1;
    }
    std::string line;
    while (std::getline(std::cin, line)) {
        auto tokens = pTokenizer->tokenize(line);
        auto ids = pTokenizer->convertTokensToIds(tokens);
        std::cout << "#" << convertFromUnicode(boost::join(tokens, L" ")) << "#" << "\t";
        for (size_t i = 0; i < ids.size(); i++) {
            if (i != 0) std::cout << " ";
            std::cout << ids[i];
        }
        std::cout << std::endl;
    }
    delete pTokenizer;  // release the heap-allocated tokenizer
    return 0;
}
Original BERT version: d66a146741588fb208450bde15aa7db143baaa69
OMG, dude, you have to indicate which license you attach to this piece of code.
All right, I've done it.
Hi, hello. After compiling and running, it shows:
i am chinese a in beijing
i am chinese a in beijing
#[UNK] [UNK] [UNK] [UNK] [UNK] [UNK]# 0 0 0 0 0 0
wo 我是中国人
w我是中国人
#[UNK] [UNK] [UNK] [UNK] [UNK] [UNK]# 0 0 0 0 0 0
You probably gave a wrong vocab.txt path.
I've just added an exception that is thrown when opening the file fails.
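For reference, here is a minimal sketch (not part of the gist) of a replacement for the main() above that takes the vocab path from the command line instead of hard-coding data/chinese_L-12_H-768_A-12/vocab.txt, so a wrong path fails with a clear message:

// Hypothetical replacement for main() in the file above; nothing else changes.
int main(int argc, char** argv) {
    if (argc < 2) {
        std::cerr << "usage: " << argv[0] << " <vocab.txt>" << std::endl;
        return -1;
    }
    try {
        FullTokenizer tokenizer(argv[1]);  // loadVocab throws std::runtime_error if the file cannot be opened
        std::string line;
        while (std::getline(std::cin, line)) {
            auto tokens = tokenizer.tokenize(line);
            auto ids = tokenizer.convertTokensToIds(tokens);
            std::cout << "#" << convertFromUnicode(boost::join(tokens, L" ")) << "#" << "\t";
            for (size_t i = 0; i < ids.size(); i++) {
                if (i != 0) std::cout << " ";
                std::cout << ids[i];
            }
            std::cout << std::endl;
        }
    }
    catch (const std::exception& e) {
        std::cerr << "failed to load vocab from " << argv[1] << ": " << e.what() << std::endl;
        return -1;
    }
    return 0;
}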
Hello, utf8proc is a C library; how do I get a C++ library?
Why not just use the C library?
I haven't compiled it on Windows; maybe you can try CMake.
See the official documentation: https://github.com/JuliaStrings/utf8proc
g++ tokenization.cpp --std=c++11 -o tokenization -Iutf8proc-2.4.0/ -Lutf8proc-2.4.0/ -lutf8proc
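To make the C-vs-C++ point concrete: utf8proc's header declares its functions with C linkage when compiled as C++, so you can include utf8proc.h from C++ directly and link with -lutf8proc, exactly as the gist does. A minimal standalone sketch (file name and test string are my own, assuming the same utf8proc-2.4.0 layout as the command above, e.g. g++ nfd_check.cpp --std=c++11 -o nfd_check -Iutf8proc-2.4.0/ -Lutf8proc-2.4.0/ -lutf8proc):

// nfd_check.cpp -- hypothetical example showing utf8proc (a C library) used from C++
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <utf8proc.h>

int main() {
    const char* s = u8"café";  // UTF-8 input; the é is a single precomposed code point
    // utf8proc_NFD returns a newly malloc'd, NUL-terminated UTF-8 string (or NULL on error)
    unsigned char* nfd = utf8proc_NFD((const unsigned char*)s);
    if (nfd) {
        std::cout << "NFC bytes: " << std::strlen(s)
                  << ", NFD bytes: " << std::strlen((const char*)nfd) << std::endl;
        std::free(nfd);  // the caller owns and must free the buffer
    }
    return 0;
}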