import torch
import numpy as np
from collections.abc import Iterable
from tqdm.auto import tqdm
from multiprocessing import Pool
from vina2vi.models.char_based.bigram import Bigram
from vina2vi.util import (
Vietnamese,
uncased_vina_normalizer,
cased_vi_normalizer,
)
def count_np(s: str):
    """Return the character-bigram count matrix of a single string.

    The string is normalized with ``cased_vi_normalizer`` and lower-cased.
    Each character is then mapped to its vocabulary index, with characters
    missing from ``Bigram.ctoi`` falling back to the unknown-token index.
    The index sequence is framed by the BOS token on the left and the EOS
    token on the right. Every adjacent pair (row=previous, col=next) is
    counted.

    Args:
        s: The text to count. The empty string yields an all-zero matrix.

    Returns:
        A ``(len(Bigram.itoc), len(Bigram.itoc))`` ``np.int32`` array.
    """
    vocab_size = len(Bigram.itoc)
    counts = np.zeros((vocab_size, vocab_size), dtype=np.int32)
    if s == "":
        # An empty input contributes no bigrams at all (not even BOS -> EOS).
        return counts

    # Without normalization, one may obtain a very different count matrix
    normalized = cased_vi_normalizer.normalize_str(s).lower()
    unk = Bigram.ctoi[Bigram.unk_token]
    lookup = Bigram.ctoi.get
    indices = [lookup(ch, unk) for ch in normalized]
    # Pair each position with its successor: BOS->c0, c0->c1, ..., c_last->EOS.
    left = [lookup(Bigram.bos_token, unk), *indices]
    right = [*indices, lookup(Bigram.eos_token, unk)]
    # np.add.at is unbuffered, so repeated (row, col) pairs are all counted.
    np.add.at(counts, (left, right), 1)
    return counts
class BigramNew(Bigram):
    """Bigram model whose counting step runs in a multiprocessing pool.

    Counting is delegated to the module-level ``count_np`` rather than a
    method, so only strings and NumPy arrays are sent between processes.
    """

    def fit(
        self,
        data: Iterable[str],
        *,
        total: int | None = None,
        chunksize: int = 1,
    ) -> None:
        """Add the bigram counts of ``data`` to the model and refresh probabilities.

        Args:
            data: Strings to count. Each one is processed by ``count_np`` in
                a worker process.
            total: Number of items in ``data``, used only for the progress
                bar. When omitted, ``len(data)`` is used if ``data`` supports
                ``len``; otherwise the bar shows no total.
            chunksize: Number of strings sent to a worker per task. It is
                forwarded to ``Pool.imap_unordered``.

        If a worker raises, the exception propagates and ``self.count_matrix``
        is left unchanged.
        """
        if total is None:
            # The bar wraps the pool's result iterator, which has no length,
            # so infer the total from the input here (as tqdm would have).
            try:
                total = len(data)
            except TypeError:
                pass

        # Sum the per-string matrices locally and commit them once, so a
        # failure partway through does not leave the model partially updated.
        batch_counts = None
        # Multiprocessing pool idea borrowed from mCoding
        # https://www.youtube.com/watch?v=X7vBbelRXn0&t=280s
        with Pool() as pool:
            # Unable to use a method like self.count in imap_unordered() here
            # because the class Bigram contains a torch.Generator, which is
            # not picklable.
            matrices = pool.imap_unordered(count_np, data, chunksize=chunksize)
            # Wrap the results rather than the input, so the bar advances on
            # completed work instead of on items handed to the pool.
            for matrix in tqdm(matrices, total=total):
                if batch_counts is None:
                    batch_counts = matrix.astype(np.int64)
                else:
                    batch_counts += matrix

        if batch_counts is not None:
            self.count_matrix += torch.from_numpy(batch_counts)
        self.update_proba_matrix()
Last active
September 26, 2023 08:32
-
-
Save catdingding/e1f539add03be46cb428ed19d736e782 to your computer and use it in GitHub Desktop.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment