diff --git a/backends/aoti/TARGETS b/backends/aoti/TARGETS new file mode 100644 index 00000000000..68ca395a0c4 --- /dev/null +++ b/backends/aoti/TARGETS @@ -0,0 +1,31 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +runtime.python_library( + name = "aoti_delegate", + srcs = [ + "__init__.py", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/export:lib", + "//executorch/backends/aoti/recipes:aoti_recipe_provider", + "//executorch/backends/aoti/recipes:aoti_recipe_types", + ], +) + +runtime.python_library( + name = "aoti_delegate_module", + srcs = [ + "aoti_delegate_module.py", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//caffe2:libtorch", + ], +) diff --git a/backends/aoti/__init__.py b/backends/aoti/__init__.py new file mode 100644 index 00000000000..3e5db87983b --- /dev/null +++ b/backends/aoti/__init__.py @@ -0,0 +1,11 @@ +from executorch.export import recipe_registry + +from .recipes.aoti_recipe_provider import AOTIRecipeProvider +from .recipes.aoti_recipe_types import AOTIRecipeType + +# Auto-register AOTI recipe provider +recipe_registry.register_backend_recipe_provider(AOTIRecipeProvider()) + +__all__ = [ + "AOTIRecipeType", +] diff --git a/backends/aoti/aoti_delegate_module.py b/backends/aoti/aoti_delegate_module.py new file mode 100644 index 00000000000..acc76e92e29 --- /dev/null +++ b/backends/aoti/aoti_delegate_module.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import torch + +from torch.export import ExportedProgram +from torch.utils import _pytree as pytree + +# TODO: These should probably be in pytorch + + +class AOTInductorRunnerWrapper(torch.nn.Module): + # pyre-fixme[2]: Parameter must be annotated. 
+ def __init__(self, aoti_runner) -> None: + super().__init__() + # pyre-fixme[4]: Attribute must be annotated. + self.aoti_runner = aoti_runner + + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def forward(self, *flat_inputs): + return self.aoti_runner.run(flat_inputs) + + +class AOTIDelegateModule(torch.nn.Module): + """ + This module is the primary artifact produced by AOTInductor lowering. + It is eagerly runnable in Python and traceable by torch.export. + It also contains all necessary information and metadata to be packaged and consumed + by the delegate executor at runtime later. + + """ + + def __init__(self, exported_program: ExportedProgram, so_path: str) -> None: + super().__init__() + self.so_path = so_path + self.exported_program = exported_program + self.exported_program.graph_module.recompile() + + # register parameters + for name, parameter in self.exported_program.named_parameters(): + normalized_name = name.replace(".", "_") + self.register_parameter(normalized_name, parameter) + + # register buffers + non_persistent_buffer_names = ( + exported_program.graph_signature.non_persistent_buffers + ) + for name, buffer in self.exported_program.named_buffers(): + normalized_name = name.replace(".", "_") + if name in non_persistent_buffer_names: + self.register_buffer(normalized_name, buffer, persistent=False) + else: + self.register_buffer(normalized_name, buffer, persistent=True) + + # handle tensor constants + self.constant_names: list[str] = [] + for name, constant in self.exported_program.tensor_constants.items(): + # skip non-persistent buffers + if name in non_persistent_buffer_names: + continue + normalized_name = name.replace(".", "_") + setattr(self, normalized_name, constant) + self.constant_names.append(normalized_name) + + # pyre-ignore[4]: Missing attribute annotation + # pyre-ignore[16]: Undefined attribute + # TODO: CPU only for now. 
Add GPU + self.engine = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) + self.aoti_runner_wrapper = AOTInductorRunnerWrapper(self.engine) + + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def forward(self, *inputs): + weights_args = [ + *self.parameters(), + *self.buffers(), + ] + [getattr(self, const_name) for const_name in self.constant_names] + + flat_inputs = pytree.tree_flatten((inputs, {}))[0] + flat_outputs = torch._higher_order_ops.aoti_call_delegate( + self.aoti_runner_wrapper, + self.exported_program.graph_module, + weights_args, + flat_inputs, + ) + return pytree.tree_unflatten( + flat_outputs, self.exported_program.call_spec.out_spec + ) diff --git a/backends/aoti/recipes/TARGETS b/backends/aoti/recipes/TARGETS new file mode 100644 index 00000000000..0519459cf7d --- /dev/null +++ b/backends/aoti/recipes/TARGETS @@ -0,0 +1,33 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +runtime.python_library( + name = "aoti_recipe_provider", + srcs = [ + "aoti_recipe_provider.py", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//caffe2:torch", + "//executorch/export:lib", + ":aoti_recipe_types", + ], +) + +runtime.python_library( + name = "aoti_recipe_types", + srcs = [ + "aoti_recipe_types.py", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + deps = [ + "//executorch/export:lib", + ], +) diff --git a/backends/aoti/recipes/aoti_recipe_provider.py b/backends/aoti/recipes/aoti_recipe_provider.py new file mode 100644 index 00000000000..1c049afea8a --- /dev/null +++ b/backends/aoti/recipes/aoti_recipe_provider.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from typing import Any, Optional, Sequence + +from executorch.backends.aoti.recipes.aoti_recipe_types import AOTIRecipeType + +from executorch.export import ( + BackendRecipeProvider, + ExportRecipe, + LoweringRecipe, + RecipeType, + StageType, +) + + +class AOTIRecipeProvider(BackendRecipeProvider): + @property + def backend_name(self) -> str: + return "aoti" + + def get_supported_recipes(self) -> Sequence[RecipeType]: + return [AOTIRecipeType.FP32] + + def create_recipe( + self, recipe_type: RecipeType, **kwargs: Any + ) -> Optional[ExportRecipe]: + """Create AOTI recipe""" + + if recipe_type not in self.get_supported_recipes(): + return None + + if recipe_type == AOTIRecipeType.FP32: + return self._build_fp32_recipe(recipe_type) + + def _get_aoti_lowering_recipe(self) -> LoweringRecipe: + return LoweringRecipe( + partitioners=None, + edge_transform_passes=None, + edge_compile_config=None, + ) + + def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe: + return ExportRecipe( + name=recipe_type.value, + lowering_recipe=self._get_aoti_lowering_recipe(), + pipeline_stages=[StageType.TORCH_EXPORT, StageType.AOTI_LOWERING], + ) diff --git a/backends/aoti/recipes/aoti_recipe_types.py b/backends/aoti/recipes/aoti_recipe_types.py new file mode 100644 index 00000000000..2f2fc281dc1 --- /dev/null +++ b/backends/aoti/recipes/aoti_recipe_types.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from executorch.export import RecipeType + + +class AOTIRecipeType(RecipeType): + """AOTInductor-specific recipe types""" + + FP32 = "fp32" + # more to be added... 
+ + @classmethod + def get_backend_name(cls) -> str: + return "aoti" diff --git a/backends/aoti/test/TARGETS b/backends/aoti/test/TARGETS new file mode 100644 index 00000000000..8256615313f --- /dev/null +++ b/backends/aoti/test/TARGETS @@ -0,0 +1,16 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load("@fbcode_macros//build_defs:python_unittest_remote_gpu.bzl", "python_unittest_remote_gpu") + +oncall("executorch") + +runtime.python_test( + name = "test_aoti_recipes", + srcs = [ + "recipes/test_aoti_recipes.py", + ], + deps = [ + "//executorch/backends/aoti:aoti_delegate", + "//executorch/export:lib", + "//executorch/examples/models:models", # @manual + ], +) diff --git a/backends/aoti/test/recipes/test_aoti_recipes.py b/backends/aoti/test/recipes/test_aoti_recipes.py new file mode 100644 index 00000000000..58a9cbcade2 --- /dev/null +++ b/backends/aoti/test/recipes/test_aoti_recipes.py @@ -0,0 +1,81 @@ +import unittest + +import torch + +from executorch.backends.aoti.recipes.aoti_recipe_types import AOTIRecipeType +from executorch.examples.models import MODEL_NAME_TO_MODEL +from executorch.examples.models.model_factory import EagerModelFactory +from executorch.export import export, ExportRecipe, StageType +from torch.testing._internal.common_quantization import TestHelperModules + + +class TestAotiRecipes(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + + def tearDown(self) -> None: + super().tearDown() + + def test_basic_recipe(self) -> None: + m_eager = TestHelperModules.TwoLinearModule().eval() + example_inputs = [(torch.randn(9, 8),)] + session = export( + model=m_eager, + example_inputs=example_inputs, + export_recipe=ExportRecipe.get_recipe(AOTIRecipeType.FP32), + ) + artifacts = session.get_stage_artifacts() + aoti_artifacts = artifacts[StageType.AOTI_LOWERING] + aoti_delegate_module = aoti_artifacts.data["forward"] + with torch.inference_mode(): + eager_out = m_eager(*example_inputs[0]) + aoti_out = 
aoti_delegate_module(*example_inputs[0]) + + self.assertTrue(torch.allclose(eager_out, aoti_out, atol=1e-3)) + + def _test_model_with_factory(self, model_name: str) -> None: + if model_name not in MODEL_NAME_TO_MODEL: + self.skipTest(f"Model {model_name} not found in MODEL_NAME_TO_MODEL") + return + + # Create model using factory + model, example_inputs, _example_kwarg_inputs, dynamic_shapes = ( + EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL[model_name]) + ) + model = model.eval() + + # Export with recipe + session = export( + model=model, + example_inputs=[example_inputs], + export_recipe=ExportRecipe.get_recipe(AOTIRecipeType.FP32), + dynamic_shapes=dynamic_shapes, + ) + + artifacts = session.get_stage_artifacts() + aoti_artifacts = artifacts[StageType.AOTI_LOWERING] + aoti_delegate_module = aoti_artifacts.data["forward"] + + with torch.inference_mode(): + eager_out = model(*example_inputs) + aoti_out = aoti_delegate_module(*example_inputs) + + self.assertTrue(torch.allclose(eager_out, aoti_out, atol=1e-3)) + + def test_all_models_with_recipes(self) -> None: + models_to_test = [ + "linear", + "add", + "add_mul", + "ic3", + "mv2", + "mv3", + "resnet18", + "resnet50", + "vit", + "w2l", + "llama2", + ] + for model_name in models_to_test: + with self.subTest(model=model_name): + self._test_model_with_factory(model_name) diff --git a/export/TARGETS b/export/TARGETS index 816a3a1a289..4db6c549bc1 100644 --- a/export/TARGETS +++ b/export/TARGETS @@ -50,6 +50,7 @@ runtime.python_library( deps = [ ":recipe", ":types", + "//executorch/backends/aoti:aoti_delegate_module", "//executorch/devtools/backend_debug:delegation_info", "//executorch/exir/backend:backend_api", "//executorch/exir:pass_manager", diff --git a/export/export.py b/export/export.py index 597ec28665b..7cbd195811e 100644 --- a/export/export.py +++ b/export/export.py @@ -18,6 +18,7 @@ from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe from .stages import ( + AOTILoweringStage, 
EdgeTransformAndLowerStage, ExecutorchStage, PipelineArtifact, @@ -203,6 +204,8 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]: stage = ToBackendStage.from_recipe(self._lowering_recipe) elif stage_type == StageType.TO_EXECUTORCH: stage = ExecutorchStage(self._export_recipe.executorch_backend_config) + elif stage_type == StageType.AOTI_LOWERING: + stage = AOTILoweringStage() else: logging.info( f"{stage_type} is unknown, you have to register it before executing export()" diff --git a/export/stages.py b/export/stages.py index dd22155e929..1f552ce985c 100644 --- a/export/stages.py +++ b/export/stages.py @@ -500,3 +500,62 @@ def delegation_info(self) -> Any: Returns the delegation info. """ return self._artifact.get_context("delegation_info") + + +class AOTILoweringStage(Stage): + """ + For Executorch-full runtime, lowering with AOTInductor. + """ + + def __init__( + self, + partitioners: Optional[List[Any]] = None, + compile_config: Optional[Any] = None, + ) -> None: + self._partitioners = partitioners + self._compile_config = compile_config + + @classmethod + def from_recipe( + cls, lowering_recipe: Optional["LoweringRecipe"] + ) -> "AOTILoweringStage": + if lowering_recipe is None: + return cls() + + return cls( + partitioners=lowering_recipe.partitioners, + compile_config=lowering_recipe.edge_compile_config, + ) + + @property + def stage_type(self) -> str: + return StageType.AOTI_LOWERING + + @property + def valid_predecessor_stages(self) -> List["StageType"]: + return [StageType.TORCH_EXPORT] + + @property + def can_start_pipeline(self) -> bool: + return False + + def run(self, artifact: PipelineArtifact) -> None: + """ + Lowering with AOTInductor. 
+ """ + from executorch.backends.aoti.aoti_delegate_module import AOTIDelegateModule + + aoti_delegate_modules = {} + exported_programs = artifact.data + for name, exported_program in exported_programs.items(): + args, kwargs = exported_program.example_inputs + so_path = torch._inductor.aot_compile( + exported_program.module(), + args, + kwargs, + ) + assert isinstance(so_path, str) + aoti_delegate_module = AOTIDelegateModule(exported_program, so_path) + aoti_delegate_modules[name] = aoti_delegate_module + + self._artifact = artifact.copy_with_new_data(aoti_delegate_modules) diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py index 92aeebb7304..2b94898eb46 100644 --- a/export/tests/test_export_session.py +++ b/export/tests/test_export_session.py @@ -304,6 +304,8 @@ def test_invalid_pipeline_start_stages(self) -> None: # Edge stage cannot start pipeline [StageType.TO_EDGE_TRANSFORM_AND_LOWER], [StageType.TO_EDGE_TRANSFORM_AND_LOWER, StageType.TO_EXECUTORCH], + # AOTI Lowering stage cannot start pipeline + [StageType.AOTI_LOWERING], ] for i, stages in enumerate(invalid_stage_sequence): @@ -336,6 +338,15 @@ def test_pipeline_transitions(self) -> None: False, ), ([StageType.TO_EXECUTORCH, StageType.TORCH_EXPORT], False), + ([StageType.TORCH_EXPORT, StageType.AOTI_LOWERING], True), + ( + [ + StageType.TORCH_EXPORT, + StageType.AOTI_LOWERING, + StageType.TO_EDGE_TRANSFORM_AND_LOWER, + ], + False, + ), ] for i, (stages, should_pass) in enumerate(test_cases): diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index 2b3e533723a..a3b2e5bcc7f 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -13,6 +13,7 @@ from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager from executorch.export import QuantizationRecipe from executorch.export.stages import ( + AOTILoweringStage, EdgeTransformAndLowerStage, ExecutorchStage, PipelineArtifact, @@ -100,6 
+101,27 @@ def test_get_artifacts_before_run(self) -> None: self.assertIn("Stage: TorchExportStage not executed", str(cm.exception)) +class TestAOTILoweringStage(unittest.TestCase): + def setUp(self) -> None: + self.mock_exported_program = Mock(spec=ExportedProgram) + self.exported_programs = {"forward": self.mock_exported_program} + self.context = {"constant_methods": None} + + def test_run_with_partitioners_and_config(self) -> None: + """Test execution with partitioners and compile config""" + mock_partitioners = [Mock()] + mock_compile_config = Mock() + + stage = AOTILoweringStage( + partitioners=mock_partitioners, compile_config=mock_compile_config + ) + + # Test that the stage has the right configuration + self.assertEqual(stage.stage_type, StageType.AOTI_LOWERING) + self.assertEqual(stage._partitioners, mock_partitioners) + self.assertEqual(stage._compile_config, mock_compile_config) + + class TestEdgeTransformAndLowerStage(unittest.TestCase): def setUp(self) -> None: self.mock_exported_program = Mock(spec=ExportedProgram) diff --git a/export/types.py b/export/types.py index 760f8461d41..85d8fb4dc51 100644 --- a/export/types.py +++ b/export/types.py @@ -15,6 +15,7 @@ class StageType(str, Enum): SOURCE_TRANSFORM = "source_transform" QUANTIZE = "quantize" TORCH_EXPORT = "torch_export" + AOTI_LOWERING = "aoti_lowering" TO_EDGE_TRANSFORM_AND_LOWER = "to_edge_transform_and_lower" TO_EDGE = "to_edge" TO_BACKEND = "to_backend"