
[torch-ort-infer] Aten fallback doesn't work #139

@saipj

Description

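When a model contains an op that is exported as an ATen node (here `Tensor.T`), torch-ort-infer does not fall back to the native PyTorch runtime for that op as expected. Instead, creating the ONNX Runtime inference session fails.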

Versions:
Torch - 1.12.0
OnnxRuntime - 1.12.0
Torch-ort-infer - 1.12.0

Reproduction steps:

```python
import torch
from torch_ort import ORTInferenceModule

def test_numpy_T(input_shape):
    class NeuralNet(torch.nn.Module):
        def __init__(self):
            super(NeuralNet, self).__init__()
        def forward(self, input):
            return input.T  # Tensor.T reverses all dimensions, e.g. [3, 2, 5] -> [5, 2, 3]

    device = "cpu"
    ort_model = ORTInferenceModule(NeuralNet().to(device))  # wrap the module for ONNX Runtime inference

    def run_step(model, input):
        prediction = model(input)
        return prediction

    ort_input = torch.rand(input_shape, dtype=torch.float, device=device)
    ort_prediction = run_step(ort_model, ort_input)  # raises while the InferenceSession is created

if __name__ == "__main__":
    test_numpy_T([3, 2, 5])
```
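For reference, once the op falls back to native PyTorch, the result should match eager mode. `Tensor.T` returns a view with all dimensions reversed (equivalent to `permute(n-1, ..., 0)` for an `n`-dimensional tensor), so for the `[3, 2, 5]` input in the reproduction script the expected output shape is `[5, 2, 3]`. A small stdlib-only sketch of that expectation follows; the helper names `t_permutation` and `expected_t_shape` are hypothetical and are not part of torch or torch-ort.

```python
from typing import List, Sequence


def t_permutation(ndim: int) -> List[int]:
    """Permutation applied by Tensor.T: every dimension reversed."""
    return list(reversed(range(ndim)))


def expected_t_shape(shape: Sequence[int]) -> List[int]:
    """Shape that eager PyTorch returns for input.T given an input of `shape`."""
    return [shape[axis] for axis in t_permutation(len(shape))]


if __name__ == "__main__":
    # Input shape used in the reproduction script.
    print(t_permutation(3))             # [2, 1, 0]
    print(expected_t_shape([3, 2, 5]))  # [5, 2, 3]
```

Once the session builds, comparing `ort_prediction.shape` against this expectation (or against the same module run eagerly) is a quick check that the fallback produced the right result.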

Error log

```
Traceback (most recent call last):
  File "unit_test_atenop.py", line 23, in <module>
    test_numpy_T([3, 2, 5])
  File "unit_test_atenop.py", line 20, in test_numpy_T
    ort_prediction = run_step(ort_model, ort_input)
  File "unit_test_atenop.py", line 16, in run_step
    prediction = model(input)
  File "/ort_aten_fb/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/ort_aten_fb/lib/python3.8/site-packages/torch_ort/ortinferencemodule/_utils_infer.py", line 98, in _forward
    return ortinferencemodule._forward_call(*inputs, **kwargs)
  File "/ort_aten_fb/lib/python3.8/site-packages/torch_ort/ortinferencemodule/ortinferencemodule.py", line 107, in _forward_call
    self._inference_session = onnxruntime.InferenceSession(
  File "/ort_aten_fb/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 347, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/ort_aten_fb/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 386, in _create_inference_session
    sess = C.InferenceSession(session_options, self.model_bytes, False, self.read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (ATen_0) output arg (data) type inference failed.
```

I also tried running the symbolic shape inference that ORTModule uses (ref: symbolic_shape) on the exported model. It fails with `Exception("Incomplete symbolic shape inference")`.
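Both failures are consistent with the output type and shape of the ATen node not being derivable from the exported graph alone: the node wraps an arbitrary PyTorch op, so an inference pass presumably needs an explicit rule for each wrapped operator. The sketch below is hypothetical and is not ONNX Runtime's `SymbolicShapeInference` code. The registry, the exception class, and the `"numpy_T"` operator name (an assumption about how the exporter labels `aten::numpy_T`) are all illustrative. It shows the general pattern: rules keyed by the wrapped operator's name, dimensions that may be concrete sizes or symbolic names, and an "incomplete" error when no rule exists.

```python
from typing import Callable, Dict, List, Union

# A dimension is either a concrete size or a symbolic name such as "batch".
Dim = Union[int, str]
ShapeRule = Callable[[List[Dim]], List[Dim]]


class IncompleteShapeInference(Exception):
    """Raised when no rule can produce the output shape of a node (hypothetical)."""


def reverse_dims(input_shape: List[Dim]) -> List[Dim]:
    # Tensor.T reverses every dimension, so symbolic dims carry over unchanged.
    return list(reversed(input_shape))


# Hypothetical registry keyed by the ATen node's operator name.
ATEN_SHAPE_RULES: Dict[str, ShapeRule] = {
    "numpy_T": reverse_dims,
}


def infer_aten_output_shape(operator: str, input_shape: List[Dim]) -> List[Dim]:
    rule = ATEN_SHAPE_RULES.get(operator)
    if rule is None:
        # No rule registered for this op: the output shape stays unknown.
        raise IncompleteShapeInference(f"no shape rule for ATen operator {operator!r}")
    return rule(input_shape)


if __name__ == "__main__":
    print(infer_aten_output_shape("numpy_T", ["batch", 2, 5]))  # [5, 2, 'batch']
    try:
        infer_aten_output_shape("not_registered", [3, 2, 5])
    except IncompleteShapeInference as exc:
        print(exc)
```

With a rule registered, a symbolic input such as `["batch", 2, 5]` infers to `[5, 2, "batch"]`; without one, the pass stops, which mirrors the failure reported above.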
