@@ -141,7 +141,7 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
This is almost the ``layer_tp_plan`` we need to apply Tensor Parallelism to the ``TransformerBlock``. However, one thing to be aware of is that when a linear layer is sharded column-wise, its output becomes sharded on the last tensor dimension, whereas a row-wise sharded linear layer directly accepts an input that is sharded on the last dimension.
If there are any further tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape-related ops to the sharded shape.
-For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for ``wq``/``wk``/``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would need to adjust the ``num_heads`` to local ``num_heads``.
+For the Llama model, in the attention layer, there are several view operations related to shape. Specifically, for column-wise parallelism in the ``wq``/``wk``/``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension. To manage the difference between global and local ``num_heads``, we should set ``use_local_output=False`` to ensure the output is a DTensor. Unlike a regular tensor, a DTensor is aware of the parallelism plans and will automatically handle changes in the ``num_heads`` dimension.
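For reference, such a per-block plan could look like the following minimal sketch (the module names, e.g. ``attention.wq`` and ``feed_forward.w1``, assume a Llama-style ``TransformerBlock`` and are illustrative):

.. code-block:: python

    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

    layer_tp_plan = {
        # Shard the q/k/v projections column-wise and keep their outputs as DTensors,
        # so that view ops on the num_heads dimension are handled automatically.
        "attention.wq": ColwiseParallel(use_local_output=False),
        "attention.wk": ColwiseParallel(use_local_output=False),
        "attention.wv": ColwiseParallel(use_local_output=False),
        # Shard the output projection row-wise; it directly accepts input
        # that is sharded on the last dimension.
        "attention.wo": RowwiseParallel(),
        # The feed-forward block follows the same column-wise / row-wise pattern.
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
        "feed_forward.w3": ColwiseParallel(),
    }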
Finally, we need to call the ``parallelize_module`` API to make the plan for each ``TransformerBlock`` effective. Under the hood, it distributes the model parameters inside the ``Attention`` and ``FeedForward`` layers to DTensors, and registers communication hooks for module inputs and outputs (before and after each module, respectively), if necessary:
@@ -150,11 +150,6 @@ Finally, we need to call ``parallelize_module`` API to make the plan for each ``
for layer_id, transformer_block in enumerate(model.layers):
    layer_tp_plan = {...}  # i.e. the plan we just generated
-    # Adjust attention module to use the local number of heads
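For reference, a minimal sketch of how the plan is then applied to each block inside the loop (assuming ``tp_mesh`` is the tensor-parallel ``DeviceMesh`` created earlier in the tutorial):

.. code-block:: python

    from torch.distributed.tensor.parallel import parallelize_module

    # Distribute the block's parameters as DTensors and register the
    # input/output communication hooks prescribed by the plan.
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_tp_plan,
    )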