distributed 3D-rfft using JAX
Slurm job script:
#!/bin/bash
#SBATCH --time=01:00:00
#SBATCH --nodes=2  # equal to -N 2
#SBATCH --tasks-per-node=2
#SBATCH --exclusive
#SBATCH --job-name=jax-fft-test
#SBATCH --gpus=4
#SBATCH --output output/slurm-%j.out

nvidia-smi

source $DATA/venv-jax/bin/activate
cd ~/jax-testing/

#export XLA_PYTHON_CLIENT_PREALLOCATE=false
#export XLA_PYTHON_CLIENT_ALLOCATOR=platform

srun --output "output/slurm-%2j-%2t.out" python -u main.py
main.py:
import os
from pathlib import Path

import jax
import numpy as np
import scipy
from jax import jit
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

import sharded_rfft_general
from utils import Timing, plot_graph

print("jax version", jax.__version__)

num_gpus = int(os.environ.get("SLURM_GPUS"))


# jax.config.update("jax_enable_x64", True)

def host_subset(array: jax.Array | np.ndarray, size: int):
    # slice of the second (sharded) axis based on the process index
    host_id = jax.process_index()
    start = host_id * size // num_gpus
    end = (host_id + 1) * size // num_gpus
    return array[:, start:end]


def print_subset(x):
    print(x[0, :4, :4])


def compare(a, b):
    is_equal = np.allclose(a, b, rtol=1.e-2, atol=1.e-4)
    print(is_equal)
    diff = a - np.asarray(b)
    max_value = np.max(np.real(a))  # scale of the reference values
    max_diff = np.max(np.abs(diff))
    print("max_value", max_value)
    print("max_diff", max_diff)
    print("max_diff / max_value", max_diff / max_value)


print("distributed initialize")
jax.distributed.initialize()
timing = Timing(print)
print("CUDA_VISIBLE_DEVICES", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("devices:", jax.device_count(), jax.devices())
print("local_devices:", jax.local_device_count(), jax.local_devices())
print("process_index", jax.process_index())
print("total number of GPUs:", num_gpus)

timing.log("random ICs")
size = 512
rng = np.random.default_rng(12345)
x_np_full = rng.random((size, size, size), dtype=np.float32)
x_np = host_subset(x_np_full, size)
print("x_np shape", x_np.shape)
global_shape = (size, size, size)
timing.log("generated")
print(x_np.nbytes / 1024 / 1024 / 1024, "GB")
print(x_np.shape, x_np.dtype)

# 1D device mesh over all GPUs; arrays are sharded along their second axis
devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))
timing.log("start")
with mesh:
    # assemble a globally sharded array from the per-host slab
    x_single = jax.device_put(x_np)
    xshard = jax.make_array_from_single_device_arrays(
        global_shape,
        NamedSharding(mesh, P(None, "gpus")),
        [x_single])
    rfftn_jit = jit(
        sharded_rfft_general.rfftn,
        donate_argnums=0,  # doesn't help
        in_shardings=(NamedSharding(mesh, P(None, "gpus"))),
        out_shardings=(NamedSharding(mesh, P(None, "gpus")))
    )
    irfftn_jit = jit(
        sharded_rfft_general.irfftn,
        donate_argnums=0,
        in_shardings=(NamedSharding(mesh, P(None, "gpus"))),
        out_shardings=(NamedSharding(mesh, P(None, "gpus")))
    )
    if jax.process_index() == 0:
        # dump the compiled HLO and a graph plot from one process only
        with jax.spmd_mode('allow_all'):
            a = Path("compiled.txt")
            a.write_text(rfftn_jit.lower(xshard).compile().as_text())
            z = jax.xla_computation(rfftn_jit)(xshard)
            plot_graph(z)
    sync_global_devices("wait for compiler output")
    with jax.spmd_mode('allow_all'):
        timing.log("warmup")
        rfftn_jit(xshard).block_until_ready()
        timing.log("calculating")
        out_jit: jax.Array = rfftn_jit(xshard).block_until_ready()
        print(out_jit.nbytes / 1024 / 1024 / 1024, "GB")
        print(out_jit.shape, out_jit.dtype)
        timing.log("inverse calculating")
        out_inverse: jax.Array = irfftn_jit(out_jit).block_until_ready()
        timing.log("collecting")
        sync_global_devices("loop")
        local_out_subset = out_jit.addressable_data(0)
        local_inverse_subset = out_inverse.addressable_data(0)
        print(local_out_subset.shape)
        print_subset(local_out_subset)
        # print("JAX output without JIT:")
        # print_subset(out)
        # print("JAX output with JIT:")
        # # print_subset(out_jit)
        # print("out_jit.shape1", out_jit.shape)
        # print(out_jit.dtype)
        timing.log("done")

        # reference result computed on the full array with scipy
        out_ref = scipy.fft.rfftn(x_np_full, workers=128)
        timing.log("ref done")
        print("out_ref", out_ref.shape)
        out_ref_subs = host_subset(out_ref, size)
        print("out_ref_subs", out_ref_subs.shape)
        print("JAX output with JIT:")
        print_subset(local_out_subset)
        print("Reference output:")
        print_subset(out_ref_subs)
        print("ref")
        compare(out_ref_subs, local_out_subset)
        print("inverse")
        compare(x_np, local_inverse_subset)
        print_subset(x_np)
        print_subset(local_inverse_subset)
sharded_rfft_general.py:
from typing import Callable

import jax
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import PartitionSpec as P, NamedSharding


def fft_partitioner(fft_func: Callable[[jax.Array], jax.Array], partition_spec: P):
    """Wrap fft_func with a custom partitioning rule so that it is applied
    per shard under the given partition spec, i.e. only along axes that are
    not distributed. Note: the callback signatures below match the
    custom_partitioning API around jax 0.4.19 (see the comments at the end)."""

    @custom_partitioning
    def func(x):
        return fft_func(x)

    def supported_sharding(sharding, shape):
        return NamedSharding(sharding.mesh, partition_spec)

    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
        return fft_func, supported_sharding(arg_shardings[0], arg_shapes[0]), (
            supported_sharding(arg_shardings[0], arg_shapes[0]),)

    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
        return supported_sharding(arg_shardings[0], arg_shapes[0])

    func.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition
    )
    return func


def _fft_XY(x):
    return jax.numpy.fft.fftn(x, axes=[0, 1])


def _fft_Z(x):
    return jax.numpy.fft.rfft(x, axis=2)


def _ifft_XY(x):
    return jax.numpy.fft.ifftn(x, axes=[0, 1])


def _ifft_Z(x):
    return jax.numpy.fft.irfft(x, axis=2)


# the X/Y FFTs run with the array sharded along Z, the Z FFT with the array sharded along Y
fft_XY = fft_partitioner(_fft_XY, P(None, None, "gpus"))
fft_Z = fft_partitioner(_fft_Z, P(None, "gpus"))
ifft_XY = fft_partitioner(_ifft_XY, P(None, None, "gpus"))
ifft_Z = fft_partitioner(_ifft_Z, P(None, "gpus"))


def rfftn(x):
    # real FFT along the local Z axis, then reshard and FFT over the now-local X and Y axes
    x = fft_Z(x)
    x = fft_XY(x)
    return x


def irfftn(x):
    # inverse transforms in the opposite order
    x = ifft_XY(x)
    x = ifft_Z(x)
    return x
utils.py:
import subprocess
import time


def plot_graph(z):
    # render the HLO graph to a PNG via graphviz (needs the `dot` binary)
    with open("t.dot", "w") as f:
        f.write(z.as_hlo_dot_graph())
    with open("t.png", "wb") as f:
        subprocess.run(["dot", "t.dot", "-Tpng"], stdout=f)


class Timing:
    """Logs the total elapsed time and the delta since the previous log call."""

    def __init__(self, print_func):
        self.start = time.perf_counter()
        self.last = self.start
        self.print_func = print_func

    def log(self, message: str) -> None:
        now = time.perf_counter()
        delta = now - self.start
        self.print_func(f"{delta:.4f} / {now - self.last:.4f}: {message}")
        self.last = now
Comments:

Hi @Findus23, thank you very much for the great work!
Could you please tell me the jax version you used to run this code? When using the latest version of jax (0.4.24), I get the following runtime error and have no clue how to resolve it:
custom_partitioner: TypeError: 'Mesh' object is not subscriptable
Thanks in advance!

Hi,
I should have documented this, but I used the latest version back then, so probably 0.4.19.
I have also gotten this error a few times in the past and don't fully understand it. I can't test this code right now, but I will report back whether it still works like this in the latest version.

Hi, thanks for the reply!
I have resolved the problem. The bug is triggered because the custom_partitioning signature has changed, so the arg_shapes variable in partition gets filled with the mesh object, which is obviously incorrect.

Oh, I thought I had updated this snippet after that change, but apparently I didn't.
I will do that in the next week.
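
For readers hitting the same TypeError on newer jax releases, here is a minimal, untested sketch of how fft_partitioner might be adapted. It assumes the newer custom_partitioning API, in which partition and infer_sharding_from_operands receive the mesh as their first argument and the argument/result shapes are jax.ShapeDtypeStruct objects that carry their sharding. This is not the author's updated version, only an illustration of the change discussed above.

# Sketch only: fft_partitioner adapted to the newer custom_partitioning API
# (mesh passed as first argument to the callbacks). Untested; check against
# the jax version in use.
from typing import Callable

import jax
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import PartitionSpec as P, NamedSharding


def fft_partitioner(fft_func: Callable[[jax.Array], jax.Array], partition_spec: P):
    @custom_partitioning
    def func(x):
        return fft_func(x)

    def supported_sharding(sharding, shape):
        return NamedSharding(sharding.mesh, partition_spec)

    # new signature: (mesh, arg_shapes, result_shape); the shapes are
    # jax.ShapeDtypeStruct objects whose sharding is available as .sharding
    def partition(mesh, arg_shapes, result_shape):
        arg_shardings = jax.tree_util.tree_map(lambda s: s.sharding, arg_shapes)
        # return (mesh, per-shard function, result sharding, argument shardings)
        return (mesh, fft_func,
                supported_sharding(arg_shardings[0], arg_shapes[0]),
                (supported_sharding(arg_shardings[0], arg_shapes[0]),))

    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
        arg_shardings = jax.tree_util.tree_map(lambda s: s.sharding, arg_shapes)
        return supported_sharding(arg_shardings[0], arg_shapes[0])

    func.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition,
    )
    return func

The rest of sharded_rfft_general.py (the per-axis FFT helpers and rfftn/irfftn) would stay unchanged under this assumption.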