[NFC][mlir][mesh,shard] Fixing misnomers in mesh dialect, renaming 'mesh' dialect to 'shard' #150177

Merged: 5 commits, Jul 25, 2025

74 changes: 0 additions & 74 deletions mlir/docs/Dialects/Mesh.md

This file was deleted.

92 changes: 92 additions & 0 deletions mlir/docs/Dialects/Shard.md
@@ -0,0 +1,92 @@
# 'shard' Dialect

The 'shard' dialect defines a set of attributes, operations, and interfaces for
working with tensor sharding and device communication.

It’s inspired by [GSPMD](https://arxiv.org/abs/2105.04663)
(*General and Scalable Parallelization for ML Computation Graphs*).

Originally, the dialect was called `mesh`, but it was renamed to better reflect
what it actually does.

[TOC]

## Collective Communication Operations

The 'shard' dialect includes several collective operations that help coordinate
communication between devices arranged in a grid.

If you’re not already familiar with collective operations, [this Wikipedia
article](https://en.wikipedia.org/wiki/Collective_operation) is a good starting
point.

Unlike traditional collectives that are defined in terms of message-passing
between explicit buffers on each process, the collectives in this dialect work
at a higher level. They’re defined in terms of how data moves across the
dimensions of a tensor, and the participating processes are inferred from how
the tensor is sharded - not specified manually.

### Device Groups

Each collective operation runs within a group of devices. You define groups
using the `grid` and `grid_axes` attributes, which describe how to slice the
full device grid into smaller groups.

Devices that have the same coordinates *outside* the listed `grid_axes` belong
to the same group.

Example: Say your device grid is shaped `2×3×4×5`, and you set
`grid_axes = [0, 1]`. This splits the grid into groups by fixing axes 2 and 3.
You’d get groups like:

```
{ { (i, j, k, m) | 0 ≤ i < 2, 0 ≤ j < 3 } | 0 ≤ k < 4, 0 ≤ m < 5 }
```

So the groups are identified by the coordinates `(k, m)`, and devices like
`(1, 0, 2, 3)` and `(1, 1, 2, 3)` are in the same group. But `(1, 0, 2, 4)`
is in a different group.
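
To make the grouping rule concrete, here is a small sketch in plain Python. It
is purely illustrative and not part of the dialect or its API (the helper name
`device_groups` is made up for this example): it enumerates the groups induced
by a grid shape and a `grid_axes` list by keying each device on its coordinates
over the axes *outside* `grid_axes`.

```python
from itertools import product

# Illustrative only: group the devices of a grid by their coordinates on the
# axes NOT listed in grid_axes; devices sharing that key form one group.
def device_groups(grid_shape, grid_axes):
    other_axes = [a for a in range(len(grid_shape)) if a not in grid_axes]
    groups = {}
    for device in product(*(range(n) for n in grid_shape)):
        key = tuple(device[a] for a in other_axes)  # e.g. (k, m) above
        groups.setdefault(key, []).append(device)
    return groups

groups = device_groups((2, 3, 4, 5), grid_axes=[0, 1])
assert (1, 0, 2, 3) in groups[(2, 3)]      # same group as (1, 1, 2, 3)
assert (1, 1, 2, 3) in groups[(2, 3)]
assert (1, 0, 2, 4) not in groups[(2, 3)]  # belongs to group (2, 4) instead
```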

For some collectives (like `all-to-all`), the order of devices in the group
matters. The device order is based on the order of axes in `grid_axes`, from
outermost to innermost.

Example: If `grid_axes = [3, 1]`, then device `(i, 1, k, 0)` comes before
`(i, 0, k, 1)` and `(i, 2, k, 0)`.
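
The ordering rule can be sketched the same way (plain Python, illustrative
only): devices in a group are compared by their coordinates over `grid_axes`,
with the first listed axis acting as the outermost, most significant one.

```python
# Illustrative only: the in-group sort key reads coordinates in grid_axes
# order, so the first axis listed is the most significant.
def in_group_key(device, grid_axes):
    return tuple(device[a] for a in grid_axes)

grid_axes = [3, 1]
a = (0, 1, 0, 0)  # (i, 1, k, 0) with i = k = 0
b = (0, 0, 0, 1)  # (i, 0, k, 1)
c = (0, 2, 0, 0)  # (i, 2, k, 0)
assert in_group_key(a, grid_axes) < in_group_key(b, grid_axes)
assert in_group_key(a, grid_axes) < in_group_key(c, grid_axes)
```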

### In-group Devices

Some operations (like `broadcast`, `scatter`, and `send`) refer to a specific
device within each group. These in-group devices are identified using their
coordinates over the axes listed in `grid_axes`.

Example: In a 3D grid with `grid_axes = [0, 2]`, an in-group device is specified
as `(i, j)`. If a group is fixed at coordinate `g` on axis 1, then the full
device index would be `(i, g, j)`.
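
The mapping from in-group coordinates back to a full device index can be
sketched as follows (plain Python, illustrative only; `full_device_index` is a
made-up helper, not dialect API):

```python
# Illustrative only: combine in-group coordinates (over grid_axes) with the
# group's fixed coordinates (over the remaining axes) into a full index.
def full_device_index(in_group_coords, group_coords, grid_axes, rank):
    other_axes = [a for a in range(rank) if a not in grid_axes]
    index = [None] * rank
    for axis, coord in zip(grid_axes, in_group_coords):
        index[axis] = coord
    for axis, coord in zip(other_axes, group_coords):
        index[axis] = coord
    return tuple(index)

# 3D grid, grid_axes = [0, 2]: in-group device (i, j), group fixed at g on axis 1.
i, j, g = 1, 2, 0
assert full_device_index((i, j), (g,), grid_axes=[0, 2], rank=3) == (i, g, j)
```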

### Purity and Execution Model

Collective operations involve all devices in a group (e.g. `all-gather`,
`all-to-all`) and are considered pure. Operations like `send` and `recv` are not
collective and are not pure.

The execution model assumes SPMD (Single Program, Multiple Data):

* Every process runs the same program.
* At any collective operation, all processes are in sync.

This means compiler optimizations must treat collective ops carefully. For
example, if a collective is removed during optimization, it must be removed from
*every* path and *every* process that would have participated - otherwise, you’ll
get undefined behavior at runtime.

Marking these ops as pure also helps with standard compiler passes like dead
code elimination and common subexpression elimination. Treating collectives
consistently in this way ensures that, when the program runs, all devices reach
the same collective at the same point in the program and so avoid deadlocks.

## Operations

[include "Dialects/ShardOps.md"]

## Attributes

[include "Dialects/ShardAttrs.md"]
4 changes: 2 additions & 2 deletions mlir/docs/Passes.md
@@ -72,9 +72,9 @@ This document describes the available MLIR passes and their contracts.

[include "MemRefPasses.md"]

## 'mesh' Dialect Passes
## 'shard' Dialect Passes

[include "MeshPasses.md"]
[include "ShardPasses.md"]

## 'ml\_program' Dialect Passes

2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/Passes.h
@@ -52,7 +52,6 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
@@ -66,6 +65,7 @@
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/ShardToMPI/ShardToMPI.h"
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
8 changes: 4 additions & 4 deletions mlir/include/mlir/Conversion/Passes.td
@@ -903,13 +903,13 @@ def ConvertMemRefToSPIRVPass : Pass<"convert-memref-to-spirv"> {
}

//===----------------------------------------------------------------------===//
// MeshToMPI
// ShardToMPI
//===----------------------------------------------------------------------===//

def ConvertMeshToMPIPass : Pass<"convert-mesh-to-mpi"> {
let summary = "Convert Mesh dialect to MPI dialect.";
def ConvertShardToMPIPass : Pass<"convert-shard-to-mpi"> {
let summary = "Convert Shard dialect to MPI dialect.";
let description = [{
This pass converts communication operations from the Mesh dialect to the
This pass converts communication operations from the Shard dialect to the
MPI dialect.
If it finds the DLTI attribute "MPI:comm_world-rank" on the module it will
use that integer value instead of calling MPI_Comm_rank. This allows
@@ -1,23 +1,23 @@
//===- MeshToMPI.h - Convert Mesh to MPI dialect ----------------*- C++ -*-===//
//===- ShardToMPI.h - Convert Shard to MPI dialect --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
#define MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
#ifndef MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H
#define MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H

#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTMESHTOMPIPASS
#define GEN_PASS_DECL_CONVERTSHARDTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"

} // namespace mlir

#endif // MLIR_CONVERSION_MESHTOMPI_MESHTOMPI_H
#endif // MLIR_CONVERSION_SHARDTOMPI_SHARDTOMPI_H
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/CMakeLists.txt
@@ -19,7 +19,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(Math)
add_subdirectory(MemRef)
add_subdirectory(Mesh)
add_subdirectory(Shard)
add_subdirectory(MLProgram)
add_subdirectory(MPI)
add_subdirectory(NVGPU)
@@ -1,4 +1,4 @@
//===- MeshShardingExtensions.h - -----------------------------------------===//
//===- ShardingExtensions.h - -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -1,20 +1,20 @@
//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
//===- ShardingInterfaceImpl.h ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
#ifndef MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H

namespace mlir {
class DialectRegistry;

namespace linalg {
void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
void registerShardingInterfaceExternalModels(DialectRegistry &registry);
} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
#endif // MLIR_DIALECT_LINALG_SHARDSHARDINGINTERFACEIMPL_H
25 changes: 0 additions & 25 deletions mlir/include/mlir/Dialect/Mesh/IR/CMakeLists.txt

This file was deleted.

6 changes: 0 additions & 6 deletions mlir/include/mlir/Dialect/Mesh/Transforms/CMakeLists.txt

This file was deleted.

25 changes: 25 additions & 0 deletions mlir/include/mlir/Dialect/Shard/IR/CMakeLists.txt
@@ -0,0 +1,25 @@
add_mlir_doc(ShardOps ShardOps Dialects/ -gen-op-doc -dialect=shard)
add_mlir_doc(ShardOps ShardAttrs Dialects/ -gen-attrdef-doc -dialect=shard)

set(LLVM_TARGET_DEFINITIONS ShardOps.td)
mlir_tablegen(ShardDialect.cpp.inc -gen-dialect-defs -dialect=shard)
mlir_tablegen(ShardDialect.h.inc -gen-dialect-decls -dialect=shard)

set(LLVM_TARGET_DEFINITIONS ShardBase.td)
mlir_tablegen(ShardAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(ShardAttributes.cpp.inc -gen-attrdef-defs)

set(LLVM_TARGET_DEFINITIONS ShardBase.td)
mlir_tablegen(ShardEnums.h.inc -gen-enum-decls)
mlir_tablegen(ShardEnums.cpp.inc -gen-enum-defs)

set(LLVM_TARGET_DEFINITIONS ShardBase.td)
mlir_tablegen(ShardTypes.h.inc -gen-typedef-decls)
mlir_tablegen(ShardTypes.cpp.inc -gen-typedef-defs)

set(LLVM_TARGET_DEFINITIONS ShardOps.td)
mlir_tablegen(ShardOps.h.inc -gen-op-decls)
mlir_tablegen(ShardOps.cpp.inc -gen-op-defs)

add_public_tablegen_target(MLIRShardIncGen)
add_dependencies(mlir-headers MLIRShardIncGen)