
Lei Wang (wangleiphy)

GiggleLiu / wyckoff_assign.jl
Last active January 26, 2025 07:21
Wyckoff position assignment - given atom counts and Wyckoff position multiplicities
using JuMP, HiGHS, Graphs, LinearAlgebra, Test
# Find one valid assignment with integer programming
function lattice_sites(num_sites::Vector{Int}, multiplicity::Vector{Int}, num_atoms::Vector{Int})
    model = Model(HiGHS.Optimizer)
    JuMP.set_silent(model)
    # x[i, ia] = how many Wyckoff positions of type i are assigned to atom species ia
    @variable(model, x[1:length(num_sites), 1:length(num_atoms)] >= 0, Int)
    # place each atom species exactly num_atoms[ia] times
    for (ia, na) in enumerate(num_atoms)
        @constraint(model, sum(x[i, ia] * num_sites[i] for i in 1:length(num_sites)) == na)
    end
    # the preview is truncated here; an assumed minimal completion: solve and return
    optimize!(model)
    termination_status(model) == MOI.OPTIMAL || return nothing
    return round.(Int, value.(x))
end
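For readers without JuMP, the same feasibility problem can be sketched in Python with scipy.optimize.milp; the function name and array layout below are illustrative assumptions, not part of the gist.

import numpy as np
from scipy.optimize import milp, LinearConstraint, Bounds

def lattice_sites_py(multiplicity, num_atoms):
    # hypothetical analogue of the Julia function above:
    # x[i, a] = how many positions with multiplicity[i] host species a,
    # subject to sum_i multiplicity[i] * x[i, a] == num_atoms[a] for each a
    ns, na = len(multiplicity), len(num_atoms)
    A = np.zeros((na, ns * na))
    for a in range(na):
        A[a, a * ns:(a + 1) * ns] = multiplicity
    res = milp(c=np.zeros(ns * na),  # zero objective: pure feasibility
               constraints=LinearConstraint(A, num_atoms, num_atoms),
               integrality=np.ones(ns * na),
               bounds=Bounds(0, np.inf))
    return np.round(res.x).astype(int).reshape(na, ns).T if res.success else None

# e.g. multiplicities (1, 2, 3) and 5 atoms of a single species
print(lattice_sites_py([1, 2, 3], [5]))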
francois-rozet / flow_matching.py
Last active November 6, 2025 20:32
Flow Matching in 100 LOC
#!/usr/bin/env python
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.datasets import make_moons
from torch import Tensor
from tqdm import tqdm
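The preview above shows only the imports; as a hedged illustration of the objective the gist implements, here is a minimal conditional flow matching training loop in PyTorch (network size and hyperparameters are arbitrary, not the gist's).

import torch
import torch.nn as nn
from sklearn.datasets import make_moons

# velocity-field network v(x, t); input is the 2D point concatenated with t
net = nn.Sequential(nn.Linear(3, 64), nn.ELU(), nn.Linear(64, 64), nn.ELU(), nn.Linear(64, 2))
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

for step in range(4096):
    x1 = torch.as_tensor(make_moons(256, noise=0.05)[0], dtype=torch.float32)  # data
    x0 = torch.randn_like(x1)          # noise
    t = torch.rand(len(x1), 1)         # time in [0, 1]
    xt = (1 - t) * x0 + t * x1         # point on the straight path from x0 to x1
    target = x1 - x0                   # that path's constant velocity
    loss = (net(torch.cat([xt, t], dim=-1)) - target).square().mean()
    opt.zero_grad(); loss.backward(); opt.step()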
GiggleLiu / tensorcontract_multigpu.jl
Last active April 12, 2024 20:49
Slicing + multi-GPU for OMEinsum tensor contraction
# requires OMEinsum version >= 0.7
using Distributed
using OMEinsum, CUDA
println("find $(length(devices())) GPU devices")
const procs = addprocs(length(devices())-nprocs()+1)
const gpus = collect(devices())
const process_device_map = Dict(zip(procs, gpus))
@info process_device_map
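The underlying idea, slicing a contracted index and farming the slices out to devices, can be sketched without OMEinsum. Below is a hedged JAX illustration for a plain matrix product; the shapes are arbitrary and the sliced axis is assumed to divide evenly by the device count.

import jax
import jax.numpy as jnp

# C[i, j] = sum_k A[i, k] * B[k, j]: slice index k, contract one slice
# per device, then sum the partial results on the host
devices = jax.devices()
A = jnp.ones((256, 512))
B = jnp.ones((512, 256))
A_slices = jnp.split(A, len(devices), axis=1)
B_slices = jnp.split(B, len(devices), axis=0)
partials = [jnp.matmul(jax.device_put(a, d), jax.device_put(b, d))
            for a, b, d in zip(A_slices, B_slices, devices)]
C = sum(jax.device_get(p) for p in partials)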
nvidia_gpu_pjit.py
# `jax.distributed.initialize` is available in jax-0.2.25.
# $ pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Note: wheels only available on linux.
# Run this script on 2 GPU nodes, assuming 10.128.0.6 is the master node
# python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=2 --host_idx=0
# python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=2 --host_idx=1
from absl import app
from absl import flags
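A minimal sketch of the initialization this script's CLI implies, using the flag names from the run commands above (the rest of the gist is not shown, so the body is an assumption):

from absl import app, flags
import jax

flags.DEFINE_string("server_addr", "", "coordinator address, e.g. 10.128.0.6:1456")
flags.DEFINE_integer("num_hosts", 1, "total number of host processes")
flags.DEFINE_integer("host_idx", 0, "index of this host process")
FLAGS = flags.FLAGS

def main(argv):
    # connect this process to the multi-host JAX cluster
    jax.distributed.initialize(coordinator_address=FLAGS.server_addr,
                               num_processes=FLAGS.num_hosts,
                               process_id=FLAGS.host_idx)
    print(jax.process_index(), jax.device_count(), jax.local_device_count())

if __name__ == "__main__":
    app.run(main)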
jackd / README.md
Last active January 24, 2021 14:43
Generalized eigenvalue jvp implementation in jax

After installing jax, run with:

git clone https://gist.github.com/jackd/99e012090a56637b8dd8bb037374900e
cd 99e012090a56637b8dd8bb037374900e
python dirty_test.py
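The implementation itself lives in the cloned gist; as a hedged sketch, one standard way to define the JVP of the generalized eigenvalues of A v = w B v is first-order perturbation theory, assuming B is symmetric positive definite and the spectrum nondegenerate. This is not necessarily the gist's approach.

import jax
import jax.numpy as jnp

@jax.custom_jvp
def geneigh_vals(a, b):
    # reduce A v = w B v to a standard problem via the Cholesky factor of B
    linv = jnp.linalg.inv(jnp.linalg.cholesky(b))
    return jnp.linalg.eigvalsh(linv @ a @ linv.T)

@geneigh_vals.defjvp
def geneigh_vals_jvp(primals, tangents):
    a, b = primals
    da, db = tangents
    linv = jnp.linalg.inv(jnp.linalg.cholesky(b))
    w, u = jnp.linalg.eigh(linv @ a @ linv.T)
    v = linv.T @ u                       # B-orthonormal eigenvectors: v^T B v = I
    # first-order perturbation: dw_i = v_i^T (dA - w_i dB) v_i
    dw = jnp.einsum('ji,jk,ki->i', v, da, v) - w * jnp.einsum('ji,jk,ki->i', v, db, v)
    return w, dw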
Roger-luo / tiny_yao.jl
Created December 20, 2020 07:05
Implement your own (full amplitude) top-performance quantum circuit emulator in ONE day!
macro _threads(ex)
    return quote
        # use threads only when they pay off: multiple threads are available and
        # the state vector `st` (captured at the expansion site) is large enough
        if (Threads.nthreads() > 1) && (length(st) > 4096)
            $(Expr(:macrocall, Expr(:(.), :Threads, QuoteNode(Symbol("@threads"))), __source__, ex))
        else
            $ex
        end
    end |> esc
end
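The macro above is one building block (threaded kernels); the core of a full-amplitude emulator is applying a gate to the statevector. A hedged NumPy sketch of that kernel, illustrative rather than the gist's Julia code:

import numpy as np

def apply_gate(state, gate, k, n):
    # apply a 2x2 `gate` to qubit k of an n-qubit statevector of length 2^n
    psi = state.reshape((2,) * n)
    psi = np.moveaxis(psi, k, 0)
    psi = np.tensordot(gate, psi, axes=([1], [0]))
    psi = np.moveaxis(psi, 0, k)
    return psi.reshape(-1)

# e.g. a Hadamard on qubit 0 of |000>
H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
state = np.zeros(8); state[0] = 1.0
print(apply_gate(state, H, 0, 3))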
shoyer / jax-harmonic-oscillator-odeint.ipynb
Last active December 11, 2021 16:43
JAX harmonic oscillator odeint.ipynb
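The notebook itself does not render here; a minimal sketch of what the title describes, integrating a harmonic oscillator with jax.experimental.ode.odeint (initial condition and frequency are arbitrary):

import jax.numpy as jnp
from jax.experimental.ode import odeint

def harmonic(y, t, omega):
    x, v = y                      # position and velocity
    return jnp.array([v, -omega**2 * x])

ts = jnp.linspace(0.0, 10.0, 100)
ys = odeint(harmonic, jnp.array([1.0, 0.0]), ts, 2.0)  # omega = 2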
shoyer / simple-jax-gmres.ipynb
Created July 7, 2020 21:11
Simple JAX GMRES
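This notebook does not render either; JAX has since gained a built-in GMRES, so a minimal usage sketch of the same idea is (the matrix is arbitrary):

import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
# GMRES only needs a matrix-vector product, so A can be passed as a linear map
x, info = gmres(lambda v: A @ v, b)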
caryan / Quantum Optimal Control with SciML.ipynb
Created June 10, 2020 04:32
Quantum Optimal Control with SciML and Julia
GiggleLiu / naive_einsum.jl
Last active January 20, 2020 09:16
CUDAnative-based einsum! on GPU - the prototype
using TupleTools
using Base.Cartesian
using CuArrays, CUDAnative
"""
A naive implementation of `einsum!`
* `ixs`: input tensor indices,
* `xs`: input tensors,
* `iy`: output tensor indices,
* `y`: accumulated tensor, notice it is initialized to 0 as output!
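The preview stops at the docstring; as a hedged reference point, a naive einsum! with the same calling convention can be written as explicit loops over all index assignments. A Python/NumPy sketch (O(product of all index extents), CPU-only, not the gist's GPU kernel):

import itertools
import numpy as np

def naive_einsum(ixs, xs, iy, y):
    # collect the extent of every index label from the input shapes
    sizes = {}
    for ix, x in zip(ixs, xs):
        for label, dim in zip(ix, x.shape):
            sizes[label] = dim
    labels = sorted(sizes)
    # loop over every joint assignment of the labels, accumulating into y
    # (y must be zero-initialized, matching the docstring above)
    for vals in itertools.product(*(range(sizes[l]) for l in labels)):
        env = dict(zip(labels, vals))
        term = 1.0
        for ix, x in zip(ixs, xs):
            term *= x[tuple(env[l] for l in ix)]
        y[tuple(env[l] for l in iy)] += term
    return y

# e.g. matrix multiplication: ij,jk->ik
a, b = np.ones((2, 3)), np.ones((3, 4))
print(naive_einsum(["ij", "jk"], [a, b], "ik", np.zeros((2, 4))))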