After installing jax, run with:

git clone https://gist.github.com/jackd/99e012090a56637b8dd8bb037374900e
cd 99e012090a56637b8dd8bb037374900e
python dirty_test.py

using JuMP, HiGHS, Graphs, LinearAlgebra, Test
# Find one valid assignment with integer programming
function lattice_sites(num_sites::Vector{Int}, multiplicity::Vector{Int}, num_atoms::Vector{Int})
    model = Model(HiGHS.Optimizer)
    JuMP.set_silent(model)
    @variable(model, x[1:length(num_sites), 1:length(num_atoms)] >= 0, Int)
    for (ia, na) in enumerate(num_atoms)
        @constraint(model, sum(x[i, ia] * num_sites[i] for i in 1:length(num_sites)) == na)
    end
    # Minimal completion (assumed): solve the feasibility problem and return one
    # assignment, or `nothing` if no assignment exists.
    optimize!(model)
    termination_status(model) == MOI.OPTIMAL || return nothing
    return round.(Int, value.(x))
end
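
A small usage sketch of `lattice_sites` (the numbers below are illustrative, and `multiplicity` is only passed through because the fragment shown does not use it): each entry of the returned matrix counts how many times a site group of the given size is assigned to an atom species.

sites = [1, 2, 4]               # hypothetical site-group sizes
atoms = [6, 4]                  # hypothetical atom counts for two species
assignment = lattice_sites(sites, ones(Int, length(sites)), atoms)
@test assignment !== nothing
@test vec(sites' * assignment) == atoms   # every species is fully placed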

#!/usr/bin/env python
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.datasets import make_moons
from torch import Tensor
from tqdm import tqdm

# requires OMEinsum version >= 0.7
using Distributed
using OMEinsum, CUDA

println("found $(length(devices())) GPU devices")
const procs = addprocs(length(devices()) - nprocs() + 1)  # one worker process per GPU device
const gpus = collect(devices())
const process_device_map = Dict(zip(procs, gpus))
@info process_device_map
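
A plausible follow-up step, sketched here as an assumption rather than taken from the original snippet: load the packages on the newly added workers and pin each worker process to the GPU it was mapped to.

@everywhere using OMEinsum, CUDA
for (proc, dev) in process_device_map
    # assumed pattern: each worker selects its assigned device before doing any GPU work
    remotecall_wait(CUDA.device!, proc, dev)
end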
| # [email protected] | |
| # `jax.distributed.initialize` is available in jax-0.2.25. | |
| # $ pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Note: wheels only available on linux. | |
| # Run this script on 2 GPU nodes, assuming 10.128.0.6 is the master node | |
| # python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=2 --host_idx=0 | |
| # python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=2 --host_idx=1 | |
| from absl import app | |
| from absl import flags |

# Expand to Threads.@threads only when Julia has more than one thread and `st`
# (a collection expected to be defined in the calling scope) is large enough to benefit.
macro _threads(ex)
    return quote
        if (Threads.nthreads() > 1) && (length(st) > 4096)
            $(Expr(:macrocall, Expr(:(.), :Threads, QuoteNode(Symbol("@threads"))), __source__, ex))
        else
            $ex
        end
    end |> esc
end
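
A hypothetical usage sketch (the array and loop body are invented for illustration); the macro only assumes that a collection named `st` is visible at the call site, and falls back to a serial loop for small inputs or a single-threaded session.

xs = randn(10_000)
ys = similar(xs)
st = eachindex(xs)        # `st` must exist in the caller's scope; the macro checks its length
@_threads for i in st
    ys[i] = 2 * xs[i]     # independent writes, so the threaded branch is safe
end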

using TupleTools
using Base.Cartesian
using CuArrays, CUDAnative  # legacy GPU packages, since superseded by CUDA.jl

"""
A naive implementation of `einsum!`
* `ixs`: input tensor indices,
* `xs`: input tensors,
* `iy`: output tensor indices,
* `y`: the accumulated output tensor; note that it must be initialized to 0, since results are accumulated into it!