@taylanbil
Created July 17, 2020 17:59
$ git diff d45342e tpu-criteo-kaggle
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a81c8ee
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,138 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
diff --git a/dlrm_data_pytorch.py b/dlrm_data_pytorch.py
index 6cbe382..4297430 100644
--- a/dlrm_data_pytorch.py
+++ b/dlrm_data_pytorch.py
@@ -266,7 +266,7 @@ class CriteoDataset(Dataset):
if self.memory_map:
if self.split == 'none' or self.split == 'train':
- # check if need to swicth to next day and load data
+ # check if need to switch to next day and load data
if index == self.offset_per_file[self.day]:
# print("day_boundary switch", index)
self.day_boundary = self.offset_per_file[self.day]
@@ -376,7 +376,7 @@ def ensure_dataset_preprocessed(args, d_path):
split=split)
-def make_criteo_data_and_loaders(args):
+def make_criteo_data_and_loaders(args, n_replicas=None, rank=None):
if args.mlperf_logging and args.memory_map and args.data_set == "terabyte":
# more efficient for larger batches
@@ -496,6 +496,20 @@ def make_criteo_data_and_loaders(args):
args.memory_map
)
+ train_sampler, test_sampler = None, None
+ if rank is not None:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
+ train_data,
+ num_replicas=n_replicas,
+ rank=rank,
+ shuffle=True,
+ )
+ test_sampler = torch.utils.data.distributed.DistributedSampler(
+ test_data,
+ num_replicas=n_replicas,
+ rank=rank,
+ shuffle=False,
+ )
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=args.mini_batch_size,
@@ -503,7 +517,8 @@ def make_criteo_data_and_loaders(args):
num_workers=args.num_workers,
collate_fn=collate_wrapper_criteo,
pin_memory=False,
- drop_last=False, # True
+ drop_last=args.drop_last,
+ sampler=train_sampler,
)
test_loader = torch.utils.data.DataLoader(
@@ -513,7 +528,8 @@ def make_criteo_data_and_loaders(args):
num_workers=args.test_num_workers,
collate_fn=collate_wrapper_criteo,
pin_memory=False,
- drop_last=False, # True
+ drop_last=args.drop_last,
+ sampler=test_sampler,
)
return train_data, train_loader, test_data, test_loader
@@ -627,7 +643,9 @@ def collate_wrapper_random(list_of_tuples):
T)
-def make_random_data_and_loader(args, ln_emb, m_den):
+def make_random_data_and_loader(
+ args, ln_emb, m_den, n_replicas=None, rank=None,
+):
train_data = RandomDataset(
m_den,
@@ -645,6 +663,14 @@ def make_random_data_and_loader(args, ln_emb, m_den):
reset_seed_on_access=True,
rand_seed=args.numpy_rand_seed
) # WARNING: generates a batch of lookups at once
+ train_sampler = None
+ if rank is not None:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
+ train_data,
+ num_replicas=n_replicas,
+ rank=rank,
+ shuffle=True,
+ )
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=1,
@@ -652,7 +678,8 @@ def make_random_data_and_loader(args, ln_emb, m_den):
num_workers=args.num_workers,
collate_fn=collate_wrapper_random,
pin_memory=False,
- drop_last=False, # True
+ drop_last=args.drop_last,
+ sampler=train_sampler,
)
return train_data, train_loader
@@ -764,7 +791,9 @@ def generate_uniform_input_batch(
)
# sparse indices to be used per embedding
r = ra.random(sparse_group_size)
- sparse_group = np.unique(np.round(r * (size - 1)).astype(np.int64))
+ # XXX: why np.unique ? This is producing a different input shape..
+ #sparse_group = np.unique(np.round(r * (size - 1)).astype(np.int64))
+ sparse_group = np.round(r * (size - 1)).astype(np.int64)
# reset sparse_group_size in case some index duplicates were removed
sparse_group_size = np.int64(sparse_group.size)
# store lengths and indices
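
The loader changes above shard both the Criteo and the random datasets across
replicas: make_criteo_data_and_loaders and make_random_data_and_loader now take
an (n_replicas, rank) pair and wire a DistributedSampler into each DataLoader.
A minimal standalone sketch of how that sampler partitions indices; this is not
part of the diff, and the toy dataset and sizes are made up for illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

data = TensorDataset(torch.arange(16).float())  # toy stand-in for CriteoDataset
n_replicas = 4
for rank in range(n_replicas):
    sampler = DistributedSampler(
        data, num_replicas=n_replicas, rank=rank, shuffle=False
    )
    loader = DataLoader(data, batch_size=2, sampler=sampler, drop_last=True)
    # each rank iterates over a disjoint quarter of the 16 samples,
    # e.g. rank 0 -> [[0.0, 4.0], [8.0, 12.0]]
    print(rank, [batch[0].tolist() for batch in loader])
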
diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py
index 1955bb9..c64a9c0 100644
--- a/dlrm_s_pytorch.py
+++ b/dlrm_s_pytorch.py
@@ -58,6 +58,8 @@ import builtins
import functools
# import bisect
# import shutil
+import sys
+from datetime import datetime
import time
import json
# data generation
@@ -93,10 +95,14 @@ import sklearn.metrics
from torch.optim.lr_scheduler import _LRScheduler
+
exc = getattr(builtins, "IOError", "FileNotFoundError")
class LRPolicyScheduler(_LRScheduler):
- def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
+ def __init__(
+ self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps,
+ use_tpu=False, local_optimizer=None,
+ ):
self.num_warmup_steps = num_warmup_steps
self.decay_start_step = decay_start_step
self.decay_end_step = decay_start_step + num_decay_steps
@@ -105,6 +111,9 @@ class LRPolicyScheduler(_LRScheduler):
if self.decay_start_step < self.num_warmup_steps:
sys.exit("Learning rate warmup must finish before the decay starts")
+ self.use_tpu = use_tpu
+ self.sparse_feature_local_optimizer = local_optimizer
+
super(LRPolicyScheduler, self).__init__(optimizer)
def get_lr(self):
@@ -129,8 +138,18 @@ class LRPolicyScheduler(_LRScheduler):
lr = self.base_lrs
return lr
+ def step(self):
+ super().step()
+ if self.sparse_feature_local_optimizer is not None:
+ # XXX: is lr always a single value list?
+ lr = self.get_lr()[0]
+ for param_group in self.sparse_feature_local_optimizer.param_groups:
+ param_group['lr'] = lr
+
+
### define dlrm in PyTorch ###
class DLRM_Net(nn.Module):
+
def create_mlp(self, ln, sigmoid_layer):
# build MLP layer by layer
layers = nn.ModuleList()
@@ -171,15 +190,61 @@ class DLRM_Net(nn.Module):
# approach 2: use Sequential container to wrap all layers
return torch.nn.Sequential(*layers)
+ def _set_up_sparse_feature_info(self, ln):
+ self._group_table_count = ln.size
+ self._device_table_count = sum(
+ self._tpu_index_belongs_to_ordinal(i)
+ for i in range(ln.size)
+ )
+ self._pad_embedding_lookup = (
+ len(self._xla_replica_group) * self._device_table_count <
+ self._group_table_count
+ )
+ self._max_device_table_count_in_group = int(np.ceil(
+ self._group_table_count / len(self._xla_replica_group)
+ ))
+ self._table_count_padded = (
+ len(self._xla_replica_group) *
+ self._max_device_table_count_in_group
+ )
+ self._pad_indices = set(
+ (
+ self._table_count_padded - 1 -
+ i*self._max_device_table_count_in_group
+ )
+ for i in range(self._table_count_padded - self._group_table_count)
+ )
+ self._non_pad_indices = [
+ i for i in range(self._table_count_padded)
+ if i not in self._pad_indices
+ ]
+
def create_emb(self, m, ln):
emb_l = nn.ModuleList()
+ if self.use_tpu and self.ndevices > 1:
+ self._set_up_sparse_feature_info(ln)
for i in range(0, ln.size):
n = ln[i]
+ if (
+ self.use_tpu and
+ self.ndevices > 1 and
+ not self._tpu_index_belongs_to_ordinal(i)
+ ):
+ # tpu model-parallel mode. only create this ordinal's tables.
+ continue
# construct embedding operator
if self.qr_flag and n > self.qr_threshold:
- EE = QREmbeddingBag(n, m, self.qr_collisions,
- operation=self.qr_operation, mode="sum", sparse=True)
+ assert not self.use_tpu, \
+ 'QR trick not implemented for tpus'.upper()
+ EE = QREmbeddingBag(
+ n, m, self.qr_collisions,
+ operation=self.qr_operation,
+ mode="sum", sparse=self.sparse,
+ xla=self.use_tpu,
+ )
elif self.md_flag:
+ assert not self.use_tpu, \
+ 'MD trick not implemented for tpus'.upper()
base = max(m)
_m = m[i] if n > self.md_threshold else base
EE = PrEmbeddingBag(n, _m, base)
@@ -190,7 +255,20 @@ class DLRM_Net(nn.Module):
EE.embs.weight.data = torch.tensor(W, requires_grad=True)
else:
- EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)
+ if self.use_tpu:
+ # TODO: remove when xla supports `nn.EmbeddingBag`
+ from tools.xla_embedding_bag import XlaEmbeddingBag
+ errmsg = (
+ '`use_tpu` was specified. XLA currently only supports '
+ 'fixed length sparse groups.'
+ )
+ assert self.offset is not None, errmsg
+ EE = XlaEmbeddingBag(
+ n, m, mode="sum",
+ sparse=self.sparse, offset=self.offset,
+ )
+ else:
+ EE = nn.EmbeddingBag(n, m, mode="sum", sparse=self.sparse)
# initialize embeddings
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
@@ -208,6 +286,26 @@ class DLRM_Net(nn.Module):
return emb_l
+ def set_xla_replica_groups(self, groups):
+ self._xla_replica_groups = groups
+ for g in groups:
+ if self._ordinal in g:
+ self._xla_replica_group = g
+ assert self.ndevices == len(g), \
+ 'MP replica group does not match ndevices, {} vs {}'.format(
+ len(g), self.ndevices
+ )
+ self._xla_replica_index = g.index(self._ordinal)
+ break
+ else:
+ raise ValueError(
+ 'Ordinal {} not in replica groups!'.format(self._ordinal)
+ )
+
+ def _tpu_index_belongs_to_ordinal(self, i, ordinal=None):
+ ordinal = self._xla_replica_index if ordinal is None else ordinal
+ return i % len(self._xla_replica_group) == ordinal
+
def __init__(
self,
m_spa=None,
@@ -227,6 +325,9 @@ class DLRM_Net(nn.Module):
qr_threshold=200,
md_flag=False,
md_threshold=200,
+ sparse=True,
+ use_tpu=False,
+ offset=None
):
super(DLRM_Net, self).__init__()
@@ -247,6 +348,9 @@ class DLRM_Net(nn.Module):
self.arch_interaction_itself = arch_interaction_itself
self.sync_dense_params = sync_dense_params
self.loss_threshold = loss_threshold
+ self.sparse = (not use_tpu) and sparse
+ self.use_tpu = use_tpu
+ self.offset = offset
# create variables for QR embedding if applicable
self.qr_flag = qr_flag
if self.qr_flag:
@@ -257,12 +361,35 @@ class DLRM_Net(nn.Module):
self.md_flag = md_flag
if self.md_flag:
self.md_threshold = md_threshold
+ if self.use_tpu:
+ self._init_tpu()
# create operators
if ndevices <= 1:
self.emb_l = self.create_emb(m_spa, ln_emb)
self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
self.top_l = self.create_mlp(ln_top, sigmoid_top)
+ def _init_tpu(self):
+ # no need to be ordinal aware in full data parallel mode.
+ if self.ndevices > 1:
+ import torch_xla.core.xla_model as xm
+ self._ordinal = xm.get_ordinal()
+ self._all_reduce = xm.all_reduce
+ self._all_gather = xm.all_gather
+ #self._all_to_all = xm.all_to_all
+ #self._mark_step = xm.mark_step
+
+ def _filter_params(self, f):
+ for name, p in self.named_parameters():
+ if f(name):
+ yield p
+
+ def mlp_parameters(self):
+ return self._filter_params(lambda name: not name.startswith('emb_l'))
+
+ def emb_parameters(self):
+ return self._filter_params(lambda name: name.startswith('emb_l'))
+
def apply_mlp(self, x, layers):
# approach 1: use ModuleList
# for layer in layers:
@@ -292,6 +419,13 @@ class DLRM_Net(nn.Module):
ly.append(V)
+ if self.use_tpu and self.ndevices > 1 and self._pad_embedding_lookup:
+ # tpu-comment: this device holds fewer tables compared to some other
+ # devices in the xrt_world. In order to do the `all_to_all`
+ # correctly, pad the embeddings with a dummy tensor that's going
+ # to be dropped later.
+ ly.append(torch.zeros(V.shape, device=V.device))
+
# print(ly)
return ly
@@ -332,9 +466,17 @@ class DLRM_Net(nn.Module):
def forward(self, dense_x, lS_o, lS_i):
if self.ndevices <= 1:
return self.sequential_forward(dense_x, lS_o, lS_i)
+ elif self.use_tpu:
+ return self.tpu_parallel_forward(dense_x, lS_o, lS_i)
else:
return self.parallel_forward(dense_x, lS_o, lS_i)
+ def clamp_output(self, p):
+ z = p
+ if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
+ z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
+ return z
+
def sequential_forward(self, dense_x, lS_o, lS_i):
# process dense features (using bottom mlp), resulting in a row vector
x = self.apply_mlp(dense_x, self.bot_l)
@@ -355,13 +497,110 @@ class DLRM_Net(nn.Module):
p = self.apply_mlp(z, self.top_l)
# clamp output if needed
- if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
- z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
- else:
- z = p
-
+ z = self.clamp_output(p)
return z
+ def _partition_to_device(self, iterable):
+ return [
+ obj for k, obj in enumerate(iterable)
+ if self._tpu_index_belongs_to_ordinal(k)
+ ]
+
+ def _collect_distribute_embeddings(self, ordinal_data, local_bsz=None):
+ if local_bsz is None:
+ local_bsz = ordinal_data[0].size(0) // len(self._xla_replica_group)
+ full_data = self._gather_other_samples(ordinal_data)
+ full_data = full_data[self._non_pad_indices]
+ return self.narrow(local_bsz, full_data, dim=1)
+
+ # TODO: once the bug in alltoall is gone, use that instead of
+ # allgather+narrow, which is 10% slower.
+ def ALLTOALL_collect_distribute_embeddings(self, ordinal_data):
+
+ # XXX: clean
+ def debug_nan(self, t):
+ self._mark_step()
+ if isinstance(t, list):
+ print('DEBUG-NAN', self._ordinal, [torch.isnan(Z.view(-1)).sum().item() for Z in t])
+ print('DEBUG-INF', self._ordinal, [torch.isinf(Z.view(-1)).sum().item() for Z in t])
+ else:
+ print('DEBUG-NAN', self._ordinal, torch.isnan(t.view(-1)).sum().item())
+ print('DEBUG-INF', self._ordinal, torch.isinf(t.view(-1)).sum().item())
+
+ ordinal_data = torch.stack(ordinal_data)
+ #debug_nan(ordinal_data)
+ #self._mark_step() # TODO: needed due to bug. Delete when bug resolved.
+ full_data = self._all_to_all(
+ ordinal_data,
+ split_dimension=1,
+ concat_dimension=0,
+ split_count=self.ndevices,
+ groups=self._xla_replica_groups,
+ )
+ # FIXME: delete these when loss doesn't nan
+ #self._mark_step() # TODO: needed due to bug. Delete when bug resolved.
+ #debug_nan(full_data)
+ return full_data[self._non_pad_indices]
+
+
+ def _gather_other_samples(self, array, dim=0):
+ out = torch.stack(array)
+ # dim+1 because stack introduces a dimension to the left
+ out = self._all_gather(out, dim=dim+1, groups=self._xla_replica_groups)
+ return out
+
+ def narrow(self, local_bsz, tensor, dim=1):
+ return torch.narrow(
+ tensor, dim, self._xla_replica_index*local_bsz, local_bsz
+ )
+
+ def tpu_parallel_forward(self, dense_x, lS_o, lS_i):
+ batch_size = dense_x.size()[0]
+ ndevices = self.ndevices
+ assert not batch_size % ndevices, \
+ f"{batch_size} is bsz, {ndevices} devices"
+ local_bsz = batch_size // ndevices
+ dense_x = self.narrow(local_bsz, dense_x, dim=0)
+
+ #bottom mlp
+ x = self.bot_l(dense_x)
+ # embeddings
+ lS_i = self._partition_to_device(lS_i)
+ # offset is assumed to be constant for tpus
+ lS_o = [self.offset for _ in lS_i]
+ ly_local = self.apply_emb(lS_o, lS_i, self.emb_l)
+
+ # at this point, each device has the embeddings that belong to it.
+ # we do an all_to_all to acquire all embeddings, i.e. full input.
+ ly = self._collect_distribute_embeddings(ly_local)
+
+ # now stop gradients from flowing back.
+ ly = [_.clone().detach().requires_grad_(True) for _ in ly]
+
+ # interactions
+ z = self.interact_features(x, ly)
+
+ # top mlp
+ p = self.top_l(z)
+
+ # clamp output if needed
+ z = self.clamp_output(p)
+
+ return (
+ z,
+ ly_local if not self._pad_embedding_lookup else ly_local[:-1],
+ ly
+ ) # extra return args needed during the custom bwd.
+
+ def tpu_local_backward(self, fullbatch_localembs, localbatch_fullembs):
+ localbatch_fullgrads = [_.grad for _ in localbatch_fullembs]
+ grad = self._gather_other_samples(localbatch_fullgrads) # inv to narrow
+ grad = self._partition_to_device(grad)
+ assert len(fullbatch_localembs) == len(grad), \
+ '{} vs {}'.format(len(fullbatch_localembs), len(grad))
+ for e, g in zip(fullbatch_localembs, grad):
+ e.backward(g)
+
def parallel_forward(self, dense_x, lS_o, lS_i):
### prepare model (overwrite) ###
# WARNING: # of devices must be >= batch size in parallel_forward call
@@ -461,22 +700,14 @@ class DLRM_Net(nn.Module):
p0 = gather(p, self.output_d, dim=0)
# clamp output if needed
- if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
- z0 = torch.clamp(
- p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold)
- )
- else:
- z0 = p0
+ z0 = self.clamp_output(p0)
return z0
-if __name__ == "__main__":
- ### import packages ###
- import sys
+def parse_args():
import argparse
- ### parse arguments ###
parser = argparse.ArgumentParser(
description="Train Deep Learning Recommendation Model (DLRM)"
)
@@ -503,22 +734,24 @@ if __name__ == "__main__":
parser.add_argument("--loss-weights", type=str, default="1.0-1.0") # for wbce
parser.add_argument("--loss-threshold", type=float, default=0.0) # 1.0e-7
parser.add_argument("--round-targets", type=bool, default=False)
+ parser.add_argument("--pred-threshold", type=float, default=0.5)
# data
parser.add_argument("--data-size", type=int, default=1)
parser.add_argument("--num-batches", type=int, default=0)
parser.add_argument(
"--data-generation", type=str, default="random"
) # synthetic or dataset
+ parser.add_argument("--drop-last", action="store_true")
parser.add_argument("--data-trace-file", type=str, default="./input/dist_emb_j.log")
parser.add_argument("--data-set", type=str, default="kaggle") # or terabyte
parser.add_argument("--raw-data-file", type=str, default="")
parser.add_argument("--processed-data-file", type=str, default="")
parser.add_argument("--data-randomize", type=str, default="total") # or day or none
- parser.add_argument("--data-trace-enable-padding", type=bool, default=False)
+ parser.add_argument("--data-trace-enable-padding", action='store_true')
parser.add_argument("--max-ind-range", type=int, default=-1)
parser.add_argument("--data-sub-sample-rate", type=float, default=0.0) # in [0, 1]
parser.add_argument("--num-indices-per-lookup", type=int, default=10)
- parser.add_argument("--num-indices-per-lookup-fixed", type=bool, default=False)
+ parser.add_argument("--num-indices-per-lookup-fixed", action='store_true')
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--memory-map", action="store_true", default=False)
# training
@@ -532,8 +765,12 @@ if __name__ == "__main__":
parser.add_argument("--inference-only", action="store_true", default=False)
# onnx
parser.add_argument("--save-onnx", action="store_true", default=False)
- # gpu
+ # accelerators
parser.add_argument("--use-gpu", action="store_true", default=False)
+ parser.add_argument("--use-tpu", action="store_true", default=False)
+ parser.add_argument("--tpu-model-parallel-group-len", type=int, default=1)
+ parser.add_argument("--tpu-data-parallel", action="store_true")
+ parser.add_argument("--tpu-metrics-debug", action="store_true")
# debugging and profiling
parser.add_argument("--print-freq", type=int, default=1)
parser.add_argument("--test-freq", type=int, default=-1)
@@ -558,16 +795,7 @@ if __name__ == "__main__":
parser.add_argument("--lr-num-warmup-steps", type=int, default=0)
parser.add_argument("--lr-decay-start-step", type=int, default=0)
parser.add_argument("--lr-num-decay-steps", type=int, default=0)
- args = parser.parse_args()
-
- if args.mlperf_logging:
- print('command line args: ', json.dumps(vars(args)))
-
- ### some basic setup ###
- np.random.seed(args.numpy_rand_seed)
- np.set_printoptions(precision=args.print_precision)
- torch.set_printoptions(precision=args.print_precision)
- torch.manual_seed(args.numpy_rand_seed)
+ args, _ = parser.parse_known_args()
if (args.test_mini_batch_size < 0):
# if the parameter is not set, use the training batch size
@@ -576,8 +804,84 @@ if __name__ == "__main__":
# if the parameter is not set, use the same parameter for training
args.test_num_workers = args.num_workers
+ return args
+
+
+def tpu_get_xla_replica_groups(args):
+ import torch_xla.core.xla_model as xm
+ # XXX: would it make sense to sort by size here -- to evenly dist. tables?
+ world_size = xm.xrt_world_size()
+ if args.tpu_data_parallel:
+ return [[i] for i in range(world_size)], [list(range(world_size))]
+ num_tables = args.arch_embedding_size.count("-") + 1
+ len_mp_group = args.tpu_model_parallel_group_len
+ #assert not num_tables % len_mp_group, \
+ # 'Model parallel group size has to divide number of emb tables evenly.'
+ assert not world_size % len_mp_group, \
+ 'Length of model parallel groups has to evenly divide `xrt_world_size`'
+ len_dp_group = world_size // len_mp_group
+ mp_groups = [
+ [len_mp_group*d+m for m in range(len_mp_group)]
+ for d in range(len_dp_group)
+ ]
+ dp_groups = [
+ [len_dp_group*d+m for m in range(len_dp_group)]
+ for d in range(len_mp_group)
+ ]
+ return mp_groups, dp_groups
+
+
+def main(*_args):
+ ### import packages ###
+ import sys
+
+ args = parse_args()
+ ### some basic setup ###
+ np.random.seed(args.numpy_rand_seed)
+ np.set_printoptions(precision=args.print_precision)
+ torch.set_printoptions(precision=args.print_precision)
+ # XXX: does this imply we init the emb tables the same way?
+ torch.manual_seed(args.numpy_rand_seed)
+
use_gpu = args.use_gpu and torch.cuda.is_available()
- if use_gpu:
+ use_tpu = args.use_tpu
+ print = builtins.print
+ if use_tpu:
+ use_gpu = False
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ import torch_xla.distributed.parallel_loader as pl
+ print = xm.master_print
+ device = xm.xla_device()
+ print("Using {} TPU core(s)...".format(xm.xrt_world_size()))
+ if args.data_set in ['kaggle', 'terabyte']:
+ print('Criteo datasets have offset = 1, forcing the arguments..')
+ args.num_indices_per_lookup_fixed = True
+ args.num_indices_per_lookup = 1
+ if args.enable_profiling:
+ print("Profiling was enabled. Turning it off for TPUs.")
+ args.enable_profiling = False
+ if not args.num_indices_per_lookup_fixed:
+ # XXX: does this lead to recompilations?
+ raise NotImplementedError('This will lead to recompilations.')
+ mp_replica_groups, dp_replica_groups = tpu_get_xla_replica_groups(args)
+ print('XLA replica groups for Model Parallel:\n\t', mp_replica_groups)
+ print('XLA replica groups for Data Parallel:\n\t', dp_replica_groups)
+ if len(dp_replica_groups) == 1:
+ # i.e. no allgather etc in the emb layer.
+ print("TPU data-parallel mode, setting --tpu-data-parallel to True")
+ args.tpu_data_parallel = True
+ else:
+ print("TPU model-parallel mode, setting --drop-last=True")
+ args.drop_last = True
+ if args.print_time:
+ print(
+ '`torch_xla` async execution is not compatible with '
+ 'ms/it reporting, turning --print-time off.'
+ )
+ args.print_time = False
+
+ elif use_gpu:
torch.cuda.manual_seed_all(args.numpy_rand_seed)
torch.backends.cudnn.deterministic = True
device = torch.device("cuda", 0)
@@ -587,6 +891,9 @@ if __name__ == "__main__":
device = torch.device("cpu")
print("Using CPU...")
+ if args.mlperf_logging:
+ print('command line args: ', json.dumps(vars(args)))
+
### prepare training data ###
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
# input data
@@ -610,8 +917,26 @@ if __name__ == "__main__":
# input and target at random
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
m_den = ln_bot[0]
- train_data, train_ld = dp.make_random_data_and_loader(args, ln_emb, m_den)
- nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
+ data_args = [args, ln_emb, m_den]
+ if use_tpu:
+ ordinal = xm.get_ordinal()
+ for g in dp_replica_groups:
+ if ordinal in g:
+ n_replicas, rank = len(g), g.index(ordinal)
+ break
+ else:
+ raise ValueError(
'Ordinal {} not in replica groups!'.format(ordinal)
+ )
+ data_args.extend([n_replicas, rank]) # extend w/ n_replicas and rank
+ train_data, train_ld = dp.make_random_data_and_loader(*data_args)
+ #nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
+ nbatches = len(train_ld)
+
+ if use_tpu:
+ # XXX: test_data is unused.
+ # Wrap w/ torch_xla's loader
+ train_ld = pl.MpDeviceLoader(train_ld, device)
### parse command line arguments ###
m_spa = args.arch_sparse_feature_size
@@ -737,7 +1062,19 @@ if __name__ == "__main__":
print([S_i.detach().cpu().tolist() for S_i in lS_i])
print(T.detach().cpu().numpy())
- ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_gpu else -1
+ ndevices = -1
+ if use_tpu:
+ ndevices = len(dp_replica_groups)
+ if args.mini_batch_size % ndevices:
+ raise NotImplementedError(
+ 'ndevices need to divide --mini-batch-size. '
+ 'bsz is {}, ndevices is {}'.format(
+ args.mini_batch_size, ndevices,
+ )
+ )
+
+ elif use_gpu:
+ ndevices = min(ngpus, args.mini_batch_size, num_fea - 1)
### construct the neural network specified above ###
# WARNING: to obtain exactly the same initialization for
@@ -761,6 +1098,9 @@ if __name__ == "__main__":
qr_threshold=args.qr_threshold,
md_flag=args.md_flag,
md_threshold=args.md_threshold,
+ sparse=device.type != 'xla',
+ use_tpu=use_tpu,
+ offset=args.num_indices_per_lookup,
)
# test prints
if args.debug_mode:
@@ -776,6 +1116,13 @@ if __name__ == "__main__":
dlrm = dlrm.to(device) # .cuda()
if dlrm.ndevices > 1:
dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb)
+ if use_tpu:
+ # XXX: ndevices is redundant, same info contained in dp_replica_groups
+ if dlrm.ndevices > 1:
+ dlrm.set_xla_replica_groups(mp_replica_groups)
+ dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb)
+ dlrm.device = device
+ dlrm = dlrm.to(device)
# specify the loss function
if args.loss_function == "mse":
@@ -790,11 +1137,34 @@ if __name__ == "__main__":
if not args.inference_only:
# specify the optimizer algorithm
- optimizer = torch.optim.SGD(dlrm.parameters(), lr=args.learning_rate)
- lr_scheduler = LRPolicyScheduler(optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
- args.lr_num_decay_steps)
+
+ if use_tpu and not args.tpu_data_parallel:
+ # tpu's paradigm is different than gpu/cpu. Each process here runs
+ # on its own and all reduces need to happen in a particular way.
+ # Data parallel part, i.e. the MLP part will be allreduced.
+ # EmbeddingBag part will not be allreduced.
+ optimizer = torch.optim.SGD(
+ dlrm.mlp_parameters(), lr=args.learning_rate
+ )
+ emb_local_optimizer = torch.optim.SGD(
+ dlrm.emb_parameters(), lr=args.learning_rate
+ )
+ lr_scheduler = LRPolicyScheduler(
+ optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
+ args.lr_num_decay_steps, use_tpu=True,
+ local_optimizer=emb_local_optimizer,
+ )
+ else:
+ optimizer = torch.optim.SGD(
+ dlrm.parameters(), lr=args.learning_rate
+ )
+ lr_scheduler = LRPolicyScheduler(
+ optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
+ args.lr_num_decay_steps
+ )
### main loop ###
+
def time_wrap(use_gpu):
if use_gpu:
torch.cuda.synchronize()
@@ -808,15 +1178,16 @@ if __name__ == "__main__":
else lS_i.to(device)
lS_o = [S_o.to(device) for S_o in lS_o] if isinstance(lS_o, list) \
else lS_o.to(device)
- return dlrm(
- X.to(device),
- lS_o,
- lS_i
- )
+ return dlrm(X.to(device), lS_o, lS_i)
else:
return dlrm(X, lS_o, lS_i)
def loss_fn_wrap(Z, T, use_gpu, device):
+ if T.size(0) > Z.size(0):
+ # This happens for tpus.
+ # Target tensor is likely for global batch. Narrow to this device.
+ T = dlrm.narrow(Z.size(0), T, dim=0)
+
if args.loss_function == "mse" or args.loss_function == "bce":
if use_gpu:
return loss_fn(Z, T.to(device))
@@ -848,9 +1219,12 @@ if __name__ == "__main__":
k = 0
# Load model is specified
- if not (args.load_model == ""):
+ if args.load_model != "":
print("Loading saved model {}".format(args.load_model))
- if use_gpu:
+ if use_tpu:
+ # XXX: add tpu capabilities to load.
+ raise NotImplementedError('Add tpu capabilities to load.')
+ elif use_gpu:
if dlrm.ndevices > 1:
# NOTE: when targeting inference on multiple GPUs,
# load the model as is on CPU or GPU, with the move
@@ -932,6 +1306,7 @@ if __name__ == "__main__":
t1 = time_wrap(use_gpu)
# early exit if nbatches was set by the user and has been exceeded
+ # XXX: what about j being reset every epoch?
if nbatches > 0 and j >= nbatches:
break
'''
@@ -945,7 +1320,12 @@ if __name__ == "__main__":
'''
# forward pass
- Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, device)
+ if use_tpu and not args.tpu_data_parallel:
+ # args[1:] below will be used in the custom backward
+ Z, fullbatch_localembs, localbatch_fullembs = \
+ dlrm_wrap(X, lS_o, lS_i, use_gpu, device)
+ else:
+ Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, device)
# loss
E = loss_fn_wrap(Z, T, use_gpu, device)
@@ -955,17 +1335,25 @@ if __name__ == "__main__":
print(Z.detach().cpu().numpy())
print(E.detach().cpu().numpy())
'''
- # compute loss and accuracy
- L = E.detach().cpu().numpy() # numpy array
- S = Z.detach().cpu().numpy() # numpy array
- T = T.detach().cpu().numpy() # numpy array
- mbs = T.shape[0] # = args.mini_batch_size except maybe for last
- A = np.sum((np.round(S, 0) == T).astype(np.uint8))
+
+ if use_tpu:
+ mbs = T.shape[0] * len(mp_replica_groups)
+ else:
+ mbs = T.shape[0] # = args.mini_batch_size except maybe for last
+ # XXX: conv related: T is a vector of floats here.
+ # args.round_targets related
+ # FIXME: figure out a way to do this on tpu
+ # S = Z.detach().cpu().numpy() # numpy array
+ # T = T.detach().cpu().numpy() # numpy array
+ #A = np.sum((np.round(S, 0) == T).astype(np.uint8))
+ A = 0
if not args.inference_only:
# scaled error gradient propagation
# (where we do not accumulate gradients across mini-batches)
optimizer.zero_grad()
+ if use_tpu and not args.tpu_data_parallel:
+ emb_local_optimizer.zero_grad()
# backward pass
E.backward()
# debug prints (check gradient norm)
@@ -973,17 +1361,34 @@ if __name__ == "__main__":
# if hasattr(l, 'weight'):
# print(l.weight.grad.norm().item())
- # optimizer
- optimizer.step()
+ # backward + optimizer
+ if use_tpu and not args.tpu_data_parallel:
+ # XXX: no clip grad?
+ # Full allreduce across all devices for the MLP parts.
+ xm.optimizer_step(optimizer, groups=None)
+ # bwd pass for the embedding tables.
+ dlrm.tpu_local_backward(
+ fullbatch_localembs, localbatch_fullembs,
+ )
+ if len(dp_replica_groups[0]) > 1:
+ xm.optimizer_step(
+ emb_local_optimizer, groups=dp_replica_groups,
+ )
+ else:
+ # no allreduce, just step
+ emb_local_optimizer.step()
+ else:
+ optimizer.step()
lr_scheduler.step()
+
if args.mlperf_logging:
total_time += iteration_time
else:
t2 = time_wrap(use_gpu)
total_time += t2 - t1
total_accu += A
- total_loss += L * mbs
+ total_loss += E.detach() * mbs
total_iter += 1
total_samp += mbs
@@ -996,6 +1401,12 @@ if __name__ == "__main__":
# print time, loss and accuracy
if should_print or should_test:
+ # XXX: detach+cpu+numpy seems to avoid the "aten" counter!!!!
+ #L = E.detach().cpu().numpy() # numpy array
+ #S = Z.detach().cpu().numpy() # numpy array
+ #T = T.detach().cpu().numpy() # numpy array
+ if use_tpu:
+ xm.mark_step()
gT = 1000.0 * total_time / total_iter if args.print_time else -1
total_time = 0
@@ -1007,10 +1418,13 @@ if __name__ == "__main__":
str_run_type = "inference" if args.inference_only else "training"
print(
- "Finished {} it {}/{} of epoch {}, {:.2f} ms/it, ".format(
- str_run_type, j + 1, nbatches, k, gT
+ (
+ "Finished {} it {}/{} of epoch {}, {:.2f} ms/it, "
+ "loss {:.6f}, accuracy {:3.3f} %, {} samples, @ {}"
+ ).format(
+ str_run_type, j + 1, nbatches, k, gT,
+ gL, gA * 100, total_samp, datetime.now()
)
- + "loss {:.6f}, accuracy {:3.3f} %".format(gL, gA * 100)
)
# Uncomment the line below to print out the total time with overhead
# print("Accumulated time so far: {}" \
@@ -1020,6 +1434,8 @@ if __name__ == "__main__":
# testing
if should_test and not args.inference_only:
+ # XXX: code path not hit currently
+ raise NotImplementedError('not hit')
# don't measure training iter time in a test iteration
if args.mlperf_logging:
previous_iteration_time = None
@@ -1198,6 +1614,8 @@ if __name__ == "__main__":
break
k += 1 # nepochs
+ if use_tpu and args.tpu_metrics_debug:
+ print(met.metrics_report())
# profiling
if args.enable_profiling:
@@ -1234,3 +1652,7 @@ if __name__ == "__main__":
dlrm_pytorch_onnx = onnx.load("dlrm_s_pytorch.onnx")
# check the onnx model
onnx.checker.check_model(dlrm_pytorch_onnx)
+
+
+if __name__ == "__main__":
+ main()
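
For orientation, a standalone worked example (not part of the diff) of the
replica-group layout produced by tpu_get_xla_replica_groups and of the
round-robin table placement implied by _tpu_index_belongs_to_ordinal. The
concrete numbers (8 TPU cores, --tpu-model-parallel-group-len 4, 26 Criteo
Kaggle embedding tables) are illustrative assumptions:

world_size, len_mp_group, num_tables = 8, 4, 26  # assumed for illustration

len_dp_group = world_size // len_mp_group
mp_groups = [[len_mp_group * d + m for m in range(len_mp_group)]
             for d in range(len_dp_group)]
dp_groups = [[len_dp_group * d + m for m in range(len_dp_group)]
             for d in range(len_mp_group)]
print(mp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

# Within a model-parallel group, table i lives on the replica whose index
# equals i % group_len (see _tpu_index_belongs_to_ordinal).
for replica_index in range(len_mp_group):
    tables = [i for i in range(num_tables) if i % len_mp_group == replica_index]
    print(replica_index, len(tables), tables)
# Replicas 0 and 1 hold 7 tables, replicas 2 and 3 hold 6; the lighter replicas
# append one dummy lookup in apply_emb so that every device contributes
# ceil(26 / 4) = 7 tensors to the all_gather (see _set_up_sparse_feature_info).
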
diff --git a/dlrm_tpu_runner.py b/dlrm_tpu_runner.py
new file mode 100644
index 0000000..d31aa6d
--- /dev/null
+++ b/dlrm_tpu_runner.py
@@ -0,0 +1,15 @@
+import sys
+import argparse
+
+import torch_xla.distributed.xla_multiprocessing as xmp
+
+from dlrm_s_pytorch import main, parse_args
+
+
+if __name__ == '__main__':
+ pre_spawn_parser = argparse.ArgumentParser()
+ pre_spawn_parser.add_argument(
+ "--tpu-cores", type=int, default=8, choices=[1, 8]
+ )
+ pre_spawn_flags, _ = pre_spawn_parser.parse_known_args()
+ xmp.spawn(main, args=(), nprocs=pre_spawn_flags.tpu_cores)
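
dlrm_tpu_runner.py is the new entry point: it consumes only --tpu-cores before
spawning, and every other flag is re-parsed inside each worker process by
parse_known_args() in dlrm_s_pytorch.parse_args(). xmp.spawn(main, args=(),
nprocs=N) calls main once per process with the process index as its first
positional argument, which main(*_args) accepts and ignores. A hypothetical
invocation (the flag values below are illustrative, not taken from the diff):

$ python dlrm_tpu_runner.py --tpu-cores 8 --use-tpu \
      --data-generation dataset --data-set kaggle \
      --raw-data-file ./input/train.txt \
      --processed-data-file ./input/kaggleAdDisplayChallenge_processed.npz \
      --mini-batch-size 128 --tpu-model-parallel-group-len 4
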
diff --git a/tools/xla_embedding_bag.py b/tools/xla_embedding_bag.py
new file mode 100644
index 0000000..857748b
--- /dev/null
+++ b/tools/xla_embedding_bag.py
@@ -0,0 +1,39 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+
+
+class XlaEmbeddingBag(nn.Module):
+ """
+ nn.EmbeddingBag is not lowered just yet to xla.
+ This performs the same functionality in an xla-compatible, sub-optimal way.
+
+ Warning!: only works with constant offsets atm.
+ """
+
+ def __init__(self, n, m, mode, offset, *args, **kwargs):
+ super(XlaEmbeddingBag, self).__init__()
+ self.n = n
+ self.m = m
+ self.mode = mode
+ self.reduce_fn = getattr(torch, self.mode)
+ self.offset = offset
+ self.embtable = nn.Embedding(n, m, *args, **kwargs)
+
+ def forward(self, sparse_index_group_batch, sparse_offset_group_batch):
+ emb = self.embtable(sparse_index_group_batch)
+ return emb
+ # XXX: only works w/ constant offset atm
+ bsz = emb.size(0) // self.offset
+ emb = emb.reshape(bsz, self.offset, *emb.size()[1:])
+ return self.reduce_fn(emb, axis=1)
+ #return reduce_fn(self.embtable(_) for _ in inp_list)
+
+ @property
+ def weight(self):
+ return self.embtable.weight
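
Finally, a quick standalone sanity check (not part of the diff) for the new
XlaEmbeddingBag under the fixed-offset setting the TPU path forces for Criteo
(offset = 1, i.e. one index per bag); sizes and values here are made up:

import torch
import torch.nn as nn
from tools.xla_embedding_bag import XlaEmbeddingBag

n, m, bsz = 10, 4, 6                            # toy table size, dim, batch
xla_bag = XlaEmbeddingBag(n, m, mode="sum", offset=1)
ref_bag = nn.EmbeddingBag(n, m, mode="sum")
ref_bag.weight.data.copy_(xla_bag.weight.data)  # share the same table

idx = torch.randint(0, n, (bsz,))
offsets = torch.arange(bsz)                     # one index per bag
out_xla = xla_bag(idx, offsets)  # 2nd arg is accepted but unused; the xla
                                 # version assumes the fixed self.offset instead
out_ref = ref_bag(idx, offsets)
print(torch.allclose(out_xla, out_ref))         # True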