Last active
November 13, 2020 11:03
-
-
Save ksopyla/4db0ed33fc2143ca800c9cc428ae0871 to your computer and use it in GitHub Desktop.
Flair: loading Polish embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%% | |
from flair.embeddings import BertEmbeddings | |
from flair.embeddings import WordEmbeddings | |
from flair.embeddings import FastTextEmbeddings | |
from flair.embeddings import ELMoEmbeddings | |
from flair.embeddings import BytePairEmbeddings | |
from flair.embeddings import FlairEmbeddings | |
from flair.embeddings import XLMEmbeddings | |
from flair.embeddings import RoBERTaEmbeddings | |
from flair.data import Sentence | |
import numpy as np | |
import os | |
# Polish probe words arranged as near-synonym pairs — (stół/krzesło = table/chair,
# programista/informatyk = programmer/IT specialist, adwokat/prawnik = attorney/lawyer) —
# used below to eyeball how each embedding model scores related words.
text = "stół krzesło programista informatyk adwokat prawnik"
def similarity(token1, token2):
    """Squared Euclidean distance between two token embedding vectors.

    Despite the name, SMALLER values mean the tokens are MORE similar.
    Each argument must expose a ``.embedding`` tensor (as flair tokens do).
    """
    diff = (token1.embedding - token2.embedding).cpu().numpy()
    return np.square(diff).sum()
def print_similarity(sentence):
    """Print pairwise distance scores for the first four tokens of *sentence*.

    Compares the index pairs (0,1), (2,3), (0,2) and (1,3) — i.e. within and
    across the synonym pairs of the probe text — using ``similarity``.
    """
    for a, b in ((0, 1), (2, 3), (0, 2), (1, 3)):
        print(
            f"similarity ({sentence[a]},{sentence[b]}):{similarity(sentence[a],sentence[b])}"
        )
#%% BERT
# Multilingual BERT covers Polish only generically — ideally we would pretrain
# a dedicated Polish model.
# init embedding
pretrained_model = "bert-base-multilingual-cased"
# NOTE: the Hugging Face model id above is immediately overridden by a locally
# trained checkpoint (30k BPE vocab, 20 epochs) — kept only as a reference.
pretrained_model = "./src/playground/data/model_bert/pytorch_v30k_bpe_tok_20_epoch/"
pretrained_model = os.path.abspath(pretrained_model)
# pooling_operation: pooling operation for subword embeddings (supported: first, last, first_last and mean)
# use_scalar_mix=True mixes all transformer layers instead of using only the last one.
bert = BertEmbeddings(pretrained_model, pooling_operation="mean", use_scalar_mix=True)
# create a sentence
sentence = Sentence(text)
# embed words in sentence (mutates the tokens in place, attaching .embedding)
bert.embed(sentence)
print_similarity(sentence)
#%% Polish elmo
# Polish ELMo model (options + weights files expected on disk, e.g. the
# CLARIN-PL release), loaded from a local directory.
print("polish elmo")  # FIX: message had a typo ("elmlo")
path_to_clarin_elmo = "./src/playground/data/embeddings/elmo"
path_to_clarin_elmo = os.path.abspath(path_to_clarin_elmo)
elmo_options = f"{path_to_clarin_elmo}/options.json"
elmo_weights = f"{path_to_clarin_elmo}/weights.hdf5"
elmo = ELMoEmbeddings(options_file=elmo_options, weight_file=elmo_weights)
sentence = Sentence(text)
# embed words in sentence (attaches .embedding to each token in place)
elmo.embed(sentence)
print_similarity(sentence)
#%% FastTEXT -
#%%
print("fastext cc-bin-polish from fb")
# FIX: `path` was never defined in this file, so building `fst_bin` raised a
# NameError at runtime. Define the fastText embeddings directory explicitly,
# mirroring the ELMo section above.
# TODO(review): confirm this is where the .bin files actually live on disk.
path = os.path.abspath("./src/playground/data/embeddings/fasttext")
fasttext_bin_file = "cc.pl.300.bin"  # Common Crawl Polish vectors from fasttext.cc
# fasttext_bin_file = "wiki.pl.bin"  # alternative: Wikipedia-trained vectors
fst_bin = f"{path}/{fasttext_bin_file}"
fst_bin_embed = FastTextEmbeddings(fst_bin)
# NOTE(review): unlike the other sections, this one builds the embedder but
# never calls .embed()/print_similarity() — presumably left unfinished.
#%% BPE
# Byte-pair (subword) embeddings; "pl" selects the pretrained Polish model.
# init embedding
# dim: 50, 100, 200, 300
# vocab_size/syllables: 1000, 3000, 5000, 10k, 25k, 50k, 100k, 200k
bpe = BytePairEmbeddings("pl", dim=100, syllables=50000)
sentence = Sentence(text)
# embed words in sentence (attaches .embedding to each token in place)
bpe.embed(sentence)
print_similarity(sentence)
#%% Zalando Flair
# Character-level contextual string embeddings. Available models:
# - pl-{forward,backward},
# - pl-opus-{forward,backward}
# - multi-{forward,backward} (multilingual with polish)
# - multi-{forward,backward}-fast (multilingual with polish)
# flair = FlairEmbeddings('pl-forward')
flair = FlairEmbeddings("pl-backward")
sentence = Sentence(text)
# embed words in sentence (attaches .embedding to each token in place)
flair.embed(sentence)
print_similarity(sentence)
#%% XLM
# multilingual model, we should pretrain polish
# there is new XLM multi language model which suports Polish
# https://github.com/facebookresearch/XLM#pretrained-cross-lingual-language-models
# Error: [E050] Can't find model 'en'. It doesn't seem to be a shortcut link, a Python package or a valid path to a data directory.
# you should install spacy and download the model en
# python -m spacy.en.download
# python -c "import spacy; spacy.load('en')"
# !!! you should download the model from fb github and convert it to
# https://github.com/huggingface/pytorch-transformers/issues/1157#event-2599001476
# path_to_model = '/src/playground/data/embeddings/xlm'
# model = f'{path_to_model}/mlm_17_1280.pth'
# model='xlm-mlm-xnli15-1024'
# 17-language masked-LM XLM checkpoint (includes Polish), 1280-dim.
model = "xlm-mlm-17-1280"
xlm = XLMEmbeddings(model)
#%%
sentence = Sentence(text)
# embed words in sentence (attaches .embedding to each token in place)
xlm.embed(sentence)
print_similarity(sentence)
#%%
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment