Created
June 17, 2024 08:37
-
-
Save cyb70289/9397f6b3ebf56fd9dc763d79dbe01f13 to your computer and use it in GitHub Desktop.
TensorFlow multi-process matmul benchmark: 1x64, 2x32, 4x16 and 8x8 (processes x threads per process)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
##################################################################
# profile.py
#
# Batched float32 matmul throughput benchmark for TensorFlow.
# Run one copy per process with OMP_NUM_THREADS set (see the test
# steps below); each copy repeatedly prints how many matmuls it
# completes per second.
#
# NOTE(review): a script named profile.py shadows the standard
# library `profile` module (which `cProfile` imports) for code run
# from this directory; consider renaming if that becomes a problem.
##################################################################
import os
import timeit

# Shape of each matmul: (M x K) @ (K x N).
M, N, K = 512, 256, 128

# Batch size used when one process owns all FULL_THREADS threads;
# processes with fewer threads get a proportionally smaller batch so
# the total work across all processes stays the same.
FULL_BATCH = 2048
FULL_THREADS = 64


def parse_n_threads(value):
    """Validate an OMP_NUM_THREADS value and return it as an int.

    Args:
        value: the raw environment-variable string, or None if unset.

    Returns:
        The thread count, in the range 1..999.

    Raises:
        ValueError: if value is unset, not an integer, or out of range.
    """
    if value is None:
        raise ValueError('OMP_NUM_THREADS is not set')
    try:
        n_threads = int(value)
    except ValueError as err:
        raise ValueError(f'invalid OMP_NUM_THREADS {value!r}') from err
    if n_threads < 1 or n_threads > 999:
        raise ValueError(f'invalid n_threads {n_threads}')
    return n_threads


def batch_size(n_threads):
    """Return the per-process batch size for a process using n_threads.

    Scales FULL_BATCH linearly by n_threads / FULL_THREADS, e.g.
    64 -> 2048, 32 -> 1024, 16 -> 512, 8 -> 256. The multiply-first form
    stays well defined for n_threads > FULL_THREADS, where the previous
    2048 // (64 // n_threads) divided by zero. It also scales exactly for
    thread counts that do not divide 64.
    """
    return FULL_BATCH * n_threads // FULL_THREADS


def main():
    """Run the benchmark forever, printing a 'threads,ops' line per window."""
    # Validate before importing TensorFlow so bad input fails fast
    # instead of after the (slow) framework import.
    n_threads = parse_n_threads(os.getenv('OMP_NUM_THREADS'))
    batch = batch_size(n_threads)
    print(f'batch = {batch}')

    import tensorflow as tf

    # Thread pools must be configured before TensorFlow runs any op.
    tf.config.threading.set_intra_op_parallelism_threads(n_threads)
    tf.config.threading.set_inter_op_parallelism_threads(n_threads)
    tf.random.set_seed(42)
    a = tf.random.uniform((batch, M, K), dtype=tf.float32)
    b = tf.random.uniform((batch, K, N), dtype=tf.float32)

    def mm():
        return tf.matmul(a, b)

    while True:
        # Each window calls mm() n_threads times; every call multiplies
        # `batch` matrix pairs, so ops = matmuls completed per second.
        duration = timeit.timeit(mm, number=n_threads)
        print('threads,ops')
        print(f'{n_threads},{batch * n_threads / duration:.0f}')


if __name__ == '__main__':
    main()
##################################################################
# test steps
##################################################################
1x64
----
# one process with 64 threads
$ OMP_NUM_THREADS=64 numactl -m0 -N0 python3 profile.py
batch = 2048
ops = 56429 (total)
bw-r = 133M*8*32
bw-w = 166M*8*32
2x32
----
# 2 processes with 32 threads each
$ OMP_NUM_THREADS=32 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=32 numactl -m0 -N0 python3 profile.py
batch = 1024
ops = 30051 * 2 (= 60102 total)
bw-r = 142M*8*32
bw-w = 174M*8*32
4x16
----
# 4 processes with 16 threads each
$ OMP_NUM_THREADS=16 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=16 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=16 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=16 numactl -m0 -N0 python3 profile.py
batch = 512
ops = 15591 * 4 (= 62364 total)
bw-r = 146M*8*32
bw-w = 177M*8*32
8x8
---
# 8 processes with 8 threads each
$ OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py & \
  OMP_NUM_THREADS=8 numactl -m0 -N0 python3 profile.py
batch = 256
ops = 8144 * 8 (= 65152 total)
bw-r = 151M*8*32
bw-w = 185M*8*32
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment