Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
crowsonkb / lfq.py
Last active February 20, 2025 16:37
Lookup Free Quantization (LFQ) for PyTorch.
"""Lookup Free Quantization (LFQ) for PyTorch."""
from dataclasses import dataclass
from itertools import product
import math
from typing import Optional
import torch
from torch import distributed as dist, nn
from torch.distributed import nn as dnn
@crowsonkb
crowsonkb / plackett_luce.py
Last active November 5, 2024 21:33
Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties.
"""Plackett-Luce loss function for ranking tasks. The rankings may be partial and include ties."""
from itertools import chain
from typing import List, Optional, Tuple
import torch
def plackett_luce_loss(
scores: torch.Tensor,
@crowsonkb
crowsonkb / kld_noise_generator.py
Created October 25, 2024 17:15
Generates a smoothly varying standard normal time series.
"""Generates a smoothly varying standard normal time series."""
import numpy as np
import scipy.linalg
import torch
class KLDNoiseGenerator(torch.nn.Module):
"""Generates a smoothly varying standard normal time series.
@crowsonkb
crowsonkb / ring_attn.py
Created October 10, 2024 16:19
Ring attention for PyTorch.
"""Ring attention for PyTorch.
See https://github.com/nshepperd/flash_attn_jax/blob/main/src/flash_attn_jax/ring_attention.py.
"""
import flash_attn.flash_attn_interface as fai
import torch
import torch.distributed as dist
@crowsonkb
crowsonkb / mos.py
Last active April 11, 2024 21:23
Mixture of Softmaxes
"""Mixture of Softmaxes"""
import torch
from torch.nn import functional as F
class MixtureOfSoftmaxes(torch.autograd.Function):
@staticmethod
def forward(ctx, x, p):
with torch.cuda.amp.autocast(enabled=False):
"""Grouped linear layer using https://github.com/tgale96/grouped_gemm."""
from dataclasses import dataclass
import warnings
import torch
from torch import nn
try:
@crowsonkb
crowsonkb / spo_loss.py
Last active June 10, 2024 15:38
Scalar Preference Optimization
"""Scalar Preference Optimization."""
import torch
from torch.nn import functional as F
def logp_completion(logits, tokens, mask):
"""Compute the log probabilities of completions given their prompts.
Args:
@crowsonkb
crowsonkb / reinforce.py
Last active June 30, 2023 19:12
REINFORCE with exponential moving average baseline
"""REINFORCE (DiCE) with exponential moving average baseline. Implements "DiCE: The Infinitely
Differentiable Monte Carlo Estimator" (https://arxiv.org/abs/1802.05098)."""
import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional, Union
import itertools
import random
class WeightedSampler:
"""Samples k elements from a stream of weighted items without replacement.
See Weighted Random Sampling (Efraimidis, Spirakis 2005).
"""
"""Stochastic beam search.
Implements "Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for
Sampling Sequences Without Replacement" (https://arxiv.org/abs/1903.06059)."""
import math
import torch