Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save dongzhuoyao/82197fb190f13a1fcb762da97bf00238 to your computer and use it in GitHub Desktop.
#!/bin/bash
# SLURM batch script: brings up a 2-node Ray cluster and runs a verl GRPO
# (DeepScaleR-style) PPO training job on it.
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=1
#SBATCH --cpus-per-task=16
# NOTE(review): --gres=gpu:1 appears redundant with --gpus-per-node=1 above;
# most SLURM installations need only one of the two — confirm with the
# cluster's documentation before removing either.
#SBATCH --gres=gpu:1
#SBATCH --time=00:20:00
# %j expands to the job id. The logs/ directory must already exist,
# otherwise SLURM cannot open the output files and the job fails at submit.
#SBATCH --output=logs/slurm-%j.out
#SBATCH --error=logs/slurm-%j.err
#SBATCH --mail-type=END,FAIL,BEGIN
#SBATCH [email protected]
#SBATCH -p lrz-hgx-h100-94x4
# Required environment variables.
# vLLM attention backend used by the rollout workers.
export VLLM_ATTENTION_BACKEND=XFORMERS
# NOTE(review): never commit a real API key — prefer injecting it from the
# submission environment or reading it from a protected file.
export WANDB_API_KEY=xxxxxxxxxx
# Enumerate the hostnames allocated to this job.
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)

# The first node in the allocation acts as the Ray head; resolve its IP.
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

# `hostname --ip-address` can report several addresses separated by spaces
# (typically an IPv6 plus an IPv4). When that happens, keep the IPv4 one.
if [[ "$head_node_ip" == *" "* ]]; then
  read -r -a addrs <<<"$head_node_ip"
  # A first token longer than 16 characters must be the IPv6 address,
  # so the IPv4 address is the second token; otherwise take the first.
  if (( ${#addrs[0]} > 16 )); then
    head_node_ip=${addrs[1]}
  else
    head_node_ip=${addrs[0]}
  fi
  echo "IPV6 address detected. We split the IPV4 address as $head_node_ip"
fi

# Advertise the Ray head endpoint to the rest of the script (and workers).
port=6379
ip_head="${head_node_ip}:${port}"
export ip_head
# Load the login shell environment so `conda` is on PATH, then activate the
# environment that provides Ray / verl / vLLM.
# NOTE(review): many ~/.bashrc files return early in non-interactive shells
# (guarded by a $PS1 / interactivity check) — confirm the conda init block
# still executes when sourced from a batch job.
source ~/.bashrc
conda activate rllm
echo "Head Node IP: $ip_head"
# Start Ray on the head node. --block keeps `ray start` in the foreground of
# its srun step, and the trailing & backgrounds the step in this script.
echo "Starting HEAD node: $head_node"
# Fix: quote $port and the $SLURM_* expansions — the original left them
# unquoted here while the worker command below quoted them (inconsistent,
# and unquoted expansions break if a variable is empty or contains spaces).
srun --nodes=1 --ntasks=1 -w "$head_node" \
  ray start --head --node-ip-address="$head_node_ip" --port="$port" \
  --num-cpus="$SLURM_CPUS_PER_TASK" --num-gpus="$SLURM_GPUS_PER_NODE" --block &
sleep 10 # give the head node time to come up before workers connect

# Start one Ray worker on every node other than the head.
worker_num=$((SLURM_JOB_NUM_NODES - 1))
for ((i = 1; i <= worker_num; i++)); do
  node_i=${nodes_array[$i]}
  echo "Starting WORKER $i at $node_i"
  srun --nodes=1 --ntasks=1 -w "$node_i" \
    ray start --address "$ip_head" --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block &
  sleep 5 # stagger worker startup
done
# Parse command line arguments.
# Recognizes: --model <path>. Any other argument stops parsing and is later
# forwarded verbatim to the trainer via the trailing "$@".
while [[ $# -gt 0 ]]; do
  case $1 in
    --model)
      # Bug fix: when --model was the last argument, `shift 2` failed
      # silently (shift with a count larger than $# is a no-op returning 1),
      # so $1 stayed "--model" and this loop spun forever. Guard explicitly.
      if [[ $# -lt 2 ]]; then
        echo "error: --model requires a path argument" >&2
        exit 2
      fi
      MODEL_PATH="$2"
      shift 2
      ;;
    *)
      break
      ;;
  esac
done

# Default model path when --model was not supplied.
if [ -z "${MODEL_PATH:-}" ]; then
  MODEL_PATH="hfmodels/DeepSeek-R1-Distill-Qwen-1.5B"
fi
# Resolve to an absolute path so it stays valid regardless of the working
# directory on the remote nodes.
MODEL_PATH=$(readlink -f "$MODEL_PATH")
# Launch training. The driver process runs on the head node and Ray
# distributes rollout/actor work across the cluster started above.
# NOTE(review): --nodes=2 combined with -w "$head_node" (a single host) looks
# contradictory — SLURM may reject or reinterpret it; confirm whether
# --nodes=1 was intended, since Ray (not srun) already spans both nodes.
# Any CLI arguments left over after the --model parsing above are appended
# via the trailing "$@", so callers can override hydra-style options.
srun --nodes=2 --ntasks-per-node=1 -w "$head_node" \
python -m verl.trainer.main_ppo_remote \
algorithm.adv_estimator=grpo \
"data.train_files=${HOME}/deepscaler/data/train.parquet" \
"data.val_files=${HOME}/deepscaler/data/aime.parquet" \
data.train_batch_size=64 \
data.val_batch_size=512 \
data.max_prompt_length=1024 \
data.max_response_length=8192 \
"actor_rollout_ref.model.path=${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.temperature=0.6 \
actor_rollout_ref.rollout.val_temperature=0.6 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.rollout.n_val=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='deepscaler' \
trainer.experiment_name='deepscaler-1.5b-h100-8k' \
+trainer.val_before_train=False \
trainer.n_gpus_per_node=1 \
trainer.nnodes=2 \
trainer.save_freq=5 \
trainer.test_freq=5 \
trainer.default_hdfs_dir=null \
trainer.total_epochs=30 "$@"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment