Created
September 26, 2024 18:12
-
-
Save samos123/62771cdcd5d8064d6790f91450ad1552 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
jax 0.4.30 | |
I0926 18:08:39.951711 136615241593984 trainer.py:318] gpt_trainer process 0 step 0] Training state size: 621.76 GiB | |
Training state size (partitioned): 12.68 GiB | |
Max training state size (partitioned): 12.68 GiB | |
I0926 18:08:40.992022 136615241593984 trainer.py:465] Starting loop... | |
2024-09-26 18:08:41.106644: E tensorflow/core/util/util.cc:131] oneDNN supports DT_INT64 only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present. | |
Exception in thread gpt_trainer.checkpointer.gc_loop: | |
Traceback (most recent call last): | |
File "/root/axlearn/common/file_system.py", line 46, in _wrap_exception | |
yield | |
File "/root/axlearn/common/file_system.py", line 71, in wrapped | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/file_system.py", line 98, in listdir | |
return tf.io.gfile.listdir(path) | |
File "/opt/venv/lib/python3.10/site-packages/tensorflow/python/lib/io/file_io.py", line 768, in list_directory_v2 | |
coll_type: COLL_ALL_REDUCE | |
msg_size_tuning_rules { | |
per_rank_message_size { | |
min: 0 | |
} | |
coll_tuning_spec { | |
num_channel: 2 | |
protocol: PROTO_SIMPLE | |
algorithm: ALGO_TREE | |
} | |
} | |
} | |
coll_configs { | |
coll_type: COLL_DEFAULT | |
msg_size_tuning_rules { | |
per_rank_message_size { | |
min: 0 | |
max: 65536 | |
} | |
coll_tuning_spec { | |
num_channel: 2 | |
protocol: PROTO_SIMPLE | |
algorithm: ALGO_RING | |
} | |
} | |
msg_size_tuning_rules { | |
per_rank_message_size { | |
min: 65536 | |
} | |
coll_tuning_spec { | |
num_channel: 4 | |
protocol: PROTO_SIMPLE | |
gke-a3plus-benchmark-a3plus-benchmark-caf334a5-6fvl:7:2175 [7] /nccl-tuner-config-based/src/config_based_tuner.cc:271 NCCL WARN No communicator config selected from config:communicator_configs { | |
node_range { | |
min: 2 | |
max: 3 | |
} | |
rank_per_node_range { | |
raise errors.NotFoundError( | |
tensorflow.python.framework.errors_impl.NotFoundError: Could not find directory gs://supercomputer-testing-axlearn/70b-8n-norematoffload-1/checkpoints | |
The above exception was the direct cause of the following exception: | |
Traceback (most recent call last): | |
File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner | |
min: 1 | |
max: 2 | |
} | |
coll_configs { | |
coll_type: COLL_ALL_REDUCE | |
msg_size_tuning_rules { | |
per_rank_message_size { | |
min: 0 | |
} | |
coll_tuning_spec { | |
num_channel: 2 | |
protocol: PROTO_SIMPLE | |
algorithm: ALGO_TREE | |
} | |
} | |
} | |
coll_configs { | |
coll_type: COLL_DEFAULT | |
msg_size_tuning_rules { | |
per_rank_message_size { | |
min: 0 | |
max: 65536 | |
} | |
coll_tuning_spec { | |
num_channel: 2 | |
protocol: PROTO_SIMPLE | |
algorithm: ALGO_RING | |
} | |
} | |
msg_size_tuning_rules { | |
per_rank_message_size { | |
min: 65536 | |
} | |
coll_tuning_spec { | |
num_channel: 4 | |
protocol: PROTO_SIMPLE | |
algorithm: ALGO_RING | |
} | |
} | |
} | |
} | |
c | |
gke-a3plus-benchmark-a3plus-benchmark-caf334a5-6fvl:7:2175 [7] /nccl-tuner-config-based/src/tuner_tcpx.cc:70 NCCL WARN No communicator found for nRanks:8, nNodes:8 from config_path:/usr/local/nvidia/lib64/a3plus_tuner_config.textproto | |
self.run() | |
File "/usr/lib/python3.10/threading.py", line 953, in run | |
self._target(*self._args, **self._kwargs) | |
File "/root/axlearn/common/checkpointer.py", line 863, in _gc_loop | |
self._run_garbage_collection() | |
File "/root/axlearn/common/checkpointer.py", line 912, in _run_garbage_collection | |
for step in fs.listdir(cfg.dir) | |
File "/root/axlearn/common/file_system.py", line 67, in wrapped | |
with ( | |
File "/usr/lib/python3.10/contextlib.py", line 153, in __exit__ | |
self.gen.throw(typ, value, traceback) | |
File "/root/axlearn/common/file_system.py", line 48, in _wrap_exception | |
raise target_exc(str(e)) from e | |
axlearn.common.file_system.NotFoundError: Could not find directory gs://supercomputer-testing-axlearn/70b-8n-norematoffload-1/checkpoints | |
I0926 18:10:38.143126 136615241593984 trainer.py:473] input_batch={'input_ids': (16, 4096), 'target_labels': (16, 4096), 'target_num_bytes': (16,)} | |
I0926 18:10:38.407656 136615241593984 base_layer.py:337] Applying remat on gpt_trainer.model.decoder.transformer.repeat.layer.<function TransformerLayer.forward at 0x7c3a844039a0>: RematSpec(prevent_cse=False, policy=config_for_function(jax._src.ad_checkpoint.save_only_these_names)(fn=<function save_only_these_names at 0x7c3f87590820>, names_which_can_be_saved=['FlashAttention.q_proj', 'FlashAttention.k_proj', 'FlashAttention.v_proj', 'FlashAttention.context', 'FlashAttention.o_proj', 'TransformerFeedForwardLayer.activation', 'TransformerFeedForwardLayer.linear2'])) | |
I0926 18:10:38.483145 136615241593984 serialization.py:554] Error check finished successfully | |
I0926 18:10:38.483265 136615241593984 checkpointer.py:849] Waiting for gc_thread to finish | |
I0926 18:10:38.483329 136615241593984 checkpointer.py:854] gc_thread finished | |
I0926 18:10:38.483386 136615241593984 trainer.py:351] Waiting for watchdog_thread to finish | |
I0926 18:10:38.483608 136493205018176 trainer.py:377] Watchdog loop done | |
I0926 18:10:38.483988 136615241593984 trainer.py:356] watchdog_thread finished | |
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. | |
The above exception was the direct cause of the following exception: | |
Traceback (most recent call last): | |
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main | |
return _run_code(code, main_globals, None, | |
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code | |
exec(code, run_globals) | |
File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module> | |
app.run(main) | |
File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 308, in run | |
_run_main(main, args) | |
File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main | |
sys.exit(main(argv)) | |
File "/root/axlearn/common/launch_trainer_main.py", line 16, in main | |
launch_trainer.run_trainer(trainer_config) | |
File "/root/axlearn/common/launch_trainer.py", line 131, in run_trainer | |
output = trainer.run(prng_key) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn | |
return _call_method_in_context( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 486, in _call_method_in_context | |
return thunk() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 473, in thunk | |
return module._call_thunk(*args, method_fn=method_fn, **kwargs)() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 872, in nullary | |
return method_fn(self, *args, **kwargs) | |
File "/root/axlearn/common/trainer.py", line 482, in run | |
output = self._run_step( | |
File "/root/axlearn/common/trainer.py", line 867, in _run_step | |
self._trainer_state, outputs = self._jit_train_step(self._trainer_state, input_batch) | |
File "/root/axlearn/common/trainer.py", line 993, in _train_step | |
fwd_bwd_outputs, learner_output_collection = F( | |
File "/root/axlearn/common/module.py", line 986, in functional | |
method_outputs, output_collection = fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 927, in __call__ | |
method_outputs = self.method_fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn | |
return _call_method_in_context( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context | |
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module)))) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context | |
return thunk() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 473, in thunk | |
return module._call_thunk(*args, method_fn=method_fn, **kwargs)() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 872, in nullary | |
return method_fn(self, *args, **kwargs) | |
File "/root/axlearn/common/learner.py", line 343, in forward_and_backward | |
updates = _value_and_grad( | |
File "/root/axlearn/common/learner.py", line 717, in _value_and_grad | |
(_, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)( | |
File "/root/axlearn/common/learner.py", line 674, in forward | |
outputs = fun(model_params=model_params, inputs=inputs) # type: ignore | |
File "/root/axlearn/common/learner.py", line 641, in filtered_forward | |
return fun(model_params=model_params, inputs=inputs) | |
File "/root/axlearn/common/trainer.py", line 988, in _forward | |
loss, aux = self.model(input_batch=train_cast(inputs["input_batch"])) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 865, in __call__ | |
return self.forward(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn | |
return _call_method_in_context( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context | |
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module)))) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context | |
return thunk() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 473, in thunk | |
return module._call_thunk(*args, method_fn=method_fn, **kwargs)() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 872, in nullary | |
return method_fn(self, *args, **kwargs) | |
File "/root/axlearn/common/causal_lm.py", line 120, in forward | |
predictions = self.predict(input_batch) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn | |
return _call_method_in_context( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context | |
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module)))) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context | |
return thunk() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 473, in thunk | |
return module._call_thunk(*args, method_fn=method_fn, **kwargs)() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 872, in nullary | |
return method_fn(self, *args, **kwargs) | |
File "/root/axlearn/common/causal_lm.py", line 282, in predict | |
decoder_output = self.decoder( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 865, in __call__ | |
return self.forward(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn | |
return _call_method_in_context( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context | |
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module)))) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 493, in call_thunk_in_context | |
return call_thunk_in_context(reversed_path) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context | |
return thunk() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 473, in thunk | |
return module._call_thunk(*args, method_fn=method_fn, **kwargs)() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 872, in nullary | |
return method_fn(self, *args, **kwargs) | |
File "/root/axlearn/common/decoder.py", line 561, in forward | |
_, output = self._forward_for_mode( | |
File "/root/axlearn/common/decoder.py", line 479, in _forward_for_mode | |
transformer_state, x = None, self.transformer( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 865, in __call__ | |
return self.forward(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn | |
return _call_method_in_context( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context | |
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module)))) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 493, in call_thunk_in_context | |
return call_thunk_in_context(reversed_path) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context | |
return thunk() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 473, in thunk | |
return module._call_thunk(*args, method_fn=method_fn, **kwargs)() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 872, in nullary | |
return method_fn(self, *args, **kwargs) | |
File "/root/axlearn/common/attention.py", line 3771, in forward | |
return self.repeat(data, **layer_kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 865, in __call__ | |
return self.forward(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn | |
return _call_method_in_context( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context | |
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module)))) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 493, in call_thunk_in_context | |
return call_thunk_in_context(reversed_path) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context | |
return thunk() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 473, in thunk | |
return module._call_thunk(*args, method_fn=method_fn, **kwargs)() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 872, in nullary | |
return method_fn(self, *args, **kwargs) | |
File "/root/axlearn/common/attention.py", line 3687, in forward | |
_, output = self._forward_for_mode( | |
File "/root/axlearn/common/attention.py", line 3664, in _forward_for_mode | |
repeat_outputs: Repeat.Output = self._run(layer_fn, carry=carry, xs=cached_states) | |
File "/root/axlearn/common/repeat.py", line 199, in _run | |
carry, ys = scan_in_context( | |
File "/root/axlearn/common/module.py", line 1059, in scan_in_context | |
carry, scan_ys = jax.lax.scan(scan_fn, init=carry, xs=xs) | |
File "/root/axlearn/common/module.py", line 1044, in scan_fn | |
carry_i, y_i = fn(carry_i, x_i) | |
File "/root/axlearn/common/attention.py", line 3636, in layer_fn | |
layer_states, layer_outputs = None, self.layer(**carry, **layer_kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 865, in __call__ | |
return self.forward(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn | |
return _call_method_in_context( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context | |
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module)))) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context | |
return thunk() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 473, in thunk | |
return module._call_thunk(*args, method_fn=method_fn, **kwargs)() | |
File "/root/axlearn/common/base_layer.py", line 340, in nullary_with_remat | |
outputs, output_collection = jax.ad_checkpoint.remat( | |
File "/root/axlearn/common/base_layer.py", line 334, in fn | |
outputs = method_fn(self, *args, **kwargs, **static_kwargs) | |
File "/root/axlearn/common/attention.py", line 3066, in forward | |
_, output = self._forward_for_mode( | |
File "/root/axlearn/common/attention.py", line 3004, in _forward_for_mode | |
self.self_attention( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 865, in __call__ | |
return self.forward(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn | |
return _call_method_in_context( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context | |
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module)))) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 493, in call_thunk_in_context | |
return call_thunk_in_context(reversed_path) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context | |
return thunk() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 473, in thunk | |
return module._call_thunk(*args, method_fn=method_fn, **kwargs)() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 872, in nullary | |
return method_fn(self, *args, **kwargs) | |
File "/root/axlearn/common/attention.py", line 2552, in forward | |
_, output = self._forward_for_mode( | |
File "/root/axlearn/common/attention.py", line 2506, in _forward_for_mode | |
atten_state, atten_output = attention_thunk(norm_target) | |
File "/root/axlearn/common/attention.py", line 2476, in attention_thunk | |
self.attention( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 865, in __call__ | |
return self.forward(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn | |
return _call_method_in_context( | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context | |
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module)))) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 493, in call_thunk_in_context | |
return call_thunk_in_context(reversed_path) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context | |
return thunk() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 473, in thunk | |
return module._call_thunk(*args, method_fn=method_fn, **kwargs)() | |
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper | |
return fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 872, in nullary | |
return method_fn(self, *args, **kwargs) | |
File "/root/axlearn/common/attention.py", line 1871, in forward | |
_, output = self._forward_for_mode( | |
File "/root/axlearn/common/attention.py", line 1760, in _forward_for_mode | |
context, probs = self._compute_attention( | |
File "/root/axlearn/common/flash_attention/layer.py", line 189, in _compute_attention | |
cfg.mha_dim_to_partition_spec["bt"], | |
KeyError: 'bt' | |
The above exception was the direct cause of the following exception: | |
Traceback (most recent call last): | |
File "/root/axlearn/common/traceback_util.py", line 163, in in_context_exception_wrapper | |
raise in_context_exception from e | |
axlearn.common.traceback_util._InContextException: | |
An error was encountered in a wrapped Module method. | |
Below is an AXLearn stack summary, which may be easier to read. | |
Immediately above is the stack frame in which the stack summary was initialized. | |
You can probably ignore that frame. | |
Further above that is the original Python Exception with the raw stack trace, which caused this Exception. | |
Stack Summary (most recent call last): | |
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code | |
exec(code, run_globals) | |
File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module> | |
app.run(main) | |
File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 330, in run | |
raise | |
File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main | |
sys.exit(main(argv)) | |
File "/root/axlearn/common/launch_trainer_main.py", line 16, in main | |
launch_trainer.run_trainer(trainer_config) | |
File "/root/axlearn/common/launch_trainer.py", line 131, in run_trainer | |
output = trainer.run(prng_key) | |
Wrapped call axlearn.common.trainer.SpmdTrainer.run(jaxlib.xla_extension.ArrayImpl) | |
File "/root/axlearn/common/trainer.py", line 482, in run | |
output = self._run_step( | |
File "/root/axlearn/common/trainer.py", line 867, in _run_step | |
self._trainer_state, outputs = self._jit_train_step(self._trainer_state, input_batch) | |
File "/root/axlearn/common/trainer.py", line 993, in _train_step | |
fwd_bwd_outputs, learner_output_collection = F( | |
File "/root/axlearn/common/module.py", line 986, in functional | |
method_outputs, output_collection = fn(*args, **kwargs) | |
File "/root/axlearn/common/module.py", line 927, in __call__ | |
method_outputs = self.method_fn(*args, **kwargs) | |
Wrapped call axlearn.common.learner.Learner.forward_and_backward(fn: function, inputs: dict, opt_params: dict) | |
File "/root/axlearn/common/learner.py", line 343, in forward_and_backward | |
updates = _value_and_grad( | |
File "/root/axlearn/common/learner.py", line 717, in _value_and_grad | |
(_, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)( | |
File "/root/axlearn/common/learner.py", line 674, in forward | |
outputs = fun(model_params=model_params, inputs=inputs) # type: ignore | |
File "/root/axlearn/common/learner.py", line 641, in filtered_forward | |
return fun(model_params=model_params, inputs=inputs) | |
File "/root/axlearn/common/trainer.py", line 988, in _forward | |
loss, aux = self.model(input_batch=train_cast(inputs["input_batch"])) | |
Wrapped call axlearn.common.causal_lm.Model.forward(input_batch: dict) | |
File "/root/axlearn/common/causal_lm.py", line 120, in forward | |
predictions = self.predict(input_batch) | |
Wrapped call axlearn.common.causal_lm.Model.predict(dict) | |
File "/root/axlearn/common/causal_lm.py", line 282, in predict | |
decoder_output = self.decoder( | |
Wrapped call axlearn.common.decoder.Decoder.forward(input_ids: jax._src.interpreters.partial_eval.DynamicJaxprTracer, token_type_ids: NoneType, input_segment_ids: NoneType, positions: NoneType) | |
File "/root/axlearn/common/decoder.py", line 561, in forward | |
_, output = self._forward_for_mode( | |
File "/root/axlearn/common/decoder.py", line 479, in _forward_for_mode | |
transformer_state, x = None, self.transformer( | |
Wrapped call axlearn.common.attention.RepeatedTransformerLayer.forward(jax._src.interpreters.ad.JVPTracer, self_attention_logit_biases: NoneType, segment_ids: NoneType, cross_attention_data: NoneType, cross_attention_logit_biases: NoneType) | |
File "/root/axlearn/common/attention.py", line 3771, in forward | |
return self.repeat(data, **layer_kwargs) | |
Wrapped call axlearn.common.attention._TransformerRepeat.forward(jax._src.interpreters.ad.JVPTracer, self_attention_logit_biases: NoneType, segment_ids: NoneType, cross_attention_data: NoneType, cross_attention_logit_biases: NoneType) | |
File "/root/axlearn/common/attention.py", line 3687, in forward | |
_, output = self._forward_for_mode( | |
File "/root/axlearn/common/attention.py", line 3664, in _forward_for_mode | |
repeat_outputs: Repeat.Output = self._run(layer_fn, carry=carry, xs=cached_states) | |
File "/root/axlearn/common/repeat.py", line 199, in _run | |
carry, ys = scan_in_context( | |
File "/root/axlearn/common/module.py", line 1059, in scan_in_context | |
carry, scan_ys = jax.lax.scan(scan_fn, init=carry, xs=xs) | |
File "/root/axlearn/common/module.py", line 1044, in scan_fn | |
carry_i, y_i = fn(carry_i, x_i) | |
File "/root/axlearn/common/attention.py", line 3636, in layer_fn | |
layer_states, layer_outputs = None, self.layer(**carry, **layer_kwargs) | |
Wrapped call axlearn.common.attention.TransformerLayer.forward(data: jax._src.interpreters.partial_eval.DynamicJaxprTracer, self_attention_logit_biases: NoneType, segment_ids: NoneType, cross_attention_data: NoneType, cross_attention_logit_biases: NoneType) | |
File "/root/axlearn/common/base_layer.py", line 340, in nullary_with_remat | |
outputs, output_collection = jax.ad_checkpoint.remat( | |
File "/root/axlearn/common/base_layer.py", line 334, in fn | |
outputs = method_fn(self, *args, **kwargs, **static_kwargs) | |
File "/root/axlearn/common/attention.py", line 3066, in forward | |
_, output = self._forward_for_mode( | |
File "/root/axlearn/common/attention.py", line 3004, in _forward_for_mode | |
self.self_attention( | |
Wrapped call axlearn.common.attention.TransformerAttentionLayer.forward(target: jax._src.interpreters.partial_eval.DynamicJaxprTracer, segment_ids: NoneType, source: NoneType, attention_logit_biases: NoneType, return_aux: set) | |
File "/root/axlearn/common/attention.py", line 2552, in forward | |
_, output = self._forward_for_mode( | |
File "/root/axlearn/common/attention.py", line 2506, in _forward_for_mode | |
atten_state, atten_output = attention_thunk(norm_target) | |
File "/root/axlearn/common/attention.py", line 2476, in attention_thunk | |
self.attention( | |
Wrapped call axlearn.common.flash_attention.layer.FlashAttention.forward(query: jax._src.interpreters.partial_eval.DynamicJaxprTracer, return_aux: set, attention_logit_biases: NoneType, segment_ids: NoneType) | |
File "/root/axlearn/common/attention.py", line 1871, in forward | |
_, output = self._forward_for_mode( | |
File "/root/axlearn/common/attention.py", line 1760, in _forward_for_mode | |
context, probs = self._compute_attention( | |
File "/root/axlearn/common/flash_attention/layer.py", line 189, in _compute_attention | |
cfg.mha_dim_to_partition_spec["bt"], | |
KeyError: 'bt' | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment