-
Notifications
You must be signed in to change notification settings - Fork 53
Description
The ATen op does not fall back to the native PyTorch runtime as expected. Instead, creating the ONNX Runtime inference session fails (see the error log below).
Versions:
PyTorch - 1.12.0
ONNX Runtime - 1.12.0
torch-ort-infer - 1.12.0
Reproduction steps:
import torch
from torch_ort import ORTInferenceModule
def test_numpy_T(input_shape, device="cpu"):
    """Run ``input.T`` through ORTInferenceModule and check it against eager PyTorch.

    Reproduction for the reported failure: the exported graph contains an
    ATen node that is expected to fall back to native PyTorch. Instead,
    creating the ONNX Runtime inference session fails.

    Args:
        input_shape: Shape of the random float input tensor, e.g. ``[3, 2, 5]``.
        device: Device for the models and the input. Defaults to ``"cpu"``,
            the value the original reproduction hard-coded.

    Returns:
        The prediction produced by the ORT-wrapped module.

    Raises:
        AssertionError: If the ORT prediction differs from the eager PyTorch
            prediction for the same input.
    """

    class NeuralNet(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, input):
            # `.T` on a tensor with more than 2 dims is the op under test.
            return input.T

    def run_step(model, x):
        prediction = model(x)
        return prediction

    ort_model = ORTInferenceModule(NeuralNet().to(device))
    # A separate, parameter-free instance serves as the eager reference, so the
    # ORT wrapper never shares a module object with the reference path.
    reference_model = NeuralNet().to(device)

    ort_input = torch.rand(input_shape, dtype=torch.float, device=device)
    ort_prediction = run_step(ort_model, ort_input)

    # Check the result rather than only that the call returned.
    expected = run_step(reference_model, ort_input)
    torch.testing.assert_close(ort_prediction, expected)
    return ort_prediction
if __name__ == "__main__":
    # Exercise the op with a 3-D input of shape (3, 2, 5).
    test_numpy_T(input_shape=[3, 2, 5])
Error log
Traceback (most recent call last):
File "unit_test_atenop.py", line 23, in <module>
test_numpy_T([3, 2, 5])
File "unit_test_atenop.py", line 20, in test_numpy_T
ort_prediction = run_step(ort_model, ort_input)
File "unit_test_atenop.py", line 16, in run_step
prediction = model(input)
File "/ort_aten_fb/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/ort_aten_fb/lib/python3.8/site-packages/torch_ort/ortinferencemodule/_utils_infer.py", line 98, in _forward
return ortinferencemodule._forward_call(*inputs, **kwargs)
File "/ort_aten_fb/lib/python3.8/site-packages/torch_ort/ortinferencemodule/ortinferencemodule.py", line 107, in _forward_call
self._inference_session = onnxruntime.InferenceSession(
File "/ort_aten_fb/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 347, in __init__
self._create_inference_session(providers, provider_options, disabled_optimizers)
File "/ort_aten_fb/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 386, in _create_inference_session
sess = C.InferenceSession(session_options, self.model_bytes, False, self.read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (ATen_0) output arg (data) type inference failed.
Also tested with the symbolic shape inference call used by ORTModule (ref: symbolic_shape); it fails with Exception("Incomplete symbolic shape inference").