# From a GitHub gist by @Birch-san, created July 10, 2022:
# https://gist.github.com/Birch-san/77326b08a624fb9aca0e21f2234f508a
from torch.nn import Embedding
from typing import Tuple, TypeVar, Iterable
from typing_extensions import TypeAlias
from enum import Enum, auto
from math import ceil
from torch import BoolTensor, LongTensor, sparse_coo_tensor, ones
from itertools import chain
class Label(Enum):
    """Closed vocabulary of caption labels.

    Covers franchises, characters, and visual traits. Values start at 0 and
    are contiguous, so a member's ``.value`` doubles as its column index in
    the multi-hot caption tensor.
    """
    # franchises
    touhou = 0
    hololive = auto()
    # Touhou characters
    marisa = auto()
    reimu = auto()
    youmu = auto()
    sakuya = auto()
    flandre = auto()
    reiuji = auto()
    reisen = auto()
    tewi = auto()
    patchouli = auto()
    aya = auto()
    # hololive characters
    pekora = auto()
    kronii = auto()
    gura = auto()
    suisei = auto()
    ame = auto()
    noel = auto()
    subaru = auto()
    kiara = auto()
    # hair colours
    black_hair = auto()
    silver_hair = auto()
    blue_hair = auto()
    blonde_hair = auto()
    purple_hair = auto()
    orange_hair = auto()
    # other visual features
    bunny_ears = auto()
    bird_person = auto()
vocab_size = len(Label)
# Derive the embedding width from t5-small's ratio: t5-small compressed
# 32100 vocab tokens into 512 dims, and there's plenty of range per
# bfloat16 to represent a variety of tokens.
embedding_dim = ceil(512 / 32100 * vocab_size)
model = Embedding(
    num_embeddings=vocab_size,
    embedding_dim=embedding_dim,
    sparse=True,
)
# Generic element type for the flatten() helper below.
T = TypeVar('T')
# One caption: a tuple of labels for a single example.
_Caption: TypeAlias = Tuple[Label, ...]
# A batch of captions.
_Captions: TypeAlias = Tuple[_Caption, ...]
def make_row_indices(enumerated: Tuple[int, _Caption]) -> Tuple[int, ...]:
    """Repeat a caption's batch index once per label it contains.

    Takes an (index, caption) pair as produced by enumerate(), so it can be
    mapped directly over an enumerated batch to build sparse row coordinates.
    """
    ix, labels = enumerated
    return tuple(ix for _ in labels)
def flatten(captions: Iterable[Tuple[T, ...]]) -> Iterable[T]:
    """Lazily concatenate an iterable of tuples into one flat iterable."""
    for group in captions:
        yield from group
def get_value(label: Label) -> int:
    """Unwrap an enum member to its integer value (a map-friendly accessor)."""
    value: int = label.value
    return value
def captions_to_tensor(captions: _Captions) -> BoolTensor:
    """Encode a batch of captions as a sparse boolean multi-hot tensor.

    Row i corresponds to captions[i]; column j is True when the label whose
    value is j appears in that caption.

    :param captions: batch of captions, each a tuple of Label members
    :return: sparse COO BoolTensor of shape (len(captions), len(Label))
    """
    # one row coordinate per label occurrence, grouped by caption index
    row_indices: Tuple[int, ...] = tuple(flatten(map(make_row_indices, enumerate(captions))))
    # matching column coordinates: the labels' integer values
    labels: Tuple[int, ...] = tuple(map(get_value, flatten(captions)))
    indices_nominal: Tuple[Tuple[int, ...], Tuple[int, ...]] = (row_indices, labels)
    # Pass an explicit size: without it, sparse_coo_tensor infers the dense
    # shape from the max coordinate present, so a batch missing the
    # highest-valued labels would yield a tensor narrower than the vocab.
    return sparse_coo_tensor(
        indices=LongTensor(indices_nominal),
        values=ones(len(row_indices), dtype=bool),
        size=(len(captions), len(Label)),
        dtype=bool)
# Training data: one caption per example — franchise, character, then traits.
captions: _Captions = (
    (Label.touhou, Label.marisa, Label.blonde_hair),
    (Label.touhou, Label.reimu, Label.black_hair),
    (Label.touhou, Label.youmu, Label.silver_hair),
    (Label.touhou, Label.sakuya, Label.silver_hair),
    (Label.touhou, Label.flandre, Label.blonde_hair),
    (Label.touhou, Label.reiuji, Label.black_hair, Label.bird_person),
    (Label.touhou, Label.reisen, Label.purple_hair, Label.bunny_ears),
    (Label.touhou, Label.tewi, Label.black_hair, Label.bunny_ears),
    (Label.touhou, Label.patchouli, Label.purple_hair),
    # fixed: black_hair was accidentally listed twice here, which would emit
    # a duplicate coordinate into the sparse caption tensor
    (Label.touhou, Label.aya, Label.black_hair),
    (Label.hololive, Label.pekora, Label.blue_hair, Label.bunny_ears),
    (Label.hololive, Label.kronii, Label.blue_hair),
    (Label.hololive, Label.suisei, Label.blue_hair),
    (Label.hololive, Label.gura, Label.silver_hair),
    (Label.hololive, Label.noel, Label.silver_hair),
    (Label.hololive, Label.ame, Label.blonde_hair),
    (Label.hololive, Label.subaru, Label.black_hair, Label.bird_person),
    (Label.hololive, Label.kiara, Label.black_hair, Label.bird_person),
)
# Smoke test: sparse multi-hot encoding of just the first two captions
# (rows = captions, columns = label values).
batch_of_captions_tensor: BoolTensor = captions_to_tensor(captions[:2])
# okay we have our Embedding, we have a batch of captions... now we need to train the Embedding
# (end of gist)