Skip to content

Instantly share code, notes, and snippets.

@Birch-san
Created January 15, 2025 21:24
Show Gist options
  • Save Birch-san/cc7a16b4e073abe3d6c40cb1f7773adb to your computer and use it in GitHub Desktop.
Save Birch-san/cc7a16b4e073abe3d6c40cb1f7773adb to your computer and use it in GitHub Desktop.
Draw from a uniform distribution, stratified
from typing import NamedTuple, Sequence, Optional
import torch
from torch import FloatTensor, LongTensor
class DevicePlacement(NamedTuple):
    """Location of the current process within a (possibly distributed) job.

    Fields:
        global_rank: this process's rank; get_stratification expects it to
            lie in ``[0, world_size)``.
        world_size: total number of participating processes.
    """

    global_rank: int
    world_size: int
class GradAcc(NamedTuple):
    """Position within a gradient-accumulation cycle.

    Fields:
        acc_step: index of the current accumulation step; get_stratification
            expects it to lie in ``[0, acc_steps)``.
        acc_steps: number of accumulation steps in the cycle.
    """

    acc_step: int
    acc_steps: int
class Stratification(NamedTuple):
    """Identifies which share of a stratified draw a caller is responsible for.

    Fields:
        group: this caller's group index; stratified_uniform requires it to
            lie in ``[0, groups)``.
        groups: total number of groups splitting the strata; must be positive.
    """

    group: int
    groups: int
def get_stratification(
    placement: DevicePlacement = DevicePlacement(0, 1),
    grad_acc: GradAcc = GradAcc(0, 1),
) -> Stratification:
    """Assign a distinct stratification group to a (rank, accumulation step) pair.

    Groups are numbered rank-major::

        group  = global_rank * acc_steps + acc_step
        groups = world_size * acc_steps

    which maps every valid ``(global_rank, acc_step)`` pair to a different
    index in ``[0, groups)``.

    Args:
        placement: ``(global_rank, world_size)`` of the current process.
        grad_acc: ``(acc_step, acc_steps)`` of the current accumulation step.

    Returns:
        The ``Stratification(group, groups)`` for this rank and step.

    Raises:
        ValueError: if ``world_size`` or ``acc_steps`` is not positive, or if
            ``global_rank`` / ``acc_step`` falls outside its valid range.
    """
    global_rank, world_size = placement
    acc_step, acc_steps = grad_acc

    if world_size <= 0:
        raise ValueError(f"world_size must be positive, got {world_size}")
    if not 0 <= global_rank < world_size:
        raise ValueError(
            f"global_rank must be in [0, {world_size}), got {global_rank}"
        )
    if acc_steps <= 0:
        raise ValueError(f"acc_steps must be positive, got {acc_steps}")
    if not 0 <= acc_step < acc_steps:
        # Without this check an out-of-range step silently aliases the group
        # of a different rank/step, so two callers would draw the same strata.
        raise ValueError(f"acc_step must be in [0, {acc_steps}), got {acc_step}")

    return Stratification(
        group=global_rank * acc_steps + acc_step,
        groups=world_size * acc_steps,
    )
def stratified_uniform(
    shape: Sequence[int],
    stratum: Stratification = Stratification(0, 1),
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str | int] = None,
    generator: Optional[torch.Generator] = None,
) -> FloatTensor:
    """Draws stratified samples from a uniform distribution on [0, 1).

    The unit interval is split into ``n = shape[-1] * groups`` equal-width
    strata. Position ``i`` along the last dimension takes stratum
    ``group + i * groups`` and draws one uniform sample inside it. Callers
    that use every ``group`` in ``range(groups)`` with the same ``shape``
    therefore cover all ``n`` strata exactly once per row.

    Args:
        shape: output shape (must have at least one dimension); the last
            dimension is the number of samples this group draws per row.
        stratum: ``(group, groups)`` pair, e.g. from ``get_stratification()``.
        dtype: dtype passed to ``torch.arange`` and ``torch.rand``. If None,
            the result uses torch's default floating dtype.
        device: device passed to ``torch.arange`` and ``torch.rand``.
        generator: optional RNG passed to ``torch.rand``.

    Returns:
        Tensor of shape ``shape`` with values in [0, 1). Along the last
        dimension, element ``i`` lies in stratum ``group + i * groups``, up to
        floating-point rounding.

    Raises:
        ValueError: if ``groups <= 0`` or ``group`` is not in ``[0, groups)``.

    Note:
        Reduced-precision dtypes (float16, bfloat16) cannot represent large
        integer offsets exactly, so for large ``n`` the strata are only
        approximate in those dtypes.
    """
    group, groups = stratum
    if groups <= 0:
        raise ValueError(f"groups must be positive, got {groups}")
    if group < 0 or group >= groups:
        raise ValueError(f"group must be in [0, {groups}), got {group}")

    n = shape[-1] * groups
    # Left edges (in units of 1/n) of the strata owned by this group.
    offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
    # One uniform jitter per sample; broadcasts against offsets on the last dim.
    u = torch.rand(shape, dtype=dtype, device=device, generator=generator)
    samples = (offsets + u) / n

    # For the top stratum, `offsets + u` is just below n, but rounding (in the
    # sum or in the division) can land the result on exactly 1.0. This is rare
    # in float32 and much more likely in float16/bfloat16. Clamp to the largest
    # representable value below 1 so the result honours [0, 1).
    largest_below_one = 1.0 - torch.finfo(samples.dtype).eps / 2
    return samples.clamp_(max=largest_below_one)
def scale_shift_positive(x: FloatTensor, low: int, high: int) -> LongTensor:
    """
    Adapt torch.rand(shape) output into torch.randint(low, high, shape).

    Maps each element of ``x`` (expected in [0, 1)) to an integer in
    ``[low, high)``. Assumes ``high > low``.

    The "positive" in the name refers to the precondition that ``x`` is
    non-negative, as torch.rand() output is. Then ``x * (high - low)`` is also
    non-negative, so truncating it with ``.long()`` equals flooring it and no
    explicit ``.floor()`` is needed.

    ``low`` is added only after the conversion, in integer arithmetic. This
    keeps the result correct when ``low`` is negative, because truncation
    rounds toward zero rather than down. It also avoids float rounding error
    when ``low`` is large.
    """
    span = high - low
    return torch.mul(x, span).long().add_(low)
if __name__ == "__main__":
    # Demo: eight stratified draws from [0, 1) using a fixed-seed generator.
    # Guarded so that importing this module has no side effects.
    print(stratified_uniform((8,), generator=torch.Generator().manual_seed(42)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment