Closed
Description
Issue description
I am trying to use StableHLOCompositeBuilder
to capture the PyTorch module hierarchy in a StableHLO export artifact, but calling a module that itself uses another StableHLOCompositeBuilder
fails with this non-obvious error message: does not have a ordering number in its outer func
. I wonder whether nested builders are unsupported by design, or whether I have hit some corner case.
Code example
import torch
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
builder = StableHLOCompositeBuilder("SubModule")
x, y = builder.mark_inputs(x, y)
out = x + y
out = builder.mark_outputs(out)
return out
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.submodule = SubModule()
def forward(self, x, y):
builder = StableHLOCompositeBuilder("Model")
x, y = builder.mark_inputs(x, y)
a = x + y
b = x - y
c = self.submodule(a, b)
a, b, c = builder.mark_outputs(a, b, c)
return a + b + c
sample_input = (torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32))
exported = torch.export.export(Model(), sample_input)
stablehlo_program = exported_program_to_stablehlo(exported)
print(stablehlo_program.get_stablehlo_text('forward'))
Error message
$ python export_nested.py
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:Defaulting to PJRT_DEVICE=CPU
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1714149188.639784 26257 cpu_client.cc:405] TfrtCpuClient created.
loc("custom-call.23"): error: does not have a ordering number in its outer func.
error: failed to build composite.
Traceback (most recent call last):
File "/home/user/export_nested.py", line 36, in <module>
stablehlo_program = exported_program_to_stablehlo(exported)
File "/miniconda3/envs/venv/lib/python3.10/site-packages/torch_xla/stablehlo.py", line 568, in exported_program_to_stablehlo
bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
File "/miniconda3/envs/venv/lib/python3.10/site-packages/torch_xla/stablehlo.py", line 356, in _exported_program_to_stablehlo_bundle
stablehlo_content = xm.get_stablehlo_bytecode(res)
File "/miniconda3/envs/venv/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1112, in get_stablehlo_bytecode
return torch_xla._XLAC._get_stablehlo(
RuntimeError: torch_xla/csrc/runtime/stablehlo_helper.cc:107 : Check failed: status.ok()
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::ConvertHloToStableHlo(xla::HloModuleProto const*, mlir::ModuleOp*)
torch_xla::hloToStablehlo(xla::HloModuleProto const*, bool)
torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyRun_SimpleFileObject
_PyRun_AnyFileObject
Py_RunMain
Py_BytesMain
__libc_start_main
*** End stack trace ***
MHLO -> StableHLO conversion failed.
StableHLO Module from MHLO -> StableHLO conversion is not leagal.Please open a github issue to PyTorch/XLA.
Original HLO dump:
HloModule IrToHlo.35, entry_computation_layout={(f32[1,1,32,32]{3,2,1,0}, f32[1,1,32,32]{3,2,1,0})->(f32[1,1,32,32]{3,2,1,0})}
ENTRY %IrToHlo.35 (p0.8: f32[1,1,32,32], p1.10: f32[1,1,32,32]) -> (f32[1,1,32,32]) {
%p1.10 = f32[1,1,32,32]{3,2,1,0} parameter(1)
%custom-call.11 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %p1.10), custom_call_target="xla_mark_tensor", backend_config={"name": "Model", "pos": 0, "id": "6066fe3b63d64105a8ccecc7ca0d9eb1", "is_input": true, "attr": null}
%p0.8 = f32[1,1,32,32]{3,2,1,0} parameter(0)
%custom-call.9 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %p0.8), custom_call_target="xla_mark_tensor", backend_config={"name": "Model", "pos": 1, "id": "6066fe3b63d64105a8ccecc7ca0d9eb1", "is_input": true, "attr": null}
%constant.15 = f32[] constant(1)
%broadcast.16 = f32[1,1,32,32]{3,2,1,0} broadcast(f32[] %constant.15), dimensions={}
%multiply.17 = f32[1,1,32,32]{3,2,1,0} multiply(f32[1,1,32,32]{3,2,1,0} %custom-call.9, f32[1,1,32,32]{3,2,1,0} %broadcast.16)
%add.18 = f32[1,1,32,32]{3,2,1,0} add(f32[1,1,32,32]{3,2,1,0} %custom-call.11, f32[1,1,32,32]{3,2,1,0} %multiply.17)
%custom-call.27 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %add.18), custom_call_target="xla_mark_tensor", backend_config={"name": "Model", "pos": 0, "id": "6066fe3b63d64105a8ccecc7ca0d9eb1", "is_input": false, "attr": null}
%constant.3 = f32[] constant(1)
%reshape.4 = f32[1,1,1,1]{3,2,1,0} reshape(f32[] %constant.3)
%broadcast.5 = f32[1,1,1,1]{3,2,1,0} broadcast(f32[1,1,1,1]{3,2,1,0} %reshape.4), dimensions={0,1,2,3}
%reshape.6 = f32[1,1]{1,0} reshape(f32[1,1,1,1]{3,2,1,0} %broadcast.5)
%broadcast.7 = f32[1,1,32,32]{3,2,1,0} broadcast(f32[1,1]{1,0} %reshape.6), dimensions={0,1}
%multiply.12 = f32[1,1,32,32]{3,2,1,0} multiply(f32[1,1,32,32]{3,2,1,0} %custom-call.9, f32[1,1,32,32]{3,2,1,0} %broadcast.7)
%subtract.13 = f32[1,1,32,32]{3,2,1,0} subtract(f32[1,1,32,32]{3,2,1,0} %custom-call.11, f32[1,1,32,32]{3,2,1,0} %multiply.12)
%custom-call.26 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %subtract.13), custom_call_target="xla_mark_tensor", backend_config={"name": "Model", "pos": 1, "id": "6066fe3b63d64105a8ccecc7ca0d9eb1", "is_input": false, "attr": null}
%constant.25 = f32[] constant(1)
%broadcast.28 = f32[1,1,32,32]{3,2,1,0} broadcast(f32[] %constant.25), dimensions={}
%multiply.29 = f32[1,1,32,32]{3,2,1,0} multiply(f32[1,1,32,32]{3,2,1,0} %custom-call.26, f32[1,1,32,32]{3,2,1,0} %broadcast.28)
%add.30 = f32[1,1,32,32]{3,2,1,0} add(f32[1,1,32,32]{3,2,1,0} %custom-call.27, f32[1,1,32,32]{3,2,1,0} %multiply.29)
%custom-call.19 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %add.18), custom_call_target="xla_mark_tensor", backend_config={"name": "SubModule", "pos": 0, "id": "8cee2e7f05fb42179262200522ba886e", "is_input": true, "attr": null}
%custom-call.14 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %subtract.13), custom_call_target="xla_mark_tensor", backend_config={"name": "SubModule", "pos": 1, "id": "8cee2e7f05fb42179262200522ba886e", "is_input": true, "attr": null}
%constant.2 = f32[] constant(1)
%broadcast.20 = f32[1,1,32,32]{3,2,1,0} broadcast(f32[] %constant.2), dimensions={}
%multiply.21 = f32[1,1,32,32]{3,2,1,0} multiply(f32[1,1,32,32]{3,2,1,0} %custom-call.14, f32[1,1,32,32]{3,2,1,0} %broadcast.20)
%add.22 = f32[1,1,32,32]{3,2,1,0} add(f32[1,1,32,32]{3,2,1,0} %custom-call.19, f32[1,1,32,32]{3,2,1,0} %multiply.21)
%custom-call.23 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %add.22), custom_call_target="xla_mark_tensor", backend_config={"name": "SubModule", "pos": 0, "id": "8cee2e7f05fb42179262200522ba886e", "is_input": false, "attr": null}
%custom-call.24 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %custom-call.23), custom_call_target="xla_mark_tensor", backend_config={"name": "Model", "pos": 2, "id": "6066fe3b63d64105a8ccecc7ca0d9eb1", "is_input": false, "attr": null}
%constant.1 = f32[] constant(1)
%broadcast.31 = f32[1,1,32,32]{3,2,1,0} broadcast(f32[] %constant.1), dimensions={}
%multiply.32 = f32[1,1,32,32]{3,2,1,0} multiply(f32[1,1,32,32]{3,2,1,0} %custom-call.24, f32[1,1,32,32]{3,2,1,0} %broadcast.31)
%add.33 = f32[1,1,32,32]{3,2,1,0} add(f32[1,1,32,32]{3,2,1,0} %add.30, f32[1,1,32,32]{3,2,1,0} %multiply.32)
ROOT %tuple.34 = (f32[1,1,32,32]{3,2,1,0}) tuple(f32[1,1,32,32]{3,2,1,0} %add.33)
}
I0000 00:00:1714149188.968508 26257 cpu_client.cc:408] TfrtCpuClient destroyed.
System Info
- python3.10
- torch==2.3.0
- torch_xla==2.3.0