Commit 567e4fb

torch.compile support for ScaledGroupedMMTensor
1 parent: c57226b

4 files changed: +18 -7 lines changed

test/prototype/moe_training/test_training.py

Lines changed: 8 additions & 1 deletion
@@ -35,7 +35,10 @@
         ["does.not.exist"],
     ],
 )
-def test_moe_float8_training(target_fqns: list[str]):
+@pytest.mark.parametrize(
+    "compile", [False, True]
+)
+def test_moe_float8_training(target_fqns: list[str], compile: bool):
     model_args = TransformerModelArgs(
         moe_enabled=True,
         num_experts=8,
@@ -73,6 +76,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         target_fqns=target_fqns,
     )
 
+    if compile:
+        model = torch.compile(model, fullgraph=False)
+        ref_model = torch.compile(ref_model, fullgraph=False)
+
     # inputs
     batch, seq, dim = 8, 2048, 256
     ref_x = torch.randn(
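For context: the new parametrization runs the existing numerics comparison with both the converted model and the reference model wrapped in torch.compile(..., fullgraph=False), so Dynamo falls back to eager at any graph breaks rather than erroring out. A minimal sketch of that compile-and-compare pattern, with a hypothetical toy module standing in for the MoE transformer used in the real test:

import copy

import torch
import torch.nn as nn

# Hypothetical stand-in for the MoE transformer in the actual test.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
ref_model = copy.deepcopy(model)

# fullgraph=False tolerates graph breaks instead of raising, which is the
# safer setting while tensor-subclass tracing support is still maturing.
model = torch.compile(model, fullgraph=False)
ref_model = torch.compile(ref_model, fullgraph=False)

x = torch.randn(8, 256)
ref_x = x.detach().clone()

# Both paths should produce numerically close outputs.
torch.testing.assert_close(model(x), ref_model(ref_x))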

torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 7 additions & 3 deletions
@@ -42,7 +42,11 @@
     for block_size_cols in block_sizes
 ]
 
+from torch.library import triton_op, wrap_triton
 
+
+
+@triton_op("torchao::triton_fp8_row_major_jagged_rowwise_scales", mutates_args={})
 def triton_fp8_row_major_jagged_rowwise_scales(
     hp_tensor: torch.Tensor,
     offsets: torch.Tensor,
@@ -90,7 +94,7 @@ def triton_fp8_row_major_jagged_rowwise_scales(
         triton.cdiv(m, meta["BLOCK_SIZE_ROWS"]),
         offsets.numel(),
     )
-    _triton_fp8_row_major_jagged_rowwise_scales[grid](
+    wrap_triton(_triton_fp8_row_major_jagged_rowwise_scales)[grid](
         hp_tensor,
         offsets,
         output_buffer,
@@ -203,7 +207,7 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
     )
     tl.store(out_ptr + out_offs, fp8_data, mask=block_mask)
 
-
+@triton_op("torchao::triton_fp8_col_major_jagged_colwise_scales", mutates_args={})
 def triton_fp8_col_major_jagged_colwise_scales(
     hp_tensor: torch.Tensor,
     offsets: torch.Tensor,
@@ -251,7 +255,7 @@ def triton_fp8_col_major_jagged_colwise_scales(
         triton.cdiv(n, meta["BLOCK_SIZE_COLS"]),
         offsets.numel(),
     )
-    _triton_fp8_col_major_jagged_colwise_scales[grid](
+    wrap_triton(_triton_fp8_col_major_jagged_colwise_scales)[grid](
         hp_tensor,
         offsets,
         output_buffer,
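Why this makes the kernels compile-friendly: torch.compile cannot trace a raw Triton kernel launch made from Python, but a function registered with torch.library.triton_op becomes a custom op the compiler understands, and wrap_triton makes the kernel launch inside that op traceable. A minimal sketch of the same pattern on a toy elementwise kernel (assumes a recent PyTorch with torch.library.triton_op, Triton, and a CUDA device; the op name "mylib::toy_add" and the kernel are hypothetical examples, not part of torchao):

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _toy_add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


# triton_op registers the Python wrapper as a custom op that torch.compile
# treats as a known operator instead of tracing arbitrary Python.
@triton_op("mylib::toy_add", mutates_args={})
def toy_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    # wrap_triton makes the raw kernel launch traceable from inside the op.
    wrap_triton(_toy_add_kernel)[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out


@torch.compile
def f(x, y):
    return toy_add(x, y)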

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 0 additions & 2 deletions
@@ -40,8 +40,6 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging.
-    logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,

torchao/prototype/moe_training/tensor.py

Lines changed: 3 additions & 1 deletion
@@ -123,7 +123,9 @@ def __repr__(self):
         return f"ScaledGroupedMMTensor(data={self._data})"
 
     def __tensor_flatten__(self):
-        return ["_data"]
+        # Metadata is empty but needed to make the subclass traceable for torch.compile.
+        metadata = {}
+        return ["_data"], metadata
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
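Background on this change: for a tensor subclass to be traceable by torch.compile, __tensor_flatten__ must return a pair of (names of inner tensor attributes, metadata), and __tensor_unflatten__ must rebuild the subclass from those pieces; returning only the list of names breaks that contract. A minimal sketch of the flatten/unflatten protocol on a hypothetical pass-through wrapper subclass ("ToyWrapperTensor" is an illustration, not the actual ScaledGroupedMMTensor dispatch logic):

import torch


class ToyWrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __tensor_flatten__(self):
        # Returns (inner tensor attribute names, metadata); torch.compile
        # expects this 2-tuple even when the metadata dict is empty.
        return ["_data"], {}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        # Rebuilds the subclass from the flattened inner tensors and metadata.
        return ToyWrapperTensor(inner_tensors["_data"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Pass-through dispatch: unwrap the inner tensor and run the op on it.
        kwargs = kwargs or {}
        unwrapped = [a._data if isinstance(a, ToyWrapperTensor) else a for a in args]
        return func(*unwrapped, **kwargs)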

0 commit comments
