Skip to content

Commit f95606c

Browse files
yiming0416facebook-github-bot
authored andcommitted
Add AOTI Lowering Recipe
Summary: - Added `AOTILoweringStage` for the executorch-full runtime and a simple fp32 AOTI lowering recipe - Added `AOTIDelegateModule` as the artifact generated by AOTI Lowering to be packaged later. Reviewed By: larryliu0820 Differential Revision: D79487582
1 parent 4ce7078 commit f95606c

File tree

14 files changed

+435
-0
lines changed

14 files changed

+435
-0
lines changed

backends/aoti/TARGETS

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

# Declared for consistency with the sibling TARGETS files under
# backends/aoti/recipes and backends/aoti/test, which both set the oncall.
oncall("executorch")

# Package entry point: importing it registers the AOTI recipe provider.
runtime.python_library(
    name = "aoti_delegate",
    srcs = [
        "__init__.py",
    ],
    visibility = [
        "//executorch/...",
        "@EXECUTORCH_CLIENTS",
    ],
    deps = [
        "//executorch/export:lib",
        "//executorch/backends/aoti/recipes:aoti_recipe_provider",
        "//executorch/backends/aoti/recipes:aoti_recipe_types",
    ],
)

# The eager-runnable module produced by AOTI lowering.
runtime.python_library(
    name = "aoti_delegate_module",
    srcs = [
        "aoti_delegate_module.py",
    ],
    visibility = [
        "//executorch/...",
        "@EXECUTORCH_CLIENTS",
    ],
    deps = [
        # NOTE(review): the recipes library depends on "//caffe2:torch" for its
        # Python torch import; confirm "//caffe2:libtorch" is the intended
        # target for this Python source.
        "//caffe2:libtorch",
    ],
)

backends/aoti/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from executorch.export import recipe_registry
2+
3+
from .recipes.aoti_recipe_provider import AOTIRecipeProvider
4+
from .recipes.aoti_recipe_types import AOTIRecipeType
5+
6+
# Auto-register AOTI recipe provider
7+
recipe_registry.register_backend_recipe_provider(AOTIRecipeProvider())
8+
9+
__all__ = [
10+
"AOTIRecipeType",
11+
]

backends/aoti/aoti_delegate_module.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import torch
10+
11+
from torch.export import ExportedProgram
12+
from torch.utils import _pytree as pytree
13+
14+
# TODO: These should probably be in pytorch
15+
16+
17+
class AOTInductorRunnerWrapper(torch.nn.Module):
    """Expose an AOTInductor runner object through the ``nn.Module`` call protocol.

    Calling the module packs the positional arguments into a single tuple,
    passes that tuple to ``aoti_runner.run`` and returns the result unchanged.
    """

    # pyre-fixme[2]: Parameter must be annotated.
    def __init__(self, aoti_runner) -> None:
        super().__init__()
        # pyre-fixme[4]: Attribute must be annotated.
        self.aoti_runner = aoti_runner

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def forward(self, *flat_inputs):
        """Run the wrapped runner on ``flat_inputs`` and return its output."""
        runner = self.aoti_runner
        return runner.run(flat_inputs)
29+
30+
class AOTIDelegateModule(torch.nn.Module):
    """
    This module is the primary artifact produced by AOTInductor lowering.
    It is eagerly runnable in Python and traceable by torch.export.
    It also contains all necessary information and metadata to be packaged and
    consumed by the delegate executor at runtime later.

    Args:
        exported_program: The exported program that was lowered. Its parameters,
            buffers and tensor constants are re-attached to this module under
            names in which every "." is replaced by "_".
        so_path: Path handed to the AOTInductor container runner (presumably the
            shared library produced by lowering -- confirm against the caller).
    """

    def __init__(self, exported_program: ExportedProgram, so_path: str) -> None:
        super().__init__()
        self.so_path = so_path
        self.exported_program = exported_program
        self.exported_program.graph_module.recompile()

        # register parameters
        for name, parameter in self.exported_program.named_parameters():
            self.register_parameter(self._normalize_name(name), parameter)

        # register buffers, carrying over each buffer's persistence flag
        non_persistent_buffer_names = (
            exported_program.graph_signature.non_persistent_buffers
        )
        for name, buffer in self.exported_program.named_buffers():
            self.register_buffer(
                self._normalize_name(name),
                buffer,
                persistent=name not in non_persistent_buffer_names,
            )

        # handle tensor constants; the normalized names are remembered in order
        # so that forward() can collect them again
        self.constant_names: list[str] = []
        for name, constant in self.exported_program.tensor_constants.items():
            # skip non-persistent buffers
            if name in non_persistent_buffer_names:
                continue
            normalized_name = self._normalize_name(name)
            setattr(self, normalized_name, constant)
            self.constant_names.append(normalized_name)

        # TODO: CPU only for now. Add GPU
        # The suppressions sit directly above the statement they apply to.
        # pyre-ignore[4]: Missing attribute annotation
        # pyre-ignore[16]: Undefined attribute
        self.engine = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)
        self.aoti_runner_wrapper = AOTInductorRunnerWrapper(self.engine)

    @staticmethod
    def _normalize_name(name: str) -> str:
        """Return ``name`` with every "." replaced by "_"."""
        return name.replace(".", "_")

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def forward(self, *inputs):
        """Run the lowered program on ``inputs``.

        The weights passed to the delegate are this module's parameters, then
        its buffers, then the tensor constants recorded in ``constant_names``.
        The positional inputs are flattened with pytree, handed to the
        ``aoti_call_delegate`` higher-order op, and the flat outputs are
        unflattened with the exported program's ``call_spec.out_spec``.
        """
        weights_args = [
            *self.parameters(),
            *self.buffers(),
            *(getattr(self, const_name) for const_name in self.constant_names),
        ]

        flat_inputs = pytree.tree_flatten((inputs, {}))[0]
        flat_outputs = torch._higher_order_ops.aoti_call_delegate(
            self.aoti_runner_wrapper,
            self.exported_program.graph_module,
            weights_args,
            flat_inputs,
        )
        return pytree.tree_unflatten(
            flat_outputs, self.exported_program.call_spec.out_spec
        )

backends/aoti/recipes/TARGETS

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_library(
6+
name = "aoti_recipe_provider",
7+
srcs = [
8+
"aoti_recipe_provider.py",
9+
],
10+
visibility = [
11+
"//executorch/...",
12+
"@EXECUTORCH_CLIENTS",
13+
],
14+
deps = [
15+
"//caffe2:torch",
16+
"//executorch/export:lib",
17+
":aoti_recipe_types",
18+
],
19+
)
20+
21+
runtime.python_library(
22+
name = "aoti_recipe_types",
23+
srcs = [
24+
"aoti_recipe_types.py",
25+
],
26+
visibility = [
27+
"//executorch/...",
28+
"@EXECUTORCH_CLIENTS",
29+
],
30+
deps = [
31+
"//executorch/export:lib",
32+
],
33+
)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import Any, Optional, Sequence
10+
11+
from executorch.backends.aoti.recipes.aoti_recipe_types import AOTIRecipeType
12+
13+
from executorch.export import (
14+
BackendRecipeProvider,
15+
ExportRecipe,
16+
LoweringRecipe,
17+
RecipeType,
18+
StageType,
19+
)
20+
21+
22+
class AOTIRecipeProvider(BackendRecipeProvider):
    """Backend recipe provider for the "aoti" backend."""

    @property
    def backend_name(self) -> str:
        """Name under which this provider's backend is identified."""
        return "aoti"

    def get_supported_recipes(self) -> Sequence[RecipeType]:
        """Return the recipe types this provider can build."""
        return [AOTIRecipeType.FP32]

    def create_recipe(
        self, recipe_type: RecipeType, **kwargs: Any
    ) -> Optional[ExportRecipe]:
        """Create AOTI recipe

        Args:
            recipe_type: The recipe type to build.
            **kwargs: Accepted but not used by any recipe built here.

        Returns:
            The export recipe for ``recipe_type``, or ``None`` when the type is
            not supported by this provider.
        """

        if recipe_type not in self.get_supported_recipes():
            return None

        if recipe_type == AOTIRecipeType.FP32:
            return self._build_fp32_recipe(recipe_type)

        # A type listed in get_supported_recipes() without a builder branch
        # above ends up here; return None explicitly rather than implicitly.
        return None

    def _get_aoti_lowering_recipe(self) -> LoweringRecipe:
        """Build the lowering recipe with no partitioners, passes or edge config."""
        return LoweringRecipe(
            partitioners=None,
            edge_transform_passes=None,
            edge_compile_config=None,
        )

    def _build_fp32_recipe(self, recipe_type: RecipeType) -> ExportRecipe:
        """Build the recipe that runs torch.export followed by AOTI lowering."""
        return ExportRecipe(
            name=recipe_type.value,
            lowering_recipe=self._get_aoti_lowering_recipe(),
            pipeline_stages=[StageType.TORCH_EXPORT, StageType.AOTI_LOWERING],
        )
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from executorch.export import RecipeType
10+
11+
12+
class AOTIRecipeType(RecipeType):
    """Recipe types offered by the AOTInductor ("aoti") backend."""

    FP32 = "fp32"
    # more to be added...

    @classmethod
    def get_backend_name(cls) -> str:
        """Return the backend name these recipe types belong to."""
        return "aoti"

backends/aoti/test/TARGETS

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

# The load of python_unittest_remote_gpu was dropped: nothing in this file
# referenced the symbol.
runtime.python_test(
    name = "test_aoti_recipes",
    srcs = [
        "recipes/test_aoti_recipes.py",
    ],
    deps = [
        "//executorch/backends/aoti:aoti_delegate",
        "//executorch/export:lib",
        "//executorch/examples/models:models",  # @manual
    ],
)
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import unittest
2+
3+
import torch
4+
5+
from executorch.backends.aoti.recipes.aoti_recipe_types import AOTIRecipeType
6+
from executorch.examples.models import MODEL_NAME_TO_MODEL
7+
from executorch.examples.models.model_factory import EagerModelFactory
8+
from executorch.export import export, ExportRecipe, StageType
9+
from torch.testing._internal.common_quantization import TestHelperModules
10+
11+
12+
class TestAotiRecipes(unittest.TestCase):
    """Checks that models exported with the AOTI FP32 recipe match eager outputs."""

    def _assert_matches_eager(self, session, model, inputs) -> None:
        """Compare the AOTI delegate module's output with the eager model's.

        Looks up the "forward" artifact of the AOTI lowering stage in
        ``session``, runs it and ``model`` on ``inputs`` under inference mode,
        and asserts the two results agree within ``atol=1e-3``.
        """
        artifacts = session.get_stage_artifacts()
        aoti_artifacts = artifacts[StageType.AOTI_LOWERING]
        aoti_delegate_module = aoti_artifacts.data["forward"]

        with torch.inference_mode():
            eager_out = model(*inputs)
            aoti_out = aoti_delegate_module(*inputs)

        self.assertTrue(
            torch.allclose(eager_out, aoti_out, atol=1e-3),
            "AOTI delegate output does not match eager output (atol=1e-3)",
        )

    def test_basic_recipe(self) -> None:
        m_eager = TestHelperModules.TwoLinearModule().eval()
        example_inputs = [(torch.randn(9, 8),)]
        session = export(
            model=m_eager,
            example_inputs=example_inputs,
            export_recipe=ExportRecipe.get_recipe(AOTIRecipeType.FP32),
        )
        self._assert_matches_eager(session, m_eager, example_inputs[0])

    def _test_model_with_factory(self, model_name: str) -> None:
        """Export the named example model with the FP32 recipe and verify it."""
        if model_name not in MODEL_NAME_TO_MODEL:
            # skipTest raises, so nothing after this call runs.
            self.skipTest(f"Model {model_name} not found in MODEL_NAME_TO_MODEL")

        # Create model using factory
        model, example_inputs, _example_kwarg_inputs, dynamic_shapes = (
            EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL[model_name])
        )
        model = model.eval()

        # Export with recipe
        session = export(
            model=model,
            example_inputs=[example_inputs],
            export_recipe=ExportRecipe.get_recipe(AOTIRecipeType.FP32),
            dynamic_shapes=dynamic_shapes,
        )

        self._assert_matches_eager(session, model, example_inputs)

    def test_all_models_with_recipes(self) -> None:
        models_to_test = [
            "linear",
            "add",
            "add_mul",
            "ic3",
            "mv2",
            "mv3",
            "resnet18",
            "resnet50",
            "vit",
            "w2l",
            "llama2",
        ]
        # One subTest per model so a single failure or skip does not hide the rest.
        for model_name in models_to_test:
            with self.subTest(model=model_name):
                self._test_model_with_factory(model_name)

export/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ runtime.python_library(
5050
deps = [
5151
":recipe",
5252
":types",
53+
"//executorch/backends/aoti:aoti_delegate_module",
5354
"//executorch/devtools/backend_debug:delegation_info",
5455
"//executorch/exir/backend:backend_api",
5556
"//executorch/exir:pass_manager",

export/export.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe
2020
from .stages import (
21+
AOTILoweringStage,
2122
EdgeTransformAndLowerStage,
2223
ExecutorchStage,
2324
PipelineArtifact,
@@ -203,6 +204,8 @@ def _build_stages(self, stages: List[StageType]) -> Dict[StageType, Stage]:
203204
stage = ToBackendStage.from_recipe(self._lowering_recipe)
204205
elif stage_type == StageType.TO_EXECUTORCH:
205206
stage = ExecutorchStage(self._export_recipe.executorch_backend_config)
207+
elif stage_type == StageType.AOTI_LOWERING:
208+
stage = AOTILoweringStage()
206209
else:
207210
logging.info(
208211
f"{stage_type} is unknown, you have to register it before executing export()"

0 commit comments

Comments
 (0)