Skip to content

Instantly share code, notes, and snippets.

View antferdom's full-sized avatar

A.J antferdom

View GitHub Profile
@antferdom
antferdom / flash3.py
Created September 16, 2024 11:33
FlashAttention v3, compatible with torch.compile
from typing import Any, Iterable, List, Optional, Sequence, Set, Tuple
import torch
# Optional FlashAttention-3 (Hopper) extension handle. It starts as None and is
# replaced only if the import below succeeds, so None means "not available".
_C_flashattention3 = None
try:
    from flash_attn_interface import flashattn_hopper_cuda as _C_flashattention3
except ImportError:
    # Deliberate best-effort: the import fails when the extension is missing.
    # The original note attributes this to the GPU arch not being sm_90a.
    # Keep the None default in that case.
    pass
import torch
import torch.nn.functional as F
def to_float8(x, dtype: torch.dtype = torch.float8_e4m3fn):
    """Prepare ``x`` for a saturating cast to a float8 dtype.

    Computes a per-tensor scale: the largest finite value of ``dtype``
    divided by the absolute maximum of ``x``. Multiplying ``x`` by this
    scale maps its largest-magnitude element onto the edge of the float8
    range.

    Args:
        x: Input tensor. It is presumably a floating-point ``torch.Tensor``
            (TODO confirm); only ``abs()`` and ``max()`` are used on it here.
        dtype: Target float8 dtype. Its ``torch.finfo`` supplies the range.

    NOTE(review): the body is truncated in this view. Only the scale
    computation is visible. The scale/clamp/cast step described by the
    trailing comment is missing, and so is any ``return`` statement. As
    shown, the function returns ``None``. Restore the remainder from the
    upstream source before relying on it.
    """
    finfo = torch.finfo(dtype)
    # Calculate the scale as dtype max divided by absmax.
    # The clamp keeps the divisor >= 1e-12, so an all-zero ``x`` does not
    # cause a division by zero.
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    # Scale and clamp the tensor to bring it into the representable range
    # of the float8 data type (the default cast is unsaturated).
    # NOTE(review): the code implementing this step is not visible here.
# see https://github.com/pytorch/torchsnapshot/blob/main/benchmarks/fsdp/main.py
import torch
from torch import distributed as dist, nn
def create_model() -> nn.Module:
    """Construct a large ``nn.Transformer`` model.

    NOTE(review): this definition is truncated in this view. The
    ``nn.Transformer(`` call below is never closed and no ``return`` is
    visible, so the remaining constructor arguments are unknown. Restore
    the rest from the upstream source before use.
    """
    # 7.8GB model, 1.9B parameters. This is the author's estimate; it also
    # depends on constructor arguments that are not visible here.
    model = nn.Transformer(
        d_model=864,
        num_encoder_layers=1,
        num_decoder_layers=20,
import torch
# Footprint of the trainable parameters only (requires_grad=True); frozen
# parameters are excluded from both figures.
# NOTE(review): `model` is not bound at module level in this view. It is
# presumably created elsewhere (e.g. via create_model()); confirm.

# Total size in bytes: bytes per element times element count, summed.
model_size = sum(
    weight.element_size() * weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)

# Total number of trainable scalar parameters.
model_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
@antferdom
antferdom / playground.py
Last active September 15, 2023 09:26
CoT_Q_A_RL
from typing import List, Dict, Tuple, Union, Optional
def largest_prime_factor(n: int):
"""Return the largest prime factor of n. Assume n > 1 and is not a prime.
Q: What/How is a prime factor of n?
A: A prime factor of n is a prime number that divides n.
Q: What/how a number is not prime?
A: A number is not prime if it has more than two factors.
Examples:
Q: 15