Using StableHLOCompositeBuilder in a nested way #6978

Closed
@tlsdmstn56

Description

Issue description

I am trying to use StableHLOCompositeBuilder to capture the PyTorch module hierarchy in the StableHLO export artifact, but calling a module that uses another StableHLOCompositeBuilder fails with this non-obvious error message: "does not have a ordering number in its outer func." I wonder whether this is unsupported by design or whether I hit a corner case. For comparison, a sketch of the non-nested variant I would expect to work is included after the error log below.

Code example

import torch
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder

class SubModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Inner composite: marked while the outer "Model" composite is still open.
        builder = StableHLOCompositeBuilder("SubModule")
        x, y = builder.mark_inputs(x, y)
        out = x + y
        out = builder.mark_outputs(out)
        return out

class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.submodule = SubModule()

    def forward(self, x, y):
        # Outer composite: its region encloses the SubModule composite above.
        builder = StableHLOCompositeBuilder("Model")
        x, y = builder.mark_inputs(x, y)
        a = x + y
        b = x - y
        c = self.submodule(a, b)
        a, b, c = builder.mark_outputs(a, b, c)
        return a + b + c

sample_input = (torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32))
exported = torch.export.export(Model(), sample_input)
stablehlo_program = exported_program_to_stablehlo(exported)
print(stablehlo_program.get_stablehlo_text('forward'))

Error message

$ python export_nested.py 
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:Defaulting to PJRT_DEVICE=CPU
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR 
I0000 00:00:1714149188.639784   26257 cpu_client.cc:405] TfrtCpuClient created.
loc("custom-call.23"): error: does not have a ordering number in its outer func.
error: failed to build composite.
Traceback (most recent call last):
  File "/home/user/export_nested.py", line 36, in <module>
    stablehlo_program = exported_program_to_stablehlo(exported)
  File "/miniconda3/envs/venv/lib/python3.10/site-packages/torch_xla/stablehlo.py", line 568, in exported_program_to_stablehlo
    bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
  File "/miniconda3/envs/venv/lib/python3.10/site-packages/torch_xla/stablehlo.py", line 356, in _exported_program_to_stablehlo_bundle
    stablehlo_content = xm.get_stablehlo_bytecode(res)
  File "/miniconda3/envs/venv/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1112, in get_stablehlo_bytecode
    return torch_xla._XLAC._get_stablehlo(
RuntimeError: torch_xla/csrc/runtime/stablehlo_helper.cc:107 : Check failed: status.ok()
*** Begin stack trace ***
  tsl::CurrentStackTrace()
  torch_xla::ConvertHloToStableHlo(xla::HloModuleProto const*, mlir::ModuleOp*)
  torch_xla::hloToStablehlo(xla::HloModuleProto const*, bool)
  torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
  torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)




  _PyObject_MakeTpCall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault
  _PyFunction_Vectorcall
  _PyEval_EvalFrameDefault

  PyEval_EvalCode



  _PyRun_SimpleFileObject
  _PyRun_AnyFileObject
  Py_RunMain
  Py_BytesMain
  __libc_start_main

*** End stack trace ***
MHLO -> StableHLO conversion failed.
StableHLO Module from MHLO -> StableHLO conversion is not leagal.Please open a github issue to PyTorch/XLA.
Original HLO dump:
HloModule IrToHlo.35, entry_computation_layout={(f32[1,1,32,32]{3,2,1,0}, f32[1,1,32,32]{3,2,1,0})->(f32[1,1,32,32]{3,2,1,0})}

ENTRY %IrToHlo.35 (p0.8: f32[1,1,32,32], p1.10: f32[1,1,32,32]) -> (f32[1,1,32,32]) {
  %p1.10 = f32[1,1,32,32]{3,2,1,0} parameter(1)
  %custom-call.11 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %p1.10), custom_call_target="xla_mark_tensor", backend_config={"name": "Model", "pos": 0, "id": "6066fe3b63d64105a8ccecc7ca0d9eb1", "is_input": true, "attr": null}
  %p0.8 = f32[1,1,32,32]{3,2,1,0} parameter(0)
  %custom-call.9 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %p0.8), custom_call_target="xla_mark_tensor", backend_config={"name": "Model", "pos": 1, "id": "6066fe3b63d64105a8ccecc7ca0d9eb1", "is_input": true, "attr": null}
  %constant.15 = f32[] constant(1)
  %broadcast.16 = f32[1,1,32,32]{3,2,1,0} broadcast(f32[] %constant.15), dimensions={}
  %multiply.17 = f32[1,1,32,32]{3,2,1,0} multiply(f32[1,1,32,32]{3,2,1,0} %custom-call.9, f32[1,1,32,32]{3,2,1,0} %broadcast.16)
  %add.18 = f32[1,1,32,32]{3,2,1,0} add(f32[1,1,32,32]{3,2,1,0} %custom-call.11, f32[1,1,32,32]{3,2,1,0} %multiply.17)
  %custom-call.27 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %add.18), custom_call_target="xla_mark_tensor", backend_config={"name": "Model", "pos": 0, "id": "6066fe3b63d64105a8ccecc7ca0d9eb1", "is_input": false, "attr": null}
  %constant.3 = f32[] constant(1)
  %reshape.4 = f32[1,1,1,1]{3,2,1,0} reshape(f32[] %constant.3)
  %broadcast.5 = f32[1,1,1,1]{3,2,1,0} broadcast(f32[1,1,1,1]{3,2,1,0} %reshape.4), dimensions={0,1,2,3}
  %reshape.6 = f32[1,1]{1,0} reshape(f32[1,1,1,1]{3,2,1,0} %broadcast.5)
  %broadcast.7 = f32[1,1,32,32]{3,2,1,0} broadcast(f32[1,1]{1,0} %reshape.6), dimensions={0,1}
  %multiply.12 = f32[1,1,32,32]{3,2,1,0} multiply(f32[1,1,32,32]{3,2,1,0} %custom-call.9, f32[1,1,32,32]{3,2,1,0} %broadcast.7)
  %subtract.13 = f32[1,1,32,32]{3,2,1,0} subtract(f32[1,1,32,32]{3,2,1,0} %custom-call.11, f32[1,1,32,32]{3,2,1,0} %multiply.12)
  %custom-call.26 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %subtract.13), custom_call_target="xla_mark_tensor", backend_config={"name": "Model", "pos": 1, "id": "6066fe3b63d64105a8ccecc7ca0d9eb1", "is_input": false, "attr": null}
  %constant.25 = f32[] constant(1)
  %broadcast.28 = f32[1,1,32,32]{3,2,1,0} broadcast(f32[] %constant.25), dimensions={}
  %multiply.29 = f32[1,1,32,32]{3,2,1,0} multiply(f32[1,1,32,32]{3,2,1,0} %custom-call.26, f32[1,1,32,32]{3,2,1,0} %broadcast.28)
  %add.30 = f32[1,1,32,32]{3,2,1,0} add(f32[1,1,32,32]{3,2,1,0} %custom-call.27, f32[1,1,32,32]{3,2,1,0} %multiply.29)
  %custom-call.19 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %add.18), custom_call_target="xla_mark_tensor", backend_config={"name": "SubModule", "pos": 0, "id": "8cee2e7f05fb42179262200522ba886e", "is_input": true, "attr": null}
  %custom-call.14 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %subtract.13), custom_call_target="xla_mark_tensor", backend_config={"name": "SubModule", "pos": 1, "id": "8cee2e7f05fb42179262200522ba886e", "is_input": true, "attr": null}
  %constant.2 = f32[] constant(1)
  %broadcast.20 = f32[1,1,32,32]{3,2,1,0} broadcast(f32[] %constant.2), dimensions={}
  %multiply.21 = f32[1,1,32,32]{3,2,1,0} multiply(f32[1,1,32,32]{3,2,1,0} %custom-call.14, f32[1,1,32,32]{3,2,1,0} %broadcast.20)
  %add.22 = f32[1,1,32,32]{3,2,1,0} add(f32[1,1,32,32]{3,2,1,0} %custom-call.19, f32[1,1,32,32]{3,2,1,0} %multiply.21)
  %custom-call.23 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %add.22), custom_call_target="xla_mark_tensor", backend_config={"name": "SubModule", "pos": 0, "id": "8cee2e7f05fb42179262200522ba886e", "is_input": false, "attr": null}
  %custom-call.24 = f32[1,1,32,32]{3,2,1,0} custom-call(f32[1,1,32,32]{3,2,1,0} %custom-call.23), custom_call_target="xla_mark_tensor", backend_config={"name": "Model", "pos": 2, "id": "6066fe3b63d64105a8ccecc7ca0d9eb1", "is_input": false, "attr": null}
  %constant.1 = f32[] constant(1)
  %broadcast.31 = f32[1,1,32,32]{3,2,1,0} broadcast(f32[] %constant.1), dimensions={}
  %multiply.32 = f32[1,1,32,32]{3,2,1,0} multiply(f32[1,1,32,32]{3,2,1,0} %custom-call.24, f32[1,1,32,32]{3,2,1,0} %broadcast.31)
  %add.33 = f32[1,1,32,32]{3,2,1,0} add(f32[1,1,32,32]{3,2,1,0} %add.30, f32[1,1,32,32]{3,2,1,0} %multiply.32)
  ROOT %tuple.34 = (f32[1,1,32,32]{3,2,1,0}) tuple(f32[1,1,32,32]{3,2,1,0} %add.33)
}


I0000 00:00:1714149188.968508   26257 cpu_client.cc:408] TfrtCpuClient destroyed.
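
For comparison, here is a minimal sketch of the non-nested variant that I would expect to be the supported usage: only SubModule keeps its builder, and the outer Model is left unmarked, so no composite region encloses another. (Same API as above; I have not pasted its output here, and I am only assuming this flat form is what the current lowering expects.)

import torch
from torch_xla.stablehlo import exported_program_to_stablehlo
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder

class SubModule(torch.nn.Module):

    def forward(self, x, y):
        # Single, non-nested composite around this module only.
        builder = StableHLOCompositeBuilder("SubModule")
        x, y = builder.mark_inputs(x, y)
        out = x + y
        out = builder.mark_outputs(out)
        return out

class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.submodule = SubModule()

    def forward(self, x, y):
        # No builder at this level, so the SubModule composite has no outer composite.
        a = x + y
        b = x - y
        c = self.submodule(a, b)
        return a + b + c

sample_input = (torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32))
exported = torch.export.export(Model(), sample_input)
stablehlo_program = exported_program_to_stablehlo(exported)
print(stablehlo_program.get_stablehlo_text('forward'))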

System Info

  • python3.10
  • torch==2.3.0
  • torch_xla==2.3.0

Labels

enhancement (New feature or request), stablehlo (StableHLO related work)
