Created January 15, 2025 21:24
Save Birch-san/cc7a16b4e073abe3d6c40cb1f7773adb to your computer and use it in GitHub Desktop.
Draw stratified samples from a uniform distribution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import NamedTuple, Sequence, Optional | |
import torch | |
from torch import FloatTensor, LongTensor | |
class DevicePlacement(NamedTuple):
    """Placement of this process in a distributed job, as ``(global_rank, world_size)``.

    Unpacked positionally by ``get_stratification``.
    """

    global_rank: int  # rank of this process (first field)
    world_size: int  # number of ranks (second field)
class GradAcc(NamedTuple):
    """Gradient-accumulation position, as ``(acc_step, acc_steps)``.

    Unpacked positionally by ``get_stratification``.
    """

    acc_step: int  # current accumulation step (first field)
    acc_steps: int  # accumulation steps per optimizer step (second field)
class Stratification(NamedTuple):
    """Which stratum to sample, as ``(group, groups)``.

    Produced by ``get_stratification`` and consumed by ``stratified_uniform``,
    which requires ``groups > 0`` and ``0 <= group < groups``.
    """

    group: int  # index of the stratum this caller owns
    groups: int  # total number of interleaved strata
def get_stratification(
    placement: DevicePlacement = DevicePlacement(0, 1),
    grad_acc: GradAcc = GradAcc(0, 1),
) -> Stratification:
    """Assign a unique stratum to a (device, gradient-accumulation step) pair.

    Group indices are laid out rank-major: all accumulation steps of rank 0
    come first, then those of rank 1, and so on. Every valid
    ``(global_rank, acc_step)`` pair therefore gets a distinct group in
    ``[0, world_size * acc_steps)``.

    Args:
        placement: This process's ``(global_rank, world_size)``.
        grad_acc: The current ``(acc_step, acc_steps)``.

    Returns:
        ``Stratification(group, groups)`` to pass to ``stratified_uniform``.

    Raises:
        ValueError: if ``world_size`` or ``acc_steps`` is not positive, or if
            ``global_rank`` / ``acc_step`` is outside its range. Out-of-range
            indices would otherwise silently map two workers to the same group.
    """
    global_rank, world_size = placement
    acc_step, acc_steps = grad_acc
    if world_size <= 0:
        raise ValueError(f"world_size must be positive, got {world_size}")
    if not 0 <= global_rank < world_size:
        raise ValueError(
            f"global_rank must be in [0, {world_size}), got {global_rank}"
        )
    if acc_steps <= 0:
        raise ValueError(f"acc_steps must be positive, got {acc_steps}")
    if not 0 <= acc_step < acc_steps:
        raise ValueError(f"acc_step must be in [0, {acc_steps}), got {acc_step}")
    return Stratification(
        group=global_rank * acc_steps + acc_step,
        groups=world_size * acc_steps,
    )
def stratified_uniform(
    shape: Sequence[int],
    stratum: Stratification = Stratification(0, 1),
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device | str | int] = None,
    generator: Optional[torch.Generator] = None,
) -> FloatTensor:
    """Draw stratified samples from a uniform distribution on [0, 1).

    [0, 1) is split into ``n = shape[-1] * groups`` equal-width bins. Along the
    last dimension, element ``i`` is drawn uniformly from bin
    ``group + i * groups``. Each group therefore owns an interleaved, disjoint
    set of bins, and the bins of all groups together cover [0, 1) exactly once.
    Leading dimensions reuse the same bin assignment with independent jitter.

    Args:
        shape: Output shape. A 0-d shape ``()`` yields one sample from the
            group's single bin ``[group / groups, (group + 1) / groups)``.
        stratum: ``(group, groups)``, i.e. which of ``groups`` interleaved
            strata to draw from.
        dtype: Floating-point dtype of the result (torch default if None).
        device: Device on which the result is created.
        generator: Optional RNG for reproducible draws.

    Returns:
        Tensor of ``shape`` and ``dtype`` with values in [0, 1).

    Raises:
        ValueError: if ``groups`` is not positive or ``group`` is not in
            ``[0, groups)``.
    """
    group, groups = stratum
    if groups <= 0:
        raise ValueError(f"groups must be positive, got {groups}")
    if group < 0 or group >= groups:
        raise ValueError(f"group must be in [0, {groups})")
    per_row = shape[-1] if len(shape) > 0 else 1
    n = per_row * groups
    # Build bin offsets as exact integers. A float16/bfloat16 arange would round
    # large offsets, producing duplicate or skipped bins.
    offsets = torch.arange(group, n, groups, device=device)
    u = torch.rand(shape, dtype=dtype, device=device, generator=generator)
    # Do the arithmetic in at least float32 so narrow dtypes keep the offset
    # precise. For float32/float64 this matches the original computation exactly.
    compute_dtype = torch.promote_types(u.dtype, torch.float32)
    samples = (offsets.to(compute_dtype) + u.to(compute_dtype)) / n
    if compute_dtype != u.dtype:
        # Rounding back to a narrower dtype could turn values just below 1 into
        # exactly 1.0. Pin them to the largest representable value below 1.
        samples = samples.clamp(max=1 - torch.finfo(u.dtype).eps / 2)
    # reshape only matters for 0-d shapes, where offsets contributed a length-1 dim.
    return samples.to(u.dtype).reshape(u.shape)
def scale_shift_positive(x: FloatTensor, low: int, high: int) -> LongTensor:
    """Map samples in [0, 1) to integers in [low, high).

    Adapts ``torch.rand(shape)`` output into the equivalent of
    ``torch.randint(low, high, shape)``.

    "Positive" refers to ``x``: it is assumed non-negative (as ``torch.rand()``
    output is). That makes ``x * (high - low)`` non-negative, so ``.long()``
    (which truncates toward zero) equals ``floor`` and no ``.floor()`` call is
    needed. ``low`` is added only after truncating. Adding it first would make
    intermediate values negative when ``low < 0``, and truncation would then
    round them up instead of down.

    Args:
        x: Non-negative samples, expected in [0, 1).
        low: Inclusive lower bound of the output range.
        high: Exclusive upper bound of the output range.

    Returns:
        int64 tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``high <= low`` (empty range, as in ``torch.randint``).
    """
    if high <= low:
        raise ValueError(
            f"high must be greater than low, got low={low}, high={high}"
        )
    return (x * (high - low)).long() + low
if __name__ == "__main__":
    # Demo: with the default single stratum, one sample lands in each of the
    # 8 bins of width 1/8. Guarded so that importing this module has no side effects.
    print(stratified_uniform((8,), generator=torch.Generator().manual_seed(42)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.