Update: Fixed by pinning Levanter to the latest commit (21e2de22cce28d8f0747d36986c3c217dc6f6ff0) via a git submodule!
I tried running it, but the run_on_pod_ray task (pid=2160, ip=10.164.2.50) failed with the following mesh error:
Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.11/site-packages/levanter/infra/ray_tpu.py", line 735, in run_on_pod_ray
    tpu_results[future_to_index[f]] = TpuSuccess(ray.get(f))
                                                 ^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/worker.py", line 2822, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/worker.py", line 930, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray::inference_worker_task() (pid=11862, ip=10.164.0.41)
  File "/tmp/ray/session_2025-10-05_20-36-16_257629_732/runtime_resources/working_dir_files/_ray_pkg_3c7e10bbed4a3c45/src/marin/rl/rl_job.py", line 186, in inference_worker_task
    worker = RolloutWorker(config=rollout_worker_config)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2025-10-05_20-36-16_257629_732/runtime_resources/working_dir_files/_ray_pkg_3c7e10bbed4a3c45/src/marin/rl/rollout_worker.py", line 175, in __init__
    self._build_models()
  File "/tmp/ray/session_2025-10-05_20-36-16_257629_732/runtime_resources/working_dir_files/_ray_pkg_3c7e10bbed4a3c45/src/marin/rl/rollout_worker.py", line 280, in _build_models
    initial_model = load_model_from_checkpoint(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2025-10-05_20-36-16_257629_732/runtime_resources/working_dir_files/_ray_pkg_3c7e10bbed4a3c45/src/marin/rl/model_utils.py", line 88, in load_model_from_checkpoint
    model = converter.load_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/levanter/compat/hf_checkpoints.py", line 640, in load_pretrained
    lev_model = eqx.filter_jit(load_from_state_dict, donate="all", device=cpu_device)(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/equinox/_jit.py", line 209, in __call__
    return _call(self, False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/equinox/_jit.py", line 259, in _call
    marker, _, _ = out = jit_wrapper._cached(
                         ^^^^^^^^^^^^^^^^^^^^
ValueError: Mesh context manager should not be used with jit when backend or device is also specified as an argument to jit.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
--------------------
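For context on the ValueError itself: JAX refuses to trace a jitted function that was given an explicit backend= or device= argument while a Mesh context manager is active, which is the combination load_pretrained hits here (eqx.filter_jit(..., device=cpu_device) called while a device mesh is in scope). A minimal sketch of the conflict, with illustrative names rather than Levanter's actual code:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh

# Illustrative setup: any active mesh and any explicit device will do.
mesh = Mesh(jax.devices()[:1], axis_names=("data",))
cpu = jax.devices("cpu")[0]

# jit pinned to a device, analogous to
# eqx.filter_jit(load_from_state_dict, donate="all", device=cpu_device)
f = jax.jit(lambda x: x + 1, device=cpu)

with mesh:          # an enclosing Mesh context manager, e.g. the worker's device mesh
    f(jnp.ones(4))  # ValueError: Mesh context manager should not be used with jit
                    # when backend or device is also specified as an argument to jit.

Either dropping the device=/backend= argument or doing the device-pinned call outside the active mesh removes the conflict, which is presumably what the pinned Levanter commit does.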