Skip to content

Instantly share code, notes, and snippets.

@samos123
Created September 26, 2024 18:12
Show Gist options
  • Save samos123/62771cdcd5d8064d6790f91450ad1552 to your computer and use it in GitHub Desktop.
Save samos123/62771cdcd5d8064d6790f91450ad1552 to your computer and use it in GitHub Desktop.
jax 0.4.30
I0926 18:08:39.951711 136615241593984 trainer.py:318] gpt_trainer process 0 step 0] Training state size: 621.76 GiB
Training state size (partitioned): 12.68 GiB
Max training state size (partitioned): 12.68 GiB
I0926 18:08:40.992022 136615241593984 trainer.py:465] Starting loop...
2024-09-26 18:08:41.106644: E tensorflow/core/util/util.cc:131] oneDNN supports DT_INT64 only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.
Exception in thread gpt_trainer.checkpointer.gc_loop:
Traceback (most recent call last):
File "/root/axlearn/common/file_system.py", line 46, in _wrap_exception
yield
File "/root/axlearn/common/file_system.py", line 71, in wrapped
return fn(*args, **kwargs)
File "/root/axlearn/common/file_system.py", line 98, in listdir
return tf.io.gfile.listdir(path)
File "/opt/venv/lib/python3.10/site-packages/tensorflow/python/lib/io/file_io.py", line 768, in list_directory_v2
coll_type: COLL_ALL_REDUCE
msg_size_tuning_rules {
per_rank_message_size {
min: 0
}
coll_tuning_spec {
num_channel: 2
protocol: PROTO_SIMPLE
algorithm: ALGO_TREE
}
}
}
coll_configs {
coll_type: COLL_DEFAULT
msg_size_tuning_rules {
per_rank_message_size {
min: 0
max: 65536
}
coll_tuning_spec {
num_channel: 2
protocol: PROTO_SIMPLE
algorithm: ALGO_RING
}
}
msg_size_tuning_rules {
per_rank_message_size {
min: 65536
}
coll_tuning_spec {
num_channel: 4
protocol: PROTO_SIMPLE
gke-a3plus-benchmark-a3plus-benchmark-caf334a5-6fvl:7:2175 [7] /nccl-tuner-config-based/src/config_based_tuner.cc:271 NCCL WARN No communicator config selected from config:communicator_configs {
node_range {
min: 2
max: 3
}
rank_per_node_range {
raise errors.NotFoundError(
tensorflow.python.framework.errors_impl.NotFoundError: Could not find directory gs://supercomputer-testing-axlearn/70b-8n-norematoffload-1/checkpoints
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
min: 1
max: 2
}
coll_configs {
coll_type: COLL_ALL_REDUCE
msg_size_tuning_rules {
per_rank_message_size {
min: 0
}
coll_tuning_spec {
num_channel: 2
protocol: PROTO_SIMPLE
algorithm: ALGO_TREE
}
}
}
coll_configs {
coll_type: COLL_DEFAULT
msg_size_tuning_rules {
per_rank_message_size {
min: 0
max: 65536
}
coll_tuning_spec {
num_channel: 2
protocol: PROTO_SIMPLE
algorithm: ALGO_RING
}
}
msg_size_tuning_rules {
per_rank_message_size {
min: 65536
}
coll_tuning_spec {
num_channel: 4
protocol: PROTO_SIMPLE
algorithm: ALGO_RING
}
}
}
}
c
gke-a3plus-benchmark-a3plus-benchmark-caf334a5-6fvl:7:2175 [7] /nccl-tuner-config-based/src/tuner_tcpx.cc:70 NCCL WARN No communicator found for nRanks:8, nNodes:8 from config_path:/usr/local/nvidia/lib64/a3plus_tuner_config.textproto
self.run()
File "/usr/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/root/axlearn/common/checkpointer.py", line 863, in _gc_loop
self._run_garbage_collection()
File "/root/axlearn/common/checkpointer.py", line 912, in _run_garbage_collection
for step in fs.listdir(cfg.dir)
File "/root/axlearn/common/file_system.py", line 67, in wrapped
with (
File "/usr/lib/python3.10/contextlib.py", line 153, in __exit__
self.gen.throw(typ, value, traceback)
File "/root/axlearn/common/file_system.py", line 48, in _wrap_exception
raise target_exc(str(e)) from e
axlearn.common.file_system.NotFoundError: Could not find directory gs://supercomputer-testing-axlearn/70b-8n-norematoffload-1/checkpoints
I0926 18:10:38.143126 136615241593984 trainer.py:473] input_batch={'input_ids': (16, 4096), 'target_labels': (16, 4096), 'target_num_bytes': (16,)}
I0926 18:10:38.407656 136615241593984 base_layer.py:337] Applying remat on gpt_trainer.model.decoder.transformer.repeat.layer.<function TransformerLayer.forward at 0x7c3a844039a0>: RematSpec(prevent_cse=False, policy=config_for_function(jax._src.ad_checkpoint.save_only_these_names)(fn=<function save_only_these_names at 0x7c3f87590820>, names_which_can_be_saved=['FlashAttention.q_proj', 'FlashAttention.k_proj', 'FlashAttention.v_proj', 'FlashAttention.context', 'FlashAttention.o_proj', 'TransformerFeedForwardLayer.activation', 'TransformerFeedForwardLayer.linear2']))
I0926 18:10:38.483145 136615241593984 serialization.py:554] Error check finished successfully
I0926 18:10:38.483265 136615241593984 checkpointer.py:849] Waiting for gc_thread to finish
I0926 18:10:38.483329 136615241593984 checkpointer.py:854] gc_thread finished
I0926 18:10:38.483386 136615241593984 trainer.py:351] Waiting for watchdog_thread to finish
I0926 18:10:38.483608 136493205018176 trainer.py:377] Watchdog loop done
I0926 18:10:38.483988 136615241593984 trainer.py:356] watchdog_thread finished
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
app.run(main)
File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
launch_trainer.run_trainer(trainer_config)
File "/root/axlearn/common/launch_trainer.py", line 131, in run_trainer
output = trainer.run(prng_key)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn
return _call_method_in_context(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 486, in _call_method_in_context
return thunk()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 473, in thunk
return module._call_thunk(*args, method_fn=method_fn, **kwargs)()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 872, in nullary
return method_fn(self, *args, **kwargs)
File "/root/axlearn/common/trainer.py", line 482, in run
output = self._run_step(
File "/root/axlearn/common/trainer.py", line 867, in _run_step
self._trainer_state, outputs = self._jit_train_step(self._trainer_state, input_batch)
File "/root/axlearn/common/trainer.py", line 993, in _train_step
fwd_bwd_outputs, learner_output_collection = F(
File "/root/axlearn/common/module.py", line 986, in functional
method_outputs, output_collection = fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 927, in __call__
method_outputs = self.method_fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn
return _call_method_in_context(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module))))
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context
return thunk()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 473, in thunk
return module._call_thunk(*args, method_fn=method_fn, **kwargs)()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 872, in nullary
return method_fn(self, *args, **kwargs)
File "/root/axlearn/common/learner.py", line 343, in forward_and_backward
updates = _value_and_grad(
File "/root/axlearn/common/learner.py", line 717, in _value_and_grad
(_, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)(
File "/root/axlearn/common/learner.py", line 674, in forward
outputs = fun(model_params=model_params, inputs=inputs) # type: ignore
File "/root/axlearn/common/learner.py", line 641, in filtered_forward
return fun(model_params=model_params, inputs=inputs)
File "/root/axlearn/common/trainer.py", line 988, in _forward
loss, aux = self.model(input_batch=train_cast(inputs["input_batch"]))
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 865, in __call__
return self.forward(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn
return _call_method_in_context(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module))))
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context
return thunk()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 473, in thunk
return module._call_thunk(*args, method_fn=method_fn, **kwargs)()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 872, in nullary
return method_fn(self, *args, **kwargs)
File "/root/axlearn/common/causal_lm.py", line 120, in forward
predictions = self.predict(input_batch)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn
return _call_method_in_context(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module))))
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context
return thunk()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 473, in thunk
return module._call_thunk(*args, method_fn=method_fn, **kwargs)()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 872, in nullary
return method_fn(self, *args, **kwargs)
File "/root/axlearn/common/causal_lm.py", line 282, in predict
decoder_output = self.decoder(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 865, in __call__
return self.forward(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn
return _call_method_in_context(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module))))
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 493, in call_thunk_in_context
return call_thunk_in_context(reversed_path)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context
return thunk()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 473, in thunk
return module._call_thunk(*args, method_fn=method_fn, **kwargs)()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 872, in nullary
return method_fn(self, *args, **kwargs)
File "/root/axlearn/common/decoder.py", line 561, in forward
_, output = self._forward_for_mode(
File "/root/axlearn/common/decoder.py", line 479, in _forward_for_mode
transformer_state, x = None, self.transformer(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 865, in __call__
return self.forward(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn
return _call_method_in_context(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module))))
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 493, in call_thunk_in_context
return call_thunk_in_context(reversed_path)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context
return thunk()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 473, in thunk
return module._call_thunk(*args, method_fn=method_fn, **kwargs)()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 872, in nullary
return method_fn(self, *args, **kwargs)
File "/root/axlearn/common/attention.py", line 3771, in forward
return self.repeat(data, **layer_kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 865, in __call__
return self.forward(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn
return _call_method_in_context(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module))))
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 493, in call_thunk_in_context
return call_thunk_in_context(reversed_path)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context
return thunk()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 473, in thunk
return module._call_thunk(*args, method_fn=method_fn, **kwargs)()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 872, in nullary
return method_fn(self, *args, **kwargs)
File "/root/axlearn/common/attention.py", line 3687, in forward
_, output = self._forward_for_mode(
File "/root/axlearn/common/attention.py", line 3664, in _forward_for_mode
repeat_outputs: Repeat.Output = self._run(layer_fn, carry=carry, xs=cached_states)
File "/root/axlearn/common/repeat.py", line 199, in _run
carry, ys = scan_in_context(
File "/root/axlearn/common/module.py", line 1059, in scan_in_context
carry, scan_ys = jax.lax.scan(scan_fn, init=carry, xs=xs)
File "/root/axlearn/common/module.py", line 1044, in scan_fn
carry_i, y_i = fn(carry_i, x_i)
File "/root/axlearn/common/attention.py", line 3636, in layer_fn
layer_states, layer_outputs = None, self.layer(**carry, **layer_kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 865, in __call__
return self.forward(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn
return _call_method_in_context(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module))))
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context
return thunk()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 473, in thunk
return module._call_thunk(*args, method_fn=method_fn, **kwargs)()
File "/root/axlearn/common/base_layer.py", line 340, in nullary_with_remat
outputs, output_collection = jax.ad_checkpoint.remat(
File "/root/axlearn/common/base_layer.py", line 334, in fn
outputs = method_fn(self, *args, **kwargs, **static_kwargs)
File "/root/axlearn/common/attention.py", line 3066, in forward
_, output = self._forward_for_mode(
File "/root/axlearn/common/attention.py", line 3004, in _forward_for_mode
self.self_attention(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 865, in __call__
return self.forward(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn
return _call_method_in_context(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module))))
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 493, in call_thunk_in_context
return call_thunk_in_context(reversed_path)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context
return thunk()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 473, in thunk
return module._call_thunk(*args, method_fn=method_fn, **kwargs)()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 872, in nullary
return method_fn(self, *args, **kwargs)
File "/root/axlearn/common/attention.py", line 2552, in forward
_, output = self._forward_for_mode(
File "/root/axlearn/common/attention.py", line 2506, in _forward_for_mode
atten_state, atten_output = attention_thunk(norm_target)
File "/root/axlearn/common/attention.py", line 2476, in attention_thunk
self.attention(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 865, in __call__
return self.forward(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 522, in wrap_method_fn
return _call_method_in_context(
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 146, in in_context_exception_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 495, in _call_method_in_context
return call_thunk_in_context(list(reversed(context.module.path_to_descendant_module(module))))
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 493, in call_thunk_in_context
return call_thunk_in_context(reversed_path)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 491, in call_thunk_in_context
return thunk()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 473, in thunk
return module._call_thunk(*args, method_fn=method_fn, **kwargs)()
File "/root/axlearn/common/traceback_util.py", line 268, in stack_annotation_wrapper
return fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 872, in nullary
return method_fn(self, *args, **kwargs)
File "/root/axlearn/common/attention.py", line 1871, in forward
_, output = self._forward_for_mode(
File "/root/axlearn/common/attention.py", line 1760, in _forward_for_mode
context, probs = self._compute_attention(
File "/root/axlearn/common/flash_attention/layer.py", line 189, in _compute_attention
cfg.mha_dim_to_partition_spec["bt"],
KeyError: 'bt'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/root/axlearn/common/traceback_util.py", line 163, in in_context_exception_wrapper
raise in_context_exception from e
axlearn.common.traceback_util._InContextException:
An error was encountered in a wrapped Module method.
Below is an AXLearn stack summary, which may be easier to read.
Immediately above is the stack frame in which the stack summary was initialized.
You can probably ignore that frame.
Further above that is the original Python Exception with the raw stack trace, which caused this Exception.
Stack Summary (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
app.run(main)
File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 330, in run
raise
File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
launch_trainer.run_trainer(trainer_config)
File "/root/axlearn/common/launch_trainer.py", line 131, in run_trainer
output = trainer.run(prng_key)
Wrapped call axlearn.common.trainer.SpmdTrainer.run(jaxlib.xla_extension.ArrayImpl)
File "/root/axlearn/common/trainer.py", line 482, in run
output = self._run_step(
File "/root/axlearn/common/trainer.py", line 867, in _run_step
self._trainer_state, outputs = self._jit_train_step(self._trainer_state, input_batch)
File "/root/axlearn/common/trainer.py", line 993, in _train_step
fwd_bwd_outputs, learner_output_collection = F(
File "/root/axlearn/common/module.py", line 986, in functional
method_outputs, output_collection = fn(*args, **kwargs)
File "/root/axlearn/common/module.py", line 927, in __call__
method_outputs = self.method_fn(*args, **kwargs)
Wrapped call axlearn.common.learner.Learner.forward_and_backward(fn: function, inputs: dict, opt_params: dict)
File "/root/axlearn/common/learner.py", line 343, in forward_and_backward
updates = _value_and_grad(
File "/root/axlearn/common/learner.py", line 717, in _value_and_grad
(_, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)(
File "/root/axlearn/common/learner.py", line 674, in forward
outputs = fun(model_params=model_params, inputs=inputs) # type: ignore
File "/root/axlearn/common/learner.py", line 641, in filtered_forward
return fun(model_params=model_params, inputs=inputs)
File "/root/axlearn/common/trainer.py", line 988, in _forward
loss, aux = self.model(input_batch=train_cast(inputs["input_batch"]))
Wrapped call axlearn.common.causal_lm.Model.forward(input_batch: dict)
File "/root/axlearn/common/causal_lm.py", line 120, in forward
predictions = self.predict(input_batch)
Wrapped call axlearn.common.causal_lm.Model.predict(dict)
File "/root/axlearn/common/causal_lm.py", line 282, in predict
decoder_output = self.decoder(
Wrapped call axlearn.common.decoder.Decoder.forward(input_ids: jax._src.interpreters.partial_eval.DynamicJaxprTracer, token_type_ids: NoneType, input_segment_ids: NoneType, positions: NoneType)
File "/root/axlearn/common/decoder.py", line 561, in forward
_, output = self._forward_for_mode(
File "/root/axlearn/common/decoder.py", line 479, in _forward_for_mode
transformer_state, x = None, self.transformer(
Wrapped call axlearn.common.attention.RepeatedTransformerLayer.forward(jax._src.interpreters.ad.JVPTracer, self_attention_logit_biases: NoneType, segment_ids: NoneType, cross_attention_data: NoneType, cross_attention_logit_biases: NoneType)
File "/root/axlearn/common/attention.py", line 3771, in forward
return self.repeat(data, **layer_kwargs)
Wrapped call axlearn.common.attention._TransformerRepeat.forward(jax._src.interpreters.ad.JVPTracer, self_attention_logit_biases: NoneType, segment_ids: NoneType, cross_attention_data: NoneType, cross_attention_logit_biases: NoneType)
File "/root/axlearn/common/attention.py", line 3687, in forward
_, output = self._forward_for_mode(
File "/root/axlearn/common/attention.py", line 3664, in _forward_for_mode
repeat_outputs: Repeat.Output = self._run(layer_fn, carry=carry, xs=cached_states)
File "/root/axlearn/common/repeat.py", line 199, in _run
carry, ys = scan_in_context(
File "/root/axlearn/common/module.py", line 1059, in scan_in_context
carry, scan_ys = jax.lax.scan(scan_fn, init=carry, xs=xs)
File "/root/axlearn/common/module.py", line 1044, in scan_fn
carry_i, y_i = fn(carry_i, x_i)
File "/root/axlearn/common/attention.py", line 3636, in layer_fn
layer_states, layer_outputs = None, self.layer(**carry, **layer_kwargs)
Wrapped call axlearn.common.attention.TransformerLayer.forward(data: jax._src.interpreters.partial_eval.DynamicJaxprTracer, self_attention_logit_biases: NoneType, segment_ids: NoneType, cross_attention_data: NoneType, cross_attention_logit_biases: NoneType)
File "/root/axlearn/common/base_layer.py", line 340, in nullary_with_remat
outputs, output_collection = jax.ad_checkpoint.remat(
File "/root/axlearn/common/base_layer.py", line 334, in fn
outputs = method_fn(self, *args, **kwargs, **static_kwargs)
File "/root/axlearn/common/attention.py", line 3066, in forward
_, output = self._forward_for_mode(
File "/root/axlearn/common/attention.py", line 3004, in _forward_for_mode
self.self_attention(
Wrapped call axlearn.common.attention.TransformerAttentionLayer.forward(target: jax._src.interpreters.partial_eval.DynamicJaxprTracer, segment_ids: NoneType, source: NoneType, attention_logit_biases: NoneType, return_aux: set)
File "/root/axlearn/common/attention.py", line 2552, in forward
_, output = self._forward_for_mode(
File "/root/axlearn/common/attention.py", line 2506, in _forward_for_mode
atten_state, atten_output = attention_thunk(norm_target)
File "/root/axlearn/common/attention.py", line 2476, in attention_thunk
self.attention(
Wrapped call axlearn.common.flash_attention.layer.FlashAttention.forward(query: jax._src.interpreters.partial_eval.DynamicJaxprTracer, return_aux: set, attention_logit_biases: NoneType, segment_ids: NoneType)
File "/root/axlearn/common/attention.py", line 1871, in forward
_, output = self._forward_for_mode(
File "/root/axlearn/common/attention.py", line 1760, in _forward_for_mode
context, probs = self._compute_attention(
File "/root/axlearn/common/flash_attention/layer.py", line 189, in _compute_attention
cfg.mha_dim_to_partition_spec["bt"],
KeyError: 'bt'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment