This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Lookup Free Quantization (LFQ) for PyTorch.""" | |
from dataclasses import dataclass | |
from itertools import product | |
import math | |
from typing import Optional | |
import torch | |
from torch import distributed as dist, nn | |
from torch.distributed import nn as dnn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties.""" | |
from itertools import chain | |
from typing import List, Optional, Tuple | |
import torch | |
def plackett_luce_loss( | |
scores: torch.Tensor, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Generates a smoothly varying standard normal time series.""" | |
import numpy as np | |
import scipy.linalg | |
import torch | |
class KLDNoiseGenerator(torch.nn.Module): | |
"""Generates a smoothly varying standard normal time series. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Ring attention for PyTorch. | |
See https://github.com/nshepperd/flash_attn_jax/blob/main/src/flash_attn_jax/ring_attention.py. | |
""" | |
import flash_attn.flash_attn_interface as fai | |
import torch | |
import torch.distributed as dist | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Mixture of Softmaxes""" | |
import torch | |
from torch.nn import functional as F | |
class MixtureOfSoftmaxes(torch.autograd.Function): | |
@staticmethod | |
def forward(ctx, x, p): | |
with torch.cuda.amp.autocast(enabled=False): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Grouped linear layer using https://github.com/tgale96/grouped_gemm.""" | |
from dataclasses import dataclass | |
import warnings | |
import torch | |
from torch import nn | |
try: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Scalar Preference Optimization.""" | |
import torch | |
from torch.nn import functional as F | |
def logp_completion(logits, tokens, mask): | |
"""Compute the log probabilities of completions given their prompts. | |
Args: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The Infinitely | |
Differentiable Monte Carlo Estimator" (https://arxiv.org/abs/1802.05098).""" | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from typing import Optional, Union | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import random | |
class WeightedSampler: | |
"""Samples k elements from a stream of weighted items without replacement. | |
See Weighted Random Sampling (Efraimidis, Spirakis 2005). | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Stochastic beam search. | |
Implements "Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for | |
Sampling Sequences Without Replacement" (https://arxiv.org/abs/1903.06059)""" | |
import math | |
import torch | |
NewerOlder