From 55493b2551995cb90c60e694ccaedc65a0b610a9 Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Tue, 15 Jul 2025 11:10:57 -0700
Subject: [PATCH 1/3] fsdp1 -> fsdp2

---
 intermediate_source/TP_tutorial.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/intermediate_source/TP_tutorial.rst b/intermediate_source/TP_tutorial.rst
index 91e64a87488..4108e72b02b 100644
--- a/intermediate_source/TP_tutorial.rst
+++ b/intermediate_source/TP_tutorial.rst
@@ -333,7 +333,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
 
     from torch.distributed.device_mesh import init_device_mesh
     from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp import fully_shard
 
     # i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
     mesh_2d = init_device_mesh("cuda", (8, 8))
@@ -347,7 +347,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
     # apply Tensor Parallel intra-host on tp_mesh
     model_tp = parallelize_module(model, tp_mesh, tp_plan)
     # apply FSDP inter-host on dp_mesh
-    model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)
+    model_2d = fully_shard(model_tp, mesh=dp_mesh, ...)
 
 This would allow us to easily apply Tensor Parallel within each host (intra-host) and apply FSDP across hosts (inter-hosts), with **0-code changes** to the Llama model.

From 90c66f89a133cc1fafef461d9a9143bb7c5fce87 Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Tue, 15 Jul 2025 13:30:16 -0700
Subject: [PATCH 2/3] change num_heads in tutorial

---
 intermediate_source/TP_tutorial.rst | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/intermediate_source/TP_tutorial.rst b/intermediate_source/TP_tutorial.rst
index 4108e72b02b..25739ed704a 100644
--- a/intermediate_source/TP_tutorial.rst
+++ b/intermediate_source/TP_tutorial.rst
@@ -128,9 +128,9 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
     layer_tp_plan = {
         # by default ColwiseParallel input layouts is replicated
         # and RowwiseParallel output layouts is replicated
-        "attention.wq": ColwiseParallel(),
-        "attention.wk": ColwiseParallel(),
-        "attention.wv": ColwiseParallel(),
+        "attention.wq": ColwiseParallel(use_local_output=False),
+        "attention.wk": ColwiseParallel(use_local_output=False),
+        "attention.wv": ColwiseParallel(use_local_output=False),
         "attention.wo": RowwiseParallel(),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(),
@@ -141,7 +141,7 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
 This is almost the ``layer_tp_plan`` we need to apply Tensor Parallelism to the ``TransformerBlock``. However, one thing we should be aware is that when sharding the linear layer column-wise, the output of the linear layers would become sharded on the last tensor dimension, and the row-wise sharding linear layer directly accepts an input that shards on the last dimension. If there are any more tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape related ops to sharded shape.
 
-For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for ``wq``/ ``wk``/ ``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would need to adjust the ``num_heads`` to local ``num_heads``.
+For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for wq/ wk/ wv linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would set ``use_local_output=False`` to let the output to be a DTensor. Compared to normal plain tensor, DTensor has knowledge about the parallelism plans, and will handle the ``num_heads`` dimension change under the hood.
 
 Finally, we need to call ``parallelize_module`` API to make the plan for each ``TransformerBlock`` effective. Under the hood, it distributes the model parameters inside ``Attention`` and ``FeedForward`` layers to DTensors, and registers communication hooks for model inputs and outputs (before and after each module respectively), if necessary:
@@ -150,11 +150,6 @@ Finally, we need to call ``parallelize_module`` API to make the plan for each ``
     for layer_id, transformer_block in enumerate(model.layers):
         layer_tp_plan = {...}  # i.e. the plan we just generated
 
-        # Adjust attention module to use the local number of heads
-        attn_layer = transformer_block.attention
-        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
@@ -219,12 +214,12 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
         # to represent the input/output tensors sharded on the sequence dimension
         "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
-            input_layouts=(Shard(1),),
-            desired_input_layouts=(Replicate(),),
+            input_layouts=(Shard(1), Replicate()),
+            desired_input_layouts=(Replicate(), Replicate()),
         ),
-        "attention.wq": ColwiseParallel(),
-        "attention.wk": ColwiseParallel(),
-        "attention.wv": ColwiseParallel(),
+        "attention.wq": ColwiseParallel(use_local_output=False),
+        "attention.wk": ColwiseParallel(use_local_output=False),
+        "attention.wv": ColwiseParallel(use_local_output=False),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
         "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
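
The plan change in PATCH 2 relies on ``ColwiseParallel(use_local_output=False)`` keeping the ``wq``/``wk``/``wv`` outputs as DTensors, which is why the manual ``n_heads // tp_mesh.size()`` adjustment can be deleted. A minimal sketch of that flow, not taken from the tutorial, is below; the ``ToyAttention`` module, its sizes, and the single-host launch via ``torchrun --nproc_per_node=<tp degree>`` on CUDA GPUs are illustrative assumptions::

    import os

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


    class ToyAttention(nn.Module):
        # Hypothetical stand-in for the tutorial's Llama Attention module.
        def __init__(self, dim=256, n_heads=8):
            super().__init__()
            self.n_heads, self.head_dim = n_heads, dim // n_heads
            self.wq = nn.Linear(dim, dim, bias=False)
            self.wk = nn.Linear(dim, dim, bias=False)
            self.wv = nn.Linear(dim, dim, bias=False)
            self.wo = nn.Linear(dim, dim, bias=False)

        def forward(self, x):
            bsz, seq, _ = x.shape
            # The views keep using the *global* n_heads. Because wq/wk/wv return
            # DTensors (use_local_output=False), the sharding on the heads
            # dimension is tracked for us instead of by editing n_heads.
            q = self.wq(x).view(bsz, seq, self.n_heads, self.head_dim).transpose(1, 2)
            k = self.wk(x).view(bsz, seq, self.n_heads, self.head_dim).transpose(1, 2)
            v = self.wv(x).view(bsz, seq, self.n_heads, self.head_dim).transpose(1, 2)
            out = F.scaled_dot_product_attention(q, k, v)
            out = out.transpose(1, 2).reshape(bsz, seq, -1)  # sharded on the last dim
            return self.wo(out)  # RowwiseParallel all-reduces back to a full tensor


    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # set by torchrun
    torch.manual_seed(0)  # identical init and input on every rank keeps the toy example consistent
    tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

    block = parallelize_module(
        ToyAttention().cuda(),
        tp_mesh,
        {
            "wq": ColwiseParallel(use_local_output=False),
            "wk": ColwiseParallel(use_local_output=False),
            "wv": ColwiseParallel(use_local_output=False),
            "wo": RowwiseParallel(),
        },
    )
    out = block(torch.randn(2, 16, 256, device="cuda"))
    print(out.shape)  # torch.Size([2, 16, 256]) on every rank

With ``use_local_output=True`` (the default), each rank would instead see a plain local tensor from ``wq``/``wk``/``wv``, and the ``view`` would only work after shrinking ``n_heads`` to the per-rank value, which is exactly the manual adjustment PATCH 2 removes.
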
From 630e1d23701502a2664a5a707ba3a86aa37beaae Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Tue, 15 Jul 2025 19:33:51 -0700
Subject: [PATCH 3/3] rewrite

---
 intermediate_source/TP_tutorial.rst | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/intermediate_source/TP_tutorial.rst b/intermediate_source/TP_tutorial.rst
index 25739ed704a..846b25502e4 100644
--- a/intermediate_source/TP_tutorial.rst
+++ b/intermediate_source/TP_tutorial.rst
@@ -141,7 +141,7 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
 This is almost the ``layer_tp_plan`` we need to apply Tensor Parallelism to the ``TransformerBlock``. However, one thing we should be aware is that when sharding the linear layer column-wise, the output of the linear layers would become sharded on the last tensor dimension, and the row-wise sharding linear layer directly accepts an input that shards on the last dimension. If there are any more tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape related ops to sharded shape.
 
-For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for wq/ wk/ wv linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would set ``use_local_output=False`` to let the output to be a DTensor. Compared to normal plain tensor, DTensor has knowledge about the parallelism plans, and will handle the ``num_heads`` dimension change under the hood.
+For the Llama model, in the attention layer, there are several view operations related to shape. Specifically, for column-wise parallelism in the ``wq``/``wk``/``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension. To manage the difference between global and local ``num_heads``, we should set ``use_local_output=False`` to ensure the output is a DTensor. Unlike a regular tensor, a DTensor is aware of the parallelism plans and will automatically handle changes in the ``num_heads`` dimension.
 
 Finally, we need to call ``parallelize_module`` API to make the plan for each ``TransformerBlock`` effective. Under the hood, it distributes the model parameters inside ``Attention`` and ``FeedForward`` layers to DTensors, and registers communication hooks for model inputs and outputs (before and after each module respectively), if necessary:
@@ -328,7 +328,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
 
     from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
-    from torch.distributed.fsdp import fully_shard
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
     # i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
     mesh_2d = init_device_mesh("cuda", (8, 8))
@@ -342,7 +342,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
     # apply Tensor Parallel intra-host on tp_mesh
     model_tp = parallelize_module(model, tp_mesh, tp_plan)
     # apply FSDP inter-host on dp_mesh
-    model_2d = fully_shard(model_tp, mesh=dp_mesh, ...)
+    model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)
 
 This would allow us to easily apply Tensor Parallel within each host (intra-host) and apply FSDP across hosts (inter-hosts), with **0-code changes** to the Llama model.
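
After PATCH 3 the 2-D section is back on ``FullyShardedDataParallel``, composed with ``parallelize_module`` exactly as in the snippet it restores. A runnable miniature of that composition is sketched below for completeness; the single-host 2x2 ``(dp, tp)`` mesh, the two-layer MLP standing in for a ``TransformerBlock``, and the tensor sizes are illustrative assumptions, and it is meant to be launched with ``torchrun --nproc_per_node=4``::

    import os

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # set by torchrun
    torch.manual_seed(0)  # same init and input everywhere keeps this toy example consistent

    # 2-D mesh: outer dim for data parallel, inner dim for tensor parallel.
    mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]

    # Two-layer MLP as a stand-in for a TransformerBlock's feed-forward.
    model = nn.Sequential(
        nn.Linear(128, 512, bias=False),
        nn.GELU(),
        nn.Linear(512, 128, bias=False),
    ).cuda()

    # Tensor Parallel on the tp sub-mesh: first linear column-wise, second
    # row-wise, so one all-reduce is needed at the end of the block.
    model_tp = parallelize_module(model, tp_mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()})

    # FSDP on the dp sub-mesh; use_orig_params=True is what lets FSDP wrap the
    # DTensor parameters produced by parallelize_module, as in the tutorial text.
    model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True)

    out = model_2d(torch.randn(8, 128, device="cuda"))
    out.sum().backward()
    print(out.shape)  # torch.Size([8, 128]) on every rank

Scaling this up toward the tutorial's 64-GPU ``(8, 8)`` mesh only changes the mesh shape and the model being wrapped; the ``parallelize_module`` and ``FSDP`` calls stay the same, which is the "0-code changes" point the tutorial makes.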