Created June 12, 2020 23:13
[wip] dlrm on tpu
git diff HEAD~1 .
diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py
index 1955bb9..e9ff88a 100644
--- a/dlrm_s_pytorch.py
+++ b/dlrm_s_pytorch.py
@@ -177,9 +177,11 @@ class DLRM_Net(nn.Module):
n = ln[i]
# construct embedding operator
if self.qr_flag and n > self.qr_threshold:
+ # XXX: code path not hit with current tpu tests.
EE = QREmbeddingBag(n, m, self.qr_collisions,
- operation=self.qr_operation, mode="sum", sparse=True)
+ operation=self.qr_operation, mode="sum", sparse=self.sparse)
elif self.md_flag:
+ # XXX: code path not hit with current tpu tests.
base = max(m)
_m = m[i] if n > self.md_threshold else base
EE = PrEmbeddingBag(n, _m, base)
@@ -190,7 +192,7 @@ class DLRM_Net(nn.Module):
EE.embs.weight.data = torch.tensor(W, requires_grad=True)
else:
- EE = nn.EmbeddingBag(n, m, mode="sum", sparse=True)
+ EE = nn.EmbeddingBag(n, m, mode="sum", sparse=self.sparse)
# initialize embeddings
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
@@ -227,6 +229,7 @@ class DLRM_Net(nn.Module):
qr_threshold=200,
md_flag=False,
md_threshold=200,
+ sparse=True,
):
super(DLRM_Net, self).__init__()
@@ -247,6 +250,7 @@ class DLRM_Net(nn.Module):
self.arch_interaction_itself = arch_interaction_itself
self.sync_dense_params = sync_dense_params
self.loss_threshold = loss_threshold
+ self.sparse = sparse
# create variables for QR embedding if applicable
self.qr_flag = qr_flag
if self.qr_flag:
@@ -333,7 +337,10 @@ class DLRM_Net(nn.Module):
if self.ndevices <= 1:
return self.sequential_forward(dense_x, lS_o, lS_i)
else:
+ # XXX:
+ raise NotImplementedError('TPU')
return self.parallel_forward(dense_x, lS_o, lS_i)
+ metsumm()
def sequential_forward(self, dense_x, lS_o, lS_i):
# process dense features (using bottom mlp), resulting in a row vector
@@ -363,6 +370,8 @@ class DLRM_Net(nn.Module):
return z
def parallel_forward(self, dense_x, lS_o, lS_i):
+ # XXX
+ raise NotImplementedError('tpu')
### prepare model (overwrite) ###
# WARNING: # of devices must be >= batch size in parallel_forward call
batch_size = dense_x.size()[0]
@@ -471,12 +480,9 @@ class DLRM_Net(nn.Module):
return z0
-if __name__ == "__main__":
- ### import packages ###
- import sys
+def parse_args():
import argparse
- ### parse arguments ###
parser = argparse.ArgumentParser(
description="Train Deep Learning Recommendation Model (DLRM)"
)
@@ -532,8 +538,11 @@ if __name__ == "__main__":
parser.add_argument("--inference-only", action="store_true", default=False)
# onnx
parser.add_argument("--save-onnx", action="store_true", default=False)
- # gpu
+ # accelerators
parser.add_argument("--use-gpu", action="store_true", default=False)
+ parser.add_argument("--use-tpu", action="store_true", default=False)
+ # XXX: clean
+ #parser.add_argument("--tpu-cores", type=int, default=8, choices=[1, 8])
# debugging and profiling
parser.add_argument("--print-freq", type=int, default=1)
parser.add_argument("--test-freq", type=int, default=-1)
@@ -558,16 +567,7 @@ if __name__ == "__main__":
parser.add_argument("--lr-num-warmup-steps", type=int, default=0)
parser.add_argument("--lr-decay-start-step", type=int, default=0)
parser.add_argument("--lr-num-decay-steps", type=int, default=0)
- args = parser.parse_args()
-
- if args.mlperf_logging:
- print('command line args: ', json.dumps(vars(args)))
-
- ### some basic setup ###
- np.random.seed(args.numpy_rand_seed)
- np.set_printoptions(precision=args.print_precision)
- torch.set_printoptions(precision=args.print_precision)
- torch.manual_seed(args.numpy_rand_seed)
+ args, _ = parser.parse_known_args()
if (args.test_mini_batch_size < 0):
# if the parameter is not set, use the training batch size
@@ -576,8 +576,38 @@ if __name__ == "__main__":
# if the parameter is not set, use the same parameter for training
args.test_num_workers = args.num_workers
+ return args
+
+
+def main(*_args):
+ ### import packages ###
+ import sys
+
+ args = parse_args()
+ ### some basic setup ###
+ np.random.seed(args.numpy_rand_seed)
+ np.set_printoptions(precision=args.print_precision)
+ torch.set_printoptions(precision=args.print_precision)
+ # XXX: does this imply we init the emb tables the same way?
+ torch.manual_seed(args.numpy_rand_seed)
+
use_gpu = args.use_gpu and torch.cuda.is_available()
- if use_gpu:
+ use_tpu = args.use_tpu
+ if use_tpu:
+ use_gpu = False
+ import sys
+ sys.path.insert(0,'/usr/share/torch-xla-nightly/pytorch/xla')
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ import torch_xla.distributed.parallel_loader as pl
+ builtin_print = builtins.print
+ print = xm.master_print
+ device = xm.xla_device()
+ print("Using {} TPU cores...".format(xm.xrt_world_size()))
+ if args.enable_profiling:
+ print("Profiling was enabled. Turning it off for TPUs.")
+ args.enable_profiling = False
+ elif use_gpu:
torch.cuda.manual_seed_all(args.numpy_rand_seed)
torch.backends.cudnn.deterministic = True
device = torch.device("cuda", 0)
@@ -587,7 +617,12 @@ if __name__ == "__main__":
device = torch.device("cpu")
print("Using CPU...")
+ if 1 or args.mlperf_logging:
+ print('command line args: ', json.dumps(vars(args)))
+
### prepare training data ###
+ # XXX: this doesn't seem to shard the data? check dlrm_data_pytorch.py
+ # XXX: are we dropping last?
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
# input data
if (args.data_generation == "dataset"):
@@ -613,6 +648,11 @@ if __name__ == "__main__":
train_data, train_ld = dp.make_random_data_and_loader(args, ln_emb, m_den)
nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
+ if use_tpu:
+ # XXX: test_data is unused.
+ # Wrap w/ torch_xla's loader
+ train_ld = pl.MpDeviceLoader(train_ld, device)
+
### parse command line arguments ###
m_spa = args.arch_sparse_feature_size
num_fea = ln_emb.size + 1 # num sparse + num dense features
@@ -737,7 +777,14 @@ if __name__ == "__main__":
print([S_i.detach().cpu().tolist() for S_i in lS_i])
print(T.detach().cpu().numpy())
- ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_gpu else -1
+ # XXX: why mini batch size, num_fea?
+ # it must be bc. it shards along those things.
+ if use_tpu:
+ ndevices = xm.xrt_world_size()
+ elif use_gpu:
+ ndevices = min(ngpus, args.mini_batch_size, num_fea - 1)
+ else:
+ ndevices = -1
### construct the neural network specified above ###
# WARNING: to obtain exactly the same initialization for
@@ -761,6 +808,7 @@ if __name__ == "__main__":
qr_threshold=args.qr_threshold,
md_flag=args.md_flag,
md_threshold=args.md_threshold,
+ sparse=device.type != 'xla'
)
# test prints
if args.debug_mode:
@@ -776,6 +824,12 @@ if __name__ == "__main__":
dlrm = dlrm.to(device) # .cuda()
if dlrm.ndevices > 1:
dlrm.emb_l = dlrm.create_emb(m_spa, ln_emb)
+ if use_tpu:
+ # XXX: what is ndevices exactly? How is it used in parallelization?
+ dlrm = dlrm.to(device) # .cuda()
+ if dlrm.ndevices > 1:
+ raise
+ # XXX:
# specify the loss function
if args.loss_function == "mse":
@@ -791,10 +845,13 @@ if __name__ == "__main__":
if not args.inference_only:
# specify the optimizer algorithm
optimizer = torch.optim.SGD(dlrm.parameters(), lr=args.learning_rate)
- lr_scheduler = LRPolicyScheduler(optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
- args.lr_num_decay_steps)
+ lr_scheduler = LRPolicyScheduler(
+ optimizer, args.lr_num_warmup_steps, args.lr_decay_start_step,
+ args.lr_num_decay_steps
+ )
### main loop ###
+ # XXX: what do these functions do?
def time_wrap(use_gpu):
if use_gpu:
torch.cuda.synchronize()
@@ -802,21 +859,19 @@ if __name__ == "__main__":
def dlrm_wrap(X, lS_o, lS_i, use_gpu, device):
if use_gpu: # .cuda()
+ raise
# lS_i can be either a list of tensors or a stacked tensor.
# Handle each case below:
lS_i = [S_i.to(device) for S_i in lS_i] if isinstance(lS_i, list) \
else lS_i.to(device)
lS_o = [S_o.to(device) for S_o in lS_o] if isinstance(lS_o, list) \
else lS_o.to(device)
- return dlrm(
- X.to(device),
- lS_o,
- lS_i
- )
+ return dlrm(X.to(device), lS_o, lS_i)
else:
return dlrm(X, lS_o, lS_i)
def loss_fn_wrap(Z, T, use_gpu, device):
+ # XXX: why all this .to(device) ?
if args.loss_function == "mse" or args.loss_function == "bce":
if use_gpu:
return loss_fn(Z, T.to(device))
@@ -848,9 +903,12 @@ if __name__ == "__main__":
k = 0
# Load model is specified
- if not (args.load_model == ""):
+ if args.load_model != "":
print("Loading saved model {}".format(args.load_model))
- if use_gpu:
+ if use_tpu:
+ # XXX: add tpu capabilities to load.
+ raise NotImplementedError('Add tpu capabilities to load.')
+ elif use_gpu:
if dlrm.ndevices > 1:
# NOTE: when targeting inference on multiple GPUs,
# load the model as is on CPU or GPU, with the move
@@ -907,6 +965,7 @@ if __name__ == "__main__":
)
print("time/loss/accuracy (if enabled):")
+ # XXX: what is profiler? turning it off for tpus.
with torch.autograd.profiler.profile(args.enable_profiling, use_gpu) as prof:
while k < args.nepochs:
if k < skip_upto_epoch:
@@ -932,6 +991,7 @@ if __name__ == "__main__":
t1 = time_wrap(use_gpu)
# early exit if nbatches was set by the user and has been exceeded
+ # XXX: what about j being reset every epoch?
if nbatches > 0 and j >= nbatches:
break
'''
@@ -955,6 +1015,7 @@ if __name__ == "__main__":
print(Z.detach().cpu().numpy())
print(E.detach().cpu().numpy())
'''
+ # FIXME: XXX: do this inside print_freq check
# compute loss and accuracy
L = E.detach().cpu().numpy() # numpy array
S = Z.detach().cpu().numpy() # numpy array
@@ -967,6 +1028,7 @@ if __name__ == "__main__":
# (where we do not accumulate gradients across mini-batches)
optimizer.zero_grad()
# backward pass
+ # FIXME: errors here
E.backward()
# debug prints (check gradient norm)
# for l in mlp.layers:
@@ -1234,3 +1296,18 @@ if __name__ == "__main__":
dlrm_pytorch_onnx = onnx.load("dlrm_s_pytorch.onnx")
# check the onnx model
onnx.checker.check_model(dlrm_pytorch_onnx)
+
+
+# XXX: clean
+def metsumm(stepno=''):
+ #import torch_xla.debug.metrics as met
+ x = met.metrics_report().split('\n')
+ for i, line in enumerate(x):
+ if 'CompileTime' in line or 'aten::' in line:
+ key = line.split()[-1]
+ value = x[i+1].split()[-1]
+ xm.master_print('step {}, key {}, value {}'.format(stepno, key, value))
+
+
+if __name__ == "__main__":
+ raise
diff --git a/dlrm_tpu_runner.py b/dlrm_tpu_runner.py
new file mode 100644
index 0000000..d31aa6d
--- /dev/null
+++ b/dlrm_tpu_runner.py
@@ -0,0 +1,15 @@
+import sys
+import argparse
+
+import torch_xla.distributed.xla_multiprocessing as xmp
+
+from dlrm_s_pytorch import main, parse_args
+
+
+if __name__ == '__main__':
+ pre_spawn_parser = argparse.ArgumentParser()
+ pre_spawn_parser.add_argument(
+ "--tpu-cores", type=int, default=8, choices=[1, 8]
+ )
+ pre_spawn_flags, _ = pre_spawn_parser.parse_known_args()
+ xmp.spawn(main, args=(), nprocs=pre_spawn_flags.tpu_cores)
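
For orientation, below is a minimal, self-contained sketch (not part of the gist) of the torch_xla pattern the diff is adopting: xmp.spawn launches one process per core, xm.xla_device() picks that process's device, MpDeviceLoader feeds batches to it, and the EmbeddingBag is built with dense gradients, which is what sparse=self.sparse / sparse=device.type != 'xla' accomplishes above. The toy model, the random data, and the call to xm.optimizer_step() are illustrative assumptions rather than code from the diff; the diff itself still raises NotImplementedError on several of these paths.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def train_loop(index):
    # One process per TPU core; each process owns one XLA device.
    device = xm.xla_device()
    xm.master_print("running on {} cores".format(xm.xrt_world_size()))

    # Dense (sparse=False) embedding gradients, the same reason the diff
    # threads sparse=self.sparse through create_emb.
    emb = nn.EmbeddingBag(1000, 16, mode="sum", sparse=False).to(device)
    head = nn.Linear(16, 1).to(device)
    params = list(emb.parameters()) + list(head.parameters())
    optimizer = torch.optim.SGD(params, lr=0.01)
    loss_fn = nn.BCEWithLogitsLoss()

    # Toy data: 4 sparse ids per sample plus a float label in [0, 1).
    ids = torch.randint(0, 1000, (512, 4))
    labels = torch.rand(512, 1)
    loader = DataLoader(TensorDataset(ids, labels), batch_size=32)

    # MpDeviceLoader moves each batch onto the XLA device, as in the diff.
    for ids_b, y_b in pl.MpDeviceLoader(loader, device):
        optimizer.zero_grad()
        loss = loss_fn(head(emb(ids_b)), y_b)
        loss.backward()
        # All-reduce gradients across cores, then step.
        # (Assumption: the diff has not reached the optimizer step yet.)
        xm.optimizer_step(optimizer)


if __name__ == "__main__":
    # Mirrors dlrm_tpu_runner.py: spawn one process per core.
    xmp.spawn(train_loop, args=(), nprocs=8)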
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
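Once the remaining raises are filled in, the intended launch presumably becomes something like python dlrm_tpu_runner.py --use-tpu --tpu-cores 8 plus the usual DLRM flags: both the pre-spawn parser and parse_args() call parse_known_args, so each parser takes the flags it owns and ignores the rest.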