@taylanbil
Created June 12, 2020 23:13
[wip] dlrm on tpu
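This gist is a work-in-progress diff against the DLRM PyTorch code that starts wiring it up to TPUs through torch_xla: dlrm_s_pytorch.py is split into parse_args() and main() so it can be launched per-core by xmp.spawn from a new dlrm_tpu_runner.py, the training loader is wrapped in pl.MpDeviceLoader, EmbeddingBag falls back to dense gradients on XLA devices (sparse=device.type != 'xla'), and the multi-GPU parallel_forward path is stubbed out with NotImplementedError. For orientation, the sketch below shows the generic torch_xla multiprocessing pattern the refactor is moving towards; build_model and build_loader are placeholder names, not functions from this repo, and xm.optimizer_step is the usual TPU replacement for optimizer.step() even though this WIP diff has not reached that point yet.

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp


    def _mp_fn(index):
        # One process per TPU core; each process gets its own XLA device.
        device = xm.xla_device()
        model = build_model().to(device)          # placeholder: construct the nn.Module
        loader = pl.MpDeviceLoader(build_loader(), device)  # streams batches onto the device
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loss_fn = torch.nn.BCELoss()
        for X, T in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(X), T)
            loss.backward()
            # All-reduce gradients across cores, then step
            # (the TPU analogue of a bare optimizer.step()).
            xm.optimizer_step(optimizer)


    if __name__ == '__main__':
        # Fork one training process per TPU core, mirroring dlrm_tpu_runner.py below.
        xmp.spawn(_mp_fn, args=(), nprocs=8)
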
git diff HEAD~1 .
diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py
index 1955bb9..e9ff88a 100644
--- a/dlrm_s_pytorch.py
+++ b/dlrm_s_pytorch.py
@@ -177,9 +177,11 @@ class DLRM_Net(nn.Module):
n = ln[i]
# construct embedding operator
if self.qr_flag and n > self.qr_threshold:
+ # XXX: code path not hit with current tpu tests.
EE = QREmbeddingBag(n, m, self.qr_collisions,
- operation=self.qr_operation, mode="sum", sparse=True)
+ operation=self.qr_operation, mode="sum", sparse=self.sparse)
elif self.md_flag:
+ # XXX: code path not hit with current tpu tests.
base = max(m)
_m = m[i] if n > self.md_threshold else base
EE = PrEmbeddingBag(n, _m, base)
@@ -190,7 +192,7 @@ class DLRM_Net(nn.Module):
EE.embs.weight.data = torch.tensor(W, requires_grad=True)
else:
- EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)
+ EE = nn.EmbeddingBag(n, m, mode="sum", sparse=self.sparse)
# initialize embeddings
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
@@ -227,6 +229,7 @@ class DLRM_Net(nn.Module):
qr_threshold=200,
md_flag=False,
md_threshold=200,
+ sparse=True,
):
super(DLRM_Net, self).__init__()
@@ -247,6 +250,7 @@ class DLRM_Net(nn.Module):
self.arch_interaction_itself = arch_interaction_itself
self.sync_dense_params = sync_dense_params
self.loss_threshold = loss_threshold
+ self.sparse = sparse
# create variables for QR embedding if applicable
self.qr_flag = qr_flag
if self.qr_flag:
@@ -333,7 +337,10 @@ class DLRM_Net(nn.Module):
if self.ndevices <= 1:
return self.sequential_forward(dense_x, lS_o, lS_i)
else:
+ # XXX:
+ raise NotImplementedError('TPU')
return self.parallel_forward(dense_x, lS_o, lS_i)
+ metsumm()
def sequential_forward(self, dense_x, lS_o, lS_i):
# process dense features (using bottom mlp), resulting in a row vector
@@ -363,6 +370,8 @@ class DLRM_Net(nn.Module):
return z
def parallel_forward(self, dense_x, lS_o, lS_i):
+ # XXX
+ raise NotImplementedError('tpu')
### prepare model (overwrite) ###
# WARNING: # of devices must be >= batch size in parallel_forward call
batch_size = dense_x.size()[0]
@@ -471,12 +480,9 @@ class DLRM_Net(nn.Module):
return z0
-if __name__ == "__main__":
- ### import packages ###
- import sys
+def parse_args():
import argparse
- ### parse arguments ###
parser = argparse.ArgumentParser(
description="Train Deep Learning Recommendation Model (DLRM)"
)
@@ -532,8 +538,11 @@ if __name__ == "__main__":
parser.add_argument("--inference-only", action="store_true", default=False)
# onnx
parser.add_argument("--save-onnx", action="store_true", default=False)
- # gpu
+ # accelerators
parser.add_argument("--use-gpu", action="store_true", default=False)
+ parser.add_argument("--use-tpu", action="store_true", default=False)
+ # XXX: clean
+ #parser.add_argument("--tpu-cores", type=int, default=8, choices=[1, 8])
# debugging and profiling
parser.add_argument("--print-freq", type=int, default=1)
parser.add_argument("--test-freq", type=int, default=-1)
@@ -558,16 +567,7 @@ if __name__ == "__main__":
parser.add_argument("--lr-num-warmup-steps", type=int, default=0)
parser.add_argument("--lr-decay-start-step", type=int, default=0)
parser.add_argument("--lr-num-decay-steps", type=int, default=0)
- args = parser.parse_args()
-
- if args.mlperf_logging:
- print('command line args: ', json.dumps(vars(args)))
-
- ### some basic setup ###
- np.random.seed(args.numpy_rand_seed)
- np.set_printoptions(precision=args.print_precision)
- torch.set_printoptions(precision=args.print_precision)
- torch.manual_seed(args.numpy_rand_seed)
+ args, _ = parser.parse_known_args()
if (args.test_mini_batch_size < 0):
# if the parameter is not set, use the training batch size
@@ -576,8 +576,38 @@ if __name__ == "__main__":
# if the parameter is not set, use the same parameter for training
args.test_num_workers = args.num_workers
+ return args
+
+
+def main(*_args):
+ ### import packages ###
+ import sys
+
+ args = parse_args()
+ ### some basic setup ###
+ np.random.seed(args.numpy_rand_seed)
+ np.set_printoptions(precision=args.print_precision)
+ torch.set_printoptions(precision=args.print_precision)
+ # XXX: does this imply we init the emb tables the same way?
+ torch.manual_seed(args.numpy_rand_seed)
+
use_gpu = args.use_gpu and torch.cuda.is_available()
- if use_gpu:
+ use_tpu = args.use_tpu
+ if use_tpu:
+ use_gpu = False
+ import sys
+ sys.path.insert(0,'/usr/share/torch-xla-nightly/pytorch/xla')
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ import torch_xla.distributed.parallel_loader as pl
+ builtin_print = builtins.print
+ print = xm.master_print
+ device = xm.xla_device()
+ print("Using {} TPU cores...".format(xm.xrt_world_size()))
+ if args.enable_profiling:
+ print("Profiling was enabled. Turning it off for TPUs.")
+ args.enable_profiling = False
+ elif use_gpu:
torch.cuda.manual_seed_all(args.numpy_rand_seed)
torch.backends.cudnn.deterministic = True
device = torch.device("cuda", 0)
@@ -587,7 +617,12 @@ if __name__ == "__main__":
device = torch.device("cpu")
print("Using CPU...")
+ if 1 or args.mlperf_logging:
+ print('command line args: ', json.dumps(vars(args)))
+
### prepare training data ###
+ # XXX: this doesn't seem to shard the data? check dlrm_data_pytorch.py
+ # XXX: are we dropping last?
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
# input data
if (args.data_generation == "dataset"):
@@ -613,6 +648,11 @@ if __name__ == "__main__":
train_data, train_ld = dp.make_random_data_and_loader(args, ln_emb, m_den)
nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
+ if use_tpu:
+ # XXX: test_data is unused.
+ # Wrap w/ torch_xla's loader
+ train_ld = pl.MpDeviceLoader(train_ld, device)
+
### parse command line arguments ###
m_spa = args.arch_sparse_feature_size
num_fea = ln_emb.size + 1 # num sparse + num dense features
@@ -737,7 +777,14 @@ if __name__ == "__main__":
print([S_i.detach().cpu().tolist() for S_i in lS_i])
print(T.detach().cpu().numpy())
- ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_gpu else -1
+ # XXX: why mini batch size, num_fea?
+ # it must be bc. it shards along those things.
+ if use_tpu:
+ ndevices = xm.xrt_world_size()
+ elif use_gpu:
+ ndevices = min(ngpus, args.mini_batch_size, num_fea - 1)
+ else:
+ ndevices = -1
### construct the neural network specified above ###
# WARNING: to obtain exactly the same initialization for
@@ -761,6 +808,7 @@ if __name__ == "__main__":
qr_threshold=args.qr_threshold,
md_flag=args.md_flag,
md_threshold=args.md_threshold,
+ sparse=device.type != 'xla'
)
# test prints
if args.debug_mode:
@@ -776,6 +824,12 @@ if __name__ == "__main__":
dlrm = dlrm.to(device) # .cuda()
if dlrm.ndevices > 1:
dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb)
+ if use_tpu:
+ # XXX: what is ndevices exactly? How is it used in parallelization?
+ dlrm = dlrm.to(device) # .cuda()
+ if dlrm.ndevices > 1:
+ raise
+ # XXX:
# specify the loss function
if args.loss_function == "mse":
@@ -791,10 +845,13 @@ if __name__ == "__main__":
if not args.inference_only:
# specify the optimizer algorithm
optimizer = torch.optim.SGD(dlrm.parameters(), lr=args.learning_rate)
- lr_scheduler = LRPolicyScheduler(optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
- args.lr_num_decay_steps)
+ lr_scheduler = LRPolicyScheduler(
+ optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
+ args.lr_num_decay_steps
+ )
### main loop ###
+ # XXX: what do these functions do?
def time_wrap(use_gpu):
if use_gpu:
torch.cuda.synchronize()
@@ -802,21 +859,19 @@ if __name__ == "__main__":
def dlrm_wrap(X, lS_o, lS_i, use_gpu, device):
if use_gpu: # .cuda()
+ raise
# lS_i can be either a list of tensors or a stacked tensor.
# Handle each case below:
lS_i = [S_i.to(device) for S_i in lS_i] if isinstance(lS_i, list) \
else lS_i.to(device)
lS_o = [S_o.to(device) for S_o in lS_o] if isinstance(lS_o, list) \
else lS_o.to(device)
- return dlrm(
- X.to(device),
- lS_o,
- lS_i
- )
+ return dlrm(X.to(device), lS_o, lS_i)
else:
return dlrm(X, lS_o, lS_i)
def loss_fn_wrap(Z, T, use_gpu, device):
+ # XXX: why all this .to(device) ?
if args.loss_function == "mse" or args.loss_function == "bce":
if use_gpu:
return loss_fn(Z, T.to(device))
@@ -848,9 +903,12 @@ if __name__ == "__main__":
k = 0
# Load model is specified
- if not (args.load_model == ""):
+ if args.load_model != "":
print("Loading saved model {}".format(args.load_model))
- if use_gpu:
+ if use_tpu:
+ # XXX: add tpu capabilities to load.
+ raise NotImplementedError('Add tpu capabilities to load.')
+ elif use_gpu:
if dlrm.ndevices > 1:
# NOTE: when targeting inference on multiple GPUs,
# load the model as is on CPU or GPU, with the move
@@ -907,6 +965,7 @@ if __name__ == "__main__":
)
print("time/loss/accuracy (if enabled):")
+ # XXX: what is profiler? turning it off for tpus.
with torch.autograd.profiler.profile(args.enable_profiling, use_gpu) as prof:
while k < args.nepochs:
if k < skip_upto_epoch:
@@ -932,6 +991,7 @@ if __name__ == "__main__":
t1 = time_wrap(use_gpu)
# early exit if nbatches was set by the user and has been exceeded
+ # XXX: what about j being reset every epoch?
if nbatches > 0 and j >= nbatches:
break
'''
@@ -955,6 +1015,7 @@ if __name__ == "__main__":
print(Z.detach().cpu().numpy())
print(E.detach().cpu().numpy())
'''
+ # FIXME: XXX: do this inside print_freq check
# compute loss and accuracy
L = E.detach().cpu().numpy() # numpy array
S = Z.detach().cpu().numpy() # numpy array
@@ -967,6 +1028,7 @@ if __name__ == "__main__":
# (where we do not accumulate gradients across mini-batches)
optimizer.zero_grad()
# backward pass
+ # FIXME: errors here
E.backward()
# debug prints (check gradient norm)
# for l in mlp.layers:
@@ -1234,3 +1296,18 @@ if __name__ == "__main__":
dlrm_pytorch_onnx = onnx.load("dlrm_s_pytorch.onnx")
# check the onnx model
onnx.checker.check_model(dlrm_pytorch_onnx)
+
+
+# XXX: clean
+def metsumm(stepno=''):
+ #import torch_xla.debug.metrics as met
+ x = met.metrics_report().split('\n')
+ for i, line in enumerate(x):
+ if 'CompileTime' in line or 'aten::' in line:
+ key = line.split()[-1]
+ value = x[i+1].split()[-1]
+ xm.master_print('step {}, key {}, value {}'.format(stepno, key, value))
+
+
+if __name__ == "__main__":
+ raise
diff --git a/dlrm_tpu_runner.py b/dlrm_tpu_runner.py
new file mode 100644
index 0000000..d31aa6d
--- /dev/null
+++ b/dlrm_tpu_runner.py
@@ -0,0 +1,15 @@
+import sys
+import argparse
+
+import torch_xla.distributed.xla_multiprocessing as xmp
+
+from dlrm_s_pytorch import main, parse_args
+
+
+if __name__ == '__main__':
+ pre_spawn_parser = argparse.ArgumentParser()
+ pre_spawn_parser.add_argument(
+ "--tpu-cores", type=int, default=8, choices=[1, 8]
+ )
+ pre_spawn_flags, _ = pre_spawn_parser.parse_known_args()
+ xmp.spawn(main, args=(), nprocs=pre_spawn_flags.tpu_cores)
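
Assuming the refactor lands as sketched above, a training run would presumably be launched through the new wrapper instead of dlrm_s_pytorch.py directly. The invocation below is only illustrative: --use-tpu and --tpu-cores come from this diff, while the remaining flags are existing DLRM arguments with made-up values.

    # Hypothetical invocation; flag values are illustrative, not taken from the gist.
    python dlrm_tpu_runner.py \
        --use-tpu \
        --tpu-cores 8 \
        --arch-sparse-feature-size 16 \
        --arch-mlp-bot "13-512-256-64-16" \
        --mini-batch-size 128 \
        --num-batches 100

Because both dlrm_tpu_runner.py and the new parse_args() use parse_known_args, the pre-spawn --tpu-cores flag and the regular DLRM flags can share a single command line, with each parser ignoring the options it does not recognize.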