git diff HEAD~1 | |
diff --git a/dlrm_data_pytorch.py b/dlrm_data_pytorch.py | |
index 6cbe382..6f1c849 100644 | |
--- a/dlrm_data_pytorch.py | |
+++ b/dlrm_data_pytorch.py | |
@@ -266,7 +266,7 @@ class CriteoDataset(Dataset): | |
if self.memory_map: | |
if self.split == 'none' or self.split == 'train': | |
- # check if need to swicth to next day and load data | |
+ # check if need to switch to next day and load data | |
if index == self.offset_per_file[self.day]: | |
# print("day_boundary switch", index) | |
self.day_boundary = self.offset_per_file[self.day] | |
@@ -503,7 +503,7 @@ def make_criteo_data_and_loaders(args): | |
num_workers=args.num_workers, | |
collate_fn=collate_wrapper_criteo, | |
pin_memory=False, | |
- drop_last=False, # True | |
+ drop_last=args.drop_last, | |
) | |
test_loader = torch.utils.data.DataLoader( | |
@@ -513,7 +513,7 @@ def make_criteo_data_and_loaders(args): | |
num_workers=args.test_num_workers, | |
collate_fn=collate_wrapper_criteo, | |
pin_memory=False, | |
- drop_last=False, # True | |
+ drop_last=args.drop_last, | |
) | |
return train_data, train_loader, test_data, test_loader | |
@@ -627,7 +627,9 @@ def collate_wrapper_random(list_of_tuples): | |
T) | |
-def make_random_data_and_loader(args, ln_emb, m_den): | |
+def make_random_data_and_loader( | |
+ args, ln_emb, m_den, n_replicas=None, rank=None, | |
+): | |
train_data = RandomDataset( | |
m_den, | |
@@ -645,6 +647,14 @@ def make_random_data_and_loader(args, ln_emb, m_den): | |
reset_seed_on_access=True, | |
rand_seed=args.numpy_rand_seed | |
) # WARNING: generates a batch of lookups at once | |
+ train_sampler = None | |
+ if rank is not None: | |
+ train_sampler = torch.utils.data.distributed.DistributedSampler( | |
+ train_data, | |
+ num_replicas=n_replicas, | |
+ rank=rank, | |
+ shuffle=True, | |
+ ) | |
train_loader = torch.utils.data.DataLoader( | |
train_data, | |
batch_size=1, | |
@@ -652,7 +662,8 @@ def make_random_data_and_loader(args, ln_emb, m_den): | |
num_workers=args.num_workers, | |
collate_fn=collate_wrapper_random, | |
pin_memory=False, | |
- drop_last=False, # True | |
+ drop_last=args.drop_last, | |
+ sampler=train_sampler, | |
) | |
return train_data, train_loader | |
@@ -764,7 +775,9 @@ def generate_uniform_input_batch( | |
) | |
# sparse indices to be used per embedding | |
r = ra.random(sparse_group_size) | |
- sparse_group = np.unique(np.round(r * (size - 1)).astype(np.int64)) | |
+ # XXX: why np.unique ? This is producing a different input shape.. | |
+ #sparse_group = np.unique(np.round(r * (size - 1)).astype(np.int64)) | |
+ sparse_group = np.round(r * (size - 1)).astype(np.int64) | |
# reset sparse_group_size in case some index duplicates were removed | |
sparse_group_size = np.int64(sparse_group.size) | |
# store lengths and indices | |
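
Note on the sampler change above: the DistributedSampler gives each data-parallel rank a disjoint, interleaved shard of the dataset. A minimal CPU-only sketch of that behaviour (sizes hypothetical; no process group is needed because num_replicas and rank are passed explicitly):

import torch
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(16))  # stand-in for RandomDataset
for r in range(4):
    sampler = DistributedSampler(dataset, num_replicas=4, rank=r, shuffle=False)
    print(r, list(sampler))  # rank r sees indices [r, r+4, r+8, r+12]
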
diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py | |
index 1955bb9..1fef8b0 100644 | |
--- a/dlrm_s_pytorch.py | |
+++ b/dlrm_s_pytorch.py | |
@@ -58,6 +58,7 @@ import builtins | |
import functools | |
# import bisect | |
# import shutil | |
+import sys | |
import time | |
import json | |
# data generation | |
@@ -93,10 +94,14 @@ import sklearn.metrics | |
from torch.optim.lr_scheduler import _LRScheduler | |
+ | |
exc = getattr(builtins, "IOError", "FileNotFoundError") | |
class LRPolicyScheduler(_LRScheduler): | |
- def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps): | |
+ def __init__( | |
+ self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps, | |
+ use_tpu=False, local_optimizer=None, | |
+ ): | |
self.num_warmup_steps = num_warmup_steps | |
self.decay_start_step = decay_start_step | |
self.decay_end_step = decay_start_step + num_decay_steps | |
@@ -105,6 +110,9 @@ class LRPolicyScheduler(_LRScheduler): | |
if self.decay_start_step < self.num_warmup_steps: | |
sys.exit("Learning rate warmup must finish before the decay starts") | |
+ self.use_tpu = use_tpu | |
+ self.sparse_feature_local_optimizer = local_optimizer | |
+ | |
super(LRPolicyScheduler, self).__init__(optimizer) | |
def get_lr(self): | |
@@ -129,8 +137,18 @@ class LRPolicyScheduler(_LRScheduler): | |
lr = self.base_lrs | |
return lr | |
+ def step(self): | |
+ super().step() | |
+ if self.sparse_feature_local_optimizer is not None: | |
+ # XXX: is lr always a single value? | |
+ lr = self.get_lr()[0] | |
+ for param_group in self.sparse_feature_local_optimizer.param_groups: | |
+ param_group['lr'] = lr | |
+ | |
+ | |
### define dlrm in PyTorch ### | |
class DLRM_Net(nn.Module): | |
+ | |
def create_mlp(self, ln, sigmoid_layer): | |
# build MLP layer by layer | |
layers = nn.ModuleList() | |
@@ -175,11 +193,26 @@ class DLRM_Net(nn.Module): | |
emb_l = nn.ModuleList() | |
for i in range(0, ln.size): | |
n = ln[i] | |
+ if ( | |
+ self.use_tpu and | |
+ self.ndevices > 1 and | |
+ not self._tpu_index_belongs_to_ordinal(i) | |
+ ): | |
+ # tpu model-parallel mode. only create this ordinal's tables. | |
+ continue | |
# construct embedding operator | |
if self.qr_flag and n > self.qr_threshold: | |
- EE = QREmbeddingBag(n, m, self.qr_collisions, | |
- operation=self.qr_operation, mode="sum", sparse=True) | |
+ # XXX: code path not hit with current tpu tests. | |
+ assert not self.use_tpu, 'not implemented'.upper() | |
+ EE = QREmbeddingBag( | |
+ n, m, self.qr_collisions, | |
+ operation=self.qr_operation, | |
+ mode="sum", sparse=self.sparse, | |
+ xla=self.use_tpu, | |
+ ) | |
elif self.md_flag: | |
+ assert not self.use_tpu, 'not implemented'.upper() | |
+ # XXX: code path not hit with current tpu tests. | |
base = max(m) | |
_m = m[i] if n > self.md_threshold else base | |
EE = PrEmbeddingBag(n, _m, base) | |
@@ -190,7 +223,21 @@ class DLRM_Net(nn.Module): | |
EE.embs.weight.data = torch.tensor(W, requires_grad=True) | |
else: | |
- EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True) | |
+ if self.use_tpu: | |
+ # XXX: xla currently does not support `nn.EmbeddingBag` | |
+ # so I wrote a custom thing. | |
+ from tools.xla_embedding_bag import XlaEmbeddingBag | |
+ errmsg = ( | |
+ '`use_tpu` was specified. XLA currently only supports ' | |
+ 'fixed length sparse groups.' | |
+ ) | |
+ assert self.offset is not None, errmsg | |
+ EE = XlaEmbeddingBag( | |
+ n, m, mode="sum", | |
+ sparse=self.sparse, offset=self.offset, | |
+ ) | |
+ else: | |
+ EE = nn.EmbeddingBag(n, m, mode="sum", sparse=self.sparse) | |
# initialize embeddings | |
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n)) | |
@@ -208,6 +255,25 @@ class DLRM_Net(nn.Module): | |
return emb_l | |
+ def set_xla_replica_groups(self, groups): | |
+ self._xla_replica_groups = groups | |
+ for g in groups: | |
+ if self._ordinal in g: | |
+ self._xla_replica_group = g | |
+ assert self.ndevices == len(g), \ | |
+ 'replica group does not match ndevices, {} vs {}'.format( | |
+ len(g), self.ndevices | |
+ ) | |
+ self._xla_replica_index = g.index(self._ordinal) | |
+ break | |
+ else: | |
+ raise ValueError( | |
+ 'Ordinal {} not in replica groups!'.format(self._ordinal) | |
+ ) | |
+ | |
+ def _tpu_index_belongs_to_ordinal(self, i): | |
+ return i % len(self._xla_replica_group) == self._xla_replica_index | |
+ | |
def __init__( | |
self, | |
m_spa=None, | |
@@ -227,6 +293,9 @@ class DLRM_Net(nn.Module): | |
qr_threshold=200, | |
md_flag=False, | |
md_threshold=200, | |
+ sparse=True, | |
+ use_tpu=False, | |
+ offset=None | |
): | |
super(DLRM_Net, self).__init__() | |
@@ -247,6 +316,9 @@ class DLRM_Net(nn.Module): | |
self.arch_interaction_itself = arch_interaction_itself | |
self.sync_dense_params = sync_dense_params | |
self.loss_threshold = loss_threshold | |
+ self.sparse = (not use_tpu) and sparse | |
+ self.use_tpu = use_tpu | |
+ self.offset = offset | |
# create variables for QR embedding if applicable | |
self.qr_flag = qr_flag | |
if self.qr_flag: | |
@@ -256,13 +328,43 @@ class DLRM_Net(nn.Module): | |
# create variables for MD embedding if applicable | |
self.md_flag = md_flag | |
if self.md_flag: | |
+ if use_tpu: | |
+ raise NotImplementedError( | |
+ """ | |
+ XXX: | |
+ md trick produces mixed shape inputs => recompilations | |
+ on tpu. Document the tradeoff after experimenting. | |
+ """ | |
+ ) | |
self.md_threshold = md_threshold | |
+ if self.use_tpu: | |
+ self._init_tpu() | |
# create operators | |
if ndevices <= 1: | |
self.emb_l = self.create_emb(m_spa, ln_emb) | |
self.bot_l = self.create_mlp(ln_bot, sigmoid_bot) | |
self.top_l = self.create_mlp(ln_top, sigmoid_top) | |
+ def _init_tpu(self): | |
+ # no need to be ordinal aware in full data parallel mode. | |
+ if self.ndevices > 1: | |
+ import torch_xla.core.xla_model as xm | |
+ self._ordinal = xm.get_ordinal() | |
+ self._local_ordinal = xm.get_local_ordinal() | |
+ self._all_reduce = xm.all_reduce | |
+ self._all_gather = xm.all_gather | |
+ | |
+ def _filter_params(self, f): | |
+ for name, p in self.named_parameters(): | |
+ if f(name): | |
+ yield p | |
+ | |
+ def mlp_parameters(self): | |
+ return self._filter_params(lambda name: not name.startswith('emb_l')) | |
+ | |
+ def emb_parameters(self): | |
+ return self._filter_params(lambda name: name.startswith('emb_l')) | |
+ | |
def apply_mlp(self, x, layers): | |
# approach 1: use ModuleList | |
# for layer in layers: | |
@@ -332,9 +434,17 @@ class DLRM_Net(nn.Module): | |
def forward(self, dense_x, lS_o, lS_i): | |
if self.ndevices <= 1: | |
return self.sequential_forward(dense_x, lS_o, lS_i) | |
+ elif self.use_tpu: | |
+ return self.tpu_parallel_forward(dense_x, lS_o, lS_i) | |
else: | |
return self.parallel_forward(dense_x, lS_o, lS_i) | |
+ def clamp_output(self, p): | |
+ z = p | |
+ if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: | |
+ z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold)) | |
+ return z | |
+ | |
def sequential_forward(self, dense_x, lS_o, lS_i): | |
# process dense features (using bottom mlp), resulting in a row vector | |
x = self.apply_mlp(dense_x, self.bot_l) | |
@@ -355,13 +465,84 @@ class DLRM_Net(nn.Module): | |
p = self.apply_mlp(z, self.top_l) | |
# clamp output if needed | |
- if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: | |
- z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold)) | |
- else: | |
- z = p | |
- | |
+ z = self.clamp_output(p) | |
return z | |
+ def _partition_to_device(self, iterable): | |
+ return [ | |
+ obj for k, obj in enumerate(iterable) | |
+ if self._tpu_index_belongs_to_ordinal(k) | |
+ ] | |
+ | |
+ def _gather_other_embeddings(self, ordinal_data): | |
+ x = iter(ordinal_data) | |
+ ordinal_data = torch.stack(ordinal_data) | |
+ full_data = self._all_gather( | |
+ ordinal_data, dim=0, groups=self._xla_replica_groups | |
+ ) | |
+ return full_data | |
+ | |
+ def _narrow(self, local_bsz, tensor, dim=1): | |
+ return torch.narrow( | |
+ tensor, dim, self._xla_replica_index*local_bsz, local_bsz | |
+ ) | |
+ | |
+ def _gather_other_samples(self, array, dim=0): | |
+ out = torch.stack(array) | |
+ # dim+1 because stack introduces a dimension to the left | |
+ out = self._all_gather(out, dim=dim+1, groups=self._xla_replica_groups) | |
+ return out | |
+ | |
+ def tpu_parallel_forward(self, dense_x, lS_o, lS_i): | |
+ batch_size = dense_x.size()[0] | |
+ ndevices = self.ndevices | |
+ assert not batch_size % ndevices, \ | |
+ f"{batch_size} is bsz, {ndevices} devices" | |
+ local_bsz = batch_size // ndevices | |
+ # XXX: no redistribute model if bsz changes for tpus. is this ok? | |
+ # this is bc all weights are updated at all times. | |
+ # I think on gpus, it updates `ndevices` many, which can fluctuate | |
+ # w/ changes in bsz. | |
+ | |
+ dense_x = self._narrow(local_bsz, dense_x, dim=0) | |
+ | |
+ #bottom mlp | |
+ x = self.bot_l(dense_x) | |
+ # embeddings | |
+ lS_o = self._partition_to_device(lS_o) | |
+ lS_i = self._partition_to_device(lS_i) | |
+ ly_local = self.apply_emb(lS_o, lS_i, self.emb_l) | |
+ | |
+ # at this point, each device has the embeddings belonging to itself. | |
+ # we do a gather to acquire all embeddings, i.e. full input. | |
+ # followed by a `narrow`, so rest of the model can run data-parallel. | |
+ # XXX: pods? will gather collect from all devices? option to collect only locally? | |
+ ly = self._gather_other_embeddings(ly_local) | |
+ # _gather introduces a dim, so batch dim is now 1 | |
+ ly = self._narrow(local_bsz, ly, dim=1) | |
+ # now stop gradients from flowing back. | |
+ ly = [_.clone().detach().requires_grad_(True) for _ in ly] | |
+ | |
+ # interactions | |
+ z = self.interact_features(x, ly) | |
+ | |
+ # top mlp | |
+ p = self.top_l(z) | |
+ | |
+ # clamp output if needed | |
+ z = self.clamp_output(p) | |
+ | |
+ return z, ly_local, ly # extra return args needed during bwd. | |
+ | |
+ def tpu_local_backward(self, fullbatch_localembs, localbatch_fullembs): | |
+ localbatch_fullgrads = [_.grad for _ in localbatch_fullembs] | |
+ grad = self._gather_other_samples(localbatch_fullgrads) # inv to narrow | |
+ grad = self._partition_to_device(grad) | |
+ assert len(fullbatch_localembs) == len(grad), \ | |
+ '{} vs {}'.format(len(fullbatch_localembs), len(grad)) | |
+ for e, g in zip(fullbatch_localembs, grad): | |
+ e.backward(g) | |
+ | |
def parallel_forward(self, dense_x, lS_o, lS_i): | |
### prepare model (overwrite) ### | |
# WARNING: # of devices must be >= batch size in parallel_forward call | |
@@ -461,22 +642,14 @@ class DLRM_Net(nn.Module): | |
p0 = gather(p, self.output_d, dim=0) | |
# clamp output if needed | |
- if 0.0 < self.loss_threshold and self.loss_threshold < 1.0: | |
- z0 = torch.clamp( | |
- p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold) | |
- ) | |
- else: | |
- z0 = p0 | |
+ z0 = self.clamp_output(p0) | |
return z0 | |
-if __name__ == "__main__": | |
- ### import packages ### | |
- import sys | |
+def parse_args(): | |
import argparse | |
- ### parse arguments ### | |
parser = argparse.ArgumentParser( | |
description="Train Deep Learning Recommendation Model (DLRM)" | |
) | |
@@ -503,22 +676,24 @@ if __name__ == "__main__": | |
parser.add_argument("--loss-weights", type=str, default="1.0-1.0") # for wbce | |
parser.add_argument("--loss-threshold", type=float, default=0.0) # 1.0e-7 | |
parser.add_argument("--round-targets", type=bool, default=False) | |
+ parser.add_argument("--pred-threshold", type=float, default=0.5) | |
# data | |
parser.add_argument("--data-size", type=int, default=1) | |
parser.add_argument("--num-batches", type=int, default=0) | |
parser.add_argument( | |
"--data-generation", type=str, default="random" | |
) # synthetic or dataset | |
+ parser.add_argument("--drop-last", action="store_true") | |
parser.add_argument("--data-trace-file", type=str, default="./input/dist_emb_j.log") | |
parser.add_argument("--data-set", type=str, default="kaggle") # or terabyte | |
parser.add_argument("--raw-data-file", type=str, default="") | |
parser.add_argument("--processed-data-file", type=str, default="") | |
parser.add_argument("--data-randomize", type=str, default="total") # or day or none | |
- parser.add_argument("--data-trace-enable-padding", type=bool, default=False) | |
+ parser.add_argument("--data-trace-enable-padding", action='store_true') | |
parser.add_argument("--max-ind-range", type=int, default=-1) | |
parser.add_argument("--data-sub-sample-rate", type=float, default=0.0) # in [0, 1] | |
parser.add_argument("--num-indices-per-lookup", type=int, default=10) | |
- parser.add_argument("--num-indices-per-lookup-fixed", type=bool, default=False) | |
+ parser.add_argument("--num-indices-per-lookup-fixed", action='store_true') | |
parser.add_argument("--num-workers", type=int, default=0) | |
parser.add_argument("--memory-map", action="store_true", default=False) | |
# training | |
@@ -532,8 +707,11 @@ if __name__ == "__main__": | |
parser.add_argument("--inference-only", action="store_true", default=False) | |
# onnx | |
parser.add_argument("--save-onnx", action="store_true", default=False) | |
- # gpu | |
+ # accelerators | |
parser.add_argument("--use-gpu", action="store_true", default=False) | |
+ parser.add_argument("--use-tpu", action="store_true", default=False) | |
+ parser.add_argument("--tpu-model-parallel-group-len", type=int, default=1) | |
+ parser.add_argument("--tpu-data-parallel", action="store_true") | |
# debugging and profiling | |
parser.add_argument("--print-freq", type=int, default=1) | |
parser.add_argument("--test-freq", type=int, default=-1) | |
@@ -558,16 +736,7 @@ if __name__ == "__main__": | |
parser.add_argument("--lr-num-warmup-steps", type=int, default=0) | |
parser.add_argument("--lr-decay-start-step", type=int, default=0) | |
parser.add_argument("--lr-num-decay-steps", type=int, default=0) | |
- args = parser.parse_args() | |
- | |
- if args.mlperf_logging: | |
- print('command line args: ', json.dumps(vars(args))) | |
- | |
- ### some basic setup ### | |
- np.random.seed(args.numpy_rand_seed) | |
- np.set_printoptions(precision=args.print_precision) | |
- torch.set_printoptions(precision=args.print_precision) | |
- torch.manual_seed(args.numpy_rand_seed) | |
+ args, _ = parser.parse_known_args() | |
if (args.test_mini_batch_size < 0): | |
# if the parameter is not set, use the training batch size | |
@@ -576,8 +745,73 @@ if __name__ == "__main__": | |
# if the parameter is not set, use the same parameter for training | |
args.test_num_workers = args.num_workers | |
+ return args | |
+ | |
+ | |
+def tpu_get_xla_replica_groups(args): | |
+ import torch_xla.core.xla_model as xm | |
+ # XXX: would it make sense to sort by size here -- to evenly dist. tables? | |
+ world_size = xm.xrt_world_size() | |
+ if args.tpu_data_parallel: | |
+ return [[i] for i in range(world_size)], [list(range(world_size))] | |
+ num_tables = args.arch_embedding_size.count("-") + 1 | |
+ len_mp_group = args.tpu_model_parallel_group_len | |
+ assert not num_tables % len_mp_group, \ | |
+ 'Model parallel group size has to divide number of emb tables evenly.' | |
+ assert not world_size % len_mp_group, \ | |
+ 'Length of model parallel groups has to evenly divide `xrt_world_size`' | |
+ len_dp_group = world_size // len_mp_group | |
+ mp_groups = [ | |
+ [len_mp_group*d+m for m in range(len_mp_group)] | |
+ for d in range(len_dp_group) | |
+ ] | |
+ dp_groups = [ | |
+ [len_dp_group*d+m for m in range(len_dp_group)] | |
+ for d in range(len_mp_group) | |
+ ] | |
+ return mp_groups, dp_groups | |
+ | |
+ | |
+def main(*_args): | |
+ ### import packages ### | |
+ import sys | |
+ | |
+ args = parse_args() | |
+ ### some basic setup ### | |
+ np.random.seed(args.numpy_rand_seed) | |
+ np.set_printoptions(precision=args.print_precision) | |
+ torch.set_printoptions(precision=args.print_precision) | |
+ # XXX: does this imply we init the emb tables the same way? | |
+ torch.manual_seed(args.numpy_rand_seed) | |
+ | |
use_gpu = args.use_gpu and torch.cuda.is_available() | |
- if use_gpu: | |
+ use_tpu = args.use_tpu | |
+ print = builtins.print | |
+ if use_tpu: | |
+ use_gpu = False | |
+ import torch_xla.core.xla_model as xm | |
+ import torch_xla.debug.metrics as met | |
+ import torch_xla.distributed.parallel_loader as pl | |
+ print = xm.master_print | |
+ device = xm.xla_device() | |
+ print("Using {} TPU core(s)...".format(xm.xrt_world_size())) | |
+ if args.enable_profiling: | |
+ print("Profiling was enabled. Turning it off for TPUs.") | |
+ args.enable_profiling = False | |
+ if not args.num_indices_per_lookup_fixed: | |
+ # XXX: does this lead to recompilations? | |
+ raise NotImplementedError | |
+ mp_replica_groups, dp_replica_groups = tpu_get_xla_replica_groups(args) | |
+ print('XLA replica groups for Model Parallel:\n\t', mp_replica_groups) | |
+ print('XLA replica groups for Data Parallel:\n\t', dp_replica_groups) | |
+ if len(dp_replica_groups) == 1: | |
+ # i.e. no allgather etc in the emb layer. | |
+ print("TPU data-parallel mode, setting --tpu-data-parallel to True") | |
+ args.tpu_data_parallel = True | |
+ else: | |
+ print("TPU model-parallel mode, setting --drop-last=True") | |
+ args.drop_last = True | |
+ elif use_gpu: | |
torch.cuda.manual_seed_all(args.numpy_rand_seed) | |
torch.backends.cudnn.deterministic = True | |
device = torch.device("cuda", 0) | |
@@ -587,6 +821,9 @@ if __name__ == "__main__": | |
device = torch.device("cpu") | |
print("Using CPU...") | |
+ if args.mlperf_logging: | |
+ print('command line args: ', json.dumps(vars(args))) | |
+ | |
### prepare training data ### | |
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-") | |
# input data | |
@@ -610,8 +847,26 @@ if __name__ == "__main__": | |
# input and target at random | |
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-") | |
m_den = ln_bot[0] | |
- train_data, train_ld = dp.make_random_data_and_loader(args, ln_emb, m_den) | |
- nbatches = args.num_batches if args.num_batches > 0 else len(train_ld) | |
+ data_args = [args, ln_emb, m_den] | |
+ if use_tpu: | |
+ ordinal = xm.get_ordinal() | |
+ for g in dp_replica_groups: | |
+ if ordinal in g: | |
+ n_replicas, rank = len(g), g.index(ordinal) | |
+ break | |
+ else: | |
+ raise ValueError( | |
+ 'Ordinal {} not in replica groups!'.format(ordinal) | |
+ ) | |
+ data_args.extend([n_replicas, rank]) # extend w/ n_replicas and rank | |
+ train_data, train_ld = dp.make_random_data_and_loader(*data_args) | |
+ #nbatches = args.num_batches if args.num_batches > 0 else len(train_ld) | |
+ nbatches = len(train_ld) | |
+ | |
+ if use_tpu: | |
+ # XXX: test_data is unused. | |
+ # Wrap w/ torch_xla's loader | |
+ train_ld = pl.MpDeviceLoader(train_ld, device) | |
### parse command line arguments ### | |
m_spa = args.arch_sparse_feature_size | |
@@ -737,7 +992,26 @@ if __name__ == "__main__": | |
print([S_i.detach().cpu().tolist() for S_i in lS_i]) | |
print(T.detach().cpu().numpy()) | |
- ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_gpu else -1 | |
+ ndevices = -1 | |
+ if use_tpu: | |
+ ndevices = len(dp_replica_groups) | |
+ # XXX: it could work when the following are violated too. | |
+ # TODO: implement that. | |
+ if args.mini_batch_size % ndevices: | |
+ raise NotImplementedError( | |
+ 'bsz is {}, ndevices is {}'.format( | |
+ args.mini_batch_size, ndevices, | |
+ ) | |
+ ) | |
+ if (num_fea - 1) % ndevices: | |
+ raise NotImplementedError( | |
+ 'num embtables is {}, ndevices is {}'.format( | |
+ num_fea-1, ndevices, | |
+ ) | |
+ ) | |
+ | |
+ elif use_gpu: | |
+ ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) | |
### construct the neural network specified above ### | |
# WARNING: to obtain exactly the same initialization for | |
@@ -761,6 +1035,9 @@ if __name__ == "__main__": | |
qr_threshold=args.qr_threshold, | |
md_flag=args.md_flag, | |
md_threshold=args.md_threshold, | |
+ sparse=device.type != 'xla', | |
+ use_tpu=use_tpu, | |
+ offset=args.num_indices_per_lookup, | |
) | |
# test prints | |
if args.debug_mode: | |
@@ -776,6 +1053,13 @@ if __name__ == "__main__": | |
dlrm = dlrm.to(device) # .cuda() | |
if dlrm.ndevices > 1: | |
dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb) | |
+ if use_tpu: | |
+ # XXX: ndevices is redundant, but meh. | |
+ if dlrm.ndevices > 1: | |
+ dlrm.set_xla_replica_groups(mp_replica_groups) | |
+ dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb) | |
+ dlrm.device = device | |
+ dlrm = dlrm.to(device) | |
# specify the loss function | |
if args.loss_function == "mse": | |
@@ -790,11 +1074,34 @@ if __name__ == "__main__": | |
if not args.inference_only: | |
# specify the optimizer algorithm | |
- optimizer = torch.optim.SGD(dlrm.parameters(), lr=args.learning_rate) | |
- lr_scheduler = LRPolicyScheduler(optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step, | |
- args.lr_num_decay_steps) | |
+ | |
+ if use_tpu and not args.tpu_data_parallel: | |
+ # tpu's paradigm is different than gpu/cpu. Each process here runs | |
+ # on its own and all reduces need to happen in a particular way. | |
+ # Data parallel part, i.e. the MLP part will be allreduced. | |
+ # EmbeddingBag part will not be allreduced. | |
+ optimizer = torch.optim.SGD( | |
+ dlrm.mlp_parameters(), lr=args.learning_rate | |
+ ) | |
+ emb_local_optimizer = torch.optim.SGD( | |
+ dlrm.emb_parameters(), lr=args.learning_rate | |
+ ) | |
+ lr_scheduler = LRPolicyScheduler( | |
+ optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step, | |
+ args.lr_num_decay_steps, use_tpu=True, | |
+ local_optimizer=emb_local_optimizer, | |
+ ) | |
+ else: | |
+ optimizer = torch.optim.SGD( | |
+ dlrm.parameters(), lr=args.learning_rate | |
+ ) | |
+ lr_scheduler = LRPolicyScheduler( | |
+ optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step, | |
+ args.lr_num_decay_steps | |
+ ) | |
### main loop ### | |
+ | |
def time_wrap(use_gpu): | |
if use_gpu: | |
torch.cuda.synchronize() | |
@@ -808,15 +1115,16 @@ if __name__ == "__main__": | |
else lS_i.to(device) | |
lS_o = [S_o.to(device) for S_o in lS_o] if isinstance(lS_o, list) \ | |
else lS_o.to(device) | |
- return dlrm( | |
- X.to(device), | |
- lS_o, | |
- lS_i | |
- ) | |
+ return dlrm(X.to(device), lS_o, lS_i) | |
else: | |
return dlrm(X, lS_o, lS_i) | |
def loss_fn_wrap(Z, T, use_gpu, device): | |
+ if T.size(0) > Z.size(0): | |
+ # This happens for tpus. | |
+ # Target tensor is likely for global batch. Narrow to this device. | |
+ T = dlrm._narrow(Z.size(0), T, dim=0) | |
+ | |
if args.loss_function == "mse" or args.loss_function == "bce": | |
if use_gpu: | |
return loss_fn(Z, T.to(device)) | |
@@ -848,9 +1156,12 @@ if __name__ == "__main__": | |
k = 0 | |
# Load model is specified | |
- if not (args.load_model == ""): | |
+ if args.load_model != "": | |
print("Loading saved model {}".format(args.load_model)) | |
- if use_gpu: | |
+ if use_tpu: | |
+ # XXX: add tpu capabilities to load. | |
+ raise NotImplementedError('Add tpu capabilities to load.') | |
+ elif use_gpu: | |
if dlrm.ndevices > 1: | |
# NOTE: when targeting inference on multiple GPUs, | |
# load the model as is on CPU or GPU, with the move | |
@@ -907,6 +1218,7 @@ if __name__ == "__main__": | |
) | |
print("time/loss/accuracy (if enabled):") | |
+ # XXX: what is profiler? turning it off for tpus. | |
with torch.autograd.profiler.profile(args.enable_profiling, use_gpu) as prof: | |
while k < args.nepochs: | |
if k < skip_upto_epoch: | |
@@ -932,6 +1244,7 @@ if __name__ == "__main__": | |
t1 = time_wrap(use_gpu) | |
# early exit if nbatches was set by the user and has been exceeded | |
+ # XXX: what about j being reset every epoch? | |
if nbatches > 0 and j >= nbatches: | |
break | |
''' | |
@@ -944,8 +1257,22 @@ if __name__ == "__main__": | |
print(T.detach().cpu().numpy()) | |
''' | |
+ # XXX: clean | |
+ """ | |
+ print( | |
+ 'SHAPES', | |
+ X.shape, | |
+ [_.shape for _ in lS_o], | |
+ [_.shape for _ in lS_i], | |
+ ) | |
+ """ | |
# forward pass | |
- Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, device) | |
+ if use_tpu and not args.tpu_data_parallel: | |
+ # args[1:] below will be used in the custom backward | |
+ Z, fullbatch_localembs, localbatch_fullembs = \ | |
+ dlrm_wrap(X, lS_o, lS_i, use_gpu, device) | |
+ else: | |
+ Z = dlrm_wrap(X, lS_o, lS_i, use_gpu, device) | |
# loss | |
E = loss_fn_wrap(Z, T, use_gpu, device) | |
@@ -955,12 +1282,15 @@ if __name__ == "__main__": | |
print(Z.detach().cpu().numpy()) | |
print(E.detach().cpu().numpy()) | |
''' | |
- # compute loss and accuracy | |
- L = E.detach().cpu().numpy() # numpy array | |
- S = Z.detach().cpu().numpy() # numpy array | |
- T = T.detach().cpu().numpy() # numpy array | |
+ | |
mbs = T.shape[0] # = args.mini_batch_size except maybe for last | |
- A = np.sum((np.round(S, 0) == T).astype(np.uint8)) | |
+ # XXX: conv related: T is a vector of floats here. | |
+ # args.round_targets related | |
+ # FIXME: figure out a way to do this on tpu | |
+ # S = Z.detach().cpu().numpy() # numpy array | |
+ # T = T.detach().cpu().numpy() # numpy array | |
+ #A = np.sum((np.round(S, 0) == T).astype(np.uint8)) | |
+ A = 0 | |
if not args.inference_only: | |
# scaled error gradient propagation | |
@@ -973,17 +1303,35 @@ if __name__ == "__main__": | |
# if hasattr(l, 'weight'): | |
# print(l.weight.grad.norm().item()) | |
- # optimizer | |
- optimizer.step() | |
+ # backward + optimizer | |
+ if use_tpu and not args.tpu_data_parallel: | |
+ # XXX: no clip grad? | |
+ # Full allreduce across all devices for the MLP part. | |
+ # XXX: what about the bottom mlp? # FIXME | |
+ xm.optimizer_step(optimizer, groups=None) | |
+ # bwd pass for the embedding tables. | |
+ dlrm.tpu_local_backward( | |
+ fullbatch_localembs, localbatch_fullembs, | |
+ ) | |
+ if len(dp_replica_groups[0]) > 1: | |
+ xm.optimizer_step( | |
+ emb_local_optimizer, groups=dp_replica_groups, | |
+ ) | |
+ else: | |
+ # no allreduce, just step | |
+ emb_local_optimizer.step() | |
+ else: | |
+ optimizer.step() | |
lr_scheduler.step() | |
+ | |
if args.mlperf_logging: | |
total_time += iteration_time | |
else: | |
t2 = time_wrap(use_gpu) | |
total_time += t2 - t1 | |
total_accu += A | |
- total_loss += L * mbs | |
+ total_loss += E.detach() * mbs | |
total_iter += 1 | |
total_samp += mbs | |
@@ -996,6 +1344,12 @@ if __name__ == "__main__": | |
# print time, loss and accuracy | |
if should_print or should_test: | |
+ # XXX: detach+cpu+numpy seems to avoid the "aten" counter!!!! | |
+ #L = E.detach().cpu().numpy() # numpy array | |
+ #S = Z.detach().cpu().numpy() # numpy array | |
+ #T = T.detach().cpu().numpy() # numpy array | |
+ if use_tpu: | |
+ xm.mark_step() | |
gT = 1000.0 * total_time / total_iter if args.print_time else -1 | |
total_time = 0 | |
@@ -1020,6 +1374,8 @@ if __name__ == "__main__": | |
# testing | |
if should_test and not args.inference_only: | |
+ # XXX: code path not hit currently | |
+ raise NotImplementedError('not hit') | |
# don't measure training iter time in a test iteration | |
if args.mlperf_logging: | |
previous_iteration_time = None | |
@@ -1234,3 +1590,7 @@ if __name__ == "__main__": | |
dlrm_pytorch_onnx = onnx.load("dlrm_s_pytorch.onnx") | |
# check the onnx model | |
onnx.checker.check_model(dlrm_pytorch_onnx) | |
+ | |
+ | |
+if __name__ == "__main__": | |
+ main() | |
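
The core of the TPU model-parallel path above is tpu_parallel_forward: each ordinal computes embeddings only for its own tables over the full batch, all-gathers them across its model-parallel replica group, then narrows everything down to its local batch slice so the MLPs run data-parallel. A rough CPU stand-in for the shape flow (xm.all_gather simulated with torch.cat; sizes hypothetical):

import torch

ndevices, global_bsz, emb_dim = 2, 8, 4          # 2 ordinals, 4 tables in total
local_bsz = global_bsz // ndevices

# per-ordinal output of apply_emb: one (global_bsz, emb_dim) tensor per local table
ly_local = [torch.randn(global_bsz, emb_dim) for _ in range(2)]
stacked = torch.stack(ly_local)                  # (2, global_bsz, emb_dim)

# xm.all_gather(stacked, dim=0, groups=...) concatenates the stacks from both
# ordinals; pretend the other ordinal produced an identical stack here.
full = torch.cat([stacked, stacked], dim=0)      # (4, global_bsz, emb_dim)

# _narrow(local_bsz, full, dim=1) keeps only this replica's slice of the batch
replica_index = 0
ly = torch.narrow(full, 1, replica_index * local_bsz, local_bsz)
print(ly.shape)                                  # torch.Size([4, 4, 4])
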
diff --git a/dlrm_tpu_runner.py b/dlrm_tpu_runner.py | |
new file mode 100644 | |
index 0000000..d31aa6d | |
--- /dev/null | |
+++ b/dlrm_tpu_runner.py | |
@@ -0,0 +1,15 @@ | |
+import sys | |
+import argparse | |
+ | |
+import torch_xla.distributed.xla_multiprocessing as xmp | |
+ | |
+from dlrm_s_pytorch import main, parse_args | |
+ | |
+ | |
+if __name__ == '__main__': | |
+ pre_spawn_parser = argparse.ArgumentParser() | |
+ pre_spawn_parser.add_argument( | |
+ "--tpu-cores", type=int, default=8, choices=[1, 8] | |
+ ) | |
+ pre_spawn_flags, _ = pre_spawn_parser.parse_known_args() | |
+ xmp.spawn(main, args=(), nprocs=pre_spawn_flags.tpu_cores) | |
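
For context on the runner above: like torch.multiprocessing.spawn, xmp.spawn starts one process per core and calls the target as fn(index, *args), which is why main() is declared as main(*_args) and re-parses its own flags via parse_known_args(). A rough stand-in for that calling convention (the real xmp.spawn forks processes and sets up an XLA device for each):

def fake_spawn(fn, args=(), nprocs=8):
    for index in range(nprocs):
        fn(index, *args)
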
diff --git a/tools/xla_embedding_bag.py b/tools/xla_embedding_bag.py | |
new file mode 100644 | |
index 0000000..c474bf1 | |
--- /dev/null | |
+++ b/tools/xla_embedding_bag.py | |
@@ -0,0 +1,38 @@ | |
+# Copyright (c) Facebook, Inc. and its affiliates. | |
+# | |
+# This source code is licensed under the MIT license found in the | |
+# LICENSE file in the root directory of this source tree. | |
+ | |
+ | |
+import torch | |
+import torch.nn as nn | |
+ | |
+ | |
+class XlaEmbeddingBag(nn.Module): | |
+ """ | |
+ nn.EmbeddingBag is not lowered just yet to xla. | |
+ This performs the same functionality, in an xla compatible, sub-optimal way. | |
+ | |
+ Warning!: only works with constant offsets atm. | |
+ """ | |
+ | |
+ def __init__(self, n, m, mode, offset, *args, **kwargs): | |
+ super(XlaEmbeddingBag, self).__init__() | |
+ self.n = n | |
+ self.m = m | |
+ self.mode = mode | |
+ self.offset = offset | |
+ self.embtable = nn.Embedding(n, m, *args, **kwargs) | |
+ | |
+ def forward(self, sparse_index_group_batch, sparse_offset_group_batch): | |
+ emb = self.embtable(sparse_index_group_batch) | |
+ # XXX: only works w/ constant offset atm | |
+ bsz = emb.size(0) // self.offset | |
+ emb = emb.reshape(bsz, self.offset, *emb.size()[1:]) | |
+ reduce_fn = getattr(torch, self.mode) | |
+ return reduce_fn(emb, axis=1) | |
+ #return reduce_fn(self.embtable(_) for _ in inp_list) | |
+ | |
+ @property | |
+ def weight(self): | |
+ return self.embtable.weight | |
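
XlaEmbeddingBag above reproduces nn.EmbeddingBag(mode="sum") for fixed-size sparse groups by reshaping to (batch, offset, dim) and reducing over the group dimension. A quick equivalence check (import path assumes the repo root is on sys.path; sizes hypothetical):

import torch
import torch.nn as nn
from tools.xla_embedding_bag import XlaEmbeddingBag

n, m, group, bsz = 10, 3, 4, 2
xla_bag = XlaEmbeddingBag(n, m, mode="sum", offset=group)
ref = nn.EmbeddingBag(n, m, mode="sum")
ref.weight.data.copy_(xla_bag.weight.data)

idx = torch.randint(0, n, (bsz * group,))
offsets = torch.arange(0, bsz * group, group)   # [0, 4]; only used to match the signature
print(torch.allclose(xla_bag(idx, offsets), ref(idx, offsets)))  # True
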
diff --git a/tricks/qr_embedding_bag.py b/tricks/qr_embedding_bag.py | |
index 290d795..941ac1b 100644 | |
--- a/tricks/qr_embedding_bag.py | |
+++ b/tricks/qr_embedding_bag.py | |
@@ -112,7 +112,7 @@ class QREmbeddingBag(nn.Module): | |
def __init__(self, num_categories, embedding_dim, num_collisions, | |
operation='mult', max_norm=None, norm_type=2., | |
scale_grad_by_freq=False, mode='mean', sparse=False, | |
- _weight=None): | |
+ _weight=None, xla=False): | |
super(QREmbeddingBag, self).__init__() | |
assert operation in ['concat', 'mult', 'add'], 'Not valid operation!' | |
@@ -127,6 +127,7 @@ class QREmbeddingBag(nn.Module): | |
self.max_norm = max_norm | |
self.norm_type = norm_type | |
self.scale_grad_by_freq = scale_grad_by_freq | |
+ self.xla = xla | |
if self.operation == 'add' or self.operation == 'mult': | |
assert self.embedding_dim[0] == self.embedding_dim[1], \ | |
@@ -153,17 +154,32 @@ class QREmbeddingBag(nn.Module): | |
nn.init.uniform_(self.weight_q, np.sqrt(1 / self.num_categories)) | |
nn.init.uniform_(self.weight_r, np.sqrt(1 / self.num_categories)) | |
+ def _embed_input(self, input_q, input_r, offsets, per_sample_weights): | |
+ # XXX | |
+ if self.xla: | |
+ raise NotImplementedError('implement tricks later') | |
+ | |
+ else: | |
+ embed_q = F.embedding_bag( | |
+ input_q, self.weight_q, offsets, self.max_norm, self.norm_type, | |
+ self.scale_grad_by_freq, self.mode, self.sparse, | |
+ per_sample_weights | |
+ ) | |
+ embed_r = F.embedding_bag( | |
+ input_r, self.weight_r, offsets, self.max_norm, self.norm_type, | |
+ self.scale_grad_by_freq, self.mode, self.sparse, | |
+ per_sample_weights | |
+ ) | |
+ return embed_q, embed_r | |
+ | |
+ | |
def forward(self, input, offsets=None, per_sample_weights=None): | |
input_q = (input / self.num_collisions).long() | |
input_r = torch.remainder(input, self.num_collisions).long() | |
- | |
- embed_q = F.embedding_bag(input_q, self.weight_q, offsets, self.max_norm, | |
- self.norm_type, self.scale_grad_by_freq, self.mode, | |
- self.sparse, per_sample_weights) | |
- embed_r = F.embedding_bag(input_r, self.weight_r, offsets, self.max_norm, | |
- self.norm_type, self.scale_grad_by_freq, self.mode, | |
- self.sparse, per_sample_weights) | |
- | |
+ embed_q, embed_r = self._embed_input( | |
+ input_q, input_r, offsets=offsets, | |
+ per_sample_weights=per_sample_weights, | |
+ ) | |
if self.operation == 'concat': | |
embed = torch.cat((embed_q, embed_r), dim=1) | |
elif self.operation == 'add': |
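
One more note on the optimizer split in dlrm_s_pytorch.py: the MLP parameters get the all-reduced optimizer while the embedding tables get a local optimizer, and the overridden LRPolicyScheduler.step() copies the scheduled learning rate into the local optimizer's param groups so the two stay in sync. A standalone sketch of that sync pattern (StepLR stands in for LRPolicyScheduler; names hypothetical):

import torch

mlp = torch.nn.Linear(4, 4)
emb = torch.nn.Embedding(8, 4)
mlp_opt = torch.optim.SGD(mlp.parameters(), lr=0.1)
emb_opt = torch.optim.SGD(emb.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(mlp_opt, step_size=1, gamma=0.5)

for _ in range(3):
    mlp_opt.step()
    sched.step()
    lr = sched.get_last_lr()[0]
    for group in emb_opt.param_groups:  # mirror of the gist's step() override
        group["lr"] = lr
    print(lr, emb_opt.param_groups[0]["lr"])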