@ryan-williams
Created November 3, 2025 20:46

marin-community/marin#1726 Fix KL loss in the RL training.

Description

A few fixes for RL training.

  • We were computing tokens incorrectly, deriving them from the decoded token text rather than from the logprob tokens; the two can diverge when there are special tokens in the output.
  • Our KL loss was calculating the KL divergence but not actually applying it as a penalty - the model was being encouraged to diverge from the reference (see the sketch after this list).
  • We were using the "old" mesh syntax in a number of locations.
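
To make the KL point concrete, here is a minimal sketch of how a per-token KL estimate has to enter the loss to act as a penalty. This is not the actual Marin/Levanter code; the function and argument names are assumptions.

import jax.numpy as jnp

def kl_penalized_loss(policy_logprobs, ref_logprobs, pg_loss, kl_coef=0.1):
    # policy_logprobs / ref_logprobs: per-token log-probs of the sampled tokens
    # under the current policy and the frozen reference model, shape [batch, seq].
    # Simple estimator: KL(policy || ref) ~= logprob_policy - logprob_ref.
    kl = jnp.mean(policy_logprobs - ref_logprobs)
    # Added to the minimized loss, the KL term penalizes drifting away from the
    # reference. Subtracting it (or flipping the estimator's sign) instead
    # rewards divergence, which is the bug described above.
    return pg_loss + kl_coef * kl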

Also adds a script to load a rollout pickle for debugging.
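
The debugging script itself isn't reproduced here, but loading a rollout pickle for offline inspection amounts to something like the sketch below (the filename and the structure of the pickled object are assumptions):

import pickle
import pprint
from pathlib import Path

def load_rollout(path):
    # Load a pickled rollout batch written by a rollout worker (format assumed).
    with Path(path).open("rb") as f:
        return pickle.load(f)

if __name__ == "__main__":
    rollout = load_rollout("rollout_debug.pkl")  # hypothetical filename
    pprint.pprint(getattr(rollout, "__dict__", rollout))  # dump whatever fields it carries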

Update: Fixed by pinning the Levanter submodule to the latest commit, 21e2de22cce28d8f0747d36986c3c217dc6f6ff0!

I tried running it but got the following mesh error:

(from run_on_pod_ray pid=2160, ip=10.164.2.50)

Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.11/site-packages/levanter/infra/ray_tpu.py", line 735, in run_on_pod_ray
    tpu_results[future_to_index[f]] = TpuSuccess(ray.get(f))
                                                 ^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/worker.py", line 2822, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/worker.py", line 930, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray::inference_worker_task() (pid=11862, ip=10.164.0.41)
  File "/tmp/ray/session_2025-10-05_20-36-16_257629_732/runtime_resources/working_dir_files/_ray_pkg_3c7e10bbed4a3c45/src/marin/rl/rl_job.py", line 186, in inference_worker_task
    worker = RolloutWorker(config=rollout_worker_config)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2025-10-05_20-36-16_257629_732/runtime_resources/working_dir_files/_ray_pkg_3c7e10bbed4a3c45/src/marin/rl/rollout_worker.py", line 175, in __init__
    self._build_models()
  File "/tmp/ray/session_2025-10-05_20-36-16_257629_732/runtime_resources/working_dir_files/_ray_pkg_3c7e10bbed4a3c45/src/marin/rl/rollout_worker.py", line 280, in _build_models
    initial_model = load_model_from_checkpoint(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ray/session_2025-10-05_20-36-16_257629_732/runtime_resources/working_dir_files/_ray_pkg_3c7e10bbed4a3c45/src/marin/rl/model_utils.py", line 88, in load_model_from_checkpoint
    model = converter.load_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/levanter/compat/hf_checkpoints.py", line 640, in load_pretrained
    lev_model = eqx.filter_jit(load_from_state_dict, donate="all", device=cpu_device)(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/equinox/_jit.py", line 209, in __call__
    return _call(self, False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ray/anaconda3/lib/python3.11/site-packages/equinox/_jit.py", line 259, in _call
    marker, _, _ = out = jit_wrapper._cached(
                         ^^^^^^^^^^^^^^^^^^^^
ValueError: Mesh context manager should not be used with jit when backend or device is also specified as an argument to jit.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
--------------------

Suggested fix: delete the device=cpu_device argument from the eqx.filter_jit call in levanter/compat/hf_checkpoints.py (line 640 in the traceback above).
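
For context, the ValueError is raised by JAX itself: a jit-compiled function that was given an explicit device= cannot be called while a Mesh context manager is active. A minimal repro, independent of Marin/Levanter (behavior may vary with JAX version, and device= on jit is deprecated in newer releases):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh

cpu_device = jax.devices("cpu")[0]
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def double(x):
    return x * 2

with mesh:
    try:
        # Rejected: device specified on jit while a mesh context manager is active.
        jax.jit(double, device=cpu_device)(jnp.ones(4))
    except ValueError as e:
        print(e)
    # Dropping device= (as suggested above) lets the mesh / default placement decide.
    print(jax.jit(double)(jnp.ones(4)))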
