from . import contextual_watch_sequence_dataset
from . import word2gm
from . import mixturesamefamily as M
import pandas as pd
import pickle
from itertools import combinations
import time
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as D


class GMM_Scale_Clipper(object):
    """Clamps the per-component scale embeddings of the given tokens to [lower_scale_cap, upper_scale_cap]."""

    def __init__(self, lower_scale_cap, upper_scale_cap):
        self.lower_scale_cap = lower_scale_cap
        self.upper_scale_cap = upper_scale_cap

    def __call__(self, module, to_update):
        with torch.no_grad():
            for i in range(module.num_mixture_components):
                scale_tensor = getattr(module, f'gmm_scales_comp_{i}')
                for token in to_update:
                    # Element-wise equivalent of max(lower_cap, min(upper_cap, scale)).
                    scale_tensor.weight.data[token, :] = scale_tensor.weight.data[token, :].clamp(
                        self.lower_scale_cap, self.upper_scale_cap)


def _batch(iterable, n=1):
    # Yield successive chunks of size n from `iterable`.
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]


def _set_default_tensor_type(device):
    # Route newly created tensors to the GPU when one is available.
    if 'cuda' in device:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        torch.set_default_tensor_type(torch.FloatTensor)
    return


def train(cc_input_data, kg_metadata, softcoded_recs, params):
    MARGIN = params['margin']
    BATCH_SIZE = params['batch_size']  # small as possible?
    NUM_MIXTURE_COMPONENTS = params['num_mixture_components']
    COMPONENT_DIM = params['dim']
    EPOCHS = params['epochs']  # more?
    CONTEXT_WINDOW = params['ws']  # good?
    ALPHA = params['alpha']
    INIT_SCALE = params['init_scale']
    LOW_SCALE_CAP = params['low_scale_cap']
    UPPER_SCALE_CAP = params['upper_scale_cap']
    MEAN_NORM_CAP = params['mean_norm_cap']
    COVAR_MODE = 'diagonal'
    KG = kg_metadata
    SOFTCODED_RECS = softcoded_recs
    CLASS_LABELS = params['class_labels']
    SCALE_GRAD_BY_FREQ = params['scale_grad_by_freq']
    LR = params['lr']
    PATH = cc_input_data

    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'
    print(f'Using {device}')
    _set_default_tensor_type(device)

    dataset = contextual_watch_sequence_dataset.contextual_watch_sequence_dataset(
        PATH, KG, SOFTCODED_RECS, context_window=CONTEXT_WINDOW, alpha=ALPHA, class_label=CLASS_LABELS)
    train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
    print(f'Dataset size: {len(dataset)}')
    print(f'Number of batches: {len(train_loader)}')
    df = dataset.title_id

    model = word2gm.word2gm(NUM_MIXTURE_COMPONENTS, COMPONENT_DIM, dataset.content2idx,
                            dataset.metadata2metadataidx, BATCH_SIZE, INIT_SCALE, COVAR_MODE,
                            MEAN_NORM_CAP, SCALE_GRAD_BY_FREQ)
    model = model.to(device)
    scale_clipper = GMM_Scale_Clipper(LOW_SCALE_CAP, UPPER_SCALE_CAP)

    optimizer = optim.Adadelta(model.parameters())
    # optimizer = optim.Adam(model.parameters(), lr=LR)
    # optimizer = optim.SGD(model.parameters(), lr=LR)

    tokens = list(dataset.metadata2metadataidx.values()) + list(dataset.content2idx.values())

    for epoch in range(EPOCHS):
        # Training
        print(f'Epoch: {epoch}')
        for i, (batch, true_context, fake_context) in enumerate(train_loader):
            start = time.time()
            optimizer.zero_grad()
            log_true_energy = model(batch, true_context)
            log_fake_energy = model(batch, fake_context)
            # pairwise hinge
            loss = F.relu(MARGIN + log_fake_energy - log_true_energy)
            # pairwise logistic
            # loss = F.softplus(log_fake_energy - log_true_energy)
            # pointwise hinge
            # loss = F.relu(MARGIN - log_true_energy) + F.relu(MARGIN + log_fake_energy)
            # pointwise logistic
            # loss = F.softplus(-1. * log_true_energy) + F.softplus(log_fake_energy)
            loss = torch.mean(loss)
            loss.backward(retain_graph=False)
            optimizer.step()

            # Clamp the scale embeddings of every token touched in this batch.
            to_update = []
            for t in batch:
                to_update.append(t.item())
            for t in true_context:
                to_update.append(t.item())
            for t in fake_context:
                to_update.append(t.item())
            to_update = list(set(to_update))
            scale_clipper(model, to_update)

            end = time.time()
            with torch.no_grad():
                if i % 50 == 0:
                    print(50 * '--')
                    print(f'Elapsed time: {end - start}')
                    print(epoch, i, loss.item())
            break  # NOTE: stops after the first batch of each epoch.

    model = model.to('cpu')
    return model, dataset


def save_model(directory, dataset, model):
    torch.save(model, directory + 'model.pth')
    with open(directory + 'dataset.pkl', 'wb') as f:
        pickle.dump(dataset, f)


def get_token_mean(token, model):
    return torch.stack([getattr(model, f'gmm_means_comp_{i}')(torch.tensor([token], dtype=torch.long))
                        for i in range(model.num_mixture_components)], dim=1).squeeze(0)


def get_token_scale(token, model):
    return torch.stack([getattr(model, f'gmm_scales_comp_{i}')(torch.tensor([token], dtype=torch.long))
                        for i in range(model.num_mixture_components)], dim=1).squeeze(0)


def get_token_mix(token, model):
    return model.gmm_mix(torch.tensor([token], dtype=torch.long)).squeeze(0)


def get_token_gmm(token, model):
    # Build a diagonal-covariance Gaussian mixture for a single token from its embeddings.
    token_means = get_token_mean(token, model)
    token_scales = get_token_scale(token, model)
    token_mix = get_token_mix(token, model)
    gmm = M.MixtureSameFamily(D.Categorical(logits=token_mix),
                              D.Independent(D.Normal(token_means, token_scales), 1))
    return gmm


def get_mean_gmm(content_mog_stack):
    # Given a list of per-token GMMs (assumed independent), build the GMM of the
    # *average* of their random vectors by repeatedly combining pairs: component
    # means add (scaled by 1/history_length), variances add (scaled by
    # 1/history_length**2), and mixture weights multiply across the component grid.
    mean_gmm = None
    history_length = len(content_mog_stack)
    if history_length == 1:
        return content_mog_stack[0]
    with torch.no_grad():
        init_round = True
        while len(content_mog_stack) != 0:
            mog_a = content_mog_stack.pop(0)
            mog_b = content_mog_stack.pop(0)
            new_gmm_mix = []
            new_gmm_means = []
            new_gmm_vars = []
            for c_n in range(len(mog_a.mixture_distribution.probs)):
                for c_m in range(len(mog_b.mixture_distribution.probs)):
                    new_comp_weight = mog_a.mixture_distribution.probs[c_n] * mog_b.mixture_distribution.probs[c_m]
                    if init_round:
                        # First pair: both inputs are raw token GMMs, so both get the 1/N scaling.
                        new_comp_mean = (1. / history_length * mog_a.components_distribution.base_dist.loc[c_n, :]) \
                                        + (1. / history_length * mog_b.components_distribution.base_dist.loc[c_m, :])
                        new_comp_var = (1. / (history_length ** 2) * mog_a.components_distribution.base_dist.scale[c_n, :].pow(2)) \
                                       + (1. / (history_length ** 2) * mog_b.components_distribution.base_dist.scale[c_m, :].pow(2))
                    else:
                        # Later rounds: mog_a is the running (already scaled) partial sum.
                        new_comp_mean = mog_a.components_distribution.base_dist.loc[c_n, :] \
                                        + (1. / history_length * mog_b.components_distribution.base_dist.loc[c_m, :])
                        new_comp_var = mog_a.components_distribution.base_dist.scale[c_n, :].pow(2) \
                                       + (1. / (history_length ** 2) * mog_b.components_distribution.base_dist.scale[c_m, :].pow(2))
                    new_gmm_mix.append(new_comp_weight)
                    new_gmm_means.append(new_comp_mean)
                    new_gmm_vars.append(new_comp_var)
            new_gmm_mix = torch.stack(new_gmm_mix)
            new_gmm_means = torch.stack(new_gmm_means)
            new_gmm_vars = torch.stack(new_gmm_vars)
            new_gmm = M.MixtureSameFamily(D.Categorical(probs=new_gmm_mix),
                                          D.Independent(D.Normal(new_gmm_means, torch.sqrt(new_gmm_vars)), 1))
            init_round = False
            if len(content_mog_stack) == 0:
                mean_gmm = new_gmm
                break
            else:
                content_mog_stack.insert(0, new_gmm)
    return mean_gmm
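

# The sketch below is an illustrative, untested usage example: the `params` keys
# mirror exactly the ones read inside train(), but every concrete value, the
# './artifacts/' directory, and the token indices are assumptions for illustration.
def _example_train_and_query(cc_input_data, kg_metadata, softcoded_recs,
                             some_token_idx, history_token_idxs):
    params = {
        'margin': 1.0, 'batch_size': 512, 'num_mixture_components': 2, 'dim': 50,
        'epochs': 5, 'ws': 5, 'alpha': 0.75, 'init_scale': 1.0,
        'low_scale_cap': 0.05, 'upper_scale_cap': 5.0, 'mean_norm_cap': None,
        'class_labels': None, 'scale_grad_by_freq': False, 'lr': 0.01,
    }
    model, dataset = train(cc_input_data, kg_metadata, softcoded_recs, params)
    save_model('./artifacts/', dataset, model)
    # Per-token mixture, and the mixture of the *average* vector over a watch history.
    token_gmm = get_token_gmm(some_token_idx, model)
    history_gmm = get_mean_gmm([get_token_gmm(t, model) for t in history_token_idxs])
    return token_gmm, history_gmm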


# ===== word2gm.py =====
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
from . import mixturesamefamily as M
import math
import time


# An implementation of word2gm (Gaussian-mixture word embeddings).
class word2gm(nn.Module):
    def __init__(self, num_mixture_components,
                 dist_dimensions,
                 content2idx,
                 metadata2metadataidx,
                 batch_size,
                 init_scale=1.,
                 covar_mode='diagonal',
                 mean_max_norm=None,
                 scale_grad_by_freq=False):
        super(word2gm, self).__init__()
        self.content2idx = content2idx
        self.metadata2metadataidx = metadata2metadataidx
        self.batch_size = batch_size
        self.num_mixture_components = num_mixture_components
        self.dist_dimensions = dist_dimensions
        self.init_scale = init_scale
        self.covar_mode = covar_mode
        self.mean_max_norm = mean_max_norm
        self.scale_grad_by_freq = scale_grad_by_freq
        vocab_size = len(content2idx) + len(self.metadata2metadataidx)
        # One mean table and one scale table per mixture component, plus one
        # table of (unnormalised) mixture weights per token.
        for i in range(self.num_mixture_components):
            setattr(self, f'gmm_means_comp_{i}',
                    nn.Embedding(vocab_size, self.dist_dimensions,
                                 max_norm=mean_max_norm, scale_grad_by_freq=scale_grad_by_freq))
            nn.init.uniform_(getattr(self, f'gmm_means_comp_{i}').weight,
                             -1. * math.sqrt(3. / self.dist_dimensions),
                             math.sqrt(3. / self.dist_dimensions))
            setattr(self, f'gmm_scales_comp_{i}',
                    nn.Embedding(vocab_size, self.dist_dimensions, scale_grad_by_freq=scale_grad_by_freq))
            nn.init.constant_(getattr(self, f'gmm_scales_comp_{i}').weight, init_scale)
        self.gmm_mix = nn.Embedding(vocab_size, self.num_mixture_components, scale_grad_by_freq=scale_grad_by_freq)
        nn.init.constant_(self.gmm_mix.weight, 1.)

    def forward(self, word, true_context):
        # The score of a (token, context) pair is the log expected-likelihood
        # kernel between their two Gaussian mixtures.
        a = self.log_expected_likelihood_kernel(word, true_context)
        return a

    def _log_expected_likelihood_kernel(self, batch_token_a, batch_token_b):
        # Gather means, scales and mixture weights for both sides of the batch:
        # shapes are (batch, num_components, dim) for means/scales and
        # (batch, num_components, 1) for the softmaxed mixture weights.
        mean_batch_a = torch.stack([getattr(self, f'gmm_means_comp_{i}')(batch_token_a) for i in range(self.num_mixture_components)], dim=1)
        mean_batch_b = torch.stack([getattr(self, f'gmm_means_comp_{i}')(batch_token_b) for i in range(self.num_mixture_components)], dim=1)
        scale_batch_a = torch.stack([getattr(self, f'gmm_scales_comp_{i}')(batch_token_a) for i in range(self.num_mixture_components)], dim=1)
        scale_batch_b = torch.stack([getattr(self, f'gmm_scales_comp_{i}')(batch_token_b) for i in range(self.num_mixture_components)], dim=1)
        mix_batch_a = F.softmax(self.gmm_mix(batch_token_a), dim=1).unsqueeze(-1)
        mix_batch_b = F.softmax(self.gmm_mix(batch_token_b), dim=1).unsqueeze(-1)
        diag_var_batch_a = scale_batch_a.pow(2)
        diag_var_batch_b = scale_batch_b.pow(2)
        return self.batched_partial_log_energy_diagonal_components(
            mix_batch_a, mix_batch_b, mean_batch_a, mean_batch_b,
            diag_var_batch_a, diag_var_batch_b, len(batch_token_a))

    def log_expected_likelihood_kernel(self, batch_token_a, batch_token_b):
        log_energy = self._log_expected_likelihood_kernel(batch_token_a, batch_token_b)
        return log_energy

    def batched_partial_log_energy_diagonal_components(self, mix_batch_a, mix_batch_b, mean_batch_a, mean_batch_b, diag_var_batch_a, diag_var_batch_b, batch_size):
        # Log expected-likelihood kernel between two diagonal Gaussian mixtures:
        #   log E[f, g] = log sum_{n,m} w_n * v_m * N(mu_n; nu_m, Sigma_n + Sigma_m),
        # computed with a max-shifted (log-sum-exp style) reduction over all
        # component pairs; the constant -d/2 * log(2*pi) is dropped.
        eps = 0.
        # Outer products / differences over the two component axes, flattened to
        # (batch, num_components**2, ...).
        mix_prod = mix_batch_a.unsqueeze(2) * mix_batch_b.unsqueeze(1)
        mix_prod = mix_prod.reshape((batch_size, -1))
        mean_diff = mean_batch_a.unsqueeze(2) - mean_batch_b.unsqueeze(1)
        mean_diff = mean_diff.reshape((batch_size, -1, self.dist_dimensions))
        diag_sum = diag_var_batch_a.unsqueeze(2) + diag_var_batch_b.unsqueeze(1)
        diag_sum = diag_sum.reshape((batch_size, -1, self.dist_dimensions)) + eps
        inv_diag_sum = 1. / diag_sum
        ple = -0.5 * torch.sum(torch.log(diag_sum), axis=-1) - 0.5 * torch.sum(mean_diff * inv_diag_sum * mean_diff, axis=-1)
        max_ple = torch.max(ple, axis=-1)[0].view((-1, 1))
        log_energy = max_ple.view((-1,)) + torch.log(torch.sum(mix_prod * (torch.exp(ple - max_ple)), axis=-1))
        return log_energy

    def partial_log_energy_diagonal_components(self, mean_word2, mean_word1, diag_var_word2, diag_var_word1):
        # Unbatched version of the per-pair term above: log N(mu_2; mu_1, Sigma_2 + Sigma_1)
        # with the -d/2 * log(2*pi) constant dropped.
        eps = 0.
        mean_diff = mean_word2 - mean_word1
        inv_diag_sum = 1. / (diag_var_word2 + diag_var_word1 + eps)
        ple = -0.5 * torch.sum(torch.log(diag_var_word2 + diag_var_word1 + eps)) - 0.5 * torch.sum(mean_diff * inv_diag_sum * mean_diff)
        return ple
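

# A minimal sanity-check sketch (illustrative, not part of the original module; the
# function name `_check_partial_log_energy` is hypothetical). The expected-likelihood
# kernel between two Gaussians satisfies
#   E_{x ~ N(mu_a, S_a)}[ N(x; mu_b, S_b) ] = N(mu_a; mu_b, S_a + S_b),
# so partial_log_energy_diagonal_components should match that log-density up to the
# dropped constant (d/2) * log(2*pi).
def _check_partial_log_energy(model, seed=0):
    torch.manual_seed(seed)
    d = model.dist_dimensions
    mean_a, mean_b = torch.randn(d), torch.randn(d)
    var_a, var_b = torch.rand(d) + 0.1, torch.rand(d) + 0.1
    ple = model.partial_log_energy_diagonal_components(mean_a, mean_b, var_a, var_b)
    ref = D.Independent(D.Normal(mean_b, torch.sqrt(var_a + var_b)), 1).log_prob(mean_a)
    assert torch.allclose(ple, ref + 0.5 * d * math.log(2 * math.pi), atol=1e-4)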