Skip to content

Instantly share code, notes, and snippets.

@cyb70289
Created June 17, 2024 08:37
Show Gist options
  • Save cyb70289/9397f6b3ebf56fd9dc763d79dbe01f13 to your computer and use it in GitHub Desktop.
Save cyb70289/9397f6b3ebf56fd9dc763d79dbe01f13 to your computer and use it in GitHub Desktop.
TensorFlow matmul benchmark: one multi-threaded process vs. multiple processes
##################################################################
# profile.py
##################################################################
import tensorflow as tf
import timeit
import os
# Benchmark driver: repeatedly times a batched float32 matmul and prints the
# number of matrix products completed per second by this process.
#
# The per-process thread count is read from the OMP_NUM_THREADS environment
# variable. Fail with an explicit message when it is missing instead of the
# opaque TypeError that int(None) would raise.
_omp_num_threads = os.getenv('OMP_NUM_THREADS')
if _omp_num_threads is None:
    raise Exception('OMP_NUM_THREADS is not set')
n_threads = int(_omp_num_threads)
if n_threads < 1 or n_threads > 999:
    raise Exception(f'invalid n_threads {n_threads}')

# Each product multiplies an (M, K) matrix by a (K, N) matrix.
M, N, K = 512, 256, 128

# Batch size shrinks as the thread count drops: 64 threads -> 2048,
# 32 -> 1024, 16 -> 512, 8 -> 256.
# 64 // n_threads is 0 for any n_threads above 64, which the range check
# above still accepts; clamp the divisor to 1 so those values use the full
# batch of 2048 instead of raising ZeroDivisionError.
BATCH = 2048 // max(1, 64 // n_threads)
print(f'batch = {BATCH}')

# Use the same thread count for both of TensorFlow's thread-pool settings.
tf.config.threading.set_intra_op_parallelism_threads(n_threads)
tf.config.threading.set_inter_op_parallelism_threads(n_threads)

# Fixed seed before generating the two random input tensors.
tf.random.set_seed(42)
a = tf.random.uniform((BATCH, M, K), dtype=tf.float32)
b = tf.random.uniform((BATCH, K, N), dtype=tf.float32)


def mm():
    """Return the batched matrix product of the module-level tensors a and b."""
    return tf.matmul(a, b)


# Runs until the process is interrupted. Every sample times n_threads calls
# of mm(); each call performs BATCH matrix products, so the printed rate is
# BATCH * n_threads products divided by the elapsed seconds.
# The 'threads,ops' header line is printed before every sample.
while True:
    duration = timeit.timeit(mm, number=n_threads)
    print('threads,ops')
    print(f'{n_threads},{BATCH*n_threads/duration:.0f}')
##################################################################
# test steps and measured results
##################################################################
1x64
----
# one process with 64 threads
$ OMP_NUM_THREADS=64 numactl -m0 -N0 python3 profile.py
batch = 2048
ops = 56429
bw-r = 133M*8*32
bw-w = 166M*8*32
2x32
----
# 2 processes with 32 threads each
$ OMP_NUM_THREADS=32 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=32 numactl -m0 -N0 python3 profile.py
batch = 1024
ops = 30051 * 2
bw-r = 142M*8*32
bw-w = 174M*8*32
4x16
----
# 4 processes with 16 threads each
$ OMP_NUM_THREADS=16 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=16 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=16 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=16 numactl -m0 -N0 python3 profile.py
batch = 512
ops = 15591 * 4
bw-r = 146M*8*32
bw-w = 177M*8*32
8x8
---
# 8 processes with 8 threads each
$ OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py
batch = 256
ops = 8144 * 8
bw-r = 151M*8*32
bw-w = 185M*8*32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment