diff --git a/torch_xla/distributed/spmd/debugging.py b/torch_xla/distributed/spmd/debugging.py index 508d8cbb371c..e5f53d04aea1 100644 --- a/torch_xla/distributed/spmd/debugging.py +++ b/torch_xla/distributed/spmd/debugging.py @@ -57,6 +57,7 @@ def visualize_sharding(sharding: str, # eg: '{devices=[2,2]0,1,2,3}' # eg: '{replicated}' # eg: '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}' + print(f"Visualizing {sharding} (showing up to the first two dimensions)") if sharding == '{replicated}' or len(sharding) == 0: heights = 1 widths = 1 @@ -64,7 +65,8 @@ def visualize_sharding(sharding: str, device_ids = list(range(num_devices)) slices.setdefault((0, 0), device_ids) else: - sharding_spac = sharding[sharding.index('['):sharding.index(']') + 1] + sharding_spec = sharding[sharding.index('[') + + 1:sharding.index(']')].split(",") device_list_original = sharding.split(' last_tile_dim_replicate') if len(device_list_original) == 2 and device_list_original[1] == '}': try: @@ -72,9 +74,9 @@ def visualize_sharding(sharding: str, device_list = device_list_original_first[device_list_original_first. index(']') + 1:] device_indices_map = [int(s) for s in device_list.split(',')] - heights = int(sharding_spac[1]) - widths = int(sharding_spac[3]) - last_dim_depth = int(sharding_spac[5]) + heights = int(sharding_spec[0]) + widths = int(sharding_spec[1]) + last_dim_depth = int(sharding_spec[-1]) devices_len = len(device_indices_map) len_after_dim_down = devices_len // last_dim_depth for i in range(len_after_dim_down): @@ -96,8 +98,8 @@ def visualize_sharding(sharding: str, device_list = device_list_original_first[device_list_original_first. index(']') + 1:-1] device_indices_map = [int(i) for i in device_list.split(',')] - heights = int(sharding_spac[1]) - widths = int(sharding_spac[3]) + heights = int(sharding_spec[0]) + widths = int(sharding_spec[1]) devices_len = len(device_indices_map) for i in range(devices_len): slices.setdefault((i // widths, i % widths), device_indices_map[i])