From 8b6c773305dfd477413b99db4ef2b775b78ad685 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 15 Nov 2023 23:48:17 -0600 Subject: [PATCH 1/3] [mlir][python] set the registry free --- mlir/python/mlir/_mlir_libs/__init__.py | 206 ++++++++++++------------ 1 file changed, 105 insertions(+), 101 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 6ce77b4cb93f6..468925d278c61 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -56,110 +56,114 @@ def get_include_dirs() -> Sequence[str]: # # This facility allows downstreams to customize Context creation to their # needs. -def _site_initialize(): - import importlib - import itertools - import logging - from ._mlir import ir - - logger = logging.getLogger(__name__) - registry = ir.DialectRegistry() - post_init_hooks = [] - disable_multithreading = False - - def process_initializer_module(module_name): - nonlocal disable_multithreading - try: - m = importlib.import_module(f".{module_name}", __name__) - except ModuleNotFoundError: - return False - except ImportError: - message = ( - f"Error importing mlir initializer {module_name}. This may " - "happen in unclean incremental builds but is likely a real bug if " - "encountered otherwise and the MLIR Python API may not function." +import importlib +import itertools +import logging +from ._mlir import ir + +logger = logging.getLogger(__name__) +registry = ir.DialectRegistry() +post_init_hooks = [] +disable_multithreading = False + + +def get_registry(): + return registry + + +def process_initializer_module(module_name): + global disable_multithreading + try: + m = importlib.import_module(f".{module_name}", __name__) + except ModuleNotFoundError: + return False + except ImportError: + message = ( + f"Error importing mlir initializer {module_name}. This may " + "happen in unclean incremental builds but is likely a real bug if " + "encountered otherwise and the MLIR Python API may not function." + ) + logger.warning(message, exc_info=True) + + logger.debug("Initializing MLIR with module: %s", module_name) + if hasattr(m, "register_dialects"): + logger.debug("Registering dialects from initializer %r", m) + m.register_dialects(registry) + if hasattr(m, "context_init_hook"): + logger.debug("Adding context init hook from %r", m) + post_init_hooks.append(m.context_init_hook) + if hasattr(m, "disable_multithreading"): + if bool(m.disable_multithreading): + logger.debug("Disabling multi-threading for context") + disable_multithreading = True + return True + + +# If _mlirRegisterEverything is built, then include it as an initializer +# module. +init_module = None +if process_initializer_module("_mlirRegisterEverything"): + init_module = importlib.import_module(f"._mlirRegisterEverything", __name__) + +# Load all _site_initialize_{i} modules, where 'i' is a number starting +# at 0. +for i in itertools.count(): + module_name = f"_site_initialize_{i}" + if not process_initializer_module(module_name): + break + + +class Context(ir._BaseContext): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.append_dialect_registry(get_registry()) + for hook in post_init_hooks: + hook(self) + if not disable_multithreading: + self.enable_multithreading(True) + # TODO: There is some debate about whether we should eagerly load + # all dialects. It is being done here in order to preserve existing + # behavior. See: https://github.com/llvm/llvm-project/issues/56037 + self.load_all_available_dialects() + if init_module: + logger.debug("Registering translations from initializer %r", init_module) + init_module.register_llvm_translations(self) + + +ir.Context = Context + + +class MLIRError(Exception): + """ + An exception with diagnostic information. Has the following fields: + message: str + error_diagnostics: List[ir.DiagnosticInfo] + """ + + def __init__(self, message, error_diagnostics): + self.message = message + self.error_diagnostics = error_diagnostics + super().__init__(message, error_diagnostics) + + def __str__(self): + s = self.message + if self.error_diagnostics: + s += ":" + for diag in self.error_diagnostics: + s += ( + "\nerror: " + + str(diag.location)[4:-1] + + ": " + + diag.message.replace("\n", "\n ") ) - logger.warning(message, exc_info=True) - - logger.debug("Initializing MLIR with module: %s", module_name) - if hasattr(m, "register_dialects"): - logger.debug("Registering dialects from initializer %r", m) - m.register_dialects(registry) - if hasattr(m, "context_init_hook"): - logger.debug("Adding context init hook from %r", m) - post_init_hooks.append(m.context_init_hook) - if hasattr(m, "disable_multithreading"): - if bool(m.disable_multithreading): - logger.debug("Disabling multi-threading for context") - disable_multithreading = True - return True - - # If _mlirRegisterEverything is built, then include it as an initializer - # module. - init_module = None - if process_initializer_module("_mlirRegisterEverything"): - init_module = importlib.import_module(f"._mlirRegisterEverything", __name__) - - # Load all _site_initialize_{i} modules, where 'i' is a number starting - # at 0. - for i in itertools.count(): - module_name = f"_site_initialize_{i}" - if not process_initializer_module(module_name): - break - - class Context(ir._BaseContext): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.append_dialect_registry(registry) - for hook in post_init_hooks: - hook(self) - if not disable_multithreading: - self.enable_multithreading(True) - # TODO: There is some debate about whether we should eagerly load - # all dialects. It is being done here in order to preserve existing - # behavior. See: https://github.com/llvm/llvm-project/issues/56037 - self.load_all_available_dialects() - if init_module: - logger.debug( - "Registering translations from initializer %r", init_module - ) - init_module.register_llvm_translations(self) - - ir.Context = Context - - class MLIRError(Exception): - """ - An exception with diagnostic information. Has the following fields: - message: str - error_diagnostics: List[ir.DiagnosticInfo] - """ - - def __init__(self, message, error_diagnostics): - self.message = message - self.error_diagnostics = error_diagnostics - super().__init__(message, error_diagnostics) - - def __str__(self): - s = self.message - if self.error_diagnostics: - s += ":" - for diag in self.error_diagnostics: + for note in diag.notes: s += ( - "\nerror: " - + str(diag.location)[4:-1] + "\n note: " + + str(note.location)[4:-1] + ": " - + diag.message.replace("\n", "\n ") + + note.message.replace("\n", "\n ") ) - for note in diag.notes: - s += ( - "\n note: " - + str(note.location)[4:-1] - + ": " - + note.message.replace("\n", "\n ") - ) - return s - - ir.MLIRError = MLIRError + return s -_site_initialize() +ir.MLIRError = MLIRError From b23938a217f9fb6e489b154a79c42182c4db4e34 Mon Sep 17 00:00:00 2001 From: max Date: Thu, 16 Nov 2023 00:34:19 -0600 Subject: [PATCH 2/3] [mlir][python] hide everything in a namespace/module --- mlir/python/mlir/_mlir_libs/__init__.py | 319 ++++++++++++------------ 1 file changed, 166 insertions(+), 153 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 468925d278c61..2b2306d121cbf 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -2,168 +2,181 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Any, Sequence - -import os - -_this_dir = os.path.dirname(__file__) - - -def get_lib_dirs() -> Sequence[str]: - """Gets the lib directory for linking to shared libraries. - - On some platforms, the package may need to be built specially to export - development libraries. - """ - return [_this_dir] - - -def get_include_dirs() -> Sequence[str]: - """Gets the include directory for compiling against exported C libraries. - - Depending on how the package was build, development C libraries may or may - not be present. - """ - return [os.path.join(_this_dir, "include")] - - -# Perform Python level site initialization. This involves: -# 1. Attempting to load initializer modules, specific to the distribution. -# 2. Defining the concrete mlir.ir.Context that does site specific -# initialization. -# -# Aside from just being far more convenient to do this at the Python level, -# it is actually quite hard/impossible to have such __init__ hooks, given -# the pybind memory model (i.e. there is not a Python reference to the object -# in the scope of the base class __init__). -# -# For #1, we: -# a. Probe for modules named '_mlirRegisterEverything' and -# '_site_initialize_{i}', where 'i' is a number starting at zero and -# proceeding so long as a module with the name is found. -# b. If the module has a 'register_dialects' attribute, it will be called -# immediately with a DialectRegistry to populate. -# c. If the module has a 'context_init_hook', it will be added to a list -# of callbacks that are invoked as the last step of Context -# initialization (and passed the Context under construction). -# d. If the module has a 'disable_multithreading' attribute, it will be -# taken as a boolean. If it is True for any initializer, then the -# default behavior of enabling multithreading on the context -# will be suppressed. This complies with the original behavior of all -# contexts being created with multithreading enabled while allowing -# this behavior to be changed if needed (i.e. if a context_init_hook -# explicitly sets up multithreading). -# -# This facility allows downstreams to customize Context creation to their -# needs. import importlib import itertools import logging +import os +import sys +from typing import Sequence + from ._mlir import ir + +_this_dir = os.path.dirname(__file__) + logger = logging.getLogger(__name__) -registry = ir.DialectRegistry() -post_init_hooks = [] -disable_multithreading = False - - -def get_registry(): - return registry - - -def process_initializer_module(module_name): - global disable_multithreading - try: - m = importlib.import_module(f".{module_name}", __name__) - except ModuleNotFoundError: - return False - except ImportError: - message = ( - f"Error importing mlir initializer {module_name}. This may " - "happen in unclean incremental builds but is likely a real bug if " - "encountered otherwise and the MLIR Python API may not function." - ) - logger.warning(message, exc_info=True) - - logger.debug("Initializing MLIR with module: %s", module_name) - if hasattr(m, "register_dialects"): - logger.debug("Registering dialects from initializer %r", m) - m.register_dialects(registry) - if hasattr(m, "context_init_hook"): - logger.debug("Adding context init hook from %r", m) - post_init_hooks.append(m.context_init_hook) - if hasattr(m, "disable_multithreading"): - if bool(m.disable_multithreading): - logger.debug("Disabling multi-threading for context") - disable_multithreading = True - return True - - -# If _mlirRegisterEverything is built, then include it as an initializer -# module. -init_module = None -if process_initializer_module("_mlirRegisterEverything"): - init_module = importlib.import_module(f"._mlirRegisterEverything", __name__) - -# Load all _site_initialize_{i} modules, where 'i' is a number starting -# at 0. -for i in itertools.count(): - module_name = f"_site_initialize_{i}" - if not process_initializer_module(module_name): - break - - -class Context(ir._BaseContext): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.append_dialect_registry(get_registry()) - for hook in post_init_hooks: - hook(self) - if not disable_multithreading: - self.enable_multithreading(True) - # TODO: There is some debate about whether we should eagerly load - # all dialects. It is being done here in order to preserve existing - # behavior. See: https://github.com/llvm/llvm-project/issues/56037 - self.load_all_available_dialects() - if init_module: - logger.debug("Registering translations from initializer %r", init_module) - init_module.register_llvm_translations(self) - - -ir.Context = Context - - -class MLIRError(Exception): - """ - An exception with diagnostic information. Has the following fields: - message: str - error_diagnostics: List[ir.DiagnosticInfo] - """ - - def __init__(self, message, error_diagnostics): - self.message = message - self.error_diagnostics = error_diagnostics - super().__init__(message, error_diagnostics) - - def __str__(self): - s = self.message - if self.error_diagnostics: - s += ":" - for diag in self.error_diagnostics: - s += ( - "\nerror: " - + str(diag.location)[4:-1] - + ": " - + diag.message.replace("\n", "\n ") + +_path = __path__ +_spec = __spec__ +_name = __name__ + + +class _M: + __path__ = _path + __spec__ = _spec + __name__ = _name + + @staticmethod + def get_lib_dirs() -> Sequence[str]: + """Gets the lib directory for linking to shared libraries. + + On some platforms, the package may need to be built specially to export + development libraries. + """ + return [_this_dir] + + @staticmethod + def get_include_dirs() -> Sequence[str]: + """Gets the include directory for compiling against exported C libraries. + + Depending on how the package was build, development C libraries may or may + not be present. + """ + return [os.path.join(_this_dir, "include")] + + # Perform Python level site initialization. This involves: + # 1. Attempting to load initializer modules, specific to the distribution. + # 2. Defining the concrete mlir.ir.Context that does site specific + # initialization. + # + # Aside from just being far more convenient to do this at the Python level, + # it is actually quite hard/impossible to have such __init__ hooks, given + # the pybind memory model (i.e. there is not a Python reference to the object + # in the scope of the base class __init__). + # + # For #1, we: + # a. Probe for modules named '_mlirRegisterEverything' and + # '_site_initialize_{i}', where 'i' is a number starting at zero and + # proceeding so long as a module with the name is found. + # b. If the module has a 'register_dialects' attribute, it will be called + # immediately with a DialectRegistry to populate. + # c. If the module has a 'context_init_hook', it will be added to a list + # of callbacks that are invoked as the last step of Context + # initialization (and passed the Context under construction). + # d. If the module has a 'disable_multithreading' attribute, it will be + # taken as a boolean. If it is True for any initializer, then the + # default behavior of enabling multithreading on the context + # will be suppressed. This complies with the original behavior of all + # contexts being created with multithreading enabled while allowing + # this behavior to be changed if needed (i.e. if a context_init_hook + # explicitly sets up multithreading). + # + # This facility allows downstreams to customize Context creation to their + # needs. + + __registry = ir.DialectRegistry() + __post_init_hooks = [] + __disable_multithreading = False + from . import _mlir as _mlir + + def __get_registry(self): + return self.__registry + + def process_initializer_module(self, module_name): + try: + m = importlib.import_module(f".{module_name}", __name__) + except ModuleNotFoundError: + return False + except ImportError: + message = ( + f"Error importing mlir initializer {module_name}. This may " + "happen in unclean incremental builds but is likely a real bug if " + "encountered otherwise and the MLIR Python API may not function." ) - for note in diag.notes: + logger.warning(message, exc_info=True) + + logger.debug("Initializing MLIR with module: %s", module_name) + if hasattr(m, "register_dialects"): + logger.debug("Registering dialects from initializer %r", m) + m.register_dialects(self.__get_registry()) + if hasattr(m, "context_init_hook"): + logger.debug("Adding context init hook from %r", m) + self.__post_init_hooks.append(m.context_init_hook) + if hasattr(m, "disable_multithreading"): + if bool(m.disable_multithreading): + logger.debug("Disabling multi-threading for context") + self.__disable_multithreading = True + return True + + def __init__(self): + # If _mlirRegisterEverything is built, then include it as an initializer + # module. + init_module = None + if self.process_initializer_module("_mlirRegisterEverything"): + init_module = importlib.import_module(f"._mlirRegisterEverything", __name__) + + # Load all _site_initialize_{i} modules, where 'i' is a number starting + # at 0. + for i in itertools.count(): + module_name = f"_site_initialize_{i}" + if not self.process_initializer_module(module_name): + break + + that = self + + class Context(ir._BaseContext): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.append_dialect_registry(that._M__get_registry()) + for hook in that._M__post_init_hooks: + hook(self) + if not that._M__disable_multithreading: + self.enable_multithreading(True) + # TODO: There is some debate about whether we should eagerly load + # all dialects. It is being done here in order to preserve existing + # behavior. See: https://github.com/llvm/llvm-project/issues/56037 + self.load_all_available_dialects() + if init_module: + logger.debug( + "Registering translations from initializer %r", init_module + ) + init_module.register_llvm_translations(self) + + ir.Context = Context + + class MLIRError(Exception): + """ + An exception with diagnostic information. Has the following fields: + message: str + error_diagnostics: List[ir.DiagnosticInfo] + """ + + def __init__(self, message, error_diagnostics): + self.message = message + self.error_diagnostics = error_diagnostics + super().__init__(message, error_diagnostics) + + def __str__(self): + s = self.message + if self.error_diagnostics: + s += ":" + for diag in self.error_diagnostics: s += ( - "\n note: " - + str(note.location)[4:-1] + "\nerror: " + + str(diag.location)[4:-1] + ": " - + note.message.replace("\n", "\n ") + + diag.message.replace("\n", "\n ") ) - return s + for note in diag.notes: + s += ( + "\n note: " + + str(note.location)[4:-1] + + ": " + + note.message.replace("\n", "\n ") + ) + return s + + ir.MLIRError = MLIRError -ir.MLIRError = MLIRError +sys.modules[__name__] = _M() From b58593b2f79f2531588eaa3278e23798a15b15e1 Mon Sep 17 00:00:00 2001 From: max Date: Thu, 16 Nov 2023 01:20:22 -0600 Subject: [PATCH 3/3] [mlir][python] demo registering --- mlir/python/mlir/_mlir_libs/__init__.py | 10 +++++----- mlir/python/mlir/dialects/python_test.py | 6 ------ mlir/test/python/dialects/python_test.py | 17 ++++------------- mlir/test/python/lib/PythonTestModule.cpp | 9 +++++++++ 4 files changed, 18 insertions(+), 24 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 2b2306d121cbf..f5dfd1edf5a3e 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -82,9 +82,9 @@ def get_include_dirs() -> Sequence[str]: def __get_registry(self): return self.__registry - def process_initializer_module(self, module_name): + def process_c_ext_module(self, module_name): try: - m = importlib.import_module(f".{module_name}", __name__) + m = importlib.import_module(f"{module_name}", __name__) except ModuleNotFoundError: return False except ImportError: @@ -112,14 +112,14 @@ def __init__(self): # If _mlirRegisterEverything is built, then include it as an initializer # module. init_module = None - if self.process_initializer_module("_mlirRegisterEverything"): + if self.process_c_ext_module("._mlirRegisterEverything"): init_module = importlib.import_module(f"._mlirRegisterEverything", __name__) # Load all _site_initialize_{i} modules, where 'i' is a number starting # at 0. for i in itertools.count(): - module_name = f"_site_initialize_{i}" - if not self.process_initializer_module(module_name): + module_name = f"._site_initialize_{i}" + if not self.process_c_ext_module(module_name): break that = self diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 6579e02d8549e..8b4f718d8a53b 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -9,9 +9,3 @@ TestTensorValue, TestIntegerRankedTensorType, ) - - -def register_python_test_dialect(context, load=True): - from .._mlir_libs import _mlirPythonTest - - _mlirPythonTest.register_python_test_dialect(context, load) diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index f313a400b73c0..309de8037049c 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -1,5 +1,9 @@ # RUN: %PYTHON %s | FileCheck %s +from mlir import _mlir_libs + +_mlir_libs.process_c_ext_module("mlir._mlir_libs._mlirPythonTest") + from mlir.ir import * import mlir.dialects.func as func import mlir.dialects.python_test as test @@ -17,7 +21,6 @@ def run(f): @run def testAttributes(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) # # Check op construction with attributes. # @@ -138,7 +141,6 @@ def testAttributes(): @run def attrBuilder(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) # CHECK: python_test.attributes_op op = test.AttributesOp( # CHECK-DAG: x_affinemap = affine_map<() -> (2)> @@ -215,7 +217,6 @@ def attrBuilder(): @run def inferReturnTypes(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): op = test.InferResultsOp() @@ -260,7 +261,6 @@ def inferReturnTypes(): @run def resultTypesDefinedByTraits(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): inferred = test.InferResultsOp() @@ -295,7 +295,6 @@ def resultTypesDefinedByTraits(): @run def testOptionalOperandOp(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): @@ -312,7 +311,6 @@ def testOptionalOperandOp(): @run def testCustomAttribute(): with Context() as ctx: - test.register_python_test_dialect(ctx) a = test.TestAttr.get() # CHECK: #python_test.test_attr print(a) @@ -350,7 +348,6 @@ def testCustomAttribute(): @run def testCustomType(): with Context() as ctx: - test.register_python_test_dialect(ctx) a = test.TestType.get() # CHECK: !python_test.test_type print(a) @@ -397,8 +394,6 @@ def testCustomType(): # CHECK-LABEL: TEST: testTensorValue def testTensorValue(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) - i8 = IntegerType.get_signless(8) class Tensor(test.TestTensorValue): @@ -436,7 +431,6 @@ def __str__(self): @run def inferReturnTypeComponents(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): @@ -488,8 +482,6 @@ def inferReturnTypeComponents(): @run def testCustomTypeTypeCaster(): with Context() as ctx, Location.unknown(): - test.register_python_test_dialect(ctx) - a = test.TestType.get() assert a.typeid is not None @@ -542,7 +534,6 @@ def type_caster(pytype): @run def testInferTypeOpInterface(): with Context() as ctx, Location.unknown(ctx): - test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): i64 = IntegerType.get_signless(64) diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp index aff414894cb82..9e7decefa7166 100644 --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) { }, py::arg("context"), py::arg("load") = true); + m.def( + "register_dialects", + [](MlirDialectRegistry registry) { + MlirDialectHandle pythonTestDialect = + mlirGetDialectHandle__python_test__(); + mlirDialectHandleInsertDialect(pythonTestDialect, registry); + }, + py::arg("registry")); + mlir_attribute_subclass(m, "TestAttr", mlirAttributeIsAPythonTestTestAttribute) .def_classmethod(