Birch-san

@Birch-san
Birch-san / gist:eddad13648725d47c71799c39e8361b2
Created May 29, 2025 13:07
Example API request for generating an image using a stored NAIv4 vibe. Uses vibe files created by https://gist.github.com/Birch-san/5eb62a4a5e4a1c4447a55e3a9faf8988
#!/usr/bin/env bash
set -eo pipefail
# https://stackoverflow.com/a/12194427/5257399
create() { # fd base [qualifier [suffix [max]]]
local fd="$1" base="$2" qualifier="${3-}" suffix="${4-.png}" max="${5-}"
local n=0 file
local - # ash-style local scoping of options in 4.4+
set -o noclobber
REPLY=
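For reference, a simplified Python sketch of what the create() helper above (from the linked Stack Overflow answer) accomplishes: atomically claim the next free numbered output filename by opening with O_EXCL (the equivalent of bash's noclobber redirect), retrying with an incremented counter on collision. The original's qualifier/max handling is omitted here.
import os

def create(base: str, suffix: str = ".png") -> tuple[int, str]:
    """Open the next free numbered file exclusively; return (fd, path)."""
    n = 0
    while True:
        path = f"{base}{'' if n == 0 else f'.{n}'}{suffix}"
        try:
            # O_EXCL makes the create atomic, like bash's noclobber redirect
            fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
            return fd, path
        except FileExistsError:
            n += 1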
@Birch-san
Birch-san / vibev4_encode.sh
Last active May 29, 2025 13:04
Example API request for encoding a NAIv4 vibe. Doesn't include all the metadata (e.g. image thumbnail and which model it was encoded for) that the UI adds
#!/usr/bin/env bash
set -eo pipefail
# https://stackoverflow.com/a/12194427/5257399
create() { # fd base [qualifier [suffix [max]]]
local fd="$1" base="$2" qualifier="${3-}" suffix="${4-.png}" max="${5-}"
local n=0 file
local - # ash-style local scoping of options in 4.4+
set -o noclobber
REPLY=
@Birch-san
Birch-san / attn_jvp_test.py
Created May 27, 2025 01:34
Test stub for comparing jvp of memory-efficient attention against reference implementation
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional
from typing_extensions import override
import torch
from torch import Tensor, no_grad, enable_grad
import torch.autograd.forward_ad as fwAD
from torch.autograd.function import FunctionCtx
from torch.nn import Linear, Module
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
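The preview cuts off before the test itself. As an illustration of the kind of comparison the title describes, here is a minimal sketch (not the gist's actual harness) that checks the JVP of a query-chunked attention, standing in for a memory-efficient implementation, against a naive reference via torch.func.jvp. Shapes and the chunk size are illustrative.
import torch
from torch import Tensor
from torch.func import jvp

def ref_attn(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    scale = q.size(-1) ** -0.5
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v

def chunked_attn(q: Tensor, k: Tensor, v: Tensor, chunk: int = 32) -> Tensor:
    # processes queries chunk-by-chunk; mathematically equal to ref_attn
    scale = q.size(-1) ** -0.5
    outs = [torch.softmax(qc @ k.transpose(-2, -1) * scale, dim=-1) @ v
            for qc in q.split(chunk, dim=-2)]
    return torch.cat(outs, dim=-2)

q, k, v = (torch.randn(2, 8, 128, 64, dtype=torch.float64) for _ in range(3))
tangents = tuple(torch.randn_like(t) for t in (q, k, v))

out_ref, jvp_ref = jvp(ref_attn, (q, k, v), tangents)
out_chk, jvp_chk = jvp(chunked_attn, (q, k, v), tangents)
torch.testing.assert_close(out_chk, out_ref)
torch.testing.assert_close(jvp_chk, jvp_ref)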
@Birch-san
Birch-san / matmul_via_vmap.py
Last active April 28, 2025 01:22
How to implement mm, bmm and matmul in PyTorch via vmap
import torch
from torch import FloatTensor
def mm(a: FloatTensor, b: FloatTensor) -> FloatTensor:
assert a.ndim == 2
assert b.ndim == 2
assert a.size(-1) == b.size(-2)  # only the inner dims need to match for matmul
# batched dot product
def bdp(a_row: FloatTensor, b: FloatTensor) -> FloatTensor:
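The preview ends mid-definition. A minimal sketch of one way to finish the idea (not necessarily the gist's exact code): vmap a per-row dot product over a's rows for mm, then vmap mm over the batch dim for bmm.
import torch
from torch import FloatTensor

def mm(a: FloatTensor, b: FloatTensor) -> FloatTensor:
    assert a.ndim == 2 and b.ndim == 2 and a.size(-1) == b.size(-2)
    # dot one row of a against every column of b: (k,) x (k, n) -> (n,)
    def bdp(a_row: FloatTensor, b: FloatTensor) -> FloatTensor:
        return (a_row.unsqueeze(-1) * b).sum(-2)
    # unbind a's rows, broadcast b to every row -> (m, n)
    return torch.vmap(bdp, in_dims=(0, None))(a, b)

def bmm(a: FloatTensor, b: FloatTensor) -> FloatTensor:
    assert a.ndim == 3 and b.ndim == 3 and a.size(-1) == b.size(-2)
    return torch.vmap(mm)(a, b)  # map mm over the shared batch dim

a, b = torch.randn(4, 3), torch.randn(3, 5)
torch.testing.assert_close(mm(a, b), a @ b)
ab, bb = torch.randn(2, 4, 3), torch.randn(2, 3, 5)
torch.testing.assert_close(bmm(ab, bb), ab @ bb)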
@Birch-san
Birch-san / vmap_repro.py
Created April 17, 2025 13:00
why can't I invoke vmapped attention with a mask? why doesn't vmap unbind my mask's batch dim?
from typing import Optional
import torch
from torch import FloatTensor, BoolTensor, Tensor, inference_mode
from torch.func import functional_call, stack_module_state
from torch.nn import Module, Linear
from torch.nn.functional import scaled_dot_product_attention
from einops import rearrange
class Attention(Module):
def __init__(
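The repro is truncated here. As background for the question in the title, a small sketch (separate from the gist's code) of how vmap's in_dims determines whether an argument's leading dim is unbound per-example or shared across examples:
import torch
from torch.func import vmap

scores = torch.randn(4, 8, 16, 16)     # (batch, heads, q_len, kv_len)
mask = torch.rand(4, 1, 16, 16) > 0.5  # per-example mask

def apply_mask(s, m):
    # inside vmap, s is (heads, q_len, kv_len) and m is (1, q_len, kv_len)
    return s.masked_fill(~m, float("-inf"))

per_example = vmap(apply_mask, in_dims=(0, 0))(scores, mask)   # mask's batch dim unbound alongside scores
shared = vmap(apply_mask, in_dims=(0, None))(scores, mask[0])  # one mask broadcast to every example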
@Birch-san
Birch-san / danbooru-tag-to-prompt-bookmarklet.js
Created March 3, 2025 19:39
Danbooru tag to prompt bookmarklet
javascript: (async function copyTags() {
const replacements = {
v: "peace sign",
"double v": "double peace",
"|_|": "bar eyes",
"\\||/": "opem \\m/",
":|": "neutral face",
";|": "neutral face",
"eyepatch bikini": "square bikini",
"tachi-e": "character image",
@Birch-san
Birch-san / stratified_sampling.py
Created January 15, 2025 21:24
Draw from a uniform distribution, stratified
from typing import NamedTuple, Sequence, Optional
import torch
from torch import FloatTensor, LongTensor
class DevicePlacement(NamedTuple):
global_rank: int
world_size: int
class GradAcc(NamedTuple):
acc_step: int
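The preview stops before the sampling logic. A minimal sketch of plain stratified uniform sampling (the gist's version additionally spreads strata across ranks and gradient-accumulation steps, which this omits):
import torch
from torch import FloatTensor

def stratified_uniform(batch: int, generator=None) -> FloatTensor:
    # element i draws from U[i/batch, (i+1)/batch), so the batch as a whole
    # covers [0, 1) evenly rather than clumping the way i.i.d. uniforms can
    offsets = torch.arange(batch, dtype=torch.float32) / batch
    jitter = torch.rand(batch, generator=generator) / batch
    return offsets + jitter

samples = stratified_uniform(batch=8)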
@Birch-san
Birch-san / bench_repro.py
Created November 17, 2024 18:09
Enabling --count-flops-early (running the model under FlopCounterMode before benchmarking it) regresses the performance of the compiled model
import argparse
import math
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import torch
from einops import rearrange
from torch import (
BoolTensor,
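The flag in the description corresponds to a pattern like the following sketch (not the gist's benchmark harness, and assuming a CUDA device): run the model once under FlopCounterMode, then compile and time it.
import torch
from torch.utils.flop_counter import FlopCounterMode

model = torch.nn.Linear(1024, 1024, device="cuda")
x = torch.randn(64, 1024, device="cuda")

# count flops first: per the description, this step regresses the compiled model's performance
with FlopCounterMode(display=False) as flop_counter:
    model(x)
total_flops = flop_counter.get_total_flops()

compiled = torch.compile(model)
compiled(x)  # warmup / trigger compilation
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
compiled(x)
end.record()
torch.cuda.synchronize()
print(f"{total_flops} FLOPs, {start.elapsed_time(end):.3f} ms")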
@Birch-san
Birch-san / t5_enc_attn_bench.py
Last active October 23, 2024 20:05
Benchmark various ways of doing T5 Encoder flex_attention against SDPA
from enum import Enum
from typing import Callable, Optional, Any
from einops import rearrange
from dataclasses import dataclass
import math
import torch
from torch import FloatTensor, LongTensor, IntTensor, BoolTensor, ByteTensor, no_grad, inference_mode
from torch.nn import Embedding, Linear, Module
from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, _score_mod_signature, _mask_mod_signature
from torch.nn.functional import scaled_dot_product_attention
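As context for the comparison the title describes, a minimal sketch (not the gist's benchmark) of expressing a T5-style additive position bias as a flex_attention score_mod versus passing the same bias to SDPA as attn_mask. The bias tensor here stands in for T5's bucketed relative-position embedding.
import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.functional import scaled_dot_product_attention

heads, seq, head_dim = 8, 128, 64
q, k, v = (torch.randn(1, heads, seq, head_dim) for _ in range(3))
bias = torch.randn(heads, seq, seq)  # illustrative (heads, q_len, kv_len) bias

def score_mod(score, b, h, q_idx, kv_idx):
    # add the per-head relative-position bias to each attention score
    return score + bias[h, q_idx, kv_idx]

out_flex = flex_attention(q, k, v, score_mod=score_mod)
out_sdpa = scaled_dot_product_attention(q, k, v, attn_mask=bias.unsqueeze(0))
torch.testing.assert_close(out_flex, out_sdpa, atol=1e-3, rtol=1e-3)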
@Birch-san
Birch-san / segfault.txt
Created August 21, 2024 22:13
stable-fast torch.jit.trace segfault
Caught signal 11 (Segmentation fault: address not mapped to object at address 0x20)
==== backtrace (tid: 63632) ====
0 0x0000000000042520 __sigaction() ???:0
1 0x0000000006e9fe76 torch::jit::InterpreterStateImpl::callstack() interpreter.cpp:0
2 0x0000000006ea0172 torch::jit::InterpreterStateImpl::handleError() interpreter.cpp:0
3 0x0000000006eac9fb torch::jit::InterpreterStateImpl::runTemplate<false>() interpreter.cpp:0
4 0x0000000006eb0585 torch::jit::InterpreterStateImpl::run() interpreter.cpp:0
5 0x0000000006e897b3 torch::jit::GraphExecutorImplBase::run() graph_executor.cpp:0
6 0x0000000000d3d859 torch::jit::runAndInsertCall() python_custom_class.cpp:0
7 0x0000000000e4208b torch::jit::invokeScriptMethodFromPython() script_init.cpp:0