Skip to content

Instantly share code, notes, and snippets.

View MilesCranmer's full-sized avatar

Miles Cranmer MilesCranmer

View GitHub Profile
@MilesCranmer
MilesCranmer / reduce_precision.py
Created July 1, 2022 20:35
Reduce precision of constants in a string
import re
def reduce_precision_of_constants_in_string(s, precision=3):
    """Round every numeric literal in ``s`` to ``precision`` significant digits.

    Each matched number is reformatted with the ``g`` format spec, so with
    ``precision=3`` the literal ``"3.14159"`` becomes ``"3.14"`` and
    ``"100000"`` becomes ``"1e+05"``.

    Args:
        s: String containing numeric constants (e.g. a printed equation).
        precision: Number of significant digits to keep for each constant.

    Returns:
        A copy of ``s`` with each matched constant replaced at its own
        position; non-numeric text is left untouched.
    """
    def _reduce(match):
        # Reformat a single matched constant.
        return "{:.{precision}g}".format(float(match.group(0)), precision=precision)

    # Substitute each match in place. A global ``str.replace`` per constant
    # would also rewrite other numbers that merely contain it as a substring
    # (e.g. "1.23456" inside "11.23456"), corrupting them.
    return re.sub(r"\b[-+]?\d*\.\d+|\b[-+]?\d+\.?\d*", _reduce, s)
@MilesCranmer
MilesCranmer / unique.c
Created June 26, 2022 01:58
Get unique elements of an array using a lookup table
// Count elements of an array by using a lookup table.
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <time.h>
int main(int argc, char *argv[])
{
// Generate random array of integers, with
// size given by args.
@MilesCranmer
MilesCranmer / python_syntax_for_gin.py
Last active May 7, 2022 01:33
Enable valid Python to be used as a config.gin file, so that code analysis and syntax highlighting work
def preprocess_config(s: str):
"""Remove imports from a string representation of a python file"""
# We assume that imports are not multi-line.
lines = s.splitlines()
out_lines = []
for line in lines:
# Skip lines with import in them:
if 'import' in line:
continue
@MilesCranmer
MilesCranmer / torch_softclip.py
Last active June 13, 2024 08:55
Soft clipping in pytorch
def soft_clip(x, lo, hi, pct=0.1):
range = hi - lo
frac = (x - lo) / range
normalization = F.softplus(torch.ones_like(x))
for _ in ['lo', 'hi']:
frac = torch.where(frac > pct,
frac,
pct * F.softplus(frac / pct) / normalization
def acos2(num, denom, disamb):
    """Signed arccosine of ``num / denom`` that saturates outside (-1, 1).

    Args:
        num: Numerator tensor of the cosine.
        denom: Denominator tensor of the cosine.
        disamb: Tensor whose sign selects the branch; where ``disamb < 0``
            the returned angle is negated.

    Returns:
        ``acos(num / denom)`` (negated where ``disamb < 0``) for ratios
        strictly inside (-1, 1); ``pi`` where the ratio is ``<= -1``; and
        ``0.0`` everywhere else (ratio ``>= 1`` or NaN).
    """
    cosine = num/denom
    inside = (cosine > -1) & (cosine < 1.)
    # Only give acos in-domain inputs. torch.where still backpropagates
    # through the branch it does not select, and acos's gradient is NaN/inf
    # for |x| >= 1, so evaluating acos on the raw cosine would turn the
    # gradient into NaN (0 * NaN) even though those entries are discarded
    # in the forward pass. Masked entries use 0, whose acos gradient is finite.
    safe_cosine = torch.where(inside, cosine, torch.zeros_like(cosine))
    return torch.where(inside,
                       torch.acos(safe_cosine) * torch.where(disamb < 0.0, -1, 1),
                       torch.where(cosine <= -1, np.pi, 0.0)
                       )
def coord_transform(x):
# Assumes in CoM frame
@MilesCranmer
MilesCranmer / backtrack_nans.md
Created March 3, 2021 19:24
Trick to step backwards if a NaN occurs

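Before training, take a snapshot of the weights. Note that `model.state_dict()` returns references to the live parameter tensors, not copies, so wrap it in `copy.deepcopy(...)`. Otherwise the snapshot changes along with the weights as the optimizer updates them, and restoring it has no effect: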

last = model.state_dict()

Inside training loop, after computing loss:

if torch.isnan(loss).sum().item():
    model.load_state_dict(last)
else:
@MilesCranmer
MilesCranmer / mwe.jl
Last active January 27, 2021 22:44
Implementing SymbolicUtils.jl interface for SymbolicRegression.jl
using SymbolicUtils
mutable struct Node
#Holds operators, variables, constants in a tree
degree::Integer #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
val::Union{Float32, Integer, Nothing} #Either const value, or enumerates variable
constant::Bool #false if variable
op::Integer #enumerates operator (separately for degree=1,2)
l::Union{Node, Nothing}
r::Union{Node, Nothing}
@MilesCranmer
MilesCranmer / einop.py
Created January 5, 2021 23:18
Generic einops operation that performs a repeat, rearrange, or reduce based on indices
# Copy this into your code. Call with, e.g., einop(x, 'i j -> j', reduction='mean')
import functools
import einops as _einops
from einops.parsing import ParsedExpression
@functools.lru_cache(256)
def _match_einop(pattern: str, reduction=None, **axes_lengths: int):
"""Find the corresponding operation matching the pattern"""
left, rght = pattern.split('->')
left = ParsedExpression(left)
@MilesCranmer
MilesCranmer / analytic_approximation.py
Last active October 17, 2020 04:35
Analytic Approximation of log(1+erf(x)) for x in [-10, -5]. Uses PySR: https://pysr.readthedocs.io/, mpmath: http://mpmath.org/
import numpy as np
from mpmath import mp, mpmathify
from pysr import *
#Set mpmath's working precision to 200 decimal digits (mp.dps).
#For x in [-10, -5], 1+erf(x) is tiny (about 2e-45 at x=-10), so forming
#it as 1 + erf(x) in float64 would cancel to 0; presumably this extra
#precision is what makes the reference values below usable.
#NOTE(review): confirm against the high-precision calculation that follows.
mp.dps = 200
#300 evenly spaced sample points on the fitting interval [-10, -5].
x = np.linspace(-10, -5, num=300)
#High precision calculation:
@MilesCranmer
MilesCranmer / matrix_box_notation.tex
Last active February 9, 2021 10:49
Draw boxes for matrices in equation
%Make sure to have \usepackage{tikz}
%https://tex.stackexchange.com/a/45815/140440 - for grid
%https://tex.stackexchange.com/a/381175/140440 - for alignment in equation
% This function draws a matrix.
\newcommand{\mat}[2]{% cols, rows
\vcenter{\hbox{ %Vertical alignment
\begin{tikzpicture}[scale=0.3, align=center]