@OhadRubin
Created January 12, 2023 14:54
KL_controller

This is a KL controller:

```python
import numpy as np


class AdaptiveKLController:
    """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences"
    Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2
    Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
    """

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef  # current KL coefficient, βₜ
        self.target = target  # desired value of KL(π, ρ)
        self.horizon = horizon  # time constant (in steps) controlling how fast β adapts

    def update(self, current: float, n_steps: int) -> float:
        """Returns adaptively updated KL coefficient, βₜ₊₁.
        Arguments:
            current: The current KL value between the newest policy and the initial policy.
            n_steps: The number of steps since the last update; scales the size of the adjustment.
        """
        proportional_error = np.clip(current / self.target - 1, -0.2, 0.2)  # ϵₜ
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult  # βₜ₊₁
        return self.value
```

The KL controller is a class that adaptively updates the KL coefficient β, the weight placed on a KL-divergence penalty. The KL divergence itself measures how different two probability distributions are; in reinforcement learning fine-tuning it is often used to measure how far the current policy has moved from the initial policy. The controller takes three parameters: an initial KL coefficient, a target KL value, and a horizon (in steps) that sets how quickly the coefficient adapts. The update method takes the current KL value and the number of steps taken, computes a proportional error (the relative gap between the current and target KL, clipped to ±0.2), and multiplies the stored coefficient by a factor derived from that error. The updated coefficient scales the KL penalty used during training, so it controls how strongly the policy is pulled back toward the initial policy, which in turn steers the KL toward the target value.
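As a sketch of where β enters training, assuming the penalized reward from Section 2.2 of the paper, R(x, y) = r(x, y) − β log(π(y|x)/ρ(y|x)), the snippet below applies that penalty to a few samples and estimates KL(π, ρ) as the average log-ratio over samples drawn from π, which is the kind of `current` value one would pass to `update`. The helper names `penalized_rewards` and `estimate_kl` and all the numbers are hypothetical.

```python
from typing import List


def penalized_rewards(scores: List[float], logp_policy: List[float],
                      logp_ref: List[float], beta: float) -> List[float]:
    """Hypothetical helper: R = r - beta * (log pi(y|x) - log rho(y|x)) for each sample."""
    return [r - beta * (lp - lr) for r, lp, lr in zip(scores, logp_policy, logp_ref)]


def estimate_kl(logp_policy: List[float], logp_ref: List[float]) -> float:
    """Hypothetical helper: Monte Carlo estimate of KL(pi, rho) from samples y ~ pi."""
    return sum(lp - lr for lp, lr in zip(logp_policy, logp_ref)) / len(logp_policy)


# Made-up values for three sampled continuations y of the same prompt x.
scores = [1.0, 0.5, 2.0]          # reward-model scores r(x, y)
logp_policy = [-3.0, -4.0, -2.5]  # log pi(y|x) under the current policy
logp_ref = [-5.0, -4.5, -6.0]     # log rho(y|x) under the initial policy

kl = estimate_kl(logp_policy, logp_ref)  # (2.0 + 0.5 + 3.5) / 3 = 2.0

# A larger beta penalizes samples that drifted far from rho (large log-ratio) more heavily.
for beta in (0.05, 0.2):
    rewards = penalized_rewards(scores, logp_policy, logp_ref, beta)
    print(f"beta={beta}: {[round(r, 3) for r in rewards]}")
```

Because every reward is reduced in proportion to β times the log-ratio, raising β makes drifting away from ρ more costly for the policy optimizer, and lowering β makes it cheaper.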

Section 2.2 further explains:

"""
Models trained with different seeds and the same KL penalty β sometimes end up with quite different values of KL(π, ρ), making them hard to compare. To fix this, for some experiments we dynamically vary β to target a particular value of KL(π, ρ) using the log-space proportional controller.
"""

Can you try to explain to me how this loss function keeps the KL at a specific value?

The KL controller dynamically adjusts the KL coefficient β so that the KL divergence KL(π, ρ) between the current policy π and the initial policy ρ is steered toward a target value. Strictly, the controller is not a loss function: β is the weight on the KL penalty subtracted from the reward, so changing β changes how strongly training pulls the policy back toward ρ. The controller first computes a proportional error (proportional_error = np.clip(current / self.target - 1, -0.2, 0.2)), which is the relative amount by which the current KL is above (positive) or below (negative) the target, clipped to ±20%. It then converts that error into a multiplicative factor (mult = 1 + proportional_error * n_steps / self.horizon) and rescales the coefficient (self.value *= mult). Because β is multiplied rather than incremented, each update changes log β by roughly an amount proportional to the error, which is why the paper calls this a log-space proportional controller.
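Written out from the code (using ϵₜ and βₜ as in its comments, with KLₜ for `current` and K_β = n_steps / horizon as notation introduced here), one update is:

```latex
\begin{aligned}
\epsilon_t &= \operatorname{clip}\!\left(\frac{\mathrm{KL}_t}{\mathrm{KL}_{\mathrm{target}}} - 1,\; -0.2,\; 0.2\right)
            = \operatorname{clip}\!\left(\frac{\mathrm{KL}_t - \mathrm{KL}_{\mathrm{target}}}{\mathrm{KL}_{\mathrm{target}}},\; -0.2,\; 0.2\right) \\
\beta_{t+1} &= \beta_t \,(1 + K_\beta\, \epsilon_t), \qquad K_\beta = \frac{n_{\mathrm{steps}}}{\mathrm{horizon}} \\
\log \beta_{t+1} &= \log \beta_t + \log(1 + K_\beta\, \epsilon_t) \;\approx\; \log \beta_t + K_\beta\, \epsilon_t
\end{aligned}
```

The approximation in the last line holds when K_β ϵₜ is small, which the clipping helps guarantee: with illustrative values n_steps = 256 and horizon = 10,000, |K_β ϵₜ| ≤ 0.00512, so β changes by at most about 0.5% per update.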

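In this way the controller keeps the KL divergence close to the target as training progresses. When the KL divergence is above the target, the error is positive, mult is greater than 1, and β increases; the stronger penalty pushes the policy back toward the initial policy and lowers the KL. When the KL divergence is below the target, the error is negative, mult is less than 1, and β decreases; the weaker penalty lets the policy move further from the initial policy and raises the KL. The controller does not hard-constrain the KL: β keeps changing as long as there is any error and stops changing only when the KL equals the target, so that is the point the feedback loop steers toward. The clipping limits how much β can change in a single update, and the ratio n_steps / horizon sets how quickly it adapts.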
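A single update worked through by hand shows the direction of the adjustment. The numbers (initial β = 0.1, target KL = 6.0, horizon = 10,000, 256 steps per update) are illustrative assumptions, not values from the paper, and the controller is re-declared in plain Python with a hypothetical `clip` helper in place of `np.clip` so the snippet runs on its own.

```python
def clip(x: float, lo: float, hi: float) -> float:
    """Hypothetical scalar stand-in for np.clip."""
    return max(lo, min(x, hi))


class AdaptiveKLController:
    """Same update rule as AdaptiveKLController above, with update() returning beta_{t+1}."""

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current: float, n_steps: int) -> float:
        proportional_error = clip(current / self.target - 1, -0.2, 0.2)  # eps_t
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult  # beta_{t+1}
        return self.value


# KL above target: 9.0 / 6.0 - 1 = 0.5, clipped to 0.2 -> mult = 1 + 0.2 * 256 / 10000 = 1.00512
above = AdaptiveKLController(init_kl_coef=0.1, target=6.0, horizon=10_000)
beta_up = above.update(current=9.0, n_steps=256)

# KL below target: 5.4 / 6.0 - 1 = -0.1 (inside the clip range) -> mult = 1 - 0.1 * 256 / 10000 = 0.99744
below = AdaptiveKLController(init_kl_coef=0.1, target=6.0, horizon=10_000)
beta_down = below.update(current=5.4, n_steps=256)

print(f"{beta_up:.6f}")    # 0.100512  (beta rises: stronger penalty)
print(f"{beta_down:.6f}")  # 0.099744  (beta falls: weaker penalty)
```

Note that the first case was clipped: without the clip, a KL 50% above target would have produced a 50% relative error and a larger step.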

As a result, runs trained with different seeds end up with similar KL values, which makes their results easier to compare.
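A toy simulation illustrates why runs that start from different places end up at a similar KL. It assumes a made-up response model in which the policy's KL is inversely proportional to β (KL = c / β with c = 0.6), plus illustrative settings (target 6.0, 256 steps per update, horizon 1,000); real training dynamics are noisier than this. The controller is written in plain Python with a hypothetical `clip` helper in place of `np.clip`.

```python
from typing import List


def clip(x: float, lo: float, hi: float) -> float:
    """Hypothetical scalar stand-in for np.clip."""
    return max(lo, min(x, hi))


class AdaptiveKLController:
    """Same update rule as AdaptiveKLController above."""

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current: float, n_steps: int) -> float:
        proportional_error = clip(current / self.target - 1, -0.2, 0.2)
        self.value *= 1 + proportional_error * n_steps / self.horizon
        return self.value


def toy_kl(beta: float, c: float = 0.6) -> float:
    """Made-up response model: a stronger penalty (larger beta) gives a smaller KL."""
    return c / beta


def simulate(init_beta: float, target: float = 6.0, steps: int = 200,
             n_steps: int = 256, horizon: int = 1000) -> List[float]:
    """Run the controller against toy_kl and return the KL observed before each update."""
    controller = AdaptiveKLController(init_beta, target, horizon)
    history = []
    for _ in range(steps):
        kl = toy_kl(controller.value)   # "measure" the KL produced by the current beta
        history.append(kl)
        controller.update(kl, n_steps)  # raise beta if KL > target, lower it if KL < target
    return history


low_start = simulate(init_beta=0.5)    # KL starts at 1.2, below the target
high_start = simulate(init_beta=0.01)  # KL starts at about 60, above the target
print(f"final KL from low start:  {low_start[-1]:.4f}")   # 6.0000
print(f"final KL from high start: {high_start[-1]:.4f}")  # 6.0000
```

Under this toy model the fixed point is β = c / target = 0.1. While the error is clipped, β moves by a constant factor per update; once the error falls inside the clip range, each update moves β a fixed fraction n_steps / horizon of the way toward 0.1, so both runs settle at the same KL regardless of where they started.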
