Last active
October 29, 2024 21:25
You can use these environment variables to run a Python process on a subset of the TPU cores on a Cloud TPU VM. This allows running multiple TPU processes at the same time, since only one process can access a given TPU chip at a time. Note that on TPU v2 and v3, 1 TPU chip = 2 TpuDevice as reported by `jax.devices()` (8 devices total). On v4, 1 …
import os

# ==== Non-communicating processes
# 4x 1 chip per process:
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0"  # "1", "2", "3"
# Also set in the bash 1-liner below (assumption: each process needs its own port):
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
# 1-liner for bash: TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476
# 2x 2 chips per process:
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,2,1"
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0,1"  # "2,3"
# Also set in the bash 1-liner below (assumption: each process needs its own port):
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
# 1-liner for bash: TPU_CHIPS_PER_PROCESS_BOUNDS=1,2,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476
# 1x 4 chips for one process per host (default on v2-8, v3-8, v4-8):
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "2,2,1"
os.environ["TPU_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_VISIBLE_DEVICES"] = "0,1,2,3"
# 1-liner for bash: TPU_CHIPS_PER_PROCESS_BOUNDS=2,2,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0,1,2,3

# ==== Communicating processes
# 4x 1 chip per process:
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,1,1"
os.environ["TPU_PROCESS_BOUNDS"] = "2,2,1"
os.environ["TPU_PROCESS_ADDRESSES"] = "localhost:8476,localhost:8477,localhost:8478,localhost:8479"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0"  # "1", "2", "3"
os.environ["TPU_PROCESS_PORT"] = "8476"  # "8477", "8478", "8479"
os.environ["CLOUD_TPU_TASK_ID"] = "0"  # "1", "2", "3"
# 1-liner for bash: TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1 TPU_PROCESS_BOUNDS=2,2,1 TPU_PROCESS_ADDRESSES=localhost:8476,localhost:8477,localhost:8478,localhost:8479 TPU_VISIBLE_DEVICES=0 TPU_PROCESS_PORT=8476 CLOUD_TPU_TASK_ID=0
# 2x 2 chips per process:
os.environ["TPU_CHIPS_PER_PROCESS_BOUNDS"] = "1,2,1"
os.environ["TPU_PROCESS_BOUNDS"] = "2,1,1"
os.environ["TPU_PROCESS_ADDRESSES"] = "localhost:8476,localhost:8477"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0,1"  # "2,3"
os.environ["TPU_PROCESS_PORT"] = "8476"  # "8477"
os.environ["CLOUD_TPU_TASK_ID"] = "0"  # "1"
# 1-liner for bash: TPU_CHIPS_PER_PROCESS_BOUNDS=1,2,1 TPU_PROCESS_BOUNDS=2,1,1 TPU_PROCESS_ADDRESSES=localhost:8476,localhost:8477 TPU_VISIBLE_DEVICES=0,1 TPU_PROCESS_PORT=8476 CLOUD_TPU_TASK_ID=0
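
In the communicating settings above, only three values change from one process to the next: the visible chips, the port, and the task id. The sketch below derives all six variables from a task id, which may be easier than editing them by hand for each process. The helper name `communicating_tpu_env` and the lookup table are hypothetical, not part of any TPU or JAX API; the values are copied from the settings listed above and assume a single 4-chip host with ports starting at 8476.

```python
# Hypothetical helper: values are copied from the communicating-process
# settings listed above and assume one host with 4 TPU chips.

# num_processes -> (TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_PROCESS_BOUNDS)
_BOUNDS = {
    4: ("1,1,1", "2,2,1"),  # 4x 1 chip per process
    2: ("1,2,1", "2,1,1"),  # 2x 2 chips per process
}
CHIPS_PER_HOST = 4
BASE_PORT = 8476  # first port used in the settings above


def communicating_tpu_env(task_id, num_processes):
    """Return the environment variables for one communicating process."""
    if num_processes not in _BOUNDS:
        raise ValueError("only 2 or 4 processes are covered by this sketch")
    if not 0 <= task_id < num_processes:
        raise ValueError("task_id must be in range(num_processes)")

    chips_bounds, process_bounds = _BOUNDS[num_processes]
    chips_per_process = CHIPS_PER_HOST // num_processes
    first_chip = task_id * chips_per_process

    # Chips owned by this process, e.g. "2,3" for task 1 of 2.
    visible = ",".join(
        str(first_chip + i) for i in range(chips_per_process)
    )
    # Same address list for every process, one entry per process.
    addresses = ",".join(
        f"localhost:{BASE_PORT + i}" for i in range(num_processes)
    )
    return {
        "TPU_CHIPS_PER_PROCESS_BOUNDS": chips_bounds,
        "TPU_PROCESS_BOUNDS": process_bounds,
        "TPU_PROCESS_ADDRESSES": addresses,
        "TPU_VISIBLE_DEVICES": visible,
        "TPU_PROCESS_PORT": str(BASE_PORT + task_id),
        "CLOUD_TPU_TASK_ID": str(task_id),
    }
```

One way to use it is `os.environ.update(communicating_tpu_env(0, 4))` at the top of each process, before importing JAX, with a different task id per process.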
What I meant is that you can run up to 4 single-chip processes on a single 2x2 TPU VM, by having 1 process per chip.
I see. Thanks!
Thanks Skye! I suppose the first line is a typo and it should be
# *1*x 1 chip (2 cores) per process: