import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
NUM_DEVICES = 8
VOCAB_SIZE = 10
OUTPUT_DIM = 32
BATCH_SIZE_PER_CORE = 1
#SEQUENCE_LEN = 100
EMBEDDING_SIZE = 1 * NUM_DEVICES
LR = 0.1
NUM_STEPS = 2


class DistributedEmbedding(nn.Module):

    def __init__(self, vocab_size, embedding_size, world_size=None,
                 batch_dim=0):
        super(DistributedEmbedding, self).__init__()
        self._embedding_size = embedding_size
        self._world_size = world_size or xm.xrt_world_size()
        self._batch_dim = batch_dim
        assert embedding_size % self._world_size == 0, \
            "For this example to work, please provide an embedding size that is a multiple of {}".format(self._world_size)
        self._sliced_emb_size = self._embedding_size // self._world_size
        self.embeddings = nn.Embedding(vocab_size, self._sliced_emb_size)

    @property
    def _rank(self):
        # Resolve the rank lazily so the module can be created at global scope.
        return xm.get_ordinal()

    def set_bsz(self, bsz):
        self.bsz = bsz // self._world_size

    def init_embs(self, emb):
        # Copy this rank's column slice of the full embedding table.
        l = self._rank * self._sliced_emb_size
        r = l + self._sliced_emb_size
        x = emb[:, l:r].clone().detach()
        self.embeddings.weight.data.copy_(x.to(self.embeddings.weight.device))

    def _get_embedding_pad(self):
        # Zero-padding sizes (left, right) that would place this rank's slice
        # at its position along the full embedding dimension.
        size = self._embedding_size // self._world_size
        return self._rank * size, (self._world_size - 1 - self._rank) * size
    def forward(self, inputs):
        # `inputs` is the full global batch; each core embeds every sample,
        # but only along its slice of the embedding dimension.
        embeds = self.embeddings(inputs)
        # All-gather along the embedding dimension so every core sees the full
        # embedding for every sample.
        pembeds = xm.all_gather(embeds, dim=-1)
        # Slice the batch dimension to return only the samples this core is
        # responsible for.
        assert inputs.size(self._batch_dim) % self._world_size == 0
        isize = inputs.size(self._batch_dim) // self._world_size
        sliced_pembeds = torch.narrow(
            pembeds, self._batch_dim, self._rank * isize, isize)
        return sliced_pembeds.clone().detach().requires_grad_(True), embeds
    def _get_pad_indices(self, grad):
        # Pad-index helper; not called anywhere in this example.
        bsz = grad.size(self._batch_dim)
        indices = [0] * (2 * grad.ndim)
        indices[-1 - 2 * self._batch_dim] = bsz
        indices[-2 - 2 * self._batch_dim] = bsz * self._rank
        return indices  # (0, 0, 0, 0, l, r)
    def backward(self, fullbatch_slicedemb, grad):
        # The incoming gradient covers the full embedding dimension but only
        # the samples this core processed.
        l, r = self._get_embedding_pad()  # pad sizes; unused in this all_gather-based path
        # Gather the per-core gradients along the batch dimension, then narrow
        # back down to this core's slice of the embedding dimension so the
        # gradient matches the local `embeds` tensor.
        grad = xm.all_gather(grad, dim=self._batch_dim)
        size = self._sliced_emb_size
        sliced_grad = torch.narrow(grad, grad.ndim - 1, self._rank * size, size)
        fullbatch_slicedemb.backward(sliced_grad)
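

# Illustration only (not called by the example): a minimal CPU sketch of the
# column-sharding scheme above, using plain torch ops. The helper name
# `_emulated_sharded_lookup` is hypothetical, added here for clarity; it is not
# part of torch_xla. It shows that concatenating each rank's lookup along the
# last dimension -- which is what xm.all_gather(embeds, dim=-1) does across
# cores -- reproduces a lookup against the full, unsharded table.
def _emulated_sharded_lookup(full_table, inputs, world_size):
    vocab_size, embedding_size = full_table.shape
    slice_size = embedding_size // world_size
    per_rank = []
    for rank in range(world_size):
        # Each rank holds columns [rank*slice_size : (rank+1)*slice_size],
        # exactly the slice copied in DistributedEmbedding.init_embs.
        shard = full_table[:, rank * slice_size:(rank + 1) * slice_size]
        per_rank.append(F.embedding(inputs, shard))
    # Stand-in for xm.all_gather(embeds, dim=-1).
    gathered = torch.cat(per_rank, dim=-1)
    assert torch.equal(gathered, F.embedding(inputs, full_table))
    return gathered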


class ModelWithDistributedEmbeddings(nn.Module):

    def __init__(self, vocab_size, embedding_size, output_dim,
                 world_size=None, batch_dim=0):
        super(ModelWithDistributedEmbeddings, self).__init__()
        self.embedding = DistributedEmbedding(
            vocab_size, embedding_size, world_size=world_size)
        self.linear = nn.Linear(embedding_size, output_dim, bias=False)

    def initialize_weights(self, linear_weights, embeddings_table):
        self.linear.weight.data.copy_(linear_weights)
        self.embedding.init_embs(embeddings_table)

    def forward(self, inputs):
        embedded_batch, emb_globalbatch_dimsliced = self.embedding(inputs)
        x = self.linear(embedded_batch)
        return F.relu(x), embedded_batch, emb_globalbatch_dimsliced

    def emb_backward(self, *args, **kwargs):
        self.embedding.backward(*args, **kwargs)
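

# Illustration only (not called by the example): the autograd pattern behind
# emb_backward, shown on a single process without the all_gather/narrow
# reshuffling. DistributedEmbedding.forward returns a detached copy with
# requires_grad_(True); the downstream loss.backward() fills that copy's .grad,
# which is then fed back into the original graph via Tensor.backward(gradient).
# A minimal sketch; the helper name is hypothetical.
def _emulated_two_stage_backward():
    emb = nn.Embedding(4, 3)
    lin = nn.Linear(3, 2, bias=False)
    x = torch.tensor([0, 2, 3])
    embeds = emb(x)
    # Cut the graph, just like DistributedEmbedding.forward does.
    embeds_detached = embeds.clone().detach().requires_grad_(True)
    loss = lin(embeds_detached).pow(2).mean()
    loss.backward()                         # stops at embeds_detached
    embeds.backward(embeds_detached.grad)   # resumes into the embedding table
    return emb.weight.grad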


class StandardModel(nn.Module):

    def __init__(self, vocab_size, embedding_size, output_dim):
        super(StandardModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.linear = nn.Linear(embedding_size, output_dim, bias=False)

    def initialize_weights(self, linear_w, embeddings_table):
        device = self.linear.weight.device
        self.linear.weight.data.copy_(linear_w.to(device))
        self.embedding.weight.data.copy_(embeddings_table.to(device))

    def forward(self, inputs):
        x = self.embedding(inputs)
        x = self.linear(x)
        return F.relu(x)


def summarize_tensor(tensor):
    t = tensor.data.cpu()
    out = t.min(), t.max(), t.sum()
    return [_.item() for _ in out]

# Create models at global scope to avoid Colab host memory OOM.
MODEL_W_DISTR_EMB = xmp.MpModelWrapper(ModelWithDistributedEmbeddings(
VOCAB_SIZE, EMBEDDING_SIZE, OUTPUT_DIM, world_size=NUM_DEVICES))
STANDARD_MODEL = xmp.MpModelWrapper(
StandardModel(VOCAB_SIZE, EMBEDDING_SIZE, OUTPUT_DIM))
# Define initial weights at global scope, but initialize each model's weights
# in the forked processes, because each process must initialize its embedding
# shard with the slice determined by its rank.
EMBTABLE = torch.randn(VOCAB_SIZE, EMBEDDING_SIZE)
LINTABLE = torch.randn(OUTPUT_DIM, EMBEDDING_SIZE)
BSZ = BATCH_SIZE_PER_CORE * NUM_DEVICES
INPUT_TENSOR = torch.randint(0, VOCAB_SIZE-1, (BSZ,))


def compare_weights_after_one_step(index):
    # Send models to device
    device = xm.xla_device()
    model_w_distr_emb = MODEL_W_DISTR_EMB.to(device)
    standard_model = STANDARD_MODEL.to(device)
    # Initialize weights
    model_w_distr_emb.initialize_weights(LINTABLE, EMBTABLE)
    standard_model.initialize_weights(LINTABLE, EMBTABLE)
    # Create optimizers
    standard_optimizer = torch.optim.SGD(
        standard_model.parameters(), lr=LR*NUM_DEVICES)
    # The distributed embedding needs to have its own optimizer, because
    # the embedding table is sharded and we do not want gradient reduction
    # happening across all cores.
    distr_emb_optimizer = torch.optim.SGD(
        model_w_distr_emb.embedding.parameters(), lr=LR)
    linear_optimizer = torch.optim.SGD(
        model_w_distr_emb.linear.parameters(), lr=LR)
    # Prepare input
    input_tensor = INPUT_TENSOR.to(device)
    #input_tensor = torch.randint(0, VOCAB_SIZE-1, (bsz, SEQUENCE_LEN)).to(device)
    for step in range(NUM_STEPS):
        # Forward pass
        standard_optimizer.zero_grad()
        distr_emb_optimizer.zero_grad()
        linear_optimizer.zero_grad()
        # The standard model processes per-core batches as usual.
        slice_i = xm.get_ordinal() * BATCH_SIZE_PER_CORE
        output_st = standard_model(
            #input_tensor[slice_i:slice_i+BATCH_SIZE_PER_CORE, :])
            input_tensor[slice_i:slice_i+BATCH_SIZE_PER_CORE])
        target = torch.randn(BATCH_SIZE_PER_CORE, OUTPUT_DIM).to(device)
        # The model with distributed embeddings processes the full global
        # batch. The embedding layer's intermediate outputs are also returned
        # in order to perform the necessary operations during backward.
        output_distr, fewsamples_fullemb, fullsamples_slicedemb = model_w_distr_emb(input_tensor)
        # Backprop - partial
        loss_distr = ((output_distr - target)**2).mean()
        loss_st = ((output_st - target)**2).mean()
        loss_distr.backward()  # this only propagates up to the embeddings.
        loss_st.backward()
        # Gradient updates and the rest of backprop.
        # All-reduce and update the standard model as usual.
        xm.optimizer_step(standard_optimizer)
        # All-reduce and update the distributed embedding model's
        # linear layer as usual.
        xm.optimizer_step(linear_optimizer)
        # Now do the distributed embeddings' backprop and update (no gradient
        # reduction here: each core updates its own shard).
        model_w_distr_emb.emb_backward(fullsamples_slicedemb, fewsamples_fullemb.grad)
        distr_emb_optimizer.step()
        # Mark the step and materialize tensors
        xm.mark_step()
        # Finally, let's look at the loss values and the weight differences.
        xm.all_reduce('sum', [loss_distr])
        xm.all_reduce('sum', [loss_st])
        xm.master_print('Step {} Summary:'.format(step))
        xm.master_print(
            '\tLoss: standard {:.4f}, distr-emb {:.4f}, delta is {:.4f}'.format(
                loss_st.item(), loss_distr.item(), (loss_st-loss_distr).item()))
        # Change in the standard model's embedding table since initialization.
        xm.master_print(
            '\tEmbedding Gradient summary: {}'.format(
                summarize_tensor(standard_model.embedding.weight.cpu() - EMBTABLE)))
        # Reassemble the full distributed embedding table by gathering each
        # core's shard along the embedding dimension, then compare it to the
        # standard model's table.
        delta_embw = xm.all_gather(
            model_w_distr_emb.embedding.embeddings.weight, dim=-1)
        delta_embw = delta_embw.cpu() - standard_model.embedding.weight.cpu()
        xm.master_print(
            '\tEmbedding Difference b/w standard and distributed: {}'.format(
                summarize_tensor(delta_embw)))
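

# Illustration only (not called by the example): the gradient reshuffle that
# emb_backward / DistributedEmbedding.backward performs in the loop above,
# emulated on CPU. Each rank's downstream gradient covers only its batch slice
# but the full embedding dimension; gathering along the batch dimension and
# narrowing along the embedding dimension yields a gradient shaped like that
# rank's local `embeds`. A minimal sketch with hypothetical names:
# `per_rank_grads` stands in for the tensors that xm.all_gather(grad,
# dim=batch_dim) would combine across cores.
def _emulated_grad_reshuffle(per_rank_grads, rank, sliced_emb_size, batch_dim=0):
    # Stand-in for xm.all_gather(grad, dim=self._batch_dim).
    full_grad = torch.cat(per_rank_grads, dim=batch_dim)
    # Same narrow as DistributedEmbedding.backward.
    return torch.narrow(
        full_grad, full_grad.ndim - 1, rank * sliced_emb_size, sliced_emb_size)
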

if __name__ == '__main__':
    xmp.spawn(
        compare_weights_after_one_step,
        args=(), nprocs=NUM_DEVICES, start_method='fork')