import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
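
# This script compares a model whose embedding table is sharded across XLA
# devices along the embedding dimension (DistributedEmbedding) with a standard
# model that holds the full table, and checks that losses and embedding
# weights stay in sync after a few SGD steps.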

NUM_DEVICES = 8
VOCAB_SIZE = 10
OUTPUT_DIM = 32
BATCH_SIZE_PER_CORE = 1
#SEQUENCE_LEN = 100
EMBEDDING_SIZE = 1 * NUM_DEVICES
LR = 0.1
NUM_STEPS = 2
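# With these settings, each of the NUM_DEVICES cores owns
# EMBEDDING_SIZE // NUM_DEVICES = 1 column of the embedding table.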


class DistributedEmbedding(nn.Module):

    def __init__(self, vocab_size, embedding_size, world_size=None,
                 batch_dim=0):
        super(DistributedEmbedding, self).__init__()
        self._embedding_size = embedding_size
        self._world_size = world_size or xm.xrt_world_size()
        self._batch_dim = batch_dim
        assert embedding_size % self._world_size == 0, \
            "For this example to work, please provide an embedding size that is a multiple of {}".format(self._world_size)
        self._sliced_emb_size = self._embedding_size // self._world_size
        self.embeddings = nn.Embedding(vocab_size, self._sliced_emb_size)

    @property
    def _rank(self):
        # Query the rank lazily so the module can be created at global scope,
        # before the XLA processes are forked.
        return xm.get_ordinal()

    def set_bsz(self, bsz):
        self.bsz = bsz // self._world_size

    def init_embs(self, emb):
        l = self._rank * self._sliced_emb_size
        r = l + self._sliced_emb_size
        x = emb[:, l:r].clone().detach()
        self.embeddings.weight.data.copy_(x.to(self.embeddings.weight.device))

    def _get_embedding_pad(self):
        size = self._embedding_size // self._world_size
        return self._rank * size, (self._world_size - 1 - self._rank) * size

    def forward(self, inputs):
        embeds = self.embeddings(inputs)
        # Each core embeds the full batch with its slice of the embedding
        # dimension; all-gather along the last dimension assembles the full
        # embedding vectors from all cores.
        pembeds = xm.all_gather(embeds, dim=-1)
        # Slice the batch dimension to keep only the samples this core is
        # responsible for.
        assert inputs.size(self._batch_dim) % self._world_size == 0
        isize = inputs.size(self._batch_dim) // self._world_size
        sliced_pembeds = torch.narrow(
            pembeds, self._batch_dim, self._rank * isize, isize)
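        # Return (a) this core's share of the batch with the full embedding
        # dimension, detached and marked requires_grad so the rest of the model
        # can run on it and hand its gradient back in backward(), and (b) the
        # embedding-dim-sliced output for the whole global batch, still
        # attached to this core's embedding table.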
        return sliced_pembeds.clone().detach().requires_grad_(True), embeds

    def _get_pad_indices(self, grad):
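        # Builds F.pad-style padding amounts that would place this core's batch
        # slice at its position in the global batch. This helper (and the l, r
        # values computed in backward()) appear to be leftovers from a
        # pad + all-reduce variant of this scheme; the current code relies on
        # all_gather instead.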
        bsz = grad.size(self._batch_dim)
        indices = [0] * (2 * grad.ndim)
        indices[-1 - 2 * self._batch_dim] = bsz
        indices[-2 - 2 * self._batch_dim] = bsz * self._rank
        return indices  # (0, 0, 0, 0, l, r)

    def backward(self, fullbatch_slicedemb, grad):
        # The gradient at this point has the full embedding dimension and only
        # contains info on the samples this core processed.
        l, r = self._get_embedding_pad()
        grad = xm.all_gather(grad, dim=self._batch_dim)
        size = self._sliced_emb_size
        sliced_grad = torch.narrow(grad, grad.ndim - 1, self._rank * size, size)
        fullbatch_slicedemb.backward(sliced_grad)


class ModelWithDistributedEmbeddings(nn.Module):

    def __init__(self, vocab_size, embedding_size, output_dim,
                 world_size=None, batch_dim=0):
        super(ModelWithDistributedEmbeddings, self).__init__()
        self.embedding = DistributedEmbedding(
            vocab_size, embedding_size, world_size=world_size,
            batch_dim=batch_dim)
        self.linear = nn.Linear(embedding_size, output_dim, bias=False)

    def initialize_weights(self, linear_weights, embeddings_table):
        self.linear.weight.data.copy_(linear_weights)
        self.embedding.init_embs(embeddings_table)

    def forward(self, inputs):
        embedded_batch, emb_globalbatch_dimsliced = self.embedding(inputs)
        x = self.linear(embedded_batch)
        return F.relu(x), embedded_batch, emb_globalbatch_dimsliced

    def emb_backward(self, *args, **kwargs):
        self.embedding.backward(*args, **kwargs)


class StandardModel(nn.Module):

    def __init__(self, vocab_size, embedding_size, output_dim):
        super(StandardModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.linear = nn.Linear(embedding_size, output_dim, bias=False)

    def initialize_weights(self, linear_w, embeddings_table):
        device = self.linear.weight.device
        self.linear.weight.data.copy_(linear_w.to(device))
        self.embedding.weight.data.copy_(embeddings_table.to(device))

    def forward(self, inputs):
        x = self.embedding(inputs)
        x = self.linear(x)
        return F.relu(x)


def summarize_tensor(tensor):
    t = tensor.data.cpu()
    out = t.min(), t.max(), t.sum()
    return [_.item() for _ in out]


# Create models at global scope to avoid Colab host memory OOM.
MODEL_W_DISTR_EMB = xmp.MpModelWrapper(ModelWithDistributedEmbeddings(
    VOCAB_SIZE, EMBEDDING_SIZE, OUTPUT_DIM, world_size=NUM_DEVICES))
STANDARD_MODEL = xmp.MpModelWrapper(
    StandardModel(VOCAB_SIZE, EMBEDDING_SIZE, OUTPUT_DIM))

# Define initial weights at global scope, but initialize the models' weights in
# the forked processes, since each rank's embedding shard must be initialized
# with the slice of the table that corresponds to its ordinal.
EMBTABLE = torch.randn(VOCAB_SIZE, EMBEDDING_SIZE)
LINTABLE = torch.randn(OUTPUT_DIM, EMBEDDING_SIZE)
BSZ = BATCH_SIZE_PER_CORE * NUM_DEVICES
INPUT_TENSOR = torch.randint(0, VOCAB_SIZE - 1, (BSZ,))
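# Note: torch.randint's upper bound is exclusive, so the inputs range over
# 0..VOCAB_SIZE-2; the last row of the table simply never receives a gradient.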


def compare_weights_after_one_step(index):
    # Send models to the device.
    device = xm.xla_device()
    model_w_distr_emb = MODEL_W_DISTR_EMB.to(device)
    standard_model = STANDARD_MODEL.to(device)
    # Initialize weights.
    model_w_distr_emb.initialize_weights(LINTABLE, EMBTABLE)
    standard_model.initialize_weights(LINTABLE, EMBTABLE)
    # Create optimizers.
    standard_optimizer = torch.optim.SGD(
        standard_model.parameters(), lr=LR * NUM_DEVICES)
    # The distributed embedding needs to have its own optimizer, because
    # the embedding table is sharded and we do not want gradient reduction
    # happening across all cores.
    distr_emb_optimizer = torch.optim.SGD(
        model_w_distr_emb.embedding.parameters(), lr=LR)
    linear_optimizer = torch.optim.SGD(
        model_w_distr_emb.linear.parameters(), lr=LR)
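    # The linear layer gets its own optimizer as well, so that xm.optimizer_step
    # can all-reduce only the linear gradients across cores while each embedding
    # shard is updated locally with distr_emb_optimizer.step().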
    # Prepare input.
    input_tensor = INPUT_TENSOR.to(device)
    #input_tensor = torch.randint(0, VOCAB_SIZE-1, (bsz, SEQUENCE_LEN)).to(device)
    for step in range(NUM_STEPS):
        # Forward pass
        standard_optimizer.zero_grad()
        distr_emb_optimizer.zero_grad()
        linear_optimizer.zero_grad()
        # The standard model processes its per-core batch as usual.
        slice_i = xm.get_ordinal() * BATCH_SIZE_PER_CORE
        output_st = standard_model(
            #input_tensor[slice_i:slice_i+BATCH_SIZE_PER_CORE, :])
            input_tensor[slice_i:slice_i + BATCH_SIZE_PER_CORE])
        target = torch.randn(BATCH_SIZE_PER_CORE, OUTPUT_DIM).to(device)
        # The model with distributed embeddings processes the full global batch.
        # The embedding layer's outputs are also returned in order to perform
        # the necessary operations during backward.
        output_distr, fewsamples_fullemb, fullsamples_slicedemb = model_w_distr_emb(input_tensor)
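        # output_distr: this core's batch slice run through the linear layer;
        # fewsamples_fullemb: this core's batch slice with the full embedding
        #   dimension (a detached leaf that requires grad);
        # fullsamples_slicedemb: the whole global batch embedded with this
        #   core's embedding-dimension slice, still attached to the local table.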
        # Backprop - partial.
        loss_distr = ((output_distr - target) ** 2).mean()
        loss_st = ((output_st - target) ** 2).mean()
        loss_distr.backward()  # This only propagates up to the embeddings.
        loss_st.backward()
        # Gradient updates and the rest of the backprop.
        # All-reduce and update the standard model as usual.
        xm.optimizer_step(standard_optimizer)
        # All-reduce and update the distributed embedding model's
        # linear optimizer as usual.
        xm.optimizer_step(linear_optimizer)
        # Now do the distributed embeddings backprop and update.
        model_w_distr_emb.emb_backward(fullsamples_slicedemb, fewsamples_fullemb.grad)
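        # (fewsamples_fullemb.grad is populated because forward() returned it
        # as a detached leaf with requires_grad=True, so loss_distr.backward()
        # stopped there; emb_backward() carries that gradient the rest of the
        # way into the sharded embedding table.)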
        distr_emb_optimizer.step()
        # Mark the step and materialize tensors.
        xm.mark_step()
        # Finally, let's look at the loss values and the weight differences.
        xm.all_reduce('sum', [loss_distr])
        xm.all_reduce('sum', [loss_st])
        xm.master_print('Step {} Summary:'.format(step))
        xm.master_print(
            '\tLoss: standard {:.4f}, distr-emb {:.4f}, delta is {:.4f}'.format(
                loss_st.item(), loss_distr.item(), (loss_st - loss_distr).item()))
        xm.master_print(
            '\tEmbedding Gradient summary: {}'.format(
                summarize_tensor(standard_model.embedding.weight.cpu() - EMBTABLE)))
        delta_embw = xm.all_gather(
            model_w_distr_emb.embedding.embeddings.weight, dim=-1)
        delta_embw = delta_embw.cpu() - standard_model.embedding.weight.cpu()
        xm.master_print(
            '\tEmbedding Difference b/w standard and distributed: {}'.format(
                summarize_tensor(delta_embw)))


if __name__ == '__main__':
    xmp.spawn(
        compare_weights_after_one_step,
        args=(), nprocs=NUM_DEVICES, start_method='fork')
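

# --- Illustrative appendix (not part of the training flow above) ---
# A minimal CPU-only sketch of the forward-pass math in DistributedEmbedding,
# assuming no torch_xla: xm.all_gather(dim=-1) is emulated with torch.cat and
# each "rank" is just a loop index. The function name cpu_forward_sketch and
# its locals are illustrative only. It is never called here; run it manually
# to check that per-rank column slices of the embedding table reassemble into
# the same lookup a full nn.Embedding would produce.
def cpu_forward_sketch(world_size=8, vocab_size=10, emb_size=8, bsz_per_core=1):
    torch.manual_seed(0)
    full_table = torch.randn(vocab_size, emb_size)
    inputs = torch.randint(0, vocab_size, (bsz_per_core * world_size,))
    # Reference: an ordinary lookup with the full table.
    reference = F.embedding(inputs, full_table)
    slice_size = emb_size // world_size
    # Each rank embeds the *global* batch with its own column slice of the table.
    per_rank = [
        F.embedding(inputs, full_table[:, r * slice_size:(r + 1) * slice_size])
        for r in range(world_size)
    ]
    # Emulate xm.all_gather(..., dim=-1) by concatenating along the last dim.
    gathered = torch.cat(per_rank, dim=-1)
    # Each rank then keeps only its share of the batch, as forward() does.
    for rank in range(world_size):
        mine = gathered.narrow(0, rank * bsz_per_core, bsz_per_core)
        want = reference.narrow(0, rank * bsz_per_core, bsz_per_core)
        assert torch.allclose(mine, want)
    print('per-rank slices reassemble the full embedding lookup')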