doc: add PyTorch parallel training content (#3379)
Signed-off-by: Lysithea <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
3 people authored Mar 2, 2024
1 parent 7aee42c commit 822be1e
Showing 1 changed file with 102 additions and 7 deletions: doc/train/parallel-training.md
# Parallel training {{ tensorflow_icon }} {{ pytorch_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}
:::

## TensorFlow Implementation {{ tensorflow_icon }}
Currently, parallel training in the TensorFlow version is enabled in a synchronized way with the help of [Horovod](https://github.com/horovod/horovod).
Depending on the number of training processes (according to MPI context) and the number of GPU cards available, DeePMD-kit will decide whether to launch the training in parallel (distributed) mode or in serial mode. Therefore, no additional options are specified in your JSON/YAML input file.
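How this decision is made is internal to DeePMD-kit; as a rough, hypothetical sketch of the idea (using Horovod's public Python API, not DeePMD-kit's actual code), the MPI world size reported by Horovod is what distinguishes the two modes:

```python
# Illustrative sketch only, assuming Horovod with the TensorFlow backend is
# installed; DeePMD-kit performs an equivalent check internally.
import horovod.tensorflow as hvd

hvd.init()
if hvd.size() > 1:
    # more than one MPI process: run data-parallel (distributed) training,
    # typically binding each process to one GPU via hvd.local_rank()
    mode = "distributed"
else:
    mode = "serial"
print(f"rank {hvd.rank()} of {hvd.size()} -> {mode} training")
```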

### Tuning learning rate

Horovod works in the data-parallel mode, resulting in a larger global batch size. For example, the real batch size is 8 when {ref}`batch_size <training/training_data/batch_size>` is set to 2 in the input file and you launch 4 workers. Thus, {ref}`learning_rate <learning_rate>` is automatically scaled by the number of workers for better convergence. Technical details of such heuristic rule are discussed at [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677).
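As a concrete illustration of the arithmetic above (the batch size and base learning rate below are placeholder values, not recommendations):

```python
# Linear-scaling heuristic with placeholder numbers: each worker keeps the
# per-worker batch size from the input file, the global batch size grows with
# the number of workers, and the learning rate is scaled by the same factor.
batch_size_per_worker = 2   # batch_size in the input file
num_workers = 4             # e.g. 4 Horovod workers
base_lr = 1.0e-3            # placeholder starting learning rate

global_batch_size = batch_size_per_worker * num_workers  # 2 * 4 = 8
scaled_lr = base_lr * num_workers                        # 1e-3 -> 4e-3

print(global_batch_size, scaled_lr)
```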

In some cases, it won't work well when scaling the learning rate by worker count.

### Scaling test

When testing `examples/water/se_e2_a` on an 8-GPU host, linear acceleration can be observed as the number of cards increases.

| 4 | 1.7635 | 56.71*4 | 3.29 |
| 8 | 1.7267 | 57.91*8 | 6.72 |

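The header row of the table above is not visible in this excerpt; assuming the second column is seconds per 100 samples on one card and the last column is the speedup relative to a single card, the visible numbers are consistent with each other:

```python
# Sanity check of the visible rows under the column interpretation stated in
# the text above; everything except the row values is derived here.
rows = {4: (1.7635, 3.29), 8: (1.7267, 6.72)}  # cards -> (s per 100 samples, speedup)

for cards, (sec_per_100, speedup) in rows.items():
    per_card = 100.0 / sec_per_100         # 56.71 and 57.91 samples/s per card
    total = per_card * cards               # aggregate throughput
    implied_single_card = total / speedup  # ~69 samples/s for both rows
    print(cards, round(per_card, 2), round(total, 2), round(implied_single_card, 1))
```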
### How to use

Training workers can be launched with `horovodrun`; for example, 4 processes can be launched on the same host.

Whether distributed workers are initiated can be observed in the "Summary of the training" section of the log.

### Logging

In addition, two command-line arguments are defined to control the logging behavior when performing parallel training with MPI.
```
optional arguments:
means each process will output its own log (default:
master)
```

## PyTorch Implementation {{ pytorch_icon }}

Currently, parallel training in the PyTorch version is implemented with PyTorch Distributed Data Parallel ([DDP](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)).
DeePMD-kit will decide whether to launch the training in parallel (distributed) mode or in serial mode depending on your execution command.
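For orientation, the general DDP pattern that this builds on looks roughly like the sketch below; it uses a placeholder model and is not DeePMD-kit's internal implementation:

```python
# Generic DDP skeleton (illustrative only). torchrun sets RANK, LOCAL_RANK and
# WORLD_SIZE in the environment; when they are absent, training stays serial.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

model = torch.nn.Linear(16, 1)  # placeholder for the actual DeePMD-kit model

if int(os.environ.get("WORLD_SIZE", "1")) > 1:
    dist.init_process_group(backend="nccl")  # one process per GPU
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model = DDP(model.cuda(), device_ids=[local_rank])
# otherwise the bare model is trained in serial mode
```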

### Dataloader and Dataset
One of the major differences between the two backends during training is that the PyTorch version employs a multi-threaded data loading utility, [DataLoader](https://pytorch.org/docs/stable/data.html).
Using the PyTorch framework, we designed and implemented a multiprocessing data processing and loading system called `DpLoaderSet`, based on torch `DataLoader` and `Dataset`.


First, we establish a `DeepmdData` class for each system, which is consistent with the TensorFlow version at this level. Then, we create a dataloader for each system, resulting in the same number of dataloaders as systems. Next, we create a dataset over the dataloaders obtained in the previous step, so that the data of each system can be queried through this dataset while the iteration pointer of each system is maintained by its respective dataloader. Finally, a dataloader is created for the outermost dataset.
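A much-simplified sketch of this nesting is given below; the class and variable names are hypothetical and only illustrate the structure, not the actual `DpLoaderSet` code:

```python
# Illustration of the nesting described above: one inner DataLoader per system
# keeps that system's iteration state, and an outer Dataset indexes systems so
# that __getitem__ returns the next mini-batch of the chosen system.
from torch.utils.data import DataLoader, Dataset


class SystemsDataset(Dataset):  # hypothetical name
    def __init__(self, system_datasets, batch_sizes):
        self.loaders = [
            DataLoader(ds, batch_size=bs, shuffle=True)
            for ds, bs in zip(system_datasets, batch_sizes)
        ]
        self.iters = [iter(dl) for dl in self.loaders]

    def __len__(self):
        return len(self.loaders)

    def __getitem__(self, sys_idx):
        # draw the next mini-batch from the selected system, restarting that
        # system's iterator once it is exhausted
        try:
            return next(self.iters[sys_idx])
        except StopIteration:
            self.iters[sys_idx] = iter(self.loaders[sys_idx])
            return next(self.iters[sys_idx])


# an outer DataLoader then iterates over system indices chosen by a sampler
```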

We achieve custom sampling methods using a weighted sampler. The length of the sampler is set to `total_batch_num * num_workers`. The parameter `num_workers` defines the number of threads involved in multi-threaded loading and can be modified by setting the environment variable `NUM_WORKERS` (default: `min(8, ncpus)`).
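Expressed with standard PyTorch components, the sampler setup could look like the sketch below; the per-system batch counts are placeholders, and only the length formula mirrors the text above:

```python
# Hypothetical wiring of the weighted sampler described above.
import os

import torch
from torch.utils.data import WeightedRandomSampler

batch_counts = [120, 300]  # placeholder: number of batches in each system
total_batch_num = sum(batch_counts)
num_workers = int(os.environ.get("NUM_WORKERS", min(8, os.cpu_count() or 1)))

weights = torch.as_tensor(batch_counts, dtype=torch.double)
sampler = WeightedRandomSampler(
    weights,
    num_samples=total_batch_num * num_workers,  # sampler length as described
    replacement=True,
)
# each index drawn from the sampler selects which system provides the next batch
```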

> **Note** In parallel mode, the underlying dataloader uses a distributed sampler to ensure that each GPU receives batches with different content; in serial mode, a sequential sampler is used instead. In the TensorFlow version, Horovod shuffles the dataset with different random seeds for the same purpose.

```mermaid
flowchart LR
subgraph systems
subgraph system1
direction LR
frame1[frame 1]
frame2[frame 2]
end
subgraph system2
direction LR
frame3[frame 3]
frame4[frame 4]
frame5[frame 5]
end
end
subgraph dataset
dataset1[dataset 1]
dataset2[dataset 2]
end
system1 -- frames --> dataset1
system2 --> dataset2
subgraph distributed sampler
ds1[distributed sampler 1]
ds2[distributed sampler 2]
end
dataset1 --> ds1
dataset2 --> ds2
subgraph dataloader
dataloader1[dataloader 1]
dataloader2[dataloader 2]
end
ds1 -- mini batch --> dataloader1
ds2 --> dataloader2
subgraph index[index on Rank 0]
dl11[dataloader 1, entry 1]
dl21[dataloader 2, entry 1]
dl22[dataloader 2, entry 2]
end
dataloader1 --> dl11
dataloader2 --> dl21
dataloader2 --> dl22
index -- for each step, choose 1 system --> WeightedSampler
--> dploaderset --> bufferedq[buffered queue] --> model
```

### How to use

We use [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html#usage) to launch a DDP training session.

To start training with multiple GPUs on one node, set the parameter `nproc_per_node` to the number of GPUs:

```bash
torchrun --nproc_per_node=4 --no-python dp --pt train input.json
# Not setting `nproc_per_node` uses only 1 GPU
torchrun --no-python dp --pt train input.json
```

To train a model with a cluster, one can manually launch the task using the commands below (usually this should be done by your job management system). Set `nnodes` as the number of available nodes, `node_rank` as the rank of the current node among all nodes (not the rank of processes!), and `nproc_per_node` as the number of available GPUs in one node. Please make sure that every node can access the rendezvous address and port (`rdzv_endpoint` in the command) and has the same number of GPUs.

```bash
# Running DDP on 2 nodes with 4 GPUs each
# On node 0:
torchrun --rdzv_endpoint=node0:12321 --nnodes=2 --nproc_per_node=4 --node_rank=0 --no_python dp --pt train tests/water/se_e2_a.json
# On node 1:
torchrun --rdzv_endpoint=node0:12321 --nnodes=2 --nproc_per_node=4 --node_rank=1 --no_python dp --pt train tests/water/se_e2_a.json
```
> **Note** Set environment variables to tune [CPU-specific optimizations](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#cpu-specific-optimizations) in advance.

> **Note** for developers: `torchrun` by default passes settings as environment variables [(list here)](https://pytorch.org/docs/stable/elastic/run.html#environment-variables).
> To check forward, backward, and communication time, please set the environment variables `TORCH_CPP_LOG_LEVEL=INFO TORCH_DISTRIBUTED_DEBUG=DETAIL`. More details can be found [here](https://pytorch.org/docs/stable/distributed.html#logging).
