Created July 17, 2020 17:59
$ git diff d45342e tpu-criteo-kaggle
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a81c8ee
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,138 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
diff --git a/dlrm_data_pytorch.py b/dlrm_data_pytorch.py
index 6cbe382..4297430 100644
--- a/dlrm_data_pytorch.py
+++ b/dlrm_data_pytorch.py
@@ -266,7 +266,7 @@ class CriteoDataset(Dataset):
if self.memory_map:
if self.split == 'none' or self.split == 'train':
- # check if need to swicth to next day and load data
+ # check if need to switch to next day and load data
if index == self.offset_per_file[self.day]:
# print("day_boundary switch", index)
self.day_boundary = self.offset_per_file[self.day]
@@ -376,7 +376,7 @@ def ensure_dataset_preprocessed(args, d_path):
split=split)
-def make_criteo_data_and_loaders(args):
+def make_criteo_data_and_loaders(args, n_replicas=None, rank=None):
if args.mlperf_logging and args.memory_map and args.data_set == "terabyte":
# more efficient for larger batches
@@ -496,6 +496,20 @@ def make_criteo_data_and_loaders(args):
args.memory_map
)
+ train_sampler, test_sampler = None, None
+ if rank is not None:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
+ train_data,
+ num_replicas=n_replicas,
+ rank=rank,
+ shuffle=True,
+ )
+ test_sampler = torch.utils.data.distributed.DistributedSampler(
+ test_data,
+ num_replicas=n_replicas,
+ rank=rank,
+ shuffle=False,
+ )
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=args.mini_batch_size,
@@ -503,7 +517,8 @@ def make_criteo_data_and_loaders(args):
num_workers=args.num_workers,
collate_fn=collate_wrapper_criteo,
pin_memory=False,
- drop_last=False, # True
+ drop_last=args.drop_last,
+ sampler=train_sampler,
)
test_loader = torch.utils.data.DataLoader(
@@ -513,7 +528,8 @@ def make_criteo_data_and_loaders(args):
num_workers=args.test_num_workers,
collate_fn=collate_wrapper_criteo,
pin_memory=False,
- drop_last=False, # True
+ drop_last=args.drop_last,
+ sampler=test_sampler,
)
return train_data, train_loader, test_data, test_loader
@@ -627,7 +643,9 @@ def collate_wrapper_random(list_of_tuples):
T)
-def make_random_data_and_loader(args, ln_emb, m_den):
+def make_random_data_and_loader(
+ args, ln_emb, m_den, n_replicas=None, rank=None,
+):
train_data = RandomDataset(
m_den,
@@ -645,6 +663,14 @@ def make_random_data_and_loader(args, ln_emb, m_den):
reset_seed_on_access=True,
rand_seed=args.numpy_rand_seed
) # WARNING: generates a batch of lookups at once
+ train_sampler = None
+ if rank is not None:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
+ train_data,
+ num_replicas=n_replicas,
+ rank=rank,
+ shuffle=True,
+ )
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=1,
@@ -652,7 +678,8 @@ def make_random_data_and_loader(args, ln_emb, m_den):
num_workers=args.num_workers,
collate_fn=collate_wrapper_random,
pin_memory=False,
- drop_last=False, # True
+ drop_last=args.drop_last,
+ sampler=train_sampler,
)
return train_data, train_loader
@@ -764,7 +791,9 @@ def generate_uniform_input_batch(
)
# sparse indices to be used per embedding
r = ra.random(sparse_group_size)
- sparse_group = np.unique(np.round(r * (size - 1)).astype(np.int64))
+ # XXX: why np.unique ? This is producing a different input shape..
+ #sparse_group = np.unique(np.round(r * (size - 1)).astype(np.int64))
+ sparse_group = np.round(r * (size - 1)).astype(np.int64)
# reset sparse_group_size in case some index duplicates were removed
sparse_group_size = np.int64(sparse_group.size)
# store lengths and indices
diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py
index 1955bb9..c64a9c0 100644
--- a/dlrm_s_pytorch.py
+++ b/dlrm_s_pytorch.py
@@ -58,6 +58,8 @@ import builtins
import functools
# import bisect
# import shutil
+import sys
+from datetime import datetime
import time
import json
# data generation
@@ -93,10 +95,14 @@ import sklearn.metrics
from torch.optim.lr_scheduler import _LRScheduler
+
exc = getattr(builtins, "IOError", "FileNotFoundError")
class LRPolicyScheduler(_LRScheduler):
- def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
+ def __init__(
+ self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps,
+ use_tpu=False, local_optimizer=None,
+ ):
self.num_warmup_steps = num_warmup_steps
self.decay_start_step = decay_start_step
self.decay_end_step = decay_start_step + num_decay_steps
@@ -105,6 +111,9 @@ class LRPolicyScheduler(_LRScheduler):
if self.decay_start_step < self.num_warmup_steps:
sys.exit("Learning rate warmup must finish before the decay starts")
+ self.use_tpu = use_tpu
+ self.sparse_feature_local_optimizer = local_optimizer
+
super(LRPolicyScheduler, self).__init__(optimizer)
def get_lr(self):
@@ -129,8 +138,18 @@ class LRPolicyScheduler(_LRScheduler):
lr = self.base_lrs
return lr
+ def step(self):
+ super().step()
+ if self.sparse_feature_local_optimizer is not None:
+ # XXX: is lr always a single value list?
+ lr = self.get_lr()[0]
+ for param_group in self.sparse_feature_local_optimizer.param_groups:
+ param_group['lr'] = lr
+
+
### define dlrm in PyTorch ###
class DLRM_Net(nn.Module):
+
def create_mlp(self, ln, sigmoid_layer):
# build MLP layer by layer
layers = nn.ModuleList()
@@ -171,15 +190,61 @@ class DLRM_Net(nn.Module):
# approach 2: use Sequential container to wrap all layers
return torch.nn.Sequential(*layers)
+ def _set_up_sparse_feature_info(self, ln):
+ self._group_table_count = ln.size
+ self._device_table_count = sum(
+ self._tpu_index_belongs_to_ordinal(i)
+ for i in range(ln.size)
+ )
+ self._pad_embedding_lookup = (
+ len(self._xla_replica_group) * self._device_table_count <
+ self._group_table_count
+ )
+ self._max_device_table_count_in_group = int(np.ceil(
+ self._group_table_count / len(self._xla_replica_group)
+ ))
+ self._table_count_padded = (
+ len(self._xla_replica_group) *
+ self._max_device_table_count_in_group
+ )
+ self._pad_indices = set(
+ (
+ self._table_count_padded - 1 -
+ i*self._max_device_table_count_in_group
+ )
+ for i in range(self._table_count_padded - self._group_table_count)
+ )
+ self._non_pad_indices = [
+ i for i in range(self._table_count_padded)
+ if i not in self._pad_indices
+ ]
+
def create_emb(self, m, ln):
emb_l = nn.ModuleList()
+ if self.use_tpu and self.ndevices > 1:
+ self._set_up_sparse_feature_info(ln)
for i in range(0, ln.size):
n = ln[i]
+ if (
+ self.use_tpu and
+ self.ndevices > 1 and
+ not self._tpu_index_belongs_to_ordinal(i)
+ ):
+ # tpu model-parallel mode. only create this ordinal's tables.
+ continue
# construct embedding operator
if self.qr_flag and n > self.qr_threshold:
- EE = QREmbeddingBag(n, m, self.qr_collisions,
- operation=self.qr_operation, mode="sum", sparse=True)
+ assert not self.use_tpu, \
+ 'QR trick not implemented for tpus'.upper()
+ EE = QREmbeddingBag(
+ n, m, self.qr_collisions,
+ operation=self.qr_operation,
+ mode="sum", sparse=self.sparse,
+ xla=self.use_tpu,
+ )
elif self.md_flag:
+ assert not self.use_tpu, \
+ 'MD trick not implemented for tpus'.upper()
base = max(m)
_m = m[i] if n > self.md_threshold else base
EE = PrEmbeddingBag(n, _m, base)
@@ -190,7 +255,20 @@ class DLRM_Net(nn.Module):
EE.embs.weight.data = torch.tensor(W, requires_grad=True)
else:
- EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)
+ if self.use_tpu:
+ # TODO: remove when xla supports `nn.EmbeddingBag`
+ from tools.xla_embedding_bag import XlaEmbeddingBag
+ errmsg = (
+ '`use_tpu` was specified. XLA currently only supports '
+ 'fixed length sparse groups.'
+ )
+ assert self.offset is not None, errmsg
+ EE = XlaEmbeddingBag(
+ n, m, mode="sum",
+ sparse=self.sparse, offset=self.offset,
+ )
+ else:
+ EE = nn.EmbeddingBag(n, m, mode="sum", sparse=self.sparse)
# initialize embeddings
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
@@ -208,6 +286,26 @@ class DLRM_Net(nn.Module):
return emb_l
+ def set_xla_replica_groups(self, groups):
+ self._xla_replica_groups = groups
+ for g in groups:
+ if self._ordinal in g:
+ self._xla_replica_group = g
+ assert self.ndevices == len(g), \
+ 'MP replica group does not match ndevices, {} vs {}'.format(
+ len(g), self.ndevices
+ )
+ self._xla_replica_index = g.index(self._ordinal)
+ break
+ else:
+ raise ValueError(
+ 'Ordinal {} not in replica groups!'.format(self._ordinal)
+ )
+
+ def _tpu_index_belongs_to_ordinal(self, i, ordinal=None):
+ ordinal = self._xla_replica_index if ordinal is None else ordinal
+ return i % len(self._xla_replica_group) == ordinal
+
def __init__(
self,
m_spa=None,
@@ -227,6 +325,9 @@ class DLRM_Net(nn.Module):
qr_threshold=200,
md_flag=False,
md_threshold=200,
+ sparse=True,
+ use_tpu=False,
+ offset=None
):
super(DLRM_Net, self).__init__()
@@ -247,6 +348,9 @@ class DLRM_Net(nn.Module):
self.arch_interaction_itself = arch_interaction_itself
self.sync_dense_params = sync_dense_params
self.loss_threshold = loss_threshold
+ self.sparse = (not use_tpu) and sparse
+ self.use_tpu = use_tpu
+ self.offset = offset
# create variables for QR embedding if applicable
self.qr_flag = qr_flag
if self.qr_flag:
@@ -257,12 +361,35 @@ class DLRM_Net(nn.Module):
self.md_flag = md_flag
if self.md_flag:
self.md_threshold = md_threshold
+ if self.use_tpu:
+ self._init_tpu()
# create operators
if ndevices <= 1:
self.emb_l = self.create_emb(m_spa, ln_emb)
self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
self.top_l = self.create_mlp(ln_top, sigmoid_top)
+ def _init_tpu(self):
+ # no need to be ordinal aware in full data parallel mode.
+ if self.ndevices > 1:
+ import torch_xla.core.xla_model as xm
+ self._ordinal = xm.get_ordinal()
+ self._all_reduce = xm.all_reduce
+ self._all_gather = xm.all_gather
+ #self._all_to_all = xm.all_to_all
+ #self._mark_step = xm.mark_step
+
+ def _filter_params(self, f):
+ for name, p in self.named_parameters():
+ if f(name):
+ yield p
+
+ def mlp_parameters(self):
+ return self._filter_params(lambda name: not name.startswith('emb_l'))
+
+ def emb_parameters(self):
+ return self._filter_params(lambda name: name.startswith('emb_l'))
+
def apply_mlp(self, x, layers):
# approach 1: use ModuleList
# for layer in layers:
@@ -292,6 +419,13 @@ class DLRM_Net(nn.Module):
ly.append(V)
+ if self.use_tpu and self.ndevices > 1 and self._pad_embedding_lookup:
+ # tpu-comment: this device holds fewer tables compared to some other
+ # devices in the xrt_world. In order to do the `all_to_all`
+ # correctly, pad the embeddings with a dummy tensor that's going
+ # to be dropped later.
+ ly.append(torch.zeros(V.shape, device=V.device))
+
# print(ly)
return ly
@@ -332,9 +466,17 @@ class DLRM_Net(nn.Module):
def forward(self, dense_x, lS_o, lS_i):
if self.ndevices <= 1:
return self.sequential_forward(dense_x, lS_o, lS_i)
+ elif self.use_tpu:
+ return self.tpu_parallel_forward(dense_x, lS_o, lS_i)
else:
return self.parallel_forward(dense_x, lS_o, lS_i)
+ def clamp_output(self, p):
+ z = p
+ if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
+ z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
+ return z
+
def sequential_forward(self, dense_x, lS_o, lS_i):
# process dense features (using bottom mlp), resulting in a row vector
x = self.apply_mlp(dense_x, self.bot_l)
@@ -355,13 +497,110 @@ class DLRM_Net(nn.Module):
p = self.apply_mlp(z, self.top_l)
# clamp output if needed
- if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
- z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
- else:
- z = p
-
+ z = self.clamp_output(p)
return z
+ def _partition_to_device(self, iterable):
+ return [
+ obj for k, obj in enumerate(iterable)
+ if self._tpu_index_belongs_to_ordinal(k)
+ ]
+
+ def _collect_distribute_embeddings(self, ordinal_data, local_bsz=None):
+ if local_bsz is None:
+ local_bsz = ordinal_data[0].size(0) // len(self._xla_replica_group)
+ full_data = self._gather_other_samples(ordinal_data)
+ full_data = full_data[self._non_pad_indices]
+ return self.narrow(local_bsz, full_data, dim=1)
+
+ # TODO: once the bug in alltoall is gone, use that instead of
+ # allgather+narrow, which is 10% slower.
+ def ALLTOALL_collect_distribute_embeddings(self, ordinal_data):
+
+ # XXX: clean
+ def debug_nan(self, t):
+ self._mark_step()
+ if isinstance(t, list):
+ print('DEBUG-NAN', self._ordinal, [torch.isnan(Z.view(-1)).sum().item() for Z in t])
+ print('DEBUG-INF', self._ordinal, [torch.isinf(Z.view(-1)).sum().item() for Z in t])
+ else:
+ print('DEBUG-NAN', self._ordinal, torch.isnan(t.view(-1)).sum().item())
+ print('DEBUG-INF', self._ordinal, torch.isinf(t.view(-1)).sum().item())
+
+ ordinal_data = torch.stack(ordinal_data)
+ #debug_nan(ordinal_data)
+ #self._mark_step() # TODO: needed due to bug. Delete when bug resolved.
+ full_data = self._all_to_all(
+ ordinal_data,
+ split_dimension=1,
+ concat_dimension=0,
+ split_count=self.ndevices,
+ groups=self._xla_replica_groups,
+ )
+ # FIXME: delete these when loss doesnt nan
+ #self._mark_step() # TODO: needed due to bug. Delete when bug resolved.
+ #debug_nan(full_data)
+ return full_data[self._non_pad_indices]
+
+
+ def _gather_other_samples(self, array, dim=0):
+ out = torch.stack(array)
+ # dim+1 because stack introduces a dimension to the left
+ out = self._all_gather(out, dim=dim+1, groups=self._xla_replica_groups)
+ return out
+
+ def narrow(self, local_bsz, tensor, dim=1):
+ return torch.narrow(
+ tensor, dim, self._xla_replica_index*local_bsz, local_bsz
+ )
+
+ def tpu_parallel_forward(self, dense_x, lS_o, lS_i):
+ batch_size = dense_x.size()[0]
+ ndevices = self.ndevices
+ assert not batch_size % ndevices, \
+ f"{batch_size} is bsz, {ndevices} devices"
+ local_bsz = batch_size // ndevices
+ dense_x = self.narrow(local_bsz, dense_x, dim=0)
+
+ #bottom mlp
+ x = self.bot_l(dense_x)
+ # embeddings
+ lS_i = self._partition_to_device(lS_i)
+ # offset is assumed to be constant for tpus
+ lS_o = [self.offset for _ in lS_i]
+ ly_local = self.apply_emb(lS_o, lS_i, self.emb_l)
+
+ # at this point, each device have the embeddings belonging to itself.
+ # we do an all_to_all to acquire all embeddings, i.e. full input.
+ ly = self._collect_distribute_embeddings(ly_local)
+
+ # now stop gradients from flowing back.
+ ly = [_.clone().detach().requires_grad_(True) for _ in ly]
+
+ # interactions
+ z = self.interact_features(x, ly)
+
+ # top mlp
+ p = self.top_l(z)
+
+ # clamp output if needed
+ z = self.clamp_output(p)
+
+ return (
+ z,
+ ly_local if not self._pad_embedding_lookup else ly_local[:-1],
+ ly
+ ) # extra return args needed during the custom bwd.
+
+ def tpu_local_backward(self, fullbatch_localembs, localbatch_fullembs):
+ localbatch_fullgrads = [_.grad for _ in localbatch_fullembs]
+ grad = self._gather_other_samples(localbatch_fullgrads) # inv to narrow
+ grad = self._partition_to_device(grad)
+ assert len(fullbatch_localembs) == len(grad), \
+ '{} vs {}'.format(len(fullbatch_localembs), len(grad))
+ for e, g in zip(fullbatch_localembs, grad):
+ e.backward(g)
+
def parallel_forward(self, dense_x, lS_o, lS_i):
### prepare model (overwrite) ###
# WARNING: # of devices must be >= batch size in parallel_forward call
@@ -461,22 +700,14 @@ class DLRM_Net(nn.Module):
p0 = gather(p, self.output_d, dim=0)
# clamp output if needed
- if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
- z0 = torch.clamp(
- p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold)
- )
- else:
- z0 = p0
+ z0 = self.clamp_output(p0)
return z0
-if __name__ == "__main__":
- ### import packages ###
- import sys
+def parse_args():
import argparse
- ### parse arguments ###
parser = argparse.ArgumentParser(
description="Train Deep Learning Recommendation Model (DLRM)"
)
@@ -503,22 +734,24 @@ if __name__ == "__main__":
parser.add_argument("--loss-weights", type=str, default="1.0-1.0") # for wbce
parser.add_argument("--loss-threshold", type=float, default=0.0) # 1.0e-7
parser.add_argument("--round-targets", type=bool, default=False)
+ parser.add_argument("--pred-threshold", type=float, default=0.5)
# data
parser.add_argument("--data-size", type=int, default=1)
parser.add_argument("--num-batches", type=int, default=0)
parser.add_argument(
"--data-generation", type=str, default="random"
) # synthetic or dataset
+ parser.add_argument("--drop-last", action="store_true")
parser.add_argument("--data-trace-file", type=str, default="./input/dist_emb_j.log")
parser.add_argument("--data-set", type=str, default="kaggle") # or terabyte
parser.add_argument("--raw-data-file", type=str, default="")
parser.add_argument("--processed-data-file", type=str, default="")
parser.add_argument("--data-randomize", type=str, default="total") # or day or none
- parser.add_argument("--data-trace-enable-padding", type=bool, default=False)
+ parser.add_argument("--data-trace-enable-padding", action='store_true')
parser.add_argument("--max-ind-range", type=int, default=-1)
parser.add_argument("--data-sub-sample-rate", type=float, default=0.0) # in [0, 1]
parser.add_argument("--num-indices-per-lookup", type=int, default=10)
- parser.add_argument("--num-indices-per-lookup-fixed", type=bool, default=False)
+ parser.add_argument("--num-indices-per-lookup-fixed", action='store_true')
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--memory-map", action="store_true", default=False)
# training
@@ -532,8 +765,12 @@ if __name__ == "__main__":
parser.add_argument("--inference-only", action="store_true", default=False)
# onnx
parser.add_argument("--save-onnx", action="store_true", default=False)
- # gpu
+ # accelerators
parser.add_argument("--use-gpu", action="store_true", default=False)
+ parser.add_argument("--use-tpu", action="store_true", default=False)
+ parser.add_argument("--tpu-model-parallel-group-len", type=int, default=1)
+ parser.add_argument("--tpu-data-parallel", action="store_true")
+ parser.add_argument("--tpu-metrics-debug", action="store_true")
# debugging and profiling
parser.add_argument("--print-freq", type=int, default=1)
parser.add_argument("--test-freq", type=int, default=-1)
@@ -558,16 +795,7 @@ if __name__ == "__main__":
parser.add_argument("--lr-num-warmup-steps", type=int, default=0)
parser.add_argument("--lr-decay-start-step", type=int, default=0)
parser.add_argument("--lr-num-decay-steps", type=int, default=0)
- args = parser.parse_args()
-
- if args.mlperf_logging:
- print('command line args: ', json.dumps(vars(args)))
-
- ### some basic setup ###
- np.random.seed(args.numpy_rand_seed)
- np.set_printoptions(precision=args.print_precision)
- torch.set_printoptions(precision=args.print_precision)
- torch.manual_seed(args.numpy_rand_seed)
+ args, _ = parser.parse_known_args()
if (args.test_mini_batch_size < 0):
# if the parameter is not set, use the training batch size
@@ -576,8 +804,84 @@ if __name__ == "__main__":
# if the parameter is not set, use the same parameter for training
args.test_num_workers = args.num_workers
+ return args
+
+
+def tpu_get_xla_replica_groups(args):
+ import torch_xla.core.xla_model as xm
+ # XXX: would it make sense to sort by size here -- to evenly dist. tables?
+ world_size = xm.xrt_world_size()
+ if args.tpu_data_parallel:
+ return [[i] for i in range(world_size)], [list(range(world_size))]
+ num_tables = args.arch_embedding_size.count("-") + 1
+ len_mp_group = args.tpu_model_parallel_group_len
+ #assert not num_tables % len_mp_group, \
+ # 'Model parallel group size has to divide number of emb tables evenly.'
+ assert not world_size % len_mp_group, \
+ 'Length of model parallel groups has to evenly divide `xrt_world_size`'
+ len_dp_group = world_size // len_mp_group
+ mp_groups = [
+ [len_mp_group*d+m for m in range(len_mp_group)]
+ for d in range(len_dp_group)
+ ]
+ dp_groups = [
+ [len_dp_group*d+m for m in range(len_dp_group)]
+ for d in range(len_mp_group)
+ ]
+ return mp_groups, dp_groups
+
+
+def main(*_args):
+ ### import packages ###
+ import sys
+
+ args = parse_args()
+ ### some basic setup ###
+ np.random.seed(args.numpy_rand_seed)
+ np.set_printoptions(precision=args.print_precision)
+ torch.set_printoptions(precision=args.print_precision)
+ # XXX: does this imply we init the emb tables the same way?
+ torch.manual_seed(args.numpy_rand_seed)
+
use_gpu = args.use_gpu and torch.cuda.is_available()
- if use_gpu:
+ use_tpu = args.use_tpu
+ print = builtins.print
+ if use_tpu:
+ use_gpu = False
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ import torch_xla.distributed.parallel_loader as pl
+ print = xm.master_print
+ device = xm.xla_device()
+ print("Using {} TPU core(s)...".format(xm.xrt_world_size()))
+ if args.data_set in ['kaggle', 'terabyte']:
+ print('Criteo datasets have offset = 1, forcing the arguments..')
+ args.num_indices_per_lookup_fixed = True
+ args.num_indices_per_lookup = 1
+ if args.enable_profiling:
+ print("Profiling was enabled. Turning it off for TPUs.")
+ args.enable_profiling = False
+ if not args.num_indices_per_lookup_fixed:
+ # XXX: does this lead to recompilations?
+ raise NotImplementedError('This will lead to recompilations.')
+ mp_replica_groups, dp_replica_groups = tpu_get_xla_replica_groups(args)
+ print('XLA replica groups for Model Parallel:\n\t', mp_replica_groups)
+ print('XLA replica groups for Data Parallel:\n\t', dp_replica_groups)
+ if len(dp_replica_groups) == 1:
+ # i.e. no allgather etc in the emb layer.
+ print("TPU data-parallel mode, setting --tpu-data-parallel to True")
+ args.tpu_data_parallel = True
+ else:
+ print("TPU model-parallel mode, setting --drop-last=True")
+ args.drop_last = True
+ if args.print_time:
+ print(
+ '`torch_xla` async execution is not compatible with '
+ 'ms/it reporting, turning --print-time off.'
+ )
+ args.print_time = False
+
+ elif use_gpu:
torch.cuda.manual_seed_all(args.numpy_rand_seed)
torch.backends.cudnn.deterministic = True
device = torch.device("cuda", 0)
@@ -587,6 +891,9 @@ if __name__ == "__main__":
device = torch.device("cpu")
print("Using CPU...")
+ if args.mlperf_logging:
+ print('command line args: ', json.dumps(vars(args)))
+
### prepare training data ###
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
# input data
@@ -610,8 +917,26 @@ if __name__ == "__main__":
# input and target at random
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
m_den = ln_bot[0]
- train_data, train_ld = dp.make_random_data_and_loader(args, ln_emb, m_den)
- nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
+ data_args = [args, ln_emb, m_den]
+ if use_tpu:
+ ordinal = xm.get_ordinal()
+ for g in dp_replica_groups:
+ if ordinal in g:
+ n_replicas, rank = len(g), g.index(ordinal)
+ break
+ else:
+ raise ValueError(
+ 'Ordinal {} not in replica groups!'.format(self._ordinal)
+ )
+ data_args.extend([n_replicas, rank]) # extend w/ n_replicas and rank
+ train_data, train_ld = dp.make_random_data_and_loader(*data_args)
+ #nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
+ nbatches = len(train_ld)
+
+ if use_tpu:
+ # XXX: test_data is unused.
+ # Wrap w/ torch_xla's loader
+ train_ld = pl.MpDeviceLoader(train_ld, device)
### parse command line arguments ###
m_spa = args.arch_sparse_feature_size
@@ -737,7 +1062,19 @@ if __name__ == "__main__":
print([S_i.detach().cpu().tolist() for S_i in lS_i])
print(T.detach().cpu().numpy())
- ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_gpu else -1
+ ndevices = -1
+ if use_tpu:
+ ndevices = len(dp_replica_groups)
+ if args.mini_batch_size % ndevices:
+ raise NotImplementedError(
+ 'ndevices need to divide --mini-batch-size. '
+ 'bsz is {}, ndevices is {}'.format(
+ args.mini_batch_size, ndevices,
+ )
+ )
+
+ elif use_gpu:
+ ndevices = min(ngpus, args.mini_batch_size, num_fea - 1)
### construct the neural network specified above ###
# WARNING: to obtain exactly the same initialization for
@@ -761,6 +1098,9 @@ if __name__ == "__main__":
qr_threshold=args.qr_threshold,
md_flag=args.md_flag,
md_threshold=args.md_threshold,
+ sparse=device.type != 'xla',
+ use_tpu=use_tpu,
+ offset=args.num_indices_per_lookup,
)
# test prints
if args.debug_mode:
@@ -776,6 +1116,13 @@ if __name__ == "__main__":
dlrm = dlrm.to(device) # .cuda()
if dlrm.ndevices > 1:
dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb)
+ if use_tpu:
+ # XXX: ndevices is redundant, same info contained in dp_replica_groups
+ if dlrm.ndevices > 1:
+ dlrm.set_xla_replica_groups(mp_replica_groups)
+ dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb)
+ dlrm.device = device
+ dlrm = dlrm.to(device)
# specify the loss function
if args.loss_function == "mse":
@@ -790,11 +1137,34 @@ if __name__ == "__main__":
if not args.inference_only:
# specify the optimizer algorithm
- optimizer = torch.optim.SGD(dlrm.parameters(), lr=args.learning_rate)
- lr_scheduler = LRPolicyScheduler(optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
- args.lr_num_decay_steps)
+
+ if use_tpu and not args.tpu_data_parallel:
+ # tpu's paradigm is different than gpu/cpu. Each process here runs
+ # on its own and all reduces need to happen in a particular way.
+ # Data parallel part, i.e. the MLP part will be allreduced.
+ # EmbeddingBag part will not be allreduced.
+ optimizer = torch.optim.SGD(
+ dlrm.mlp_parameters(), lr=args.learning_rate
+ )
+ emb_local_optimizer = torch.optim.SGD(
+ dlrm.emb_parameters(), lr=args.learning_rate
+ )
+ lr_scheduler = LRPolicyScheduler(
+ optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
+ args.lr_num_decay_steps, use_tpu=True,
+ local_optimizer=emb_local_optimizer,
+ )
+ else:
+ optimizer = torch.optim.SGD(
+ dlrm.parameters(), lr=args.learning_rate
+ )
+ lr_scheduler = LRPolicyScheduler(
+ optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
+ args.lr_num_decay_steps
+ )
### main loop ###
+
def time_wrap(use_gpu):
if use_gpu:
torch.cuda.synchronize()
@@ -808,15 +1178,16 @@ if __name__ == "__main__":
else lS_i.to(device)
lS_o = [S_o.to(device) for S_o in lS_o] if isinstance(lS_o, list) \
else lS_o.to(device)
- return dlrm(
- X.to(device),
- lS_o,
- lS_i
- )
+ return dlrm(X.to(device), lS_o, lS_i)
else:
return dlrm(X, lS_o, lS_i)
def loss_fn_wrap(Z, T, use_gpu, device):
+ if T.size(0) > Z.size(0):
+ # This happens for tpus.
+ # Target tensor is likely for global batch. Narrow to this device.
+ T = dlrm.narrow(Z.size(0), T, dim=0)
+
if args.loss_function == "mse" or args.loss_function == "bce":
if use_gpu:
return loss_fn(Z, T.to(device))
@@ -848,9 +1219,12 @@ if __name__ == "__main__":
k = 0
# Load model is specified
- if not (args.load_model == ""):
+ if args.load_model != "":
print("Loading saved model {}".format(args.load_model))
- if use_gpu:
+ if use_tpu:
+ # XXX: add tpu capabilities to load.
+ raise NotImplementedError('Add tpu capabilities to load.')
+ elif use_gpu:
if dlrm.ndevices > 1:
# NOTE: when targeting inference on multiple GPUs,
# load the model as is on CPU or GPU, with the move
@@ -932,6 +1306,7 @@ if __name__ == "__main__":
t1 = time_wrap(use_gpu)
# early exit if nbatches was set by the user and has been exceeded
+ # XXX: what about j being reset every epoch?
if nbatches > 0 and j >= nbatches:
break
'''
@@ -945,7 +1320,12 @@ if __name__ == "__main__":
'''
# forward pass
- Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, device)
+ if use_tpu and not args.tpu_data_parallel:
+ # args[1:] below will be used in the custom backward
+ Z, fullbatch_localembs, localbatch_fullembs = \
+ dlrm_wrap(X, lS_o, lS_i, use_gpu, device)
+ else:
+ Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, device)
# loss
E = loss_fn_wrap(Z, T, use_gpu, device)
@@ -955,17 +1335,25 @@ if __name__ == "__main__":
print(Z.detach().cpu().numpy())
print(E.detach().cpu().numpy())
'''
- # compute loss and accuracy
- L = E.detach().cpu().numpy() # numpy array
- S = Z.detach().cpu().numpy() # numpy array
- T = T.detach().cpu().numpy() # numpy array
- mbs = T.shape[0] # = args.mini_batch_size except maybe for last
- A = np.sum((np.round(S, 0) == T).astype(np.uint8))
+
+ if use_tpu:
+ mbs = T.shape[0] * len(mp_replica_groups)
+ else:
+ mbs = T.shape[0] # = args.mini_batch_size except maybe for last
+ # XXX: conv related: T is a vector of floats here.
+ # args.round_targets related
+ # FIXME: figure out a way to do this on tpu
+ # S = Z.detach().cpu().numpy() # numpy array
+ # T = T.detach().cpu().numpy() # numpy array
+ #A = np.sum((np.round(S, 0) == T).astype(np.uint8))
+ A = 0
if not args.inference_only:
# scaled error gradient propagation
# (where we do not accumulate gradients across mini-batches)
optimizer.zero_grad()
+ if use_tpu and not args.tpu_data_parallel:
+ emb_local_optimizer.zero_grad()
# backward pass
E.backward()
# debug prints (check gradient norm)
@@ -973,17 +1361,34 @@ if __name__ == "__main__":
# if hasattr(l, 'weight'):
# print(l.weight.grad.norm().item())
- # optimizer
- optimizer.step()
+ # backward + optimizer
+ if use_tpu and not args.tpu_data_parallel:
+ # XXX: no clip grad?
+ # Full allreduce across all devices for the MLP parts.
+ xm.optimizer_step(optimizer, groups=None)
+ # bwd pass for the embedding tables.
+ dlrm.tpu_local_backward(
+ fullbatch_localembs, localbatch_fullembs,
+ )
+ if len(dp_replica_groups[0]) > 1:
+ xm.optimizer_step(
+ emb_local_optimizer, groups=dp_replica_groups,
+ )
+ else:
+ # no allreduce, just step
+ emb_local_optimizer.step()
+ else:
+ optimizer.step()
lr_scheduler.step()
+
if args.mlperf_logging:
total_time += iteration_time
else:
t2 = time_wrap(use_gpu)
total_time += t2 - t1
total_accu += A
- total_loss += L * mbs
+ total_loss += E.detach() * mbs
total_iter += 1
total_samp += mbs
@@ -996,6 +1401,12 @@ if __name__ == "__main__":
# print time, loss and accuracy
if should_print or should_test:
+ # XXX: detach+cpu+numpy seems to avoid the "aten" counter!!!!
+ #L = E.detach().cpu().numpy() # numpy array
+ #S = Z.detach().cpu().numpy() # numpy array
+ #T = T.detach().cpu().numpy() # numpy array
+ if use_tpu:
+ xm.mark_step()
gT = 1000.0 * total_time / total_iter if args.print_time else -1
total_time = 0
@@ -1007,10 +1418,13 @@ if __name__ == "__main__":
str_run_type = "inference" if args.inference_only else "training"
print(
- "Finished {} it {}/{} of epoch {}, {:.2f} ms/it, ".format(
- str_run_type, j + 1, nbatches, k, gT
+ (
+ "Finished {} it {}/{} of epoch {}, {:.2f} ms/it, "
+ "loss {:.6f}, accuracy {:3.3f} %, {} samples, @ {}"
+ ).format(
+ str_run_type, j + 1, nbatches, k, gT,
+ gL, gA * 100, total_samp, datetime.now()
)
- + "loss {:.6f}, accuracy {:3.3f} %".format(gL, gA * 100)
)
# Uncomment the line below to print out the total time with overhead
# print("Accumulated time so far: {}" \
@@ -1020,6 +1434,8 @@ if __name__ == "__main__":
# testing
if should_test and not args.inference_only:
+ # XXX: code path not hit currently
+ raise NotImplementedError('not hit')
# don't measure training iter time in a test iteration
if args.mlperf_logging:
previous_iteration_time = None
@@ -1198,6 +1614,8 @@ if __name__ == "__main__":
break
k += 1 # nepochs
+ if use_tpu and args.tpu_metrics_debug:
+ print(met.metrics_report())
# profiling
if args.enable_profiling:
@@ -1234,3 +1652,7 @@ if __name__ == "__main__":
dlrm_pytorch_onnx = onnx.load("dlrm_s_pytorch.onnx")
# check the onnx model
onnx.checker.check_model(dlrm_pytorch_onnx)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/dlrm_tpu_runner.py b/dlrm_tpu_runner.py
new file mode 100644
index 0000000..d31aa6d
--- /dev/null
+++ b/dlrm_tpu_runner.py
@@ -0,0 +1,15 @@
+import sys
+import argparse
+
+import torch_xla.distributed.xla_multiprocessing as xmp
+
+from dlrm_s_pytorch import main, parse_args
+
+
+if __name__ == '__main__':
+ pre_spawn_parser = argparse.ArgumentParser()
+ pre_spawn_parser.add_argument(
+ "--tpu-cores", type=int, default=8, choices=[1, 8]
+ )
+ pre_spawn_flags, _ = pre_spawn_parser.parse_known_args()
+ xmp.spawn(main, args=(), nprocs=pre_spawn_flags.tpu_cores)
diff --git a/tools/xla_embedding_bag.py b/tools/xla_embedding_bag.py
new file mode 100644
index 0000000..857748b
--- /dev/null
+++ b/tools/xla_embedding_bag.py
@@ -0,0 +1,39 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+
+
+class XlaEmbeddingBag(nn.Module):
+ """
+ nn.EmbeddingBag is not lowered just yet to xla.
+ This performs the same functionality, in an xla compatible, sub-optimal way.
+
+ Warning!: only works with constant offsets atm.
+ """
+
+ def __init__(self, n, m, mode, offset, *args, **kwargs):
+ super(XlaEmbeddingBag, self).__init__()
+ self.n = n
+ self.m = m
+ self.mode = mode
+ self.reduce_fn = getattr(torch, self.mode)
+ self.offset = offset
+ self.embtable = nn.Embedding(n, m, *args, **kwargs)
+
+ def forward(self, sparse_index_group_batch, sparse_offset_group_batch):
+ emb = self.embtable(sparse_index_group_batch)
+ return emb
+ # XXX: only works w/ constant offset atm
+ bsz = emb.size(0) // self.offset
+ emb = emb.reshape(bsz, self.offset, *emb.size()[1:])
+ return self.reduce_fn(emb, axis=1)
+ #return reduce_fn(self.embtable(_) for _ in inp_list)
+
+ @property
+ def weight(self):
+ return self.embtable.weight |
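
Note on tools/xla_embedding_bag.py: since `nn.EmbeddingBag` is not yet lowered for XLA, the helper falls back to a plain `nn.Embedding` lookup, and the file sketches (behind an early `return emb`) a reshape to `(batch, offset, dim)` followed by a reduction over the bag dimension, which matches `mode="sum"` only when every bag has the same fixed length. Below is a minimal standalone CPU sketch of that equivalence; the sizes and variable names are illustrative, not taken from the diff.

import torch
import torch.nn as nn

n, m, offset, bsz = 100, 8, 4, 3           # vocab size, emb dim, fixed bag length, batch size
emb = nn.Embedding(n, m)                    # plain lookup (lowered on XLA)
bag = nn.EmbeddingBag(n, m, mode="sum")     # reference op (not lowered on XLA)
bag.weight.data.copy_(emb.weight.data)      # share weights so the outputs are comparable

idx = torch.randint(0, n, (bsz * offset,))          # flat indices, `offset` per sample
offsets = torch.arange(0, bsz * offset, offset)     # bag starts: 0, 4, 8

ref = bag(idx, offsets)                              # segment-sum driven by offsets
out = emb(idx).reshape(bsz, offset, m).sum(dim=1)    # fixed-offset emulation: lookup, reshape, reduce

print(torch.allclose(ref, out))                      # True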
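
For the model-parallel path, a small illustration of the table bookkeeping that `_tpu_index_belongs_to_ordinal` and `_set_up_sparse_feature_info` compute: tables are assigned round-robin (`i % len(group) == replica_index`), and when the table count does not divide evenly, lookups are padded with dummy tensors so the collective sees the same count from every core. The numbers below assume the 26 Criteo Kaggle tables on a 4-core model-parallel group; they are illustrative, not taken from a run.

import math

num_tables, group_len = 26, 4
owned = {r: [i for i in range(num_tables) if i % group_len == r]
         for r in range(group_len)}
print(owned)                                         # ordinals 0-1 own 7 tables each, ordinals 2-3 own 6

max_per_device = math.ceil(num_tables / group_len)   # 7
padded = group_len * max_per_device                  # 28 slots after padding
pad_indices = {padded - 1 - i * max_per_device
               for i in range(padded - num_tables)}  # slots dropped after the gather
non_pad = [i for i in range(padded) if i not in pad_indices]
print(sorted(pad_indices), len(non_pad))             # [20, 27] 26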