From 89fd536e6d72c68371a0cf0a11320149a86eac09 Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Tue, 15 Jul 2025 13:17:11 -0700
Subject: [PATCH] remove manual n_heads change

---
 distributed/tensor_parallelism/fsdp_tp_example.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py
index 4ae6cb1aa2..154cee169e 100644
--- a/distributed/tensor_parallelism/fsdp_tp_example.py
+++ b/distributed/tensor_parallelism/fsdp_tp_example.py
@@ -121,12 +121,12 @@
     layer_tp_plan = {
         "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
-            input_layouts=(Shard(1), None),
-            desired_input_layouts=(Replicate(), None),
+            input_layouts=(Shard(1), Replicate()),
+            desired_input_layouts=(Replicate(), Replicate()),
         ),
-        "attention.wq": ColwiseParallel(),
-        "attention.wk": ColwiseParallel(),
-        "attention.wv": ColwiseParallel(),
+        "attention.wq": ColwiseParallel(use_local_output=False),
+        "attention.wk": ColwiseParallel(use_local_output=False),
+        "attention.wv": ColwiseParallel(use_local_output=False),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
         "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
@@ -138,11 +138,6 @@
         "feed_forward.w3": ColwiseParallel(),
     }
 
-    # Adjust attention module to use the local number of heads
-    attn_layer = transformer_block.attention
-    attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-    attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
     # Custom parallelization plan for the model
     parallelize_module(
        module=transformer_block,
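
Note (not part of the patch): the sketch below illustrates how the updated plan is meant to be applied. With use_local_output=False, the wq/wk/wv projections return DTensors, so the attention forward reshapes against the global n_heads / n_kv_heads and the per-rank "n_heads // tp_mesh.size()" adjustment removed above is no longer needed. The sketch assumes a torch build with the public DTensor API and a Llama-style block exposing attention.{wq,wk,wv,wo}; the apply_tp helper and the model object are illustrative names, not code from this repository.

# Sketch under assumptions stated above; not a verbatim excerpt of the example.
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def apply_tp(model, tp_mesh):
    """Apply the per-block tensor-parallel plan from the patch above."""
    for transformer_block in model.layers:
        layer_tp_plan = {
            "attention_norm": SequenceParallel(),
            "attention": PrepareModuleInput(
                # Both attention inputs are declared: hidden states arrive
                # sharded on the sequence dim, the second input replicated.
                input_layouts=(Shard(1), Replicate()),
                desired_input_layouts=(Replicate(), Replicate()),
            ),
            # use_local_output=False keeps q/k/v as DTensors, so view/reshape
            # in the attention forward uses the global head counts and no
            # manual per-rank head-count division is required.
            "attention.wq": ColwiseParallel(use_local_output=False),
            "attention.wk": ColwiseParallel(use_local_output=False),
            "attention.wv": ColwiseParallel(use_local_output=False),
            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
        }
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_tp_plan,
        )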