
Commit 86ed9e4

jafraustro authored and soumith committed
Enhance README and examples for Tensor Parallelism
- Added installation instructions and example running commands to README.md.
- Updated the files for better organization.

Signed-off-by: jafraustro <[email protected]>
1 parent 698a89e commit 86ed9e4

File tree

4 files changed: +94 -83 lines changed


distributed/tensor_parallelism/README.md

Lines changed: 20 additions & 2 deletions
@@ -10,7 +10,25 @@ PyTorch native Tensor Parallel APIs, which include:
 More details about the PyTorch native Tensor Parallel APIs, please see PyTorch docs:
 https://pytorch.org/docs/stable/distributed.tensor.parallel.html
 
-```
+## Installation
+
+```bash
 pip install -r requirements.txt
-torchrun --nnodes 1 --nproc-per-node 4 tensor_parallel_example.py
 ```
+
+## Running Examples
+
+You can run the examples using `torchrun` to launch distributed training:
+
+```bash
+# Simple Tensor Parallel example
+torchrun --nnodes=1 --nproc_per_node=4 tensor_parallel_example.py
+
+# Tensor Parallel with Sequence Parallel
+torchrun --nnodes=1 --nproc_per_node=4 sequence_parallel_example.py
+
+# FSDP + Tensor Parallel with Llama2 model
+torchrun --nnodes=1 --nproc_per_node=4 fsdp_tp_example.py
+```
+
+For more details, check the `run_examples.sh` script.

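Background on the `torchrun` commands added above: `torchrun` starts one process per GPU and exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` to each of them. The sketch below is not part of this commit (the file name is made up); it only shows the minimal setup a script launched this way typically performs before doing any parallel work.

```python
# check_env.py -- hypothetical helper, not part of this commit.
# Launch with: torchrun --nnodes=1 --nproc_per_node=4 check_env.py
import os

import torch
import torch.distributed as dist


def main():
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun, one value per process
    torch.cuda.set_device(local_rank)           # bind this process to its own GPU
    dist.init_process_group(backend="nccl")     # reads RANK/WORLD_SIZE from the environment
    print(f"rank {dist.get_rank()} of {dist.get_world_size()} on cuda:{local_rank}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```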
distributed/tensor_parallelism/fsdp_tp_example.py

Lines changed: 30 additions & 31 deletions
@@ -1,34 +1,3 @@
-import sys
-import os
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
-
-from log_utils import rank_log, get_logger, verify_min_gpu_count
-
-# ---- GPU check ------------
-_min_gpu_count = 4
-
-if not verify_min_gpu_count(min_gpus=_min_gpu_count):
-    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
-    sys.exit()
-# ---------------------------
-
-from llama2_model import Transformer, ModelArgs
-
-from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed._tensor import Shard, Replicate
-from torch.distributed.tensor.parallel import (
-    parallelize_module,
-    ColwiseParallel,
-    RowwiseParallel,
-    PrepareModuleInput,
-    SequenceParallel
-)
-
 """
 This is the script to test 2D Parallel which combines Tensor/Sequence
 parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a example
@@ -60,6 +29,36 @@
 https://pytorch.org/tutorials/intermediate/TP_tutorial.html
 """
 
+import sys
+import os
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+
+from log_utils import rank_log, get_logger, verify_min_gpu_count
+
+# ---- GPU check ------------
+_min_gpu_count = 4
+
+if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
+    sys.exit()
+# ---------------------------
+
+from llama2_model import Transformer, ModelArgs
+
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed._tensor import Shard, Replicate
+from torch.distributed.tensor.parallel import (
+    parallelize_module,
+    ColwiseParallel,
+    RowwiseParallel,
+    PrepareModuleInput,
+    SequenceParallel
+)
+
 tp_size = 2
 logger = get_logger()

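The docstring above describes composing Tensor/Sequence Parallel with FSDP on a 2D device mesh. The following is a minimal sketch of that composition on a toy MLP, not the Llama2 `Transformer` plan used in `fsdp_tp_example.py`; the layer names, sizes, and the 4-GPU launch are assumptions for illustration.

```python
# Illustrative 2D-parallel (TP + FSDP) sketch on a toy MLP; assumptions noted above.
# Launch with: torchrun --nnodes=1 --nproc_per_node=4 <this_file>.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)


class ToyMLP(nn.Module):
    """Two linear layers with a ReLU in between (sizes are illustrative)."""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(16, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 16)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

tp_size = 2                                        # same value the script above uses
dp_size = int(os.environ["WORLD_SIZE"]) // tp_size

# Outer mesh dimension drives FSDP sharding, inner dimension drives tensor parallelism.
mesh_2d = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
tp_mesh, dp_mesh = mesh_2d["tp"], mesh_2d["dp"]

model = ToyMLP().cuda()

# 1) Shard weights within each TP group: first layer column-wise, second row-wise.
model = parallelize_module(
    model, tp_mesh, {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()}
)
# 2) Wrap the TP-parallelized module with FSDP over the data-parallel mesh dimension.
model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)

model(torch.rand(8, 16, device="cuda")).sum().backward()
```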
distributed/tensor_parallelism/sequence_parallel_example.py

Lines changed: 19 additions & 22 deletions
@@ -1,5 +1,22 @@
-# The following is an example command to run this code
-# torchrun --nnodes 1 --nproc-per-node 4 sequence_parallel_example.py
+"""
+This is the script to test Sequence Parallel(SP) on a toy model in a
+Megetron-LM SPMD style. We show an E2E working flow from forward,
+backward and optimization.
+
+We use the example of two `nn.Linear` layers with an element-wise `nn.RELU`
+in between to show an example of sequence parallel, which was proposed in paper:
+
+https://arxiv.org/pdf/2205.05198.pdf.
+
+Like tensor parallel, we parallelize the first linear layer by column
+and also parallelize the second linear layer by row. But the input in each rank
+now is different so that we need one all-gather for input and one reduce-scatter
+in the end of the second linear layer.
+
+The following is an example command to run this code
+torchrun --nnodes 1 --nproc-per-node 4 sequence_parallel_example.py
+"""
+
 import os
 import sys
 import torch
@@ -24,28 +41,8 @@
     sys.exit()
 # ---------------------------
 
-
 from torch.distributed._tensor.device_mesh import init_device_mesh
 
-
-
-"""
-This is the script to test Sequence Parallel(SP) on a toy model in a
-Megetron-LM SPMD style. We show an E2E working flow from forward,
-backward and optimization.
-
-We use the example of two `nn.Linear` layers with an element-wise `nn.RELU`
-in between to show an example of sequence parallel, which was proposed in paper:
-
-https://arxiv.org/pdf/2205.05198.pdf.
-
-Like tensor parallel, we parallelize the first linear layer by column
-and also parallelize the second linear layer by row. But the input in each rank
-now is different so that we need one all-gather for input and one reduce-scatter
-in the end of the second linear layer.
-"""
-
-
 class ToyModel(nn.Module):
     """MLP based model"""

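To make the all-gather / reduce-scatter description in the moved docstring concrete, here is a minimal sequence-parallel sketch on a toy MLP. The layer names, sizes, and 4-GPU mesh are illustrative assumptions, not taken verbatim from the file.

```python
# Illustrative sequence-parallel sketch; names and sizes are assumptions.
# Launch with: torchrun --nnodes=1 --nproc_per_node=4 <this_file>.py
import os

import torch
import torch.nn as nn
from torch.distributed._tensor import Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)


class ToyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

sp_model = parallelize_module(
    ToyMLP().cuda(),
    mesh,
    {
        # Each rank holds a different sequence shard, so the column-wise first layer
        # declares Shard(0) inputs: this is where the all-gather happens.
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        # The row-wise second layer emits Shard(0) outputs: the usual all-reduce
        # becomes a reduce-scatter back into sequence shards.
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)

local_seq_shard = torch.rand(16, 10, device="cuda")  # this rank's slice of the sequence
sp_model(local_seq_shard).sum().backward()
```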
distributed/tensor_parallelism/tensor_parallel_example.py

Lines changed: 25 additions & 28 deletions
@@ -1,31 +1,3 @@
-# The following is an example command to run this code
-# torchrun --nnodes 1 --nproc-per-node 4 tensor_parallel_example.py
-import os
-import sys
-import torch
-import torch.nn as nn
-
-from torch.distributed.tensor.parallel import (
-    parallelize_module,
-    ColwiseParallel,
-    RowwiseParallel,
-)
-
-from log_utils import rank_log, get_logger, verify_min_gpu_count
-
-# ---- GPU check ------------
-_min_gpu_count = 2
-
-if not verify_min_gpu_count(min_gpus=_min_gpu_count):
-    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
-    sys.exit()
-# ---------------------------
-
-from torch.distributed._tensor.device_mesh import init_device_mesh
-
-
-
-
 """
 This is the script to test Tensor Parallel(TP) on a toy model in a
 Megetron-LM SPMD style. We show an E2E working flow from forward,
@@ -55,8 +27,33 @@
 to use and our `parallelize_module` API will parse and parallelize the modules
 based on the given `ParallelStyle`. We are using this PyTorch native Tensor
 Parallelism APIs in this example to show users how to use them.
+
+The following is an example command to run this code
+torchrun --nnodes 1 --nproc-per-node 4 tensor_parallel_example.py
 """
 
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.distributed.tensor.parallel import (
+    parallelize_module,
+    ColwiseParallel,
+    RowwiseParallel,
+)
+from log_utils import rank_log, get_logger, verify_min_gpu_count
+
+# ---- GPU check ------------
+_min_gpu_count = 2
+
+if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
+    sys.exit()
+# ---------------------------
+
+from torch.distributed._tensor.device_mesh import init_device_mesh
+
 class ToyModel(nn.Module):
     """MLP based model"""

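The column-wise / row-wise split that this docstring describes can be checked numerically without GPUs or process groups. The sketch below simulates two "ranks" in a single process; the shapes are made up and it is independent of the example code.

```python
# Single-process numeric check of the column/row sharding idea described above.
# No GPUs or torchrun needed; two weight shards stand in for two TP ranks.
import torch

torch.manual_seed(0)
x = torch.randn(8, 16)    # (batch, in_features)
w1 = torch.randn(16, 32)  # first linear, applied as x @ w1
w2 = torch.randn(32, 4)   # second linear, applied as relu(x @ w1) @ w2

reference = torch.relu(x @ w1) @ w2

# "Rank 0" keeps the first half of w1's columns and w2's rows; "rank 1" keeps the rest.
w1_shards = w1.chunk(2, dim=1)  # column-wise shards of the first layer
w2_shards = w2.chunk(2, dim=0)  # row-wise shards of the second layer

# Each rank computes independently; ReLU is element-wise, so it commutes with the
# column split of the first layer's output.
partials = [torch.relu(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]

# The final all-reduce performed by RowwiseParallel is just this sum.
combined = partials[0] + partials[1]

print(torch.allclose(reference, combined, atol=1e-5))  # True
```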