DeepSpeed🔗
DeepSpeed is an optimization library for PyTorch aimed at reducing memory use and improving parallelism. In particular, it enables strategies such as CPU- or disk-based offloading of large models, and it improves the efficiency of multi-GPU and multi-node training through its Zero Redundancy Optimizer (ZeRO). This makes it particularly well suited to training large language models, which are often too large to fit on a single node.
Quick guide🔗
- For single-node jobs you can launch DeepSpeed-accelerated training with `torchrun ... --deepspeed /path/to/ds_config.json` or with `deepspeed ... --deepspeed /path/to/ds_config.json`. The `deepspeed` launcher wraps `torchrun`, but accepts some additional arguments, such as `--num_gpus` (see the sketch below this list).
- For multi-node jobs, it is easiest to use `torchrun` with a DeepSpeed config.
- A variety of settings can be adjusted in the `ds_config` JSON file, including the ZeRO optimizer's offloading parameters as well as the optimizer and scheduler parameters.
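As a minimal sketch (the script name `train.py`, the GPU count, and the config path are placeholders), the two launch variants look like this:

```bash
# Launch with the deepspeed wrapper, explicitly requesting 4 GPUs on the node
deepspeed --num_gpus 4 train.py --deepspeed /path/to/ds_config.json

# Equivalent launch with torchrun
torchrun --nproc_per_node 4 train.py --deepspeed /path/to/ds_config.json
```

Note that `--deepspeed /path/to/ds_config.json` is parsed by the training script itself (for example by the HuggingFace Trainer), not by the launcher.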
To run DeepSpeed, we will use an Apptainer image built on top of a PyTorch NGC image:
Bootstrap: localimage
From: /apps/containers/PyTorch/PyTorch-2.3.0-NGC-24.04.sif

%post
# libaio is required for DeepSpeed's asynchronous I/O extension (used by NVMe offloading)
apt update
apt install -y libaio-dev
# DeepSpeed plus the libraries used by the transformers example scripts
pip install deepspeed==0.14.4 datasets==2.20.0 evaluate==0.4.2 \
    accelerate==0.33.0 sacrebleu==2.4.2 sentencepiece==0.2.0
# clone transformers so that the example scripts are available under /opt/transformers
git clone https://github.com/huggingface/transformers -b "v4.43.3" -- "/opt/transformers"
pip install "/opt/transformers"
Single-node example🔗
For this example we will use a configuration based on examples provided by HuggingFace.
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
In this case, we simply distribute the model across the GPUs available on the node, and the `auto` values in the config let the HuggingFace Trainer fill in most of the hyperparameters from its own training arguments. We can run it using the `deepspeed` command, which is a wrapper for the PyTorch launcher:
#!/bin/bash
#SBATCH --partition=alvis
#SBATCH --account=YOUR_ACCOUNT
#SBATCH --time=0-08:00:00
#SBATCH --job-name=test_single_node
#SBATCH --gpus-per-node=A40:4
#SBATCH --output=%x-%j.out
APPT_CMD="apptainer exec --nv /path/to/deepspeed/image.sif"
DSCONFIG=/path/to/deepspeed/config.json
MODEL=bigscience/T0_3B
EXEC=/opt/transformers/examples/pytorch/question-answering/run_seq2seq_qa.py
ARGS="--model_name_or_path ${MODEL} --dataset_name squad --do_train --do_eval \
--per_device_train_batch_size 1 --per_device_eval_batch_size 1 --max_train_samples 1000 \
--learning_rate 3e-5 --num_train_epochs 1 --max_seq_length 384 --doc_stride 128 \
--max_eval_samples 1000 --output_dir /path/to/output \
--overwrite_output_dir --deepspeed ${DSCONFIG}"
$APPT_CMD deepspeed $EXEC $ARGS
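To submit the job and follow its progress (the batch script name is a placeholder):

```bash
# Submit the batch script and follow the training log; names are placeholders
sbatch run_deepspeed_single.sh
squeue -u $USER                        # check that the job has started
tail -f test_single_node-<jobid>.out   # follow the output as it is written
```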
In this case, we use the model `bigscience/T0_3B`, which would otherwise be too large to fit on a single GPU. Once the job is known to launch successfully, parameters such as the batch size, learning rate, and number of epochs can be adjusted to your needs, as illustrated below.
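For example, one might raise the batch size and train for more epochs by editing `ARGS` (the values below are purely illustrative and not tuned):

```bash
# Illustrative adjustment of the training arguments; values are not tuned
ARGS="--model_name_or_path ${MODEL} --dataset_name squad --do_train --do_eval \
    --per_device_train_batch_size 4 --per_device_eval_batch_size 4 \
    --learning_rate 3e-5 --num_train_epochs 3 --max_seq_length 384 --doc_stride 128 \
    --output_dir /path/to/output --overwrite_output_dir --deepspeed ${DSCONFIG}"
```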
Multi-node example🔗
For the multi-node example, we can use the same optimizer configuration, but we need to modify our execution approach. We cannot use the `deepspeed` command directly in this case, since its multi-node launcher relies on a hostfile and SSH access between the nodes; instead we start one `torchrun` process per node with `srun`.
We again need the DeepSpeed image, and we will run an example from the `transformers` repository installed in the image.
#!/bin/bash
#SBATCH --partition=alvis
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1
#SBATCH --account=YOUR_ACCOUNT
#SBATCH --time=0-08:00:00
#SBATCH --job-name=test_multinode
#SBATCH --gpus-per-node=A40:4
#SBATCH --output=%x-%j.out
export GPUS_PER_NODE=4
# use the first node in the allocation as the rendezvous host, and derive a
# port number from the last four digits of the job ID to avoid collisions
export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))
APPT_CMD="apptainer exec --nv /path/to/deepspeed/image.sif"
DSCONFIG=/path/to/deepspeed/config.json
MODEL=bigscience/T0
EXEC=/opt/transformers/examples/pytorch/question-answering/run_seq2seq_qa.py
ARGS="--model_name_or_path ${MODEL} --dataset_name squad --do_train --do_eval \
--per_device_train_batch_size 1 --per_device_eval_batch_size 1 --max_train_samples 1000 \
--learning_rate 3e-5 --num_train_epochs 1 --max_seq_length 384 --doc_stride 128 \
--max_eval_samples 1000 --output_dir /path/to/output \
--overwrite_output_dir --deepspeed ${DSCONFIG}"
srun bash -c "$APPT_CMD python -u -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES --node_rank $SLURM_PROCID \
--rdzv_id=$SLURM_JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
$EXEC $ARGS"
Here we use the even larger `bigscience/T0` model, which does not fit even across a whole node of A40 GPUs. We therefore allocate 4 full nodes of A40 GPUs, for a total of 16 GPUs. We use the c10d rendezvous backend for communication between nodes, and some arithmetic to derive a semi-random port number from our job ID, making port collisions unlikely. Running across multiple nodes is less efficient, but it allows the use of larger models than would otherwise be possible.
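As a small illustration of the port arithmetic (the job ID below is made up):

```bash
# The last four characters of the job ID are added to 10000
SLURM_JOBID=1234567
echo $(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))   # prints 14567
```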
Other options🔗
DeepSpeed also supports offloading to regular RAM and to NVMe storage (you can use, e.g., a folder in your Mimer allocation for this), but these options come with significant overhead and do not necessarily enable the use of much larger models. Make sure you test and benchmark the various options to find out what works best for your use case.
To try these options, you can modify the `zero_optimization` section of your configuration as follows:
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
and
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "nvme",
"nvme_path": "/mimer/path/to/you/allocation",
"buffer_size": 6e8,
"buffer_count": 2,
"max_in_cpu": 0
},
respectively.
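NVMe offloading relies on DeepSpeed's asynchronous I/O extension, which is why `libaio-dev` is installed in the image above. You can check which DeepSpeed ops are available inside the image, including `async_io`, with the `ds_report` utility (the image path is a placeholder):

```bash
# Report DeepSpeed op compatibility (including async_io, needed for NVMe offload)
apptainer exec --nv /path/to/deepspeed/image.sif ds_report
```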