
Commit b78fc75

wwwjn and svekars authored
Using DTensor to handle local num_heads change while TP is applied (#3465)
* fsdp1 -> fsdp2
* change num_heads in tutorial

---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent c0e9be0 commit b78fc75

1 file changed: +9, -14 lines changed


intermediate_source/TP_tutorial.rst

Lines changed: 9 additions & 14 deletions
@@ -128,9 +128,9 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
 layer_tp_plan = {
     # by default ColwiseParallel input layouts is replicated
     # and RowwiseParallel output layouts is replicated
-    "attention.wq": ColwiseParallel(),
-    "attention.wk": ColwiseParallel(),
-    "attention.wv": ColwiseParallel(),
+    "attention.wq": ColwiseParallel(use_local_output=False),
+    "attention.wk": ColwiseParallel(use_local_output=False),
+    "attention.wv": ColwiseParallel(use_local_output=False),
     "attention.wo": RowwiseParallel(),
     "feed_forward.w1": ColwiseParallel(),
     "feed_forward.w2": RowwiseParallel(),
@@ -141,7 +141,7 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
 This is almost the ``layer_tp_plan`` we need to apply Tensor Parallelism to the ``TransformerBlock``. However, one thing we should be aware is that when sharding the linear layer column-wise, the output of the linear layers would become sharded on the last tensor dimension, and the row-wise sharding linear layer directly accepts an input that shards on the last dimension.
 If there are any more tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape related ops to sharded shape.

-For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for ``wq``/ ``wk``/ ``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would need to adjust the ``num_heads`` to local ``num_heads``.
+For the Llama model, in the attention layer, there are several view operations related to shape. Specifically, for column-wise parallelism in the ``wq``/``wk``/``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension. To manage the difference between global and local ``num_heads``, we should set ``use_local_output=False`` to ensure the output is a DTensor. Unlike a regular tensor, a DTensor is aware of the parallelism plans and will automatically handle changes in the ``num_heads`` dimension.

 Finally, we need to call ``parallelize_module`` API to make the plan for each ``TransformerBlock`` effective. Under the hood, it distributes the model parameters inside ``Attention`` and ``FeedForward`` layers to DTensors, and registers communication hooks for model inputs and outputs (before and after each module respectively), if necessary:
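As an illustration (not part of this diff), here is a minimal sketch of the mechanism the new paragraph describes, assuming a recent PyTorch where ``torch.distributed.tensor`` is public and the script is launched with ``torchrun --nproc_per_node=2``; the tensor and variable names below are illustrative stand-ins rather than tutorial code. A DTensor keeps the global shape, so a view with the global ``num_heads`` works and the sharding simply moves onto the heads dimension:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# 2-way tensor-parallel mesh on CPU (gloo); stands in for the tutorial's tp_mesh.
tp_mesh = init_device_mesh("cpu", (2,))

bsz, seqlen, n_heads, head_dim = 4, 16, 8, 32
dim = n_heads * head_dim

# Stand-in for the wq output with use_local_output=False: a DTensor sharded on the
# last (hidden) dimension, i.e. each rank holds dim // 2 of the columns.
xq = distribute_tensor(torch.randn(bsz, seqlen, dim), tp_mesh, [Shard(2)])

# The view uses the GLOBAL n_heads; DTensor propagates the sharding onto the heads
# dimension, so each rank holds n_heads // tp_mesh.size() local heads automatically.
xq = xq.view(bsz, seqlen, n_heads, head_dim)
print(xq.shape)             # global shape: [4, 16, 8, 32]
print(xq.to_local().shape)  # local shape on each rank: [4, 16, 4, 32]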

@@ -150,11 +150,6 @@ Finally, we need to call ``parallelize_module`` API to make the plan for each ``
 for layer_id, transformer_block in enumerate(model.layers):
     layer_tp_plan = {...}  # i.e. the plan we just generated

-    # Adjust attention module to use the local number of heads
-    attn_layer = transformer_block.attention
-    attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-    attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
     parallelize_module(
         module=transformer_block,
         device_mesh=tp_mesh,
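For context, a hedged sketch of what the loop looks like after this change; it assumes ``model`` (the Transformer) and ``tp_mesh`` (the tensor-parallel DeviceMesh) already exist as in the tutorial, and shows only the plan entries visible in this diff:

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

layer_tp_plan = {
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
    "attention.wo": RowwiseParallel(),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
}

for layer_id, transformer_block in enumerate(model.layers):
    # No manual `n_heads // tp_mesh.size()` adjustment anymore: the wq/wk/wv outputs
    # stay DTensors, so the module keeps working with the global head count.
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )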
@@ -219,12 +214,12 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
     # to represent the input/output tensors sharded on the sequence dimension
     "attention_norm": SequenceParallel(),
     "attention": PrepareModuleInput(
-        input_layouts=(Shard(1),),
-        desired_input_layouts=(Replicate(),),
+        input_layouts=(Shard(1), Replicate()),
+        desired_input_layouts=(Replicate(), Replicate()),
     ),
-    "attention.wq": ColwiseParallel(),
-    "attention.wk": ColwiseParallel(),
-    "attention.wv": ColwiseParallel(),
+    "attention.wq": ColwiseParallel(use_local_output=False),
+    "attention.wk": ColwiseParallel(use_local_output=False),
+    "attention.wv": ColwiseParallel(use_local_output=False),
     "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
     "ffn_norm": SequenceParallel(),
     "feed_forward": PrepareModuleInput(
