Skip to content

Instantly share code, notes, and snippets.

@Steboss
Created June 5, 2025 09:29
Show Gist options
  • Save Steboss/1dfb50e2e912c3a732f7d0816bc6ff6b to your computer and use it in GitHub Desktop.
Save Steboss/1dfb50e2e912c3a732f7d0816bc6ff6b to your computer and use it in GitHub Desktop.
Detailed, fully-commented version of the SLURM submission script (submit.sh).
#!/bin/bash
# SLURM submission script: runs AXLearn "fuji" training inside the public
# NVIDIA jax-axlearn container. Adjust account/partition/paths for your system.
#SBATCH -A your_account
#SBATCH -p your_partition
#SBATCH -N 4                    # example with 4 nodes
#SBATCH --ntasks-per-node=8     # one task (process) per GPU; keep in sync with GPUS_PER_NODE below
#SBATCH -t 04:00:00             # max run time
#SBATCH -J "something_in_line_with_your_system"

export CONFIG="fuji-70B-v2-flash" # here you can insert, for example, fuji-7B-v2-flash
export CONTAINER="ghcr.io/nvidia/jax:axlearn" # this is our public jax-axlearn container
export MOUNTS="--container-mounts=/home/workspace/:/opt/host" # mount your local folder into the container
export EXPORTS="--export=ALL" # it's usually ok to export all the envs

# Number of GPUs per node. We want one process per GPU, so this MUST match the
# --ntasks-per-node directive above, otherwise SLURM_NTASKS / SLURM_PROCID
# (used by the trainer as --num_processes / --process_id) will not agree with WORLD_SIZE.
export GPUS_PER_NODE=8
# Total processes = gpus per node * number of nodes; used for the output log path and --world_size.
export WORLD_SIZE=$((GPUS_PER_NODE * SLURM_JOB_NUM_NODES))
export BASE_DIR="/opt/host/YOUR_FOLDER" # container-side directory; YOUR_FOLDER is the local folder holding the scripts
export BASE_SCRIPT="fuji-train-perf.py" # the training entry-point script we run for fuji
export LOCAL_BASE_DIR="/home/workspace/YOUR_FOLDER" # host-side folder, the one mounted into the container
# NB: LOCAL_BASE_DIR contains fuji-train-perf.py and this submission script (submit.sh)
export GBS=${WORLD_SIZE} # global batch size; the minimum is the total number of GPUs, but it can be more
# Wrap the command in a SINGLE-QUOTED here-doc so nothing is expanded here:
# every ${...} reference (the SLURM_* variables and the exported vars above) is
# resolved only when the string runs inside the container, AFTER job submission.
read -r -d '' cmd <<'EOF'
# here you can add all the other XLA_FLAGS
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
cd ${BASE_DIR}
python3 $BASE_SCRIPT --output_log_file=/opt/host/output.log --module=text.gpt.c4_trainer --config=${CONFIG} --jax_backend=gpu --trainer_dir=/opt/host/axlearn-checkpoints --data_dir=gs://axlearn-public/tensorflow_datasets --ici_fsdp=8 --dcn_dp=4 --gbs=${GBS} --ga=1 --seq_len=4096 --max_step=301 --write_summary_steps=300 --num_processes=${SLURM_NTASKS} --distributed_coordinator=${SLURM_LAUNCH_NODE_IPADDR}:12345 --process_id=${SLURM_PROCID} --world_size=${WORLD_SIZE}
EOF
# --output_log_file: where the final metrics are written
# --trainer_dir: AXLearn's checkpoint/trainer directory
# --ici_fsdp=8 --dcn_dp=4: FSDP across the 8 GPUs inside each node, data-parallel across the 4 nodes
# alternatively you could give --ici_fsdp=NUMBER_OF_GPUS_PER_NODE and --dcn_fsdp=NUMBER_OF_NODES
# TODO(review): confirm how AXLearn runs its reference tests and which parallelism to compare against;
# ici_fsdp + dcn_fsdp is expected to outperform ici_fsdp + dcn_dp
# --num_processes=${SLURM_NTASKS} --distributed_coordinator=${SLURM_LAUNCH_NODE_IPADDR}:12345 --process_id=${SLURM_PROCID}
# all of those SLURM_* variables only exist AFTER the job is submitted — that is why the
# command is wrapped in the quoted here-doc above instead of being expanded here.
# NOTE(review): `read -d ''` returns non-zero when it hits end-of-input, so guard it
# (e.g. `|| true`) if this script is ever run under `set -e`.
# Create a run-specific output directory for ease of analysis.
FOLDER="${LOCAL_BASE_DIR}/outputs/${CONFIG}-N${SLURM_JOB_NUM_NODES}-n${WORLD_SIZE}"
# Fail loudly here: if the directory cannot be created, srun's -o/-e redirection
# below would silently lose all job output.
mkdir -p "${FOLDER}" || { echo "ERROR: cannot create output directory ${FOLDER}" >&2; exit 1; }

# Redirect both stdout and stderr to the same file; srun expands %j to the job id.
OUTFILE="${FOLDER}/output-%j.txt"

echo "Running command: $cmd"

# Launch the container on every task and run the wrapped command.
# MOUNTS and EXPORTS are intentionally left unquoted so each expands into a
# separate srun option; srun's exit status is the script's exit status.
srun \
  -o "${OUTFILE}" \
  -e "${OUTFILE}" \
  --container-image="${CONTAINER}" \
  ${MOUNTS} \
  ${EXPORTS} \
  --container-remap-root \
  bash -c "${cmd}"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment