@taylanbil
Last active June 25, 2020 00:26
git diff HEAD~1
diff --git a/dlrm_data_pytorch.py b/dlrm_data_pytorch.py
index 6cbe382..6f1c849 100644
--- a/dlrm_data_pytorch.py
+++ b/dlrm_data_pytorch.py
@@ -266,7 +266,7 @@ class CriteoDataset(Dataset):
if self.memory_map:
if self.split == 'none' or self.split == 'train':
- # check if need to swicth to next day and load data
+ # check if need to switch to next day and load data
if index == self.offset_per_file[self.day]:
# print("day_boundary switch", index)
self.day_boundary = self.offset_per_file[self.day]
@@ -503,7 +503,7 @@ def make_criteo_data_and_loaders(args):
num_workers=args.num_workers,
collate_fn=collate_wrapper_criteo,
pin_memory=False,
- drop_last=False, # True
+ drop_last=args.drop_last,
)
test_loader = torch.utils.data.DataLoader(
@@ -513,7 +513,7 @@ def make_criteo_data_and_loaders(args):
num_workers=args.test_num_workers,
collate_fn=collate_wrapper_criteo,
pin_memory=False,
- drop_last=False, # True
+ drop_last=args.drop_last,
)
return train_data, train_loader, test_data, test_loader
@@ -627,7 +627,9 @@ def collate_wrapper_random(list_of_tuples):
T)
-def make_random_data_and_loader(args, ln_emb, m_den):
+def make_random_data_and_loader(
+ args, ln_emb, m_den, n_replicas=None, rank=None,
+):
train_data = RandomDataset(
m_den,
@@ -645,6 +647,14 @@ def make_random_data_and_loader(args, ln_emb, m_den):
reset_seed_on_access=True,
rand_seed=args.numpy_rand_seed
) # WARNING: generates a batch of lookups at once
+ train_sampler = None
+ if rank is not None:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
+ train_data,
+ num_replicas=n_replicas,
+ rank=rank,
+ shuffle=True,
+ )
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=1,
@@ -652,7 +662,8 @@ def make_random_data_and_loader(args, ln_emb, m_den):
num_workers=args.num_workers,
collate_fn=collate_wrapper_random,
pin_memory=False,
- drop_last=False, # True
+ drop_last=args.drop_last,
+ sampler=train_sampler,
)
return train_data, train_loader
@@ -764,7 +775,9 @@ def generate_uniform_input_batch(
)
# sparse indices to be used per embedding
r = ra.random(sparse_group_size)
- sparse_group = np.unique(np.round(r * (size - 1)).astype(np.int64))
+ # XXX: why np.unique ? This is producing a different input shape..
+ #sparse_group = np.unique(np.round(r * (size - 1)).astype(np.int64))
+ sparse_group = np.round(r * (size - 1)).astype(np.int64)
# reset sparse_group_size in case some index duplicates were removed
sparse_group_size = np.int64(sparse_group.size)
# store lengths and indices
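
The new n_replicas/rank arguments exist so each TPU process iterates over a disjoint shard of the synthetic dataset via DistributedSampler (note the assignment to train_sampler above, which the DataLoader's sampler= argument then picks up). A minimal CPU-only sketch of the same wiring, with a toy TensorDataset standing in for RandomDataset:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# toy stand-in for RandomDataset: 16 samples of 4 dense features
data = TensorDataset(torch.randn(16, 4))

def make_loader(n_replicas=None, rank=None, drop_last=False):
    sampler = None
    if rank is not None:
        # each replica iterates over a disjoint 1/n_replicas shard
        sampler = DistributedSampler(
            data, num_replicas=n_replicas, rank=rank, shuffle=True
        )
    return DataLoader(data, batch_size=2, sampler=sampler, drop_last=drop_last)

loader = make_loader(n_replicas=4, rank=1)  # what replica 1 of 4 would see
print(len(loader))  # 2 batches of 2 -> 4 of the 16 samples
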
diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py
index 1955bb9..1fef8b0 100644
--- a/dlrm_s_pytorch.py
+++ b/dlrm_s_pytorch.py
@@ -58,6 +58,7 @@ import builtins
import functools
# import bisect
# import shutil
+import sys
import time
import json
# data generation
@@ -93,10 +94,14 @@ import sklearn.metrics
from torch.optim.lr_scheduler import _LRScheduler
+
exc = getattr(builtins, "IOError", "FileNotFoundError")
class LRPolicyScheduler(_LRScheduler):
- def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
+ def __init__(
+ self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps,
+ use_tpu=False, local_optimizer=None,
+ ):
self.num_warmup_steps = num_warmup_steps
self.decay_start_step = decay_start_step
self.decay_end_step = decay_start_step + num_decay_steps
@@ -105,6 +110,9 @@ class LRPolicyScheduler(_LRScheduler):
if self.decay_start_step < self.num_warmup_steps:
sys.exit("Learning rate warmup must finish before the decay starts")
+ self.use_tpu = use_tpu
+ self.sparse_feature_local_optimizer = local_optimizer
+
super(LRPolicyScheduler, self).__init__(optimizer)
def get_lr(self):
@@ -129,8 +137,18 @@ class LRPolicyScheduler(_LRScheduler):
lr = self.base_lrs
return lr
+ def step(self):
+ super().step()
+ if self.sparse_feature_local_optimizer is not None:
+ # XXX: is lr always a single value?
+ lr = self.get_lr()[0]
+ for param_group in self.sparse_feature_local_optimizer.param_groups:
+ param_group['lr'] = lr
+
+
### define dlrm in PyTorch ###
class DLRM_Net(nn.Module):
+
def create_mlp(self, ln, sigmoid_layer):
# build MLP layer by layer
layers = nn.ModuleList()
@@ -175,11 +193,26 @@ class DLRM_Net(nn.Module):
emb_l = nn.ModuleList()
for i in range(0, ln.size):
n = ln[i]
+ if (
+ self.use_tpu and
+ self.ndevices > 1 and
+ not self._tpu_index_belongs_to_ordinal(i)
+ ):
+ # tpu model-parallel mode. only create this ordinal's tables.
+ continue
# construct embedding operator
if self.qr_flag and n > self.qr_threshold:
- EE = QREmbeddingBag(n, m, self.qr_collisions,
- operation=self.qr_operation, mode="sum", sparse=True)
+ # XXX: code path not hit with current tpu tests.
+ assert not self.use_tpu, 'not implemented'.upper()
+ EE = QREmbeddingBag(
+ n, m, self.qr_collisions,
+ operation=self.qr_operation,
+ mode="sum", sparse=self.sparse,
+ xla=self.use_tpu,
+ )
elif self.md_flag:
+ assert not self.use_tpu, 'not implemented'.upper()
+ # XXX: code path not hit with current tpu tests.
base = max(m)
_m = m[i] if n > self.md_threshold else base
EE = PrEmbeddingBag(n, _m, base)
@@ -190,7 +223,21 @@ class DLRM_Net(nn.Module):
EE.embs.weight.data = torch.tensor(W, requires_grad=True)
else:
- EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)
+ if self.use_tpu:
+ # XXX: xla currently does not support `nn.EmbeddingBag`
+ # so I wrote a custom thing.
+ from tools.xla_embedding_bag import XlaEmbeddingBag
+ errmsg = (
+ '`use_tpu` was specified. XLA currently only supports '
+ 'fixed length sparse groups.'
+ )
+ assert self.offset is not None, errmsg
+ EE = XlaEmbeddingBag(
+ n, m, mode="sum",
+ sparse=self.sparse, offset=self.offset,
+ )
+ else:
+ EE = nn.EmbeddingBag(n, m, mode="sum", sparse=self.sparse)
# initialize embeddings
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
@@ -208,6 +255,25 @@ class DLRM_Net(nn.Module):
return emb_l
+ def set_xla_replica_groups(self, groups):
+ self._xla_replica_groups = groups
+ for g in groups:
+ if self._ordinal in g:
+ self._xla_replica_group = g
+ assert self.ndevices == len(g), \
+ 'replica group does not match ndevices, {} vs {}'.format(
+ len(g), self.ndevices
+ )
+ self._xla_replica_index = g.index(self._ordinal)
+ break
+ else:
+ raise ValueError(
+ 'Ordinal {} not in replica groups!'.format(self._ordinal)
+ )
+
+ def _tpu_index_belongs_to_ordinal(self, i):
+ return i % len(self._xla_replica_group) == self._xla_replica_index
+
def __init__(
self,
m_spa=None,
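
In model-parallel mode each ordinal only instantiates the embedding tables it owns, and ownership is round-robin over the ordinal's position inside its model-parallel replica group (_tpu_index_belongs_to_ordinal above). A standalone sketch of that assignment for a hypothetical 4-core group and 8 tables:

# mirrors _tpu_index_belongs_to_ordinal for one hypothetical 4-core mp group
replica_group = [0, 1, 2, 3]
num_tables = 8

for ordinal in replica_group:
    replica_index = replica_group.index(ordinal)
    owned = [
        i for i in range(num_tables)
        if i % len(replica_group) == replica_index
    ]
    print(ordinal, owned)
# 0 [0, 4]
# 1 [1, 5]
# 2 [2, 6]
# 3 [3, 7]
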
@@ -227,6 +293,9 @@ class DLRM_Net(nn.Module):
qr_threshold=200,
md_flag=False,
md_threshold=200,
+ sparse=True,
+ use_tpu=False,
+ offset=None
):
super(DLRM_Net, self).__init__()
@@ -247,6 +316,9 @@ class DLRM_Net(nn.Module):
self.arch_interaction_itself = arch_interaction_itself
self.sync_dense_params = sync_dense_params
self.loss_threshold = loss_threshold
+ self.sparse = (not use_tpu) and sparse
+ self.use_tpu = use_tpu
+ self.offset = offset
# create variables for QR embedding if applicable
self.qr_flag = qr_flag
if self.qr_flag:
@@ -256,13 +328,43 @@ class DLRM_Net(nn.Module):
# create variables for MD embedding if applicable
self.md_flag = md_flag
if self.md_flag:
+ if use_tpu:
+ raise NotImplementedError(
+ """
+ XXX:
+ md trick produces mixed shape inputs => recompilations
+ on tpu. Document the tradeoff after experimenting.
+ """
+ )
self.md_threshold = md_threshold
+ if self.use_tpu:
+ self._init_tpu()
# create operators
if ndevices <= 1:
self.emb_l = self.create_emb(m_spa, ln_emb)
self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
self.top_l = self.create_mlp(ln_top, sigmoid_top)
+ def _init_tpu(self):
+ # no need to be ordinal aware in full data parallel mode.
+ if self.ndevices > 1:
+ import torch_xla.core.xla_model as xm
+ self._ordinal = xm.get_ordinal()
+ self._local_ordinal = xm.get_local_ordinal()
+ self._all_reduce = xm.all_reduce
+ self._all_gather = xm.all_gather
+
+ def _filter_params(self, f):
+ for name, p in self.named_parameters():
+ if f(name):
+ yield p
+
+ def mlp_parameters(self):
+ return self._filter_params(lambda name: not name.startswith('emb_l'))
+
+ def emb_parameters(self):
+ return self._filter_params(lambda name: name.startswith('emb_l'))
+
def apply_mlp(self, x, layers):
# approach 1: use ModuleList
# for layer in layers:
@@ -332,9 +434,17 @@ class DLRM_Net(nn.Module):
def forward(self, dense_x, lS_o, lS_i):
if self.ndevices <= 1:
return self.sequential_forward(dense_x, lS_o, lS_i)
+ elif self.use_tpu:
+ return self.tpu_parallel_forward(dense_x, lS_o, lS_i)
else:
return self.parallel_forward(dense_x, lS_o, lS_i)
+ def clamp_output(self, p):
+ z = p
+ if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
+ z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
+ return z
+
def sequential_forward(self, dense_x, lS_o, lS_i):
# process dense features (using bottom mlp), resulting in a row vector
x = self.apply_mlp(dense_x, self.bot_l)
@@ -355,13 +465,84 @@ class DLRM_Net(nn.Module):
p = self.apply_mlp(z, self.top_l)
# clamp output if needed
- if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
- z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
- else:
- z = p
-
+ z = self.clamp_output(p)
return z
+ def _partition_to_device(self, iterable):
+ return [
+ obj for k, obj in enumerate(iterable)
+ if self._tpu_index_belongs_to_ordinal(k)
+ ]
+
+ def _gather_other_embeddings(self, ordinal_data):
+ x = iter(ordinal_data)
+ ordinal_data = torch.stack(ordinal_data)
+ full_data = self._all_gather(
+ ordinal_data, dim=0, groups=self._xla_replica_groups
+ )
+ return full_data
+
+ def _narrow(self, local_bsz, tensor, dim=1):
+ return torch.narrow(
+ tensor, dim, self._xla_replica_index*local_bsz, local_bsz
+ )
+
+ def _gather_other_samples(self, array, dim=0):
+ out = torch.stack(array)
+ # dim+1 because stack introduces a dimension to the left
+ out = self._all_gather(out, dim=dim+1, groups=self._xla_replica_groups)
+ return out
+
+ def tpu_parallel_forward(self, dense_x, lS_o, lS_i):
+ batch_size = dense_x.size()[0]
+ ndevices = self.ndevices
+ assert not batch_size % ndevices, \
+ f"{batch_size} is bsz, {ndevices} devices"
+ local_bsz = batch_size // ndevices
+ # XXX: no redistribute model if bsz changes for tpus. is this ok?
+ # this is bc all weights are updated at all times.
+ # I think on gpus, it updates `ndevices` many, which can fluctuate
+ # w/ changes in bsz.
+
+ dense_x = self._narrow(local_bsz, dense_x, dim=0)
+
+ #bottom mlp
+ x = self.bot_l(dense_x)
+ # embeddings
+ lS_o = self._partition_to_device(lS_o)
+ lS_i = self._partition_to_device(lS_i)
+ ly_local = self.apply_emb(lS_o, lS_i, self.emb_l)
+
+ # at this point, each device has the embeddings belonging to itself.
+ # we do a gather to acquire all embeddings, i.e. full input.
+ # followed by a `narrow`, so rest of the model can run data-parallel.
+ # XXX: pods? will gather collect from all devices? option to collect only locally?
+ ly = self._gather_other_embeddings(ly_local)
+ # _gather introduces a dim, so batch dim is now 1
+ ly = self._narrow(local_bsz, ly, dim=1)
+ # now stop gradients from flowing back.
+ ly = [_.clone().detach().requires_grad_(True) for _ in ly]
+
+ # interactions
+ z = self.interact_features(x, ly)
+
+ # top mlp
+ p = self.top_l(z)
+
+ # clamp output if needed
+ z = self.clamp_output(p)
+
+ return z, ly_local, ly # extra return args needed during bwd.
+
+ def tpu_local_backward(self, fullbatch_localembs, localbatch_fullembs):
+ localbatch_fullgrads = [_.grad for _ in localbatch_fullembs]
+ grad = self._gather_other_samples(localbatch_fullgrads) # inv to narrow
+ grad = self._partition_to_device(grad)
+ assert len(fullbatch_localembs) == len(grad), \
+ '{} vs {}'.format(len(fullbatch_localembs), len(grad))
+ for e, g in zip(fullbatch_localembs, grad):
+ e.backward(g)
+
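
The scheme above: each core runs the bottom MLP on its own slice of the batch, computes only its own tables but over the full batch, all-gathers the embedding outputs, narrows them back to its batch slice, and detaches so the top-MLP backward stops at the gathered copies; tpu_local_backward then routes the gathered gradients back into the locally owned tables. A CPU-only sketch of just the shape bookkeeping, with xm.all_gather simulated by concatenating per-device stacks:

import torch

ndevices, num_tables, bsz, m = 2, 4, 8, 3
local_bsz = bsz // ndevices              # 4
tables_per_dev = num_tables // ndevices  # 2 tables owned per core

# each simulated device produces full-batch outputs for the tables it owns
per_device = [
    [torch.randn(bsz, m) for _ in range(tables_per_dev)]
    for _ in range(ndevices)
]

# _gather_other_embeddings: stack the local tables, then all_gather along dim 0
# (simulated here by concatenating the per-device stacks)
gathered = torch.cat([torch.stack(ly_local) for ly_local in per_device], dim=0)
print(gathered.shape)   # torch.Size([4, 8, 3]) = [num_tables, bsz, m]

# _narrow on the device with replica index 1: keep only its batch slice
replica_index = 1
ly = torch.narrow(gathered, 1, replica_index * local_bsz, local_bsz)
print(ly.shape)         # torch.Size([4, 4, 3]) = [num_tables, local_bsz, m]

# detach so the top-MLP backward stops here; tpu_local_backward later gathers
# ly grads across cores and backprops them through the owned tables
ly = [t.clone().detach().requires_grad_(True) for t in ly]
print(len(ly), ly[0].shape)  # 4 tensors of shape [local_bsz, m]
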
def parallel_forward(self, dense_x, lS_o, lS_i):
### prepare model (overwrite) ###
# WARNING: # of devices must be >= batch size in parallel_forward call
@@ -461,22 +642,14 @@ class DLRM_Net(nn.Module):
p0 = gather(p, self.output_d, dim=0)
# clamp output if needed
- if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
- z0 = torch.clamp(
- p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold)
- )
- else:
- z0 = p0
+ z0 = self.clamp_output(p0)
return z0
-if __name__ == "__main__":
- ### import packages ###
- import sys
+def parse_args():
import argparse
- ### parse arguments ###
parser = argparse.ArgumentParser(
description="Train Deep Learning Recommendation Model (DLRM)"
)
@@ -503,22 +676,24 @@ if __name__ == "__main__":
parser.add_argument("--loss-weights", type=str, default="1.0-1.0") # for wbce
parser.add_argument("--loss-threshold", type=float, default=0.0) # 1.0e-7
parser.add_argument("--round-targets", type=bool, default=False)
+ parser.add_argument("--pred-threshold", type=float, default=0.5)
# data
parser.add_argument("--data-size", type=int, default=1)
parser.add_argument("--num-batches", type=int, default=0)
parser.add_argument(
"--data-generation", type=str, default="random"
) # synthetic or dataset
+ parser.add_argument("--drop-last", action="store_true")
parser.add_argument("--data-trace-file", type=str, default="./input/dist_emb_j.log")
parser.add_argument("--data-set", type=str, default="kaggle") # or terabyte
parser.add_argument("--raw-data-file", type=str, default="")
parser.add_argument("--processed-data-file", type=str, default="")
parser.add_argument("--data-randomize", type=str, default="total") # or day or none
- parser.add_argument("--data-trace-enable-padding", type=bool, default=False)
+ parser.add_argument("--data-trace-enable-padding", action='store_true')
parser.add_argument("--max-ind-range", type=int, default=-1)
parser.add_argument("--data-sub-sample-rate", type=float, default=0.0) # in [0, 1]
parser.add_argument("--num-indices-per-lookup", type=int, default=10)
- parser.add_argument("--num-indices-per-lookup-fixed", type=bool, default=False)
+ parser.add_argument("--num-indices-per-lookup-fixed", action='store_true')
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--memory-map", action="store_true", default=False)
# training
@@ -532,8 +707,11 @@ if __name__ == "__main__":
parser.add_argument("--inference-only", action="store_true", default=False)
# onnx
parser.add_argument("--save-onnx", action="store_true", default=False)
- # gpu
+ # accelerators
parser.add_argument("--use-gpu", action="store_true", default=False)
+ parser.add_argument("--use-tpu", action="store_true", default=False)
+ parser.add_argument("--tpu-model-parallel-group-len", type=int, default=1)
+ parser.add_argument("--tpu-data-parallel", action="store_true")
# debugging and profiling
parser.add_argument("--print-freq", type=int, default=1)
parser.add_argument("--test-freq", type=int, default=-1)
@@ -558,16 +736,7 @@ if __name__ == "__main__":
parser.add_argument("--lr-num-warmup-steps", type=int, default=0)
parser.add_argument("--lr-decay-start-step", type=int, default=0)
parser.add_argument("--lr-num-decay-steps", type=int, default=0)
- args = parser.parse_args()
-
- if args.mlperf_logging:
- print('command line args: ', json.dumps(vars(args)))
-
- ### some basic setup ###
- np.random.seed(args.numpy_rand_seed)
- np.set_printoptions(precision=args.print_precision)
- torch.set_printoptions(precision=args.print_precision)
- torch.manual_seed(args.numpy_rand_seed)
+ args, _ = parser.parse_known_args()
if (args.test_mini_batch_size < 0):
# if the parameter is not set, use the training batch size
@@ -576,8 +745,73 @@ if __name__ == "__main__":
# if the parameter is not set, use the same parameter for training
args.test_num_workers = args.num_workers
+ return args
+
+
+def tpu_get_xla_replica_groups(args):
+ import torch_xla.core.xla_model as xm
+ # XXX: would it make sense to sort by size here -- to evenly dist. tables?
+ world_size = xm.xrt_world_size()
+ if args.tpu_data_parallel:
+ return [[i] for i in range(world_size)], [list(range(world_size))]
+ num_tables = args.arch_embedding_size.count("-") + 1
+ len_mp_group = args.tpu_model_parallel_group_len
+ assert not num_tables % len_mp_group, \
+ 'Model parallel group size has to divide number of emb tables evenly.'
+ assert not world_size % len_mp_group, \
+ 'Length of model parallel groups has to evenly divide `xrt_world_size`'
+ len_dp_group = world_size // len_mp_group
+ mp_groups = [
+ [len_mp_group*d+m for m in range(len_mp_group)]
+ for d in range(len_dp_group)
+ ]
+ dp_groups = [
+ [len_dp_group*d+m for m in range(len_dp_group)]
+ for d in range(len_mp_group)
+ ]
+ return mp_groups, dp_groups
+
+
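
tpu_get_xla_replica_groups partitions the xrt_world_size() ordinals into model-parallel groups (which split the embedding tables among themselves) and data-parallel groups (used for the sparse-optimizer all-reduce and for sharding the input). A sketch of what the two comprehensions above produce for a hypothetical 8-core setup, with the asserts and the torch_xla import left out:

def replica_groups(world_size, len_mp_group):
    # same arithmetic as tpu_get_xla_replica_groups, asserts omitted
    len_dp_group = world_size // len_mp_group
    mp_groups = [
        [len_mp_group * d + m for m in range(len_mp_group)]
        for d in range(len_dp_group)
    ]
    dp_groups = [
        [len_dp_group * d + m for m in range(len_dp_group)]
        for d in range(len_mp_group)
    ]
    return mp_groups, dp_groups

print(replica_groups(8, 1))
# ([[0], [1], [2], [3], [4], [5], [6], [7]], [[0, 1, 2, 3, 4, 5, 6, 7]])
#   -> one data-parallel group covering all cores
print(replica_groups(8, 4))
# ([[0, 1, 2, 3], [4, 5, 6, 7]], [[0, 1], [2, 3], [4, 5], [6, 7]])
#   -> two 4-core model-parallel groups
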
+def main(*_args):
+ ### import packages ###
+ import sys
+
+ args = parse_args()
+ ### some basic setup ###
+ np.random.seed(args.numpy_rand_seed)
+ np.set_printoptions(precision=args.print_precision)
+ torch.set_printoptions(precision=args.print_precision)
+ # XXX: does this imply we init the emb tables the same way?
+ torch.manual_seed(args.numpy_rand_seed)
+
use_gpu = args.use_gpu and torch.cuda.is_available()
- if use_gpu:
+ use_tpu = args.use_tpu
+ print = builtins.print
+ if use_tpu:
+ use_gpu = False
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ import torch_xla.distributed.parallel_loader as pl
+ print = xm.master_print
+ device = xm.xla_device()
+ print("Using {} TPU core(s)...".format(xm.xrt_world_size()))
+ if args.enable_profiling:
+ print("Profiling was enabled. Turning it off for TPUs.")
+ args.enable_profiling = False
+ if not args.num_indices_per_lookup_fixed:
+ # XXX: does this lead to recompilations?
+ raise NotImplementedError
+ mp_replica_groups, dp_replica_groups = tpu_get_xla_replica_groups(args)
+ print('XLA replica groups for Model Parallel:\n\t', mp_replica_groups)
+ print('XLA replica groups for Data Parallel:\n\t', dp_replica_groups)
+ if len(dp_replica_groups) == 1:
+ # i.e. no allgather etc in the emb layer.
+ print("TPU data-parallel mode, setting --tpu-data-parallel to True")
+ args.tpu_data_parallel = True
+ else:
+ print("TPU model-parallel mode, setting --drop-last=True")
+ args.drop_last = True
+ elif use_gpu:
torch.cuda.manual_seed_all(args.numpy_rand_seed)
torch.backends.cudnn.deterministic = True
device = torch.device("cuda", 0)
@@ -587,6 +821,9 @@ if __name__ == "__main__":
device = torch.device("cpu")
print("Using CPU...")
+ if args.mlperf_logging:
+ print('command line args: ', json.dumps(vars(args)))
+
### prepare training data ###
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
# input data
@@ -610,8 +847,26 @@ if __name__ == "__main__":
# input and target at random
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
m_den = ln_bot[0]
- train_data, train_ld = dp.make_random_data_and_loader(args, ln_emb, m_den)
- nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
+ data_args = [args, ln_emb, m_den]
+ if use_tpu:
+ ordinal = xm.get_ordinal()
+ for g in dp_replica_groups:
+ if ordinal in g:
+ n_replicas, rank = len(g), g.index(ordinal)
+ break
+ else:
+ raise ValueError(
+ 'Ordinal {} not in replica groups!'.format(ordinal)
+ )
+ data_args.extend([n_replicas, rank]) # extend w/ n_replicas and rank
+ train_data, train_ld = dp.make_random_data_and_loader(*data_args)
+ #nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
+ nbatches = len(train_ld)
+
+ if use_tpu:
+ # XXX: test_data is unused.
+ # Wrap w/ torch_xla's loader
+ train_ld = pl.MpDeviceLoader(train_ld, device)
### parse command line arguments ###
m_spa = args.arch_sparse_feature_size
@@ -737,7 +992,26 @@ if __name__ == "__main__":
print([S_i.detach().cpu().tolist() for S_i in lS_i])
print(T.detach().cpu().numpy())
- ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_gpu else -1
+ ndevices = -1
+ if use_tpu:
+ ndevices = len(dp_replica_groups)
+ # XXX: it could work when the following are violated too.
+ # TODO: implement that.
+ if args.mini_batch_size % ndevices:
+ raise NotImplementedError(
+ 'bsz is {}, ndevices is {}'.format(
+ args.mini_batch_size, ndevices,
+ )
+ )
+ if (num_fea - 1) % ndevices:
+ raise NotImplementedError(
+ 'num embtables is {}, ndevices is {}'.format(
+ num_fea-1, ndevices,
+ )
+ )
+
+ elif use_gpu:
+ ndevices = min(ngpus, args.mini_batch_size, num_fea - 1)
### construct the neural network specified above ###
# WARNING: to obtain exactly the same initialization for
@@ -761,6 +1035,9 @@ if __name__ == "__main__":
qr_threshold=args.qr_threshold,
md_flag=args.md_flag,
md_threshold=args.md_threshold,
+ sparse=device.type != 'xla',
+ use_tpu=use_tpu,
+ offset=args.num_indices_per_lookup,
)
# test prints
if args.debug_mode:
@@ -776,6 +1053,13 @@ if __name__ == "__main__":
dlrm = dlrm.to(device) # .cuda()
if dlrm.ndevices > 1:
dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb)
+ if use_tpu:
+ # XXX: ndevices is redundant, but meh.
+ if dlrm.ndevices > 1:
+ dlrm.set_xla_replica_groups(mp_replica_groups)
+ dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb)
+ dlrm.device = device
+ dlrm = dlrm.to(device)
# specify the loss function
if args.loss_function == "mse":
@@ -790,11 +1074,34 @@ if __name__ == "__main__":
if not args.inference_only:
# specify the optimizer algorithm
- optimizer = torch.optim.SGD(dlrm.parameters(), lr=args.learning_rate)
- lr_scheduler = LRPolicyScheduler(optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
- args.lr_num_decay_steps)
+
+ if use_tpu and not args.tpu_data_parallel:
+ # tpu's paradigm is different than gpu/cpu. Each process here runs
+ # on its own and all reduces need to happen in a particular way.
+ # Data parallel part, i.e. the MLP part will be allreduced.
+ # EmbeddingBag part will not be allreduced.
+ optimizer = torch.optim.SGD(
+ dlrm.mlp_parameters(), lr=args.learning_rate
+ )
+ emb_local_optimizer = torch.optim.SGD(
+ dlrm.emb_parameters(), lr=args.learning_rate
+ )
+ lr_scheduler = LRPolicyScheduler(
+ optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
+ args.lr_num_decay_steps, use_tpu=True,
+ local_optimizer=emb_local_optimizer,
+ )
+ else:
+ optimizer = torch.optim.SGD(
+ dlrm.parameters(), lr=args.learning_rate
+ )
+ lr_scheduler = LRPolicyScheduler(
+ optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
+ args.lr_num_decay_steps
+ )
### main loop ###
+
def time_wrap(use_gpu):
if use_gpu:
torch.cuda.synchronize()
@@ -808,15 +1115,16 @@ if __name__ == "__main__":
else lS_i.to(device)
lS_o = [S_o.to(device) for S_o in lS_o] if isinstance(lS_o, list) \
else lS_o.to(device)
- return dlrm(
- X.to(device),
- lS_o,
- lS_i
- )
+ return dlrm(X.to(device), lS_o, lS_i)
else:
return dlrm(X, lS_o, lS_i)
def loss_fn_wrap(Z, T, use_gpu, device):
+ if T.size(0) > Z.size(0):
+ # This happens for tpus.
+ # Target tensor is likely for global batch. Narrow to this device.
+ T = dlrm._narrow(Z.size(0), T, dim=0)
+
if args.loss_function == "mse" or args.loss_function == "bce":
if use_gpu:
return loss_fn(Z, T.to(device))
@@ -848,9 +1156,12 @@ if __name__ == "__main__":
k = 0
# Load model is specified
- if not (args.load_model == ""):
+ if args.load_model != "":
print("Loading saved model {}".format(args.load_model))
- if use_gpu:
+ if use_tpu:
+ # XXX: add tpu capabilities to load.
+ raise NotImplementedError('Add tpu capabilities to load.')
+ elif use_gpu:
if dlrm.ndevices > 1:
# NOTE: when targeting inference on multiple GPUs,
# load the model as is on CPU or GPU, with the move
@@ -907,6 +1218,7 @@ if __name__ == "__main__":
)
print("time/loss/accuracy (if enabled):")
+ # XXX: what is profiler? turning it off for tpus.
with torch.autograd.profiler.profile(args.enable_profiling, use_gpu) as prof:
while k < args.nepochs:
if k < skip_upto_epoch:
@@ -932,6 +1244,7 @@ if __name__ == "__main__":
t1 = time_wrap(use_gpu)
# early exit if nbatches was set by the user and has been exceeded
+ # XXX: what about j being reset every epoch?
if nbatches > 0 and j >= nbatches:
break
'''
@@ -944,8 +1257,22 @@ if __name__ == "__main__":
print(T.detach().cpu().numpy())
'''
+ # XXX: clean
+ """
+ print(
+ 'SHAPES',
+ X.shape,
+ [_.shape for _ in lS_o],
+ [_.shape for _ in lS_i],
+ )
+ """
# forward pass
- Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, device)
+ if use_tpu and not args.tpu_data_parallel:
+ # args[1:] below will be used in the custom backward
+ Z, fullbatch_localembs, localbatch_fullembs = \
+ dlrm_wrap(X, lS_o, lS_i, use_gpu, device)
+ else:
+ Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, device)
# loss
E = loss_fn_wrap(Z, T, use_gpu, device)
@@ -955,12 +1282,15 @@ if __name__ == "__main__":
print(Z.detach().cpu().numpy())
print(E.detach().cpu().numpy())
'''
- # compute loss and accuracy
- L = E.detach().cpu().numpy() # numpy array
- S = Z.detach().cpu().numpy() # numpy array
- T = T.detach().cpu().numpy() # numpy array
+
mbs = T.shape[0] # = args.mini_batch_size except maybe for last
- A = np.sum((np.round(S, 0) == T).astype(np.uint8))
+ # XXX: conv related: T is a vector of floats here.
+ # args.round_targets related
+ # FIXME: figure out a way to do this on tpu
+ # S = Z.detach().cpu().numpy() # numpy array
+ # T = T.detach().cpu().numpy() # numpy array
+ #A = np.sum((np.round(S, 0) == T).astype(np.uint8))
+ A = 0
if not args.inference_only:
# scaled error gradient propagation
@@ -973,17 +1303,35 @@ if __name__ == "__main__":
# if hasattr(l, 'weight'):
# print(l.weight.grad.norm().item())
- # optimizer
- optimizer.step()
+ # backward + optimizer
+ if use_tpu and not args.tpu_data_parallel:
+ # XXX: no clip grad?
+ # Full allreduce across all devices for the MLP part.
+ # XXX: what about the bottom mlp? # FIXME
+ xm.optimizer_step(optimizer, groups=None)
+ # bwd pass for the embedding tables.
+ dlrm.tpu_local_backward(
+ fullbatch_localembs, localbatch_fullembs,
+ )
+ if len(dp_replica_groups[0]) > 1:
+ xm.optimizer_step(
+ emb_local_optimizer, groups=dp_replica_groups,
+ )
+ else:
+ # no allreduce, just step
+ emb_local_optimizer.step()
+ else:
+ optimizer.step()
lr_scheduler.step()
+
if args.mlperf_logging:
total_time += iteration_time
else:
t2 = time_wrap(use_gpu)
total_time += t2 - t1
total_accu += A
- total_loss += L * mbs
+ total_loss += E.detach() * mbs
total_iter += 1
total_samp += mbs
@@ -996,6 +1344,12 @@ if __name__ == "__main__":
# print time, loss and accuracy
if should_print or should_test:
+ # XXX: detach+cpu+numpy seems to avoid the "aten" counter!!!!
+ #L = E.detach().cpu().numpy() # numpy array
+ #S = Z.detach().cpu().numpy() # numpy array
+ #T = T.detach().cpu().numpy() # numpy array
+ if use_tpu:
+ xm.mark_step()
gT = 1000.0 * total_time / total_iter if args.print_time else -1
total_time = 0
@@ -1020,6 +1374,8 @@ if __name__ == "__main__":
# testing
if should_test and not args.inference_only:
+ # XXX: code path not hit currently
+ raise NotImplementedError('not hit')
# don't measure training iter time in a test iteration
if args.mlperf_logging:
previous_iteration_time = None
@@ -1234,3 +1590,7 @@ if __name__ == "__main__":
dlrm_pytorch_onnx = onnx.load("dlrm_s_pytorch.onnx")
# check the onnx model
onnx.checker.check_model(dlrm_pytorch_onnx)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/dlrm_tpu_runner.py b/dlrm_tpu_runner.py
new file mode 100644
index 0000000..d31aa6d
--- /dev/null
+++ b/dlrm_tpu_runner.py
@@ -0,0 +1,15 @@
+import sys
+import argparse
+
+import torch_xla.distributed.xla_multiprocessing as xmp
+
+from dlrm_s_pytorch import main, parse_args
+
+
+if __name__ == '__main__':
+ pre_spawn_parser = argparse.ArgumentParser()
+ pre_spawn_parser.add_argument(
+ "--tpu-cores", type=int, default=8, choices=[1, 8]
+ )
+ pre_spawn_flags, _ = pre_spawn_parser.parse_known_args()
+ xmp.spawn(main, args=(), nprocs=pre_spawn_flags.tpu_cores)
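
The runner only parses --tpu-cores itself and hands everything else to dlrm_s_pytorch via xmp.spawn; parse_args was switched to parse_known_args above so the two parsers can share one command line without rejecting each other's flags. A small illustration of that behaviour (the flag values here are only examples):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tpu-cores", type=int, default=8, choices=[1, 8])

# parse_known_args tolerates flags meant for the other parser
flags, leftover = parser.parse_known_args(
    ["--tpu-cores", "8", "--use-tpu", "--mini-batch-size", "128"]
)
print(flags.tpu_cores)   # 8
print(leftover)          # ['--use-tpu', '--mini-batch-size', '128']
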
diff --git a/tools/xla_embedding_bag.py b/tools/xla_embedding_bag.py
new file mode 100644
index 0000000..c474bf1
--- /dev/null
+++ b/tools/xla_embedding_bag.py
@@ -0,0 +1,38 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+
+
+class XlaEmbeddingBag(nn.Module):
+ """
+ nn.EmbeddingBag is not lowered just yet to xla.
+ This performs the same functionality, in an xla compatible, sub-optimal way.
+
+ Warning!: only works with constant offsets atm.
+ """
+
+ def __init__(self, n, m, mode, offset, *args, **kwargs):
+ super(XlaEmbeddingBag, self).__init__()
+ self.n = n
+ self.m = m
+ self.mode = mode
+ self.offset = offset
+ self.embtable = nn.Embedding(n, m, *args, **kwargs)
+
+ def forward(self, sparse_index_group_batch, sparse_offset_group_batch):
+ emb = self.embtable(sparse_index_group_batch)
+ # XXX: only works w/ constant offset atm
+ bsz = emb.size(0) // self.offset
+ emb = emb.reshape(bsz, self.offset, *emb.size()[1:])
+ reduce_fn = getattr(torch, self.mode)
+ return reduce_fn(emb, axis=1)
+ #return reduce_fn(self.embtable(_) for _ in inp_list)
+
+ @property
+ def weight(self):
+ return self.embtable.weight
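
For fixed-length bags, the reshape-and-reduce above should match nn.EmbeddingBag with mode="sum" when both modules share the same weights. A quick CPU check of that claim, with XlaEmbeddingBag's forward inlined rather than imported:

import torch
import torch.nn as nn

n, m, offset, bsz = 10, 4, 3, 2   # 10-row table, dim 4, 3 lookups per sample
torch.manual_seed(0)

ref = nn.EmbeddingBag(n, m, mode="sum")
emb = nn.Embedding(n, m)
emb.weight.data.copy_(ref.weight.data)   # share the table

idx = torch.randint(0, n, (bsz * offset,))
offsets = torch.arange(0, bsz * offset, offset)   # fixed-length bags: [0, 3]

out_ref = ref(idx, offsets)
# XlaEmbeddingBag.forward, inlined: plain lookup, reshape, reduce over the bag
out_xla = emb(idx).reshape(bsz, offset, m).sum(dim=1)

print(torch.allclose(out_ref, out_xla))   # True
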
diff --git a/tricks/qr_embedding_bag.py b/tricks/qr_embedding_bag.py
index 290d795..941ac1b 100644
--- a/tricks/qr_embedding_bag.py
+++ b/tricks/qr_embedding_bag.py
@@ -112,7 +112,7 @@ class QREmbeddingBag(nn.Module):
def __init__(self, num_categories, embedding_dim, num_collisions,
operation='mult', max_norm=None, norm_type=2.,
scale_grad_by_freq=False, mode='mean', sparse=False,
- _weight=None):
+ _weight=None, xla=False):
super(QREmbeddingBag, self).__init__()
assert operation in ['concat', 'mult', 'add'], 'Not valid operation!'
@@ -127,6 +127,7 @@ class QREmbeddingBag(nn.Module):
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
+ self.xla = xla
if self.operation == 'add' or self.operation == 'mult':
assert self.embedding_dim[0] == self.embedding_dim[1], \
@@ -153,17 +154,32 @@ class QREmbeddingBag(nn.Module):
nn.init.uniform_(self.weight_q, np.sqrt(1 / self.num_categories))
nn.init.uniform_(self.weight_r, np.sqrt(1 / self.num_categories))
+ def _embed_input(self, input_q, input_r, offsets, per_sample_weights):
+ # XXX
+ if self.xla:
+ raise NotImplementedError('implement tricks later')
+
+ else:
+ embed_q = F.embedding_bag(
+ input_q, self.weight_q, offsets, self.max_norm, self.norm_type,
+ self.scale_grad_by_freq, self.mode, self.sparse,
+ per_sample_weights
+ )
+ embed_r = F.embedding_bag(
+ input_r, self.weight_r, offsets, self.max_norm, self.norm_type,
+ self.scale_grad_by_freq, self.mode, self.sparse,
+ per_sample_weights
+ )
+ return embed_q, embed_r
+
+
def forward(self, input, offsets=None, per_sample_weights=None):
input_q = (input / self.num_collisions).long()
input_r = torch.remainder(input, self.num_collisions).long()
-
- embed_q = F.embedding_bag(input_q, self.weight_q, offsets, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.mode,
- self.sparse, per_sample_weights)
- embed_r = F.embedding_bag(input_r, self.weight_r, offsets, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.mode,
- self.sparse, per_sample_weights)
-
+ embed_q, embed_r = self._embed_input(
+ input_q, input_r, offsets=offsets,
+ per_sample_weights=per_sample_weights,
+ )
if self.operation == 'concat':
embed = torch.cat((embed_q, embed_r), dim=1)
elif self.operation == 'add':
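
The refactor leaves the quotient-remainder trick itself unchanged: each index is split into a row of weight_q (quotient) and a row of weight_r (remainder), the two lookups are combined by the configured operation, and only the embedding-bag call is routed through _embed_input so an xla branch can be filled in later. A tiny numeric sketch of the index split, assuming num_collisions=4:

import torch

num_collisions = 4
idx = torch.tensor([0, 5, 10, 42])

idx_q = (idx / num_collisions).long()                 # row in weight_q
idx_r = torch.remainder(idx, num_collisions).long()   # row in weight_r

print(idx_q.tolist())  # [0, 1, 2, 10]
print(idx_r.tolist())  # [0, 1, 2, 2]
# 'mult' multiplies the two lookups elementwise, 'add' sums them,
# 'concat' concatenates them along the embedding dimension.
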