From a434ef950b08cc6de3a91f1a21e7e5f500d7630d Mon Sep 17 00:00:00 2001 From: Michael Adragna Date: Thu, 7 Aug 2025 16:47:21 -0700 Subject: [PATCH] Add warning when using deprecated to_edge and to_backend methods --- .../xnnpack/partition/xnnpack_partitioner.py | 33 +++++++- .../xnnpack/test/test_xnnpack_partitioner.py | 84 +++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 backends/xnnpack/test/test_xnnpack_partitioner.py diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py index e5532e17f36..44207e2247a 100644 --- a/backends/xnnpack/partition/xnnpack_partitioner.py +++ b/backends/xnnpack/partition/xnnpack_partitioner.py @@ -4,8 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import inspect import itertools - import logging from typing import List, Optional, Type, Union @@ -65,6 +65,37 @@ def __init__( self.per_op_mode = per_op_mode super().__init__(delegation_spec, initialized_configs) + def _check_if_called_from_to_backend(self) -> bool: + """ + Check if the partition method is being called from the deprecated to_backend workflow. + Returns True if called from deprecated direct to_backend, False if called from to_edge_transform_and_lower. + """ + stack = inspect.stack() + + for frame_info in stack: + if frame_info.function == "to_edge_transform_and_lower": + return False + + for frame_info in stack: + if frame_info.function == "to_backend": + filename = frame_info.filename + if "program/_program.py" in filename: + return True + return False + + def partition(self, exported_program): + """ + Override partition to add deprecation warning when called from to_backend. + """ + # Check if we're being called from the deprecated to_backend workflow + if self._check_if_called_from_to_backend(): + logger.warning( + "\nDEPRECATION WARNING: You are using the deprecated 'to_edge() + to_backend()' workflow. " + "Please consider migrating to 'to_edge_transform_and_lower()' for better error handling and optimization. " + ) + + return super().partition(exported_program) + def generate_partitions(self, ep: ExportedProgram) -> List[Partition]: """ generate_partitions is different if partitioner is set to per_op_mode diff --git a/backends/xnnpack/test/test_xnnpack_partitioner.py b/backends/xnnpack/test/test_xnnpack_partitioner.py new file mode 100644 index 00000000000..8cd9eb92d56 --- /dev/null +++ b/backends/xnnpack/test/test_xnnpack_partitioner.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import io +import logging +import unittest + +import torch +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner +from executorch.exir import to_edge, to_edge_transform_and_lower +from torch.export import export + + +class TestXnnpackPartitioner(unittest.TestCase): + """Test cases for XnnpackPartitioner functionality and deprecation warnings.""" + + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + def test_deprecation_warning_for_to_backend_workflow(self): + """ + Test that the deprecated to_edge + to_backend workflow shows a deprecation warning. + """ + model = self.SimpleModel() + x = torch.randn(1, 10) + + exported_model = export(model, (x,)) + + # Capture log output to check for deprecation warning + log_capture_string = io.StringIO() + ch = logging.StreamHandler(log_capture_string) + ch.setLevel(logging.WARNING) + + logger = logging.getLogger( + "executorch.backends.xnnpack.partition.xnnpack_partitioner" + ) + logger.addHandler(ch) + logger.setLevel(logging.WARNING) + + edge = to_edge(exported_model) + partitioner = XnnpackPartitioner() + + edge.to_backend(partitioner) + + log_contents = log_capture_string.getvalue() + self.assertIn("DEPRECATION WARNING", log_contents) + self.assertIn("to_edge() + to_backend()", log_contents) + self.assertIn("to_edge_transform_and_lower()", log_contents) + + def test_no_warning_for_to_edge_transform_and_lower_workflow(self): + """ + Test that the recommended to_edge_transform_and_lower workflow does NOT show a deprecation warning. + """ + + model = self.SimpleModel() + x = torch.randn(1, 10) + + exported_model = export(model, (x,)) + + # Capture log output to check for deprecation warning + log_capture_string = io.StringIO() + ch = logging.StreamHandler(log_capture_string) + ch.setLevel(logging.WARNING) + + logger = logging.getLogger( + "executorch.backends.xnnpack.partition.xnnpack_partitioner" + ) + logger.addHandler(ch) + logger.setLevel(logging.WARNING) + + partitioner = XnnpackPartitioner() + + to_edge_transform_and_lower(exported_model, partitioner=[partitioner]) + + log_contents = log_capture_string.getvalue() + self.assertNotIn("DEPRECATION WARNING", log_contents)